00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00016 
00017 
00018 
00019 
00020 
00021 
00022 
00023 
00024 
00025 
00026 
00027 
00028 import roslib; roslib.load_manifest('trf_learn')
00029 import rospy
00030 
00031 import scipy.spatial as sp
00032 import threading
00033 import hrl_lib.util as ut
00034 import shutil
00035 import visualization_msgs.msg as vm
00036 import os.path as pt
00037 import numpy as np
00038 import time
00039 import pdb
00040 import os
00041 import trf_learn.recognize_3d as r3d
00042 import cv
00043 
00044 
class LocationDisplay(threading.Thread):
    """Thread that periodically republishes visualization markers for every
    location held by a locations manager.

    For each entry in ``loc_man.data`` one circle marker is published on the
    ``location_marker`` topic and one text label (the task id) on the
    ``location_label`` topic, both in the ``map`` frame.
    """

    def __init__(self, loc_man):
        """
        loc_man -- object exposing a ``data`` dict that maps task ids to
                   records containing at least a ``'center'`` entry.
        """
        threading.Thread.__init__(self)
        try:
            rospy.init_node('location_display')
        except Exception as e:
            # Best effort: presumably the hosting process has already
            # initialised a node; report the problem and carry on.
            print(e)

        self.location_label_pub = rospy.Publisher('location_label', vm.Marker)
        self.location_marker_pub = rospy.Publisher('location_marker', vm.Marker)
        self.loc_man = loc_man

    def run(self):
        """Build the markers once, then publish all of them at 2 Hz until ROS
        shuts down."""
        # The module never imported ``viz``, so the marker helpers below raised
        # NameError as soon as the thread started.
        # NOTE(review): assumes the helpers live in hrl_lib.viz (hrl_lib is
        # already a dependency of this file) -- confirm.
        import hrl_lib.viz as viz

        text_scale = .1
        text_color = np.matrix([1, 0, 0, 1.]).T

        circle_radii = .9
        circle_scale = .2
        circle_z = .03
        circle_color = np.matrix([1, 0, 0, 1.]).T

        # A leftover pdb.set_trace() used to sit here; it blocked this thread
        # on stdin before anything was published.
        data = self.loc_man.data
        circle_msgs = []
        text_msgs = []
        for task_id in data.keys():
            point_map = data[task_id]['center']
            circle_msgs.append(viz.circle_marker(point_map, circle_radii,
                                                 circle_scale, circle_color,
                                                 'map', circle_z))
            text_msgs.append(viz.text_marker(task_id, point_map, text_color,
                                             text_scale, 'map'))

        r = rospy.Rate(2)
        while not rospy.is_shutdown():
            for c in circle_msgs:
                self.location_marker_pub.publish(c)
            for txt in text_msgs:
                self.location_label_pub.publish(txt)
            r.sleep()
