00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028 import roslib; roslib.load_manifest('trf_learn')
00029 import rospy
00030
00031 import scipy.spatial as sp
00032 import threading
00033 import hrl_lib.util as ut
00034 import shutil
00035 import visualization_msgs.msg as vm
00036 import os.path as pt
00037 import numpy as np
00038 import time
00039 import pdb
00040 import os
00041 import trf_learn.recognize_3d as r3d
00042 import cv
00043
00044
class LocationDisplay(threading.Thread):
    """Thread that keeps republishing RViz markers for every saved location.

    For each task in ``loc_man.data``, one circle marker and one text label
    (the task id) are built once. Both are centred on the task's ``'center'``
    point in the ``'map'`` frame. They are then republished at 2 Hz until ROS
    shuts down. Locations added after run() has started are not shown.
    """

    def __init__(self, loc_man):
        """
        @param loc_man object with a ``data`` dict mapping task id -> record;
                       every record must have a ``'center'`` entry
                       (presumably a LocationsManager -- confirm with callers).
        """
        threading.Thread.__init__(self)
        try:
            rospy.init_node('location_display')
        except Exception as e:
            # Presumably raised when this process already initialised a node.
            # The error is only printed and construction continues.
            print(e)

        self.location_label_pub = rospy.Publisher('location_label', vm.Marker)
        self.location_marker_pub = rospy.Publisher('location_marker', vm.Marker)
        self.loc_man = loc_man

    def run(self):
        """Build all markers once, then publish them at 2 Hz until shutdown."""
        # The imports at the top of this file never bind ``viz``, so it is
        # imported here; otherwise run() dies with a NameError.
        import hrl_lib.viz as viz

        text_scale = .1
        text_color = np.matrix([1, 0, 0, 1.]).T     # presumably RGBA (opaque red)

        circle_radii = .9
        circle_scale = .2
        circle_z = .03
        circle_color = np.matrix([1, 0, 0, 1.]).T   # presumably RGBA (opaque red)

        data = self.loc_man.data
        task_ids = list(data.keys())

        # All circles first, then all labels, as separate message lists.
        circle_msgs = [viz.circle_marker(data[tid]['center'], circle_radii,
                                         circle_scale, circle_color, 'map',
                                         circle_z)
                       for tid in task_ids]
        text_msgs = [viz.text_marker(tid, data[tid]['center'], text_color,
                                     text_scale, 'map')
                     for tid in task_ids]

        r = rospy.Rate(2)
        while not rospy.is_shutdown():
            for c in circle_msgs:
                self.location_marker_pub.publish(c)
            for txt in text_msgs:
                self.location_label_pub.publish(txt)
            r.sleep()
