00001
00002
"""
Learn articulation-model priors from recorded tracks and replay them online.

usage: %(progname)s <track_file.txt> [<track_file.txt> ...]

Every track file is read and a model is fitted to it through the
``model_select`` service; these fitted models serve as priors.  Each track is
then replayed prefix by prefix: a fresh model is fitted to the prefix and
compared, via the ``model_select_eval`` service, against the priors learned
from the *other* tracks.  The candidate with the lowest BIC is published on
the ``model`` topic.
"""
00006 import roslib; roslib.load_manifest('articulation_models')
00007 from articulation_models.track_utils import *
00008 from copy import deepcopy
00009
00010
00011
def _load_tracks(filenames):
    """Read every track file with ``readtrack`` and normalise it with ``zero_start``.

    All files are read first, then every track's ``pose`` sequence is replaced
    by ``zero_start(track.pose)`` (presumably re-expressing the poses relative
    to the first one -- TODO confirm against track_utils).

    Returns the list of tracks in the same order as *filenames*.
    """
    tracks = [readtrack(filename) for filename in filenames]
    for track in tracks:
        track.pose = zero_start(track.pose)
    return tracks


def _learn_prior_models(tracks, model_select):
    """Fit one model per complete track through the ``model_select`` service.

    Each fitted model gets its ``complexity`` parameter set to 0 with type
    ``ParamMsg.PRIOR`` (presumably so the evaluator does not penalise priors
    for complexity -- confirm with the model_select_eval implementation).

    Returns the fitted models, indexed like *tracks*.
    """
    prior_models = []
    for track_id, track in enumerate(tracks):
        request = TrackModelSrvRequest()
        request.model.id = track_id
        request.model.track = track
        model = model_select(request).model
        set_param(model, "complexity", 0, ParamMsg.PRIOR)
        prior_models.append(model)
    return prior_models


def _evaluate_candidates(candidates, subtrack, model_eval):
    """Evaluate every candidate model on *subtrack* via ``model_eval``.

    Each candidate's ``track`` field is overwritten with *subtrack* before
    the call, so callers must pass models they are allowed to mutate.
    Prints name, BIC and complexity of every evaluated model and returns the
    evaluated models in candidate order.
    """
    evaluated = []
    for candidate in candidates:
        request = TrackModelSrvRequest()
        request.model = candidate
        request.model.track = subtrack
        model = model_eval(request).model
        print("%s %s %s" % (model.name,
                            get_param(model, "bic"),
                            get_param(model, "complexity")))
        evaluated.append(model)
    return evaluated


def main():
    """Learn one prior per track file, then replay every track prefix by prefix.

    For each prefix, a fresh model fitted by ``model_select`` competes with the
    priors learned from all *other* tracks; the candidate with the lowest BIC
    (as reported by ``model_select_eval``) is printed and published on the
    ``model`` topic.  Exits with the module usage text when no track files are
    given.  A ROS shutdown interrupt ends the run quietly.
    """
    argv = rospy.myargv()
    filenames = argv[1:]
    if not filenames:
        # The module docstring doubles as the usage message.
        raise SystemExit((__doc__ or "").strip() % {'progname': argv[0]})

    try:
        rospy.init_node('learn_prior')
        # NOTE(review): newer rospy warns when queue_size is omitted; it is
        # left out here to stay compatible with older rosbuild-era rospy.
        model_pub = rospy.Publisher('model', ModelMsg)

        tracks = _load_tracks(filenames)

        # Block until both services exist instead of failing on first call.
        rospy.wait_for_service('model_select')
        rospy.wait_for_service('model_select_eval')
        model_select = rospy.ServiceProxy('model_select', TrackModelSrv)
        model_eval = rospy.ServiceProxy('model_select_eval', TrackModelSrv)

        prior_models = _learn_prior_models(tracks, model_select)
        print("All models learned")

        for track_id, track in enumerate(tracks):
            print("working on track %d" % track_id)
            for i in range(1, len(track.pose)):
                print("working on track %d, considering first %d observations " % (track_id, i))
                # Presumably the first i poses -- TODO confirm the end-index
                # convention of sub_track.
                subtrack = sub_track(track, 0, i)

                request = TrackModelSrvRequest()
                request.model.id = -1
                request.model.track = subtrack
                fresh_model = model_select(request).model

                # Leave the current track's own prior out, and deep-copy the
                # rest because evaluation overwrites each candidate's track.
                candidates = [deepcopy(prior)
                              for prior_id, prior in enumerate(prior_models)
                              if prior_id != track_id]
                candidates.append(fresh_model)

                evaluated = _evaluate_candidates(candidates, subtrack, model_eval)

                # First model with the smallest BIC, matching a stable sort.
                best = min(evaluated, key=lambda model: get_param(model, "bic"))
                print("best: %s %s %s" % (best.name,
                                          get_param(best, "bic"),
                                          get_param(best, "complexity")))
                model_pub.publish(best)
    except rospy.ROSInterruptException:
        # Ctrl-C / node shutdown: exit quietly, as before.
        pass
00078
if __name__ == "__main__":
    # Script entry point: run only when executed directly, never on import.
    main()
00081