00086 
class LocationsManager:
    """Registry of task locations, their collected perceptual data and the
    classifier trained for each of them.

    State kept per instance:
      ids      -- list of task ids; position i corresponds to column i of
                  ``centers``.
      centers  -- matrix with one column per location (``None`` while empty).
      tree     -- KD-tree over the columns of ``centers`` (``None`` while
                  empty), used for radius queries.
      data     -- dict: task id -> record dict (see create_new_location).
      learners -- dict: task id -> trained learner.
    The ``ids``/``centers``/``data`` triple is persisted as a pickle in
    ``saved_locations_fname``.
    """

    def __init__(self, name, rec_params, train=True):
        """
        name       -- path of the pickle file holding the location database.
        rec_params -- recognition parameters; ``train`` reads
                      reconstruction_std_lim, reconstruction_err_toler,
                      svm_params, variance_keep and win_size from it.
        train      -- when True, train a classifier and create an image
                      publisher for every location loaded from disk.
        """
        self.RELIABILITY_RECORD_LIM = 20
        self.RELIABILITY_THRES = .9

        self.rec_params = rec_params
        self.saved_locations_fname = name
        self.LOCATION_ADD_RADIUS = .5
        self.tree = None
        self.centers = None

        self.ids = []
        self.data = {}

        self.learners = {}
        self._load_database()
        self.image_pubs = {}

        self.task_types = ['light_switch_down', 'light_switch_up',
                           'light_rocker_down', 'light_rocker_up',
                           'push_drawer', 'pull_drawer']

        self.task_pairs = [['light_switch_down', 'light_switch_up'],
                           ['light_rocker_down', 'light_rocker_up'],
                           ['pull_drawer', 'push_drawer']]

        if train:
            for k in list(self.data.keys()):
                print('==========================================')
                print('Training for %s' % (k,))
                print('==========================================')
                self.train(k)
                self.image_pubs[k] = r3d.ImagePublisher(k.replace(':', '_'))

    def _rebuild_tree(self):
        """Recreate the KD-tree from ``centers`` (``None`` when there are no
        centers to index)."""
        if self.centers is None or np.asarray(self.centers).size == 0:
            self.tree = None
        else:
            self.tree = sp.KDTree(np.array(self.centers).T)

    def _store_center(self, task_id, center):
        """Write ``center`` into the column of ``centers`` that belongs to
        ``task_id`` and refresh the KD-tree."""
        center_idx = self._id_to_center_idx(task_id)
        if center_idx is None:
            # Indexing with None would silently broadcast over the matrix
            # instead of failing.
            raise KeyError('task id %r has no entry in ids' % (task_id,))
        self.centers[:, center_idx] = center
        self._rebuild_tree()

    def revert(self):
        """Drop everything but the first four locations and remove the two
        office drawer records (hard-coded clean-up helper)."""
        if self.centers is not None:
            self.centers = self.centers[:, 0:4]
        self.ids = self.ids[0:4]
        self.data.pop('office_push_drawer', None)
        self.data.pop('office_pull_drawer', None)
        # The tree must follow the trimmed centers; a stale tree returns
        # indices that no longer exist in ids.
        self._rebuild_tree()

    def _load_database(self):
        """Load ids, centers and data from the pickle file, if it exists."""
        if not os.path.isfile(self.saved_locations_fname):
            return
        # NOTE(review): this unpickles the database file; only point it at
        # files you trust.
        d = ut.load_pickle(self.saved_locations_fname)
        self.ids = d['ids']
        self.centers = d['centers']
        self.data = d['data']
        self._rebuild_tree()

    def save_database(self):
        """Back up the current pickle file into the working directory (best
        effort), then overwrite it with the in-memory database."""
        print('Saving pickle. DONOT INTERRUPPT!!!')
        d = {'centers': self.centers,
             'ids': self.ids,
             'data': self.data}
        try:
            shutil.copyfile(self.saved_locations_fname,
                            time.strftime('%m_%d_%Y_%I_%M%p') + '_locations.pkl')
        except Exception as e:
            # The backup is optional (e.g. nothing to copy on the first save).
            print(e)

        ut.save_pickle(d, self.saved_locations_fname)
        print('SAFE!!!')

    def get_complementary_task(self, tasktype):
        """Return the task type paired with ``tasktype`` in ``task_pairs``, or
        None when it has no partner."""
        for ta, tb in self.task_pairs:
            if ta == tasktype:
                return tb
            if tb == tasktype:
                return ta
        return None

    def update_base_pose(self, taskid, base_pose):
        """Replace the stored base pose of ``taskid``."""
        print('updating base pose for task %s' % (taskid,))
        self.data[taskid]['base_pose'] = base_pose

    def create_new_location(self, task_type, point_map, base_pose, gather_data=True, name=None):
        """Register a new location, create its image directory and publisher,
        save the database and return the new task id.

        The id is ``<name>_<task_type>``; when ``name`` is None a timestamp is
        used in its place.
        """
        if name is None:
            taskid = time.strftime('%A_%m_%d_%Y_%I:%M%p') + ('_%s' % task_type)
        else:
            taskid = name + ('_%s' % task_type)

        try:
            os.mkdir(taskid)
        except OSError as e:
            print(e)

        if self.centers is None:
            # Copy so later in-place edits of ``centers`` cannot leak into the
            # caller's point or into this location's 'center'/'points'.
            self.centers = point_map.copy()
        else:
            self.centers = np.column_stack((self.centers, point_map))
        self._rebuild_tree()

        self.ids.append(taskid)
        self.data[taskid] = {'task': task_type,
                             'center': point_map,
                             'base_pose': base_pose,
                             'points': point_map,
                             'dataset': None,
                             'dataset_raw': None,
                             'gather_data': gather_data,
                             'complementary_task_id': None,
                             'pca': None,
                             'execution_record': []}
        self.image_pubs[taskid] = r3d.ImagePublisher(taskid.replace(':', '_'))
        self.save_database()
        return taskid

    def record_time(self, task_id, record_name, value):
        """Append ``value`` to the list of timings stored under
        ``record_name`` for ``task_id``."""
        times = self.data[task_id].setdefault('times', {})
        times.setdefault(record_name, []).append(value)

    def update_execution_record(self, taskid, value):
        """Append one execution outcome to the record of ``taskid``."""
        self.data[taskid]['execution_record'].append(value)

    def is_reliable(self, taskid):
        """Return True when at least RELIABILITY_RECORD_LIM outcomes exist and
        the most recent RELIABILITY_RECORD_LIM of them sum to at least
        RELIABILITY_RECORD_LIM * RELIABILITY_THRES.

        Only the recent window is summed: the threshold is scaled for exactly
        RELIABILITY_RECORD_LIM entries, so summing the entire history would
        keep a location "reliable" no matter how many failures follow.
        """
        record = self.data[taskid]['execution_record']
        lim = self.RELIABILITY_RECORD_LIM
        if len(record) < lim:
            return False

        recent = record[len(record) - lim:]
        return not (np.sum(recent) < (lim * self.RELIABILITY_THRES))

    def _id_to_center_idx(self, task_id):
        """Return the position of ``task_id`` in ``ids`` (its column in
        ``centers``), or None when it is not present."""
        try:
            return self.ids.index(task_id)
        except ValueError:
            return None

    def add_perceptual_data(self, task_id, fea_dict):
        """Add the instances in ``fea_dict`` to both the raw and the working
        dataset of ``task_id``.

        ``fea_dict`` must provide 'instances', 'labels', 'points2d',
        'points3d' and 'sizes'.
        """
        rospy.loginfo('LocationsManager: add_perceptual_data - %s adding %d instance(s)' \
                % (task_id, fea_dict['labels'].shape[1]))
        record = self.data[task_id]
        for key in ('dataset_raw', 'dataset'):
            record[key] = r3d.InterestPointDataset.add_to_dataset(
                    record[key], fea_dict['instances'],
                    fea_dict['labels'], fea_dict['points2d'],
                    fea_dict['points3d'], None, None,
                    sizes=fea_dict['sizes'])

    def get_perceptual_data(self, task_id):
        """Return the working dataset of ``task_id`` (None if nothing has been
        added yet)."""
        return self.data[task_id]['dataset']

    def remove_perceptual_data(self, task_id, instance_idx):
        """Remove instance ``instance_idx`` from both datasets of
        ``task_id``."""
        self.data[task_id]['dataset'].remove(instance_idx)
        self.data[task_id]['dataset_raw'].remove(instance_idx)

    def active_learn_add_data(self, task_id, fea_dict):
        """Alias of add_perceptual_data."""
        self.add_perceptual_data(task_id, fea_dict)

    def update(self, task_id, point_map):
        """Add ``point_map`` to the points of ``task_id`` and move the
        location's center to the mean of all its points."""
        ldata = self.data[task_id]
        ldata['points'] = np.column_stack((point_map, ldata['points']))
        ldata['center'] = ldata['points'].mean(1)
        self._store_center(task_id, ldata['center'])

    def set_center(self, task_id, point_map):
        """Discard the stored points of ``task_id`` and make ``point_map`` its
        only point and its center."""
        ldata = self.data[task_id]
        ldata['points'] = point_map
        ldata['center'] = point_map
        self._store_center(task_id, ldata['center'])

    def publish_image(self, task_id, image, postfix=''):
        """Publish ``image`` on the location's image topic and save it as a
        timestamped JPEG inside the directory named after ``task_id``."""
        if task_id not in self.image_pubs:
            # Locations loaded with train=False never got a publisher in
            # __init__; create it on first use instead of raising KeyError.
            self.image_pubs[task_id] = r3d.ImagePublisher(task_id.replace(':', '_'))
        self.image_pubs[task_id].publish(image)
        ffull = pt.join(task_id, time.strftime('%A_%m_%d_%Y_%I_%M_%S%p') + postfix + '.jpg')
        cv.SaveImage(ffull, image)

    def list_all(self):
        """Return ``[task_id, task_type]`` pairs for every known location."""
        return [[k, self.data[k]['task']] for k in self.data.keys()]

    def list_close_by(self, point_map, task=None):
        """Return ``[task_id, task_type]`` pairs for the locations whose
        center lies within LOCATION_ADD_RADIUS of ``point_map``, optionally
        restricted to task type ``task``."""
        if self.tree is None:
            return []

        indices = self.tree.query_ball_point(np.array(point_map.T), self.LOCATION_ADD_RADIUS)[0]
        print('list_close_by: indices close by %s' % (indices,))

        ids_selected = []
        for i in indices:
            sid = self.ids[i]
            stask = self.data[sid]['task']
            if task is None or task == stask:
                ids_selected.append([sid, stask])
        return ids_selected

    def train_all_classifiers(self):
        """Train the classifier of every known location."""
        for k in list(self.data.keys()):
            self.train(k)

    def train(self, task_id, dset_for_pca=None, save_pca_images=True):
        """Train a fresh learner for ``task_id`` on its working dataset.

        Nothing happens when the location has no dataset or the dataset holds
        no positive examples. Otherwise the new learner is stored in
        ``learners`` and its PCA in the location record.

        dset_for_pca    -- optional dict whose 'instances' entry is used as
                           PCA input instead of the dataset's own inputs.
        save_pca_images -- when True, write the PCA basis to
                           '<task_id>_pca.png'.
        """
        dataset = self.data[task_id]['dataset']
        rec_params = self.rec_params

        if dataset is None:
            return

        nneg = np.sum(dataset.outputs == r3d.NEGATIVE)
        npos = np.sum(dataset.outputs == r3d.POSITIVE)
        print('================= Training =================')
        print('NEG examples %s' % (nneg,))
        print('POS examples %s' % (npos,))
        print('TOTAL %s' % (dataset.outputs.shape[1],))
        if npos == 0:
            # The class weight below divides by the positive count.
            print('No positive examples for %s, skipping training' % (task_id,))
            return
        neg_to_pos_ratio = float(nneg) / float(npos)
        # Weight the positive class by the negative/positive ratio.
        weight_balance = ' -w0 1 -w1 %.2f' % neg_to_pos_ratio
        print('training')

        previous_learner = self.learners.get(task_id)

        learner = r3d.SVMPCA_ActiveLearner(use_pca=True,
                        reconstruction_std_lim=self.rec_params.reconstruction_std_lim,
                        reconstruction_err_toler=self.rec_params.reconstruction_err_toler,
                        old_learner=previous_learner, pca=self.data[task_id]['pca'])

        if dset_for_pca is not None:
            inputs_for_pca = dset_for_pca['instances']
        else:
            inputs_for_pca = dataset.inputs

        learner.train(dataset,
                      inputs_for_pca,
                      rec_params.svm_params + weight_balance,
                      rec_params.variance_keep)

        self.data[task_id]['pca'] = learner.pca
        self.learners[task_id] = learner
        if save_pca_images:
            basis = learner.pca.projection_basis
            cv.SaveImage('%s_pca.png' % task_id,
                         r3d.instances_to_image(self.rec_params.win_size, basis,
                                                np.min(basis), np.max(basis)))
00355