Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
from numpy import zeros, eye, dot, array
from numpy.linalg import inv, norm, solve
from numpy.random import multivariate_normal
class DiscreteKalmanFilter:
    """Linear discrete-time Kalman filter.

    The update equations in ``predict``/``correct`` correspond to the model::

        x_k = A x_{k-1} + B u_k + w_k,    cov(w_k) = Q
        y_k = C x_k + v_k,                cov(v_k) = R

    Naming convention: ``*_k_km1`` is the a-priori (predicted) quantity at
    step k given measurements up to k-1; ``*_k_k`` is the a-posteriori
    (corrected) quantity after incorporating the measurement y_k.

    Note: with the default zero ``R`` and ``Q`` and an invertible ``C``, the
    covariance collapses to zero after the first ``correct()``, so the
    innovation covariance becomes singular and the next ``correct()`` raises
    ``numpy.linalg.LinAlgError``. Supply non-zero noise covariances.
    """

    # Class-level default kept for backward compatibility; every instance
    # overrides it in __init__.
    quiet = True

    def __init__(self, n, P_0_0, xhat_0_0=None, A=None, B=None, C=None,
                 R=None, Q=None, quiet=True):
        """Initialise the filter state and system matrices.

        :param n: state dimension.
        :param P_0_0: initial state covariance, shape (n, n).
        :param xhat_0_0: initial state estimate; defaults to ``zeros(n)``.
        :param A: state-transition matrix; defaults to ``eye(n)``.
        :param B: input matrix, shape (n, ni); ``None`` means no input.
        :param C: observation matrix, shape (no, n); defaults to ``eye(n)``.
        :param R: measurement-noise covariance (no, no); defaults to zeros.
        :param Q: process-noise covariance (n, n); defaults to zeros.
        :param quiet: when False, dump the configuration via ``print_info``.
        """
        self.quiet = quiet
        self.n = n
        self.P_k_k = P_0_0
        # Before the first predict() the a-priori and a-posteriori
        # quantities coincide.
        self.P_k_km1 = self.P_k_k
        self.xhat_k_k = zeros(n) if xhat_0_0 is None else xhat_0_0
        self.xhat_k_km1 = self.xhat_k_k
        self.A = eye(n) if A is None else A
        if B is None:
            self.no_input = True
            self.ni = 0
        else:
            self.no_input = False
            self.B = B
            self.ni = B.shape[1]
        self.C = eye(n) if C is None else C
        # Number of outputs is taken from the observation matrix.
        self.no = self.C.shape[0]
        self.R = zeros((self.no, self.no)) if R is None else R
        self.Q = zeros((self.n, self.n)) if Q is None else Q
        if not self.quiet:
            self.print_info()

    def print_info(self):
        """Print dimensions, system matrices and current estimates."""
        print("Kalman filter initialized")
        print("n = %d, ni = %d, no = %d" % (self.n, self.ni, self.no))
        print("A = ")
        print(self.A)
        if self.no_input:
            print("no input")
        else:
            print("B = ")
            print(self.B)
        print("C = ")
        print(self.C)
        print("Q = ")
        print(self.Q)
        print("R = ")
        print(self.R)
        self.print_a_priori()
        self.print_a_posteriori()

    def print_a_posteriori(self):
        """Print the corrected estimate ``xhat_k_k`` and covariance ``P_k_k``."""
        print("xhat_k_k = ")
        print(self.xhat_k_k)
        print("P_k_k = ")
        print(self.P_k_k)

    def print_a_priori(self):
        """Print the predicted estimate ``xhat_k_km1`` and covariance ``P_k_km1``."""
        print("xhat_k_km1 = ")
        print(self.xhat_k_km1)
        print("P_k_km1 = ")
        print(self.P_k_km1)

    def predict(self, u=None):
        """Time update: propagate the corrected estimate one step forward.

        Sets ``xhat_k_km1 = A xhat_k_k (+ B u)`` and
        ``P_k_km1 = A P_k_k A^T + Q``.

        :param u: input vector of length ni; required when ``B`` was given,
            ignored otherwise.
        """
        xhat_k_km1 = dot(self.A, self.xhat_k_k)
        if not self.no_input:
            # Out-of-place add: an in-place ``+=`` raises a casting error
            # when A and xhat are integer arrays but B u is floating point.
            xhat_k_km1 = xhat_k_km1 + dot(self.B, u)
        self.xhat_k_km1 = xhat_k_km1
        self.P_k_km1 = dot(dot(self.A, self.P_k_k), self.A.T) + self.Q

    def correct(self, y_k):
        """Measurement update: fold observation ``y_k`` into the estimate.

        Updates ``xhat_k_k`` and ``P_k_k`` from the predicted quantities.

        :param y_k: measurement vector of length no.
        :raises numpy.linalg.LinAlgError: if the innovation covariance
            ``S_k = C P_k_km1 C^T + R`` is singular.
        """
        # Innovation (measurement residual) and its covariance.
        innov = y_k - dot(self.C, self.xhat_k_km1)
        S_k = dot(self.C, dot(self.P_k_km1, self.C.T)) + self.R
        # Gain K_k = P C^T S^{-1}, obtained by solving S^T K^T = C P^T rather
        # than forming inv(S) explicitly (cheaper and better conditioned).
        K_k = solve(S_k.T, dot(self.C, self.P_k_km1.T)).T
        self.xhat_k_k = self.xhat_k_km1 + dot(K_k, innov)
        # Joseph-form covariance update: equal to (I - K C) P for the optimal
        # gain, but keeps P symmetric positive semi-definite under rounding.
        I_KC = eye(self.n) - dot(K_k, self.C)
        self.P_k_k = (dot(dot(I_KC, self.P_k_km1), I_KC.T)
                      + dot(dot(K_k, self.R), K_k.T))
00132
if __name__ == "__main__":
    # Demo: estimate a constant 3-D state from noisy direct observations
    # (C = I), printing each measurement, the current estimate and the
    # estimation error norm. Runs until interrupted with Ctrl-C.
    kf = DiscreteKalmanFilter(3, eye(3) * 100.0, zeros(3), R=eye(3) * 10.0,
                              quiet=False)
    true_value = array([1, 0, 0])
    try:
        while True:
            kf.predict()
            # Synthetic measurement: true state plus zero-mean noise with
            # covariance R, matching the filter's measurement model.
            y_k = true_value + multivariate_normal(zeros(3), kf.R)
            kf.correct(y_k)
            # Single-argument print behaves the same on Python 2 and 3.
            print("%s %s %s" % (y_k, kf.xhat_k_k,
                                norm(kf.xhat_k_k - true_value)))
    except KeyboardInterrupt:
        # Intended way to stop the demo; exit without a traceback.
        pass
00142
00143
00144