pr2_viz_servo.py
Go to the documentation of this file.
00001 #! /usr/bin/python
00002 
00003 import numpy as np
00004 import sys
00005 from collections import deque
00006 
00007 import roslib
00008 roslib.load_manifest('hrl_pr2_arms')
00009 roslib.load_manifest('ar_pose')
00010 roslib.load_manifest('visualization_msgs')
00011 roslib.load_manifest('pykdl_utils')
00012 
00013 import rospy
00014 import tf.transformations as tf_trans
00015 from geometry_msgs.msg import Twist
00016 from std_msgs.msg import Float32MultiArray
00017 from visualization_msgs.msg import Marker
00018 from std_msgs.msg import ColorRGBA
00019 
00020 from hrl_generic_arms.pose_converter import PoseConverter
00021 from hrl_generic_arms.controllers import PIDController
00022 from pykdl_utils.pr2_kin import kin_from_param
00023 
00024 from ar_pose.msg import ARMarker
00025 
class ServoKalmanFilter(object):
    """Scalar constant-velocity Kalman filter with residual-based outlier rejection.

    The state is the 2x1 matrix ``[position, velocity]^T`` of one scalar
    quantity; only the position is observed. On top of the usual
    predict/update cycle the filter:

    * rejects an observation whose residual is much larger than the spread of
      the recent residuals, and
    * keeps a weighted history of rejected or missing observations, reported
      as an "unreliability level" in [0, 1].
    """

    # TODO tune these parameters properly
    def __init__(self, delta_t, sigma_z=0.02*0.001, P_init=(0.0, 0.0), sigma_a=0.02):
        """
        :param delta_t: time step between calls to :meth:`update`.
        :param sigma_z: measurement noise term; used directly as the 1x1
            measurement covariance ``R`` (i.e. it acts as a variance).
        :param P_init: diagonal of the initial 2x2 state covariance.
        :param sigma_a: scale of the process noise ``Q`` (acceleration noise).
        """
        self.P_init = P_init
        self.delta_t = delta_t
        # constant-velocity state transition
        self.F = np.asmatrix([[1, delta_t], [0, 1]])
        # process noise for a white-noise acceleration of std-dev sigma_a
        self.Q = np.asmatrix([[delta_t**4/4, delta_t**3/2],
                              [delta_t**3/2, delta_t**2]]) * sigma_a**2
        # only the position component is measured
        self.H = np.asmatrix([1, 0])
        self.R = np.asmatrix([sigma_z])
        # state estimate and covariance; None until the first observation
        self.x_cur = None
        self.P_cur = None

        # Both history windows span roughly 1/delta_t steps. Clamp to >= 1 so a
        # delta_t > 1 does not produce an empty window (which used to crash).
        window = max(1, int(1./delta_t))

        # recent residuals (oldest first) used for outlier rejection
        self.resid_q_len = window
        self.resid_sigma_reject = 3.
        self.min_reject = 0.1
        self.resid_queue = deque(maxlen=self.resid_q_len)

        # recent "unreliable" flags (oldest first)
        self.unreli_q_len = window
        self.unreli_queue = deque(maxlen=self.unreli_q_len)
        # Weights average to 1 so the weighted mean below lies in [0, 1];
        # linspace(2, 0, 1) would be [2.], hence the single-element special case.
        if self.unreli_q_len > 1:
            self.unreli_weights = np.linspace(2, 0, self.unreli_q_len)
        else:
            self.unreli_weights = np.ones(1)

    def update(self, z_obs, new_obs=True):
        """Advance the filter by one step.

        :param z_obs: scalar observation of the position. Ignored when
            ``new_obs`` is False.
        :param new_obs: False when no fresh measurement arrived this step; the
            step is then counted as unreliable and the state is left untouched
            (no prediction is applied).
        :returns: ``(x_cur, P_cur, unreli_level)``: the 2x1 state matrix, the
            2x2 covariance matrix (both None if no observation has been seen
            yet), and the weighted fraction of unreliable steps in the history
            window. The level stays 0 until the window has filled.
        """
        is_unreli = False
        if new_obs:
            if self.x_cur is None:
                # initialize at the first observation with zero velocity
                self.x_cur = np.asmatrix([z_obs, 0]).T
                self.P_cur = np.asmatrix(np.diag(self.P_init))
            # predict
            x_pred = self.F * self.x_cur  # predicted state
            P_pred = self.F * self.P_cur * self.F.T + self.Q  # predicted covariance

            # update
            y_resi = z_obs - self.H * x_pred  # 1x1 measurement residual
            S_resi = self.H * P_pred * self.H.T + self.R  # residual covariance
            K_gain = P_pred * self.H.T * S_resi**-1  # Kalman gain
            resid = float(y_resi[0, 0])

            # reject the observation if its residual is inconsistent with the
            # recent ones (only once a full window of residuals is available)
            if (len(self.resid_queue) == self.resid_q_len and
                    abs(resid) > max(self.min_reject,
                                     self.resid_sigma_reject * np.std(self.resid_queue))):
                print("INCONSISTENT %s" % (self.resid_queue,))
                is_unreli = True
            else:
                self.x_cur = x_pred + K_gain * y_resi  # update state estimate
                self.P_cur = (np.asmatrix(np.eye(2)) - K_gain * self.H) * P_pred

            # The residual is recorded even when the observation was rejected,
            # which lets a persistent jump widen the acceptance band.
            # deque(maxlen=...) drops the oldest entry automatically.
            self.resid_queue.append(resid)
        else:
            print("NOT NEW")
            is_unreli = True

        self.unreli_queue.append(is_unreli)

        # Weighted fraction of unreliable steps in the window:
        # 0 = every step reliable, 1 = no reliable step at all.
        # NOTE(review): the queue is oldest-first and the weights run 2 -> 0,
        # so the OLDEST steps weigh most. That delays lost-tag detection;
        # confirm whether the reverse weighting was intended.
        if len(self.unreli_queue) == self.unreli_q_len:
            flags = np.asarray(self.unreli_queue, dtype=float)
            unreli_level = np.sum(self.unreli_weights * flags) / self.unreli_q_len
        else:
            unreli_level = 0.

        return self.x_cur, self.P_cur, unreli_level