00086
class LocationsManager:
    """Persistent store of task locations and their per-task classifiers.

    State is kept in three parallel structures:

    * ``self.ids``     -- list of task ids, in insertion order;
    * ``self.centers`` -- one column per task id, in the same order as
      ``ids`` (``None`` until the first location exists);
    * ``self.data``    -- dict mapping task id to its record (task type,
      center, base pose, datasets, PCA, execution record, ...).

    ``self.tree`` is a KD-tree over the columns of ``centers``, used for
    radius queries. Every method that changes ``centers`` rebuilds it.
    ``ids``, ``centers`` and ``data`` are pickled to ``saved_locations_fname``.
    """

    def __init__(self, name, rec_params, train=True):
        """
        @param name        path of the pickle holding saved locations; it is
                           loaded if the file exists.
        @param rec_params  recognition parameters; this class reads
                           reconstruction_std_lim, reconstruction_err_toler,
                           svm_params, variance_keep and win_size from it.
        @param train       if True, train a classifier and create an image
                           publisher for every loaded task.
        """
        # is_reliable() requires at least RELIABILITY_RECORD_LIM executions.
        # The last RELIABILITY_RECORD_LIM of them must reach RELIABILITY_THRES.
        self.RELIABILITY_RECORD_LIM = 20
        self.RELIABILITY_THRES = .9

        self.rec_params = rec_params
        self.saved_locations_fname = name
        # Search radius used by list_close_by (same units as the map points).
        self.LOCATION_ADD_RADIUS = .5
        self.tree = None
        self.centers = None

        self.ids = []
        self.data = {}

        self.learners = {}
        self._load_database()
        self.image_pubs = {}

        if train:
            for k in self.data.keys():
                print('==========================================')
                print('Training for %s' % (k,))
                print('==========================================')
                self.train(k)
                # ':' is replaced, presumably because it is not allowed in
                # the publisher's topic name.
                self.image_pubs[k] = r3d.ImagePublisher(k.replace(':', '_'))

        self.task_types = ['light_switch_down', 'light_switch_up',
                           'light_rocker_down', 'light_rocker_up',
                           'push_drawer', 'pull_drawer']

        # Pairs of tasks that undo each other (see get_complementary_task).
        self.task_pairs = [['light_switch_down', 'light_switch_up'],
                           ['light_rocker_down', 'light_rocker_up'],
                           ['pull_drawer', 'push_drawer']]

    def revert(self):
        """Keep only the first four locations and drop the two office drawer tasks.

        NOTE(review): this is a hard-coded one-off rollback. It assumes both
        'office_*_drawer' entries exist; dict.pop raises KeyError otherwise.
        """
        self.centers = self.centers[:, 0:4]
        self.ids = self.ids[0:4]
        self.data.pop('office_push_drawer')
        self.data.pop('office_pull_drawer')
        # Keep the KD-tree in sync with the truncated centers. Otherwise
        # list_close_by could return indices past the end of self.ids.
        self.tree = sp.KDTree(np.array(self.centers).T)

    def _load_database(self):
        """Load ids, centers and data from the pickle and build the KD-tree.

        Does nothing if the pickle file does not exist yet.
        """
        if not os.path.isfile(self.saved_locations_fname):
            return
        d = ut.load_pickle(self.saved_locations_fname)
        self.ids = d['ids']
        self.centers = d['centers']
        self.data = d['data']
        self.tree = sp.KDTree(np.array(self.centers).T)

    def save_database(self):
        """Pickle ids, centers and data to saved_locations_fname.

        Before the file is overwritten, it is copied to a timestamped
        '<time>_locations.pkl' backup in the current working directory.
        A failed backup (e.g. the file does not exist yet) is printed and
        ignored.
        """
        print('Saving pickle. DONOT INTERRUPPT!!!')
        d = {'centers': self.centers,
             'ids': self.ids,
             'data': self.data}
        try:
            shutil.copyfile(self.saved_locations_fname,
                            time.strftime('%m_%d_%Y_%I_%M%p') + '_locations.pkl')
        except Exception as e:
            print(e)

        ut.save_pickle(d, self.saved_locations_fname)
        print('SAFE!!!')

    def get_complementary_task(self, tasktype):
        """Return the task type paired with tasktype in task_pairs, or None."""
        for ta, tb in self.task_pairs:
            if ta == tasktype:
                return tb
            if tb == tasktype:
                return ta
        return None

    def update_base_pose(self, taskid, base_pose):
        """Replace the stored base pose of taskid (not saved to disk)."""
        print('updating base pose for task %s' % (taskid,))
        self.data[taskid]['base_pose'] = base_pose

    def create_new_location(self, task_type, point_map, base_pose, gather_data=True, name=None):
        """Register a new task location, save the database and return its id.

        @param task_type   task type string appended to the id.
        @param point_map   location of the task. It is stored as the task's
                           center and appended as a new column of
                           self.centers (presumably a 3x1 matrix in the
                           'map' frame -- confirm).
        @param base_pose   robot base pose stored with the task.
        @param gather_data value stored in the record's 'gather_data' field.
        @param name        id prefix; defaults to the current date and time.
        @return the new task id, '<prefix>_<task_type>'.

        Side effects: creates a directory named after the id (publish_image
        writes into it), creates an image publisher and saves the database.
        """
        if name is None:
            taskid = time.strftime('%A_%m_%d_%Y_%I:%M%p') + ('_%s' % task_type)
        else:
            taskid = name + ('_%s' % task_type)

        try:
            os.mkdir(taskid)
        except OSError as e:
            # Most likely the directory already exists; not treated as fatal.
            print(e)

        # Use ``is None``, not ``== None``. Comparing a numpy matrix to None
        # is elementwise, and the truth value of the result is ambiguous.
        if self.centers is None:
            self.centers = point_map
        else:
            self.centers = np.column_stack((self.centers, point_map))
        self.tree = sp.KDTree(np.array(self.centers.T))

        self.ids.append(taskid)
        self.data[taskid] = {'task': task_type,
                             'center': point_map,
                             'base_pose': base_pose,
                             'points': point_map,
                             'dataset': None,
                             'dataset_raw': None,
                             'gather_data': gather_data,
                             'complementary_task_id': None,
                             'pca': None,
                             'execution_record': []}
        self.image_pubs[taskid] = r3d.ImagePublisher(taskid.replace(':', '_'))
        self.save_database()
        return taskid

    def record_time(self, task_id, record_name, value):
        """Append value to the task's 'times'[record_name] list, creating it if needed."""
        times = self.data[task_id].setdefault('times', {})
        times.setdefault(record_name, []).append(value)

    def update_execution_record(self, taskid, value):
        """Append one execution outcome to the task's execution record."""
        self.data[taskid]['execution_record'].append(value)

    def is_reliable(self, taskid):
        """Return True if the task's recent executions meet the reliability threshold.

        The task needs at least RELIABILITY_RECORD_LIM recorded executions.
        The last RELIABILITY_RECORD_LIM of them must also sum to at least
        RELIABILITY_RECORD_LIM * RELIABILITY_THRES. With outcomes recorded
        as 1/0 (or True/False), this is a success rate of RELIABILITY_THRES
        over that window.
        """
        record = self.data[taskid]['execution_record']
        window = self.RELIABILITY_RECORD_LIM
        if len(record) < window:
            return False
        # Only the most recent window counts. Summing the whole history
        # would let old successes mask later failures: 20 successes
        # followed by 100 failures would still pass.
        recent = record[-window:]
        return bool(np.sum(recent) >= window * self.RELIABILITY_THRES)

    def _id_to_center_idx(self, task_id):
        """Return the column of self.centers for task_id, or None if unknown."""
        try:
            return self.ids.index(task_id)
        except ValueError:
            return None

    def add_perceptual_data(self, task_id, fea_dict):
        """Append labelled feature instances to both of the task's datasets.

        @param fea_dict dict with 'instances', 'labels', 'points2d',
                        'points3d' and 'sizes'; 'labels' has one column
                        per instance.
        """
        rospy.loginfo('LocationsManager: add_perceptual_data - %s adding %d instance(s)' \
                % (task_id, fea_dict['labels'].shape[1]))
        current_raw_dataset = self.data[task_id]['dataset_raw']
        current_dataset = self.data[task_id]['dataset']

        self.data[task_id]['dataset_raw'] = \
                r3d.InterestPointDataset.add_to_dataset(
                        current_raw_dataset, fea_dict['instances'],
                        fea_dict['labels'], fea_dict['points2d'],
                        fea_dict['points3d'], None, None,
                        sizes=fea_dict['sizes'])

        self.data[task_id]['dataset'] = \
                r3d.InterestPointDataset.add_to_dataset(
                        current_dataset, fea_dict['instances'],
                        fea_dict['labels'], fea_dict['points2d'],
                        fea_dict['points3d'], None, None,
                        sizes=fea_dict['sizes'])

    def get_perceptual_data(self, task_id):
        """Return the task's training dataset (None if nothing was added yet)."""
        return self.data[task_id]['dataset']

    def remove_perceptual_data(self, task_id, instance_idx):
        """Remove instance_idx from both of the task's datasets."""
        self.data[task_id]['dataset'].remove(instance_idx)
        self.data[task_id]['dataset_raw'].remove(instance_idx)

    def active_learn_add_data(self, task_id, fea_dict):
        """Add actively-learned data; currently identical to add_perceptual_data."""
        self.add_perceptual_data(task_id, fea_dict)

    def update(self, task_id, point_map):
        """Add an observed point to the task and move its center to the mean of all points."""
        ldata = self.data[task_id]
        ldata['points'] = np.column_stack((point_map, ldata['points']))
        ldata['center'] = ldata['points'].mean(1)

        center_idx = self._id_to_center_idx(task_id)
        self.centers[:, center_idx] = ldata['center']
        self.tree = sp.KDTree(np.array(self.centers).T)

    def set_center(self, task_id, point_map):
        """Reset the task's points and center to the single point point_map."""
        ldata = self.data[task_id]
        ldata['points'] = point_map
        ldata['center'] = point_map
        center_idx = self._id_to_center_idx(task_id)
        self.centers[:, center_idx] = ldata['center']
        self.tree = sp.KDTree(np.array(self.centers).T)

    def publish_image(self, task_id, image, postfix=''):
        """Publish image on the task's publisher and save it as a timestamped JPEG.

        The file is written into the directory named after task_id.
        """
        self.image_pubs[task_id].publish(image)
        ffull = pt.join(task_id, time.strftime('%A_%m_%d_%Y_%I_%M_%S%p') + postfix + '.jpg')
        cv.SaveImage(ffull, image)

    def list_all(self):
        """Return [task_id, task_type] pairs for every stored location."""
        return [[k, self.data[k]['task']] for k in self.data.keys()]

    def list_close_by(self, point_map, task=None):
        """Return [task_id, task_type] pairs of locations near point_map.

        A location is near if its center lies within LOCATION_ADD_RADIUS.
        If task is given, only locations of that task type are returned.
        An empty list is returned when no locations exist.
        """
        if self.tree is None:
            return []

        indices = self.tree.query_ball_point(np.array(point_map.T), self.LOCATION_ADD_RADIUS)[0]
        print('list_close_by: indices close by %s' % (indices,))

        ids_selected = []
        for i in indices:
            sid = self.ids[i]
            stask = self.data[sid]['task']
            if task is None or task == stask:
                ids_selected.append([sid, stask])
        return ids_selected

    def train_all_classifiers(self):
        """Train the classifier of every stored task."""
        for k in self.data.keys():
            self.train(k)

    def train(self, task_id, dset_for_pca=None, save_pca_images=True):
        """Train (or retrain) the SVM+PCA active learner for task_id.

        Does nothing if the task has no dataset yet. The new learner and its
        PCA are stored in self.learners and the task record.

        @param task_id         task whose 'dataset' is trained on.
        @param dset_for_pca    optional dict whose 'instances' are used for
                               the PCA instead of the dataset's own inputs.
        @param save_pca_images if True, write the PCA basis to
                               '<task_id>_pca.png'.
        """
        dataset = self.data[task_id]['dataset']
        rec_params = self.rec_params

        if dataset is None:
            return

        nneg = np.sum(dataset.outputs == r3d.NEGATIVE)
        npos = np.sum(dataset.outputs == r3d.POSITIVE)
        print('================= Training =================')
        print('NEG examples %s' % (nneg,))
        print('POS examples %s' % (npos,))
        print('TOTAL %s' % (dataset.outputs.shape[1],))
        # Class-balancing weights appended to the SVM option string.
        # These are presumably libsvm-style -w<label> flags, with label 1 as
        # the positive class -- confirm.
        # NOTE(review): raises ZeroDivisionError when there are no positives.
        neg_to_pos_ratio = float(nneg) / float(npos)
        weight_balance = ' -w0 1 -w1 %.2f' % neg_to_pos_ratio
        print('training')

        # Any previously trained learner for this task is handed to the new
        # one as old_learner.
        previous_learner = self.learners.get(task_id)

        learner = r3d.SVMPCA_ActiveLearner(use_pca=True,
                reconstruction_std_lim=rec_params.reconstruction_std_lim,
                reconstruction_err_toler=rec_params.reconstruction_err_toler,
                old_learner=previous_learner, pca=self.data[task_id]['pca'])

        if dset_for_pca is not None:
            inputs_for_pca = dset_for_pca['instances']
        else:
            inputs_for_pca = dataset.inputs

        learner.train(dataset,
                      inputs_for_pca,
                      rec_params.svm_params + weight_balance,
                      rec_params.variance_keep)

        self.data[task_id]['pca'] = learner.pca
        self.learners[task_id] = learner
        if save_pca_images:
            basis = learner.pca.projection_basis
            cv.SaveImage('%s_pca.png' % task_id,
                         r3d.instances_to_image(rec_params.win_size, basis,
                                                np.min(basis), np.max(basis)))
00355