learn_prior.py
Go to the documentation of this file.
00001 #!/usr/bin/env python
00002 
00003 """
00004  usage: %(progname)s <track_file.txt> [<track_file.txt> ...]
00005 """
00006 import roslib; roslib.load_manifest('articulation_models')
00007 from articulation_models.track_utils import *
00008 from copy import deepcopy
00009 
00010 
00011 
def _learn_prior_models(model_select, tracks):
  """Fit one model per full track via the ``model_select`` service.

  Each returned model gets its "complexity" parameter set to 0 with type
  ``ParamMsg.PRIOR``. The i-th model corresponds to ``tracks[i]`` and was
  requested with ``model.id == i``.
  """
  prior_models = []
  for track_id, track in enumerate(tracks):
    request = TrackModelSrvRequest()
    request.model.id = track_id
    request.model.track = track
    model = model_select(request).model
    set_param(model, "complexity", 0, ParamMsg.PRIOR)
    prior_models.append(model)
  return prior_models


def _best_model(model_eval, candidates, subtrack):
  """Evaluate every candidate on ``subtrack`` and return the lowest-BIC result.

  The name, "bic" and "complexity" of each evaluated model are printed.
  Each candidate object has its ``track`` attribute rebound to ``subtrack``,
  so callers must pass copies of models that have to stay untouched.
  ``candidates`` must be non-empty.
  """
  evaluated = []
  for model in candidates:
    request = TrackModelSrvRequest()
    request.model = model
    request.model.track = subtrack
    result = model_eval(request).model
    print("%s %s %s" % (result.name,
                        get_param(result, "bic"),
                        get_param(result, "complexity")))
    evaluated.append(result)
  # min() returns the first of several equal-BIC models, matching the stable
  # ascending sort it replaces, and avoids the Python-2-only cmp() sort.
  return min(evaluated, key=lambda m: get_param(m, "bic"))


def main():
  """Learn prior models from track files, then replay each track incrementally.

  Every file named on the command line is read with ``readtrack`` and its
  poses are passed through ``zero_start``.

  Phase 1 fits one model per full track (``model_select`` service). Those
  models become the priors.

  Phase 2 walks each track one prefix at a time. For every prefix, a fresh
  model is selected. It competes with the priors learned from all *other*
  tracks, each re-evaluated on the prefix (``model_select_eval`` service).
  The candidate with the lowest "bic" is published on the ``model`` topic.

  Prints a usage line and returns when no track files are given. A
  ``rospy.ROSInterruptException`` (node shutdown) ends the run quietly.
  """
  args = rospy.myargv()
  filenames = args[1:]
  if not filenames:
    # Previously the node started and finished without doing anything.
    print("usage: %s <track_file.txt> [<track_file.txt> ...]" % args[0])
    return

  try:
    rospy.init_node('learn_prior')
    # NOTE(review): newer rospy releases warn when queue_size is omitted; it is
    # left unset to stay compatible with the ROS release this script targets.
    model_pub = rospy.Publisher('model', ModelMsg)

    tracks = [readtrack(filename) for filename in filenames]
    for track in tracks:
      track.pose = zero_start(track.pose)

    # Block until the services exist rather than failing on the first call.
    rospy.wait_for_service('model_select')
    rospy.wait_for_service('model_select_eval')
    model_select = rospy.ServiceProxy('model_select', TrackModelSrv)
    model_eval = rospy.ServiceProxy('model_select_eval', TrackModelSrv)

    prior_models = _learn_prior_models(model_select, tracks)
    print("All models learned")

    for track_id, track in enumerate(tracks):
      print("working on track %d" % track_id)
      for i in range(1, len(track.pose)):
        print("working on track %d, considering first %d observations " % (track_id, i))
        subtrack = sub_track(track, 0, i)

        # Fit a fresh model to the current prefix; id -1 marks "current".
        request = TrackModelSrvRequest()
        request.model.id = -1
        request.model.track = subtrack
        current_model = model_select(request).model

        # Priors from every other track, deep-copied because evaluation
        # rebinds their .track; the fresh prefix model competes with them.
        candidates = [deepcopy(model)
                      for other_id, model in enumerate(prior_models)
                      if other_id != track_id]
        candidates.append(current_model)

        best = _best_model(model_eval, candidates, subtrack)
        print("best: %s %s %s" % (best.name,
                                  get_param(best, "bic"),
                                  get_param(best, "complexity")))
        model_pub.publish(best)
  except rospy.ROSInterruptException:
    pass


if __name__ == '__main__':
  main()
00081 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Properties Friends Defines


articulation_models
Author(s): Juergen Sturm
autogenerated on Wed Dec 26 2012 15:35:18