00097 
def homo_mat_from_2d(x, y, rot):
    """Build a 4x4 homogeneous transform (``np.matrix``) from a planar pose.

    :param x: translation placed in row 0 of the last column.
    :param y: translation placed in row 1 of the last column.
    :param rot: third Euler angle passed to ``tf_trans.euler_matrix``
        (the first two are 0).
    """
    homo = np.asmatrix(tf_trans.euler_matrix(0, 0, rot))
    homo[0, 3], homo[1, 3] = x, y
    return homo
00103 
def homo_mat_to_2d(mat):
    """Project a 4x4 homogeneous transform onto the plane.

    :returns: ``(x, y, rot)`` where x and y are the translation entries
        ``mat[0, 3]`` / ``mat[1, 3]`` and rot is the third angle returned by
        ``tf_trans.euler_from_matrix`` (default axes).
    """
    _, _, yaw = tf_trans.euler_from_matrix(mat)
    return mat[0, 3], mat[1, 3], yaw
00107 
def create_base_marker(pose, id, color):
    """Create a visualization Marker in ``base_link`` (namespace ``ar_servo``).

    :param pose: anything ``PoseConverter.to_pose_msg`` accepts (callers in
        this file pass a 4x4 homogeneous matrix).
    :param id: marker id within the ``ar_servo`` namespace.
    :param color: exactly three ``(r, g, b)`` components; any 3-element
        sequence (tuple, list, array) is accepted. Alpha is fixed at 1.0.
    :returns: the populated Marker (type/action left at their defaults).
    :raises ValueError: if ``color`` does not have exactly three components.
    """
    # Unpacking (instead of ``color + (1.0,)``) accepts lists too and rejects a
    # wrong-length color instead of silently shifting the alpha channel.
    red, green, blue = color

    marker = Marker()
    marker.header.frame_id = "base_link"
    marker.header.stamp = rospy.Time.now()
    marker.ns = "ar_servo"
    marker.id = id
    marker.pose = PoseConverter.to_pose_msg(pose)
    marker.color = ColorRGBA(float(red), float(green), float(blue), 1.0)
    marker.scale.x = 0.7
    marker.scale.y = 0.7
    marker.scale.z = 0.2
    return marker
