model_learner_msg.cpp
Go to the documentation of this file.
00001 /*
00002  * model_learner_msg.cpp
00003  *
00004  *  Created on: Oct 21, 2009
00005  *      Author: sturm
00006  */
00007 
00008 #include <ros/ros.h>
00009 
00010 #include "articulation_msgs/ModelMsg.h"
00011 #include "articulation_msgs/TrackMsg.h"
00012 #include "articulation_msgs/ParamMsg.h"
00013 
00014 #include "articulation_models/models/factory.h"
00015 #include "articulation_models/utils.h"
00016 #include <boost/foreach.hpp>
00017 
00018 #define DEBUG false
00019 
00020 using namespace std;
00021 using namespace articulation_models;
00022 using namespace articulation_msgs;
00023 
00024 ros::Publisher model_pub;
00025 
00026 MultiModelFactory factory;
00027 
00028 GenericModelVector models_valid;
00029 
00030 map<string,     ros::Time> startingTime;
00031 map<string, vector<double> > measurements;
00032 
00033 ros::NodeHandle *nh;
00034 double sigma_position = 0.01;
00035 double sigma_orientation = 4*M_PI;
00036 
00037 
00038 void TIC(string name){
00039         startingTime[name] = ros::Time::now();
00040 }
00041 
00042 void TOC(string name) {
00043         measurements[name].push_back( (ros::Time::now() - startingTime[name]).toSec() );
00044 }
00045 
00046 void ADD_DATA(string name,double data) {
00047         measurements[name].push_back( data );
00048 }
00049 
00050 
00051 #define SQR(a) ((a)*(a))
00052 void EVAL() {
00053         map<string, vector<double> >::iterator it;
00054         for(it = measurements.begin(); it!=measurements.end(); it++) {
00055                 size_t n = it->second.size();
00056                 double sum = 0;
00057                 for(size_t i=0;i<n;i++) {
00058                         sum += it->second[i];
00059                 }
00060                 double mean = sum /n;
00061                 double vsum = 0;
00062                 for(size_t i=0;i<n;i++) {
00063                         vsum += SQR(it->second[i] - mean);
00064                 }
00065                 double var = vsum / n;
00066                 cout << it->first << " " << mean << " "<<sqrt(var)<< " ("<<n<<" obs)"<< endl;
00067         }
00068 }
00069 
00070 void trackCallback(const TrackMsgConstPtr& track)
00071 {
00072   ROS_INFO("Received track id [%d]", track->id);
00073 
00074   articulation_msgs::ModelMsg model_track;
00075   model_track.track = *track;
00076   setParamIfNotDefined(model_track.params, "sigma_position",
00077                         sigma_position, ParamMsg::PRIOR);
00078   setParamIfNotDefined(model_track.params, "sigma_orientation",
00079                         sigma_orientation, ParamMsg::PRIOR);
00080 
00081   TIC("createModels");
00082 
00083   GenericModelVector models_new = factory.createModels( model_track );
00084   TOC("createModels");
00085 
00086   GenericModelVector models_old = models_valid;
00087 
00088   models_valid.clear();
00089   models_old.clear();
00090 
00091   // update old models, then add valid
00092   for(size_t i=0;i<models_old.size();i++) {
00093           models_old[i]->setTrack(*track);
00094           models_old[i]->projectPoseToConfiguration();
00095           if( !models_old[i]->fitMinMaxConfigurations() ) continue;
00096           if( !models_old[i]->evaluateModel() ) continue;
00097 
00098           models_valid.push_back( models_old[i] );
00099   }
00100 
00101   // fit new models, then add valid
00102   TIC("per_track");
00103   for(size_t i=0;i<models_new.size();i++) {
00104           TIC("fitModel" + models_new[i]->getModelName());
00105           if( !models_new[i]->fitModel() ) {
00106                   if(DEBUG) cout <<"fitting of "<<models_new[i]->getModelName()<<" failed"<<endl;
00107                   continue;
00108           }
00109           TOC("fitModel" + models_new[i]->getModelName());
00110           TIC("projectPoseToConfiguration" + models_new[i]->getModelName());
00111           models_new[i]->projectPoseToConfiguration();
00112           TOC("projectPoseToConfiguration" + models_new[i]->getModelName());
00113           TIC("fitMinMaxConfigurations" + models_new[i]->getModelName());
00114           if( !models_new[i]->fitMinMaxConfigurations() ) {
00115                   if(DEBUG) cout <<"fitting of min/max conf of "<<models_new[i]->getModelName()<<" failed"<<endl;
00116                   continue;
00117           }
00118           TOC("fitMinMaxConfigurations" + models_new[i]->getModelName());
00119 
00120           TIC("evaluateModel" + models_new[i]->getModelName());
00121           if( !models_new[i]->evaluateModel() ) {
00122                   if(DEBUG) cout <<"evaluation of "<<models_new[i]->getModelName()<<" failed"<<endl;
00123                   continue;
00124           }
00125           TOC("evaluateModel" + models_new[i]->getModelName());
00126 
00127           models_valid.push_back( models_new[i] );
00128 
00129   }
00130   TOC("per_track");
00131 
00132 
00133   map<double,GenericModelPtr> models_sorted;
00134   for(size_t i=0;i<models_valid.size();i++) {
00135           if(isnan( models_valid[i]->getBIC() )) {
00136                   if(DEBUG) cout <<"BIC eval of "<<models_new[i]->getModelName()<<" is nan, skipping"<<endl;
00137                   continue;
00138           }
00139           models_sorted[models_valid[i]->getBIC()] = models_valid[i];
00140   }
00141 
00142   if(models_sorted.size()==0) {
00143           cout << "no valid models found"<<endl;
00144           return;
00145   }
00146 
00147   for(map<double,GenericModelPtr>::iterator it=models_sorted.begin();it!=models_sorted.end();it++) {
00148   cout << it->second->getModelName()<<
00149           " pos_err=" << it->second->getPositionError()<<
00150           " rot_err=" << it->second->getOrientationError()<<
00151           " bic=" << it->second->getBIC()<<
00152           " k=" << it->second->getParam("complexity") <<
00153           " n=" << it->second->getTrack().pose.size() <<
00154           endl;
00155   }
00156 //  }
00157   map<double,GenericModelPtr>::iterator it = models_sorted.begin();
00158   models_valid.clear();
00159   models_valid.push_back(it->second);
00160 
00161 //  it->second->projectPoseToConfiguration();
00162 //  it->second->fitMinMaxConfigurations();
00163   if(it->second->getModelName()=="rotational") {
00164           it->second->sampleConfigurationSpace( 0.05 );
00165   } else {
00166           it->second->sampleConfigurationSpace( 0.01 );
00167   }
00168 
00169   ModelMsg msg = it->second->getModel();
00170   msg.id = msg.track.id;
00171   model_pub.publish( msg );
00172 
00173 //  EVAL();
00174 }
00175 
00176 int main(int argc, char** argv)
00177 {
00178   ros::init(argc, argv, "model_learner");
00179   ros::NodeHandle n;
00180 
00181   std::string filter_models("rigid rotational prismatic");
00182   ros::NodeHandle("~").getParam("filter_models", filter_models);
00183   ros::NodeHandle("~").getParam("sigma_position",sigma_position);
00184   ros::NodeHandle("~").getParam("sigma_orientation",sigma_orientation);
00185   factory.setFilter(filter_models);
00186 
00187   cout <<"(param) sigma_position=" << sigma_position << endl;
00188   cout <<"(param) sigma_orientation=" << sigma_orientation << endl;
00189 
00190   model_pub = n.advertise<ModelMsg>("model", 1);
00191 
00192   ros::Subscriber track_sub = n.subscribe("track", 1, trackCallback);
00193   ros::spin();
00194 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Properties Friends Defines


articulation_models
Author(s): Juergen Sturm
autogenerated on Wed Dec 26 2012 15:35:18