00118 
class PR2VisualServoAR(object):
    """Servo the PR2 base until an observed AR tag reaches a goal 2D pose.

    AR detections arrive on ``ar_topic`` (``ARMarker``). Each one is converted
    to ``base_link`` with forward kinematics from ``base_link`` to the
    detection's frame, then cached in ``cur_ar_pose``, with
    ``ar_pose_updated`` set. The blocking methods :meth:`find_ar_tag` and
    :meth:`servo_to_tag` poll that cache; :meth:`request_preempt` asks either
    of them to return early.
    """

    def __init__(self, ar_topic):
        # Initialize every attribute the callback touches BEFORE subscribing,
        # so an early message cannot hit a missing attribute.
        self.cur_ar_pose = None       # latest tag pose in base_link (4x4 matrix)
        self.kin_arm = None           # kinematics object, created on first message
        self.ar_pose_updated = False  # set by the callback, cleared by pollers
        self.preempt_requested = False
        self.mkr_pub = rospy.Publisher("visualization_marker", Marker)
        self.base_pub = rospy.Publisher("/base_controller/command", Twist)
        # NOTE: this shadows the ``ar_sub`` callback method with the Subscriber
        # on the instance. The callback still works because the bound method is
        # evaluated before the assignment; the attribute name is kept for callers.
        self.ar_sub = rospy.Subscriber(ar_topic, ARMarker, self.ar_sub)

    def ar_sub(self, msg):
        """Subscriber callback: cache the detected tag pose in base_link."""
        if self.kin_arm is None:
            # chain from base_link to whatever frame the detection is expressed in
            self.kin_arm = kin_from_param(base_link="base_link",
                                          end_link=msg.header.frame_id)
        base_B_camera = self.kin_arm.forward_filled()
        camera_B_tag = PoseConverter.to_homo_mat(msg.pose.pose)
        cur_ar_pose = base_B_camera * camera_B_tag
        # Drop detections that land behind the robot (x < 0 in base_link);
        # the original author attributed these to an AR toolkit bug.
        if cur_ar_pose[0, 3] < 0.:
            return
        self.cur_ar_pose = cur_ar_pose
        self.ar_pose_updated = True

    def request_preempt(self):
        """Ask a running find_ar_tag/servo_to_tag call to return 'preempted'."""
        self.preempt_requested = True

    def save_ar_goal(self):
        """Print the current 2D tag pose at 10 Hz until shutdown (goal recording aid)."""
        r = rospy.Rate(10)
        while not rospy.is_shutdown():
            if self.cur_ar_pose is not None:
                ar_goal = homo_mat_to_2d(self.cur_ar_pose)
                print(ar_goal)
            r.sleep()

    def test_move(self):
        """Debug helper: rotate the base in place at 0.4 until shutdown, then stop."""
        rospy.sleep(0)
        base_twist = Twist()
        base_twist.linear.x = 0.0
        base_twist.linear.y = 0.0
        base_twist.angular.z = 0.4
        r = rospy.Rate(20.)
        while not rospy.is_shutdown():
            self.base_pub.publish(base_twist)
            r.sleep()
        self.base_pub.publish(Twist())

    def find_ar_tag(self, timeout=None):
        """Wait until the tag is seen with a stable (low-variance) 2D pose.

        The last 10 new 2D poses ``(x, y, rot)`` are kept. The method succeeds
        once their per-component std-dev is below ``sigma_thresh`` and at
        least half of the recent polling cycles delivered a new detection.

        :param timeout: seconds to wait before giving up; None waits forever.
        :returns: ``(mean_pose_2d, 'found_tag')`` on success, otherwise
            ``(None, outcome)`` with outcome ``'timeout'``, ``'preempted'``
            or ``'aborted'`` (ROS shutdown).
        """
        rate = 20.
        ar_2d_q_len = 10
        sigma_thresh = [0.005, 0.001, 0.01]  # per (x, y, rot)
        no_mean_thresh = 0.5  # min fraction of cycles with a new detection
        r = rospy.Rate(rate)
        ar_2d_queue = deque(maxlen=ar_2d_q_len)    # oldest entries drop off
        new_obs_queue = deque(maxlen=ar_2d_q_len)
        start_time = rospy.get_time()
        while True:
            if timeout is not None and rospy.get_time() - start_time > timeout:
                rospy.logwarn("[pr2_viz_servo] find_ar_tag timed out, current ar_sigma: " +
                              str(np.std(ar_2d_queue, 0)) +
                              " sigma_thresh: " +
                              str(sigma_thresh))
                return None, 'timeout'
            if self.preempt_requested:
                self.preempt_requested = False
                return None, 'preempted'
            if rospy.is_shutdown():
                return None, 'aborted'

            if self.cur_ar_pose is not None:
                # consume the "new detection" flag set by the callback
                new_obs = self.ar_pose_updated
                self.ar_pose_updated = False

                if new_obs:
                    ar_2d_queue.append(homo_mat_to_2d(self.cur_ar_pose))
                new_obs_queue.append(new_obs)

                # see if we have a low-variance tag estimate
                if len(ar_2d_queue) == ar_2d_q_len:
                    ar_sigma = np.std(ar_2d_queue, 0)
                    no_mean = np.mean(new_obs_queue, 0)
                    rospy.logdebug("[pr2_viz_servo] ar_sigma: %s, new obs fraction: %s",
                                   ar_sigma, no_mean)
                    if np.all(ar_sigma < sigma_thresh) and no_mean >= no_mean_thresh:
                        return np.mean(ar_2d_queue, 0), 'found_tag'
            r.sleep()

    def servo_to_tag(self, pose_goal, goal_error=(0.03, 0.03, 0.1), initial_ar_pose=None):
        """Drive the base until the AR tag reaches ``pose_goal`` in base_link.

        The 2D error between the current tag pose and the goal is filtered per
        component (x, y, rot) by a :class:`ServoKalmanFilter`. The filtered
        error is fed to one PIDController per component (only ``k_p`` is set
        here), whose outputs are published as the base Twist.

        :param pose_goal: goal tag pose ``(x, y, rot)`` in base_link.
        :param goal_error: per-component tolerance on |filtered error|.
        :param initial_ar_pose: optional ``(x, y, rot)`` tag pose (e.g. from
            :meth:`find_ar_tag`) used to seed the filters.
        :returns: ``'succeeded'``, ``'lost_tag'``, ``'preempted'`` or
            ``'aborted'``; a zero Twist is published before each return.
        """
        lost_tag_thresh = 0.6 #0.4

        # TODO REMOVE: debug topic (filtered x err/vel, raw x err, unreliability)
        err_pub = rospy.Publisher("servo_err", Float32MultiArray)

        goal_ar_pose = homo_mat_from_2d(*pose_goal)
        goal_ar_pose_inv = goal_ar_pose**-1  # loop invariant
        rate = 20.
        kf_x = ServoKalmanFilter(delta_t=1./rate)
        kf_y = ServoKalmanFilter(delta_t=1./rate)
        kf_r = ServoKalmanFilter(delta_t=1./rate)
        if initial_ar_pose is not None:
            ar_err = homo_mat_to_2d(homo_mat_from_2d(*initial_ar_pose) * goal_ar_pose_inv)
            kf_x.update(ar_err[0])
            kf_y.update(ar_err[1])
            kf_r.update(ar_err[2])
            rospy.logdebug("[pr2_viz_servo] initial_ar_pose: %s", initial_ar_pose)

        pid_x = PIDController(k_p=0.5, rate=rate, saturation=0.05)
        pid_y = PIDController(k_p=0.5, rate=rate, saturation=0.05)
        pid_r = PIDController(k_p=0.5, rate=rate, saturation=0.08)
        r = rospy.Rate(rate)
        while True:
            if rospy.is_shutdown():
                self.base_pub.publish(Twist())
                return 'aborted'
            if self.preempt_requested:
                self.preempt_requested = False
                self.base_pub.publish(Twist())
                return 'preempted'
            self.mkr_pub.publish(create_base_marker(goal_ar_pose, 0, (0., 1., 0.)))

            if self.cur_ar_pose is not None:
                # consume the "new detection" flag set by the callback
                new_obs = self.ar_pose_updated
                self.ar_pose_updated = False

                if not new_obs and kf_x.x_cur is None:
                    # The filters were never seeded (no initial_ar_pose, and the
                    # detection flag was already consumed, e.g. by find_ar_tag).
                    # Updating them without an observation would fail, so wait
                    # for a fresh detection.
                    r.sleep()
                    continue

                # error between the planarized tag pose and the goal
                cur_ar_pose_2d = homo_mat_from_2d(*homo_mat_to_2d(self.cur_ar_pose))
                self.mkr_pub.publish(create_base_marker(cur_ar_pose_2d, 1, (1., 0., 0.)))
                ar_err = homo_mat_to_2d(cur_ar_pose_2d * goal_ar_pose_inv)

                # filter each error component
                x_filt_err, x_filt_cov, x_unreli = kf_x.update(ar_err[0], new_obs=new_obs)
                y_filt_err, y_filt_cov, y_unreli = kf_y.update(ar_err[1], new_obs=new_obs)
                r_filt_err, r_filt_cov, r_unreli = kf_r.update(ar_err[2], new_obs=new_obs)

                if max(x_unreli, y_unreli, r_unreli) > lost_tag_thresh:
                    self.base_pub.publish(Twist())
                    return 'lost_tag'

                # TODO REMOVE
                ma = Float32MultiArray()
                ma.data = [x_filt_err[0, 0], x_filt_err[1, 0], ar_err[0],
                           x_unreli, y_unreli, r_unreli]
                err_pub.publish(ma)

                x_ctrl = pid_x.update_state(x_filt_err[0, 0])
                y_ctrl = pid_y.update_state(y_filt_err[0, 0])
                r_ctrl = pid_r.update_state(r_filt_err[0, 0])
                base_twist = Twist()
                base_twist.linear.x = x_ctrl
                base_twist.linear.y = y_ctrl
                base_twist.angular.z = r_ctrl
                cur_filt_err = np.array([x_filt_err[0, 0], y_filt_err[0, 0], r_filt_err[0, 0]])
                rospy.logdebug("[pr2_viz_servo] err: %s, filtered err: %s, unreli: %s, "
                               "cov: %s, twist: %s",
                               ar_err, cur_filt_err, (x_unreli, y_unreli, r_unreli),
                               (x_filt_cov[0, 0], y_filt_cov[0, 0], r_filt_cov[0, 0]),
                               base_twist)
                if np.all(np.fabs(cur_filt_err) < goal_error):
                    self.base_pub.publish(Twist())
                    return 'succeeded'

                self.base_pub.publish(base_twist)

            r.sleep()
00302 
def main():
    """Command-line entry point.

    Usage: ``pr2_viz_servo.py {r|l} [find|save|servo]``

    The first argument selects the right or left AR marker topic. The optional
    second argument picks the mode (default ``find``):

    * ``find``  -- wait up to 5 s for a stable tag and print the result;
    * ``save``  -- print the current 2D tag pose continuously (goal recording);
    * ``servo`` -- servo to a hard-coded goal pose and print the outcome.
    """
    # strip ROS remapping arguments (e.g. __name:=...) before parsing
    args = rospy.myargv(argv=sys.argv)
    modes = ('find', 'save', 'servo')
    # explicit validation instead of assert (asserts vanish under -O and
    # gave an IndexError when no argument was passed)
    if (len(args) < 2 or args[1] not in ('r', 'l') or
            (len(args) > 2 and args[2] not in modes)):
        sys.stderr.write("usage: %s {r|l} [%s]\n" % (args[0], '|'.join(modes)))
        sys.exit(1)
    mode = args[2] if len(args) > 2 else 'find'

    rospy.init_node("pr2_viz_servo")
    if args[1] == 'r':
        viz_servo = PR2VisualServoAR("/r_pr2_ar_pose_marker")
    else:
        viz_servo = PR2VisualServoAR("/l_pr2_ar_pose_marker")

    if mode == 'save':
        viz_servo.save_ar_goal()
    elif mode == 'servo':
        print(viz_servo.servo_to_tag((0.55761498778404717, -0.28816809195738824, 1.5722787397126308)))
    else:
        print(viz_servo.find_ar_tag(5))

if __name__ == "__main__":
    main()


kelsey_sandbox
Author(s): kelsey
autogenerated on Wed Nov 27 2013 11:52:04