Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008 #include <ros/ros.h>
00009
00010 #include "articulation_msgs/ModelMsg.h"
00011 #include "articulation_msgs/TrackMsg.h"
00012 #include "articulation_msgs/ParamMsg.h"
00013
00014 #include "articulation_models/models/factory.h"
00015 #include "articulation_models/utils.h"
00016 #include <boost/foreach.hpp>
00017
00018 #define DEBUG false
00019
00020 using namespace std;
00021 using namespace articulation_models;
00022 using namespace articulation_msgs;
00023
00024 ros::Publisher model_pub;
00025
00026 MultiModelFactory factory;
00027
00028 GenericModelVector models_valid;
00029
00030 map<string, ros::Time> startingTime;
00031 map<string, vector<double> > measurements;
00032
00033 ros::NodeHandle *nh;
00034 double sigma_position = 0.01;
00035 double sigma_orientation = 4*M_PI;
00036
00037
00038 void TIC(string name){
00039 startingTime[name] = ros::Time::now();
00040 }
00041
00042 void TOC(string name) {
00043 measurements[name].push_back( (ros::Time::now() - startingTime[name]).toSec() );
00044 }
00045
00046 void ADD_DATA(string name,double data) {
00047 measurements[name].push_back( data );
00048 }
00049
00050
00051 #define SQR(a) ((a)*(a))
00052 void EVAL() {
00053 map<string, vector<double> >::iterator it;
00054 for(it = measurements.begin(); it!=measurements.end(); it++) {
00055 size_t n = it->second.size();
00056 double sum = 0;
00057 for(size_t i=0;i<n;i++) {
00058 sum += it->second[i];
00059 }
00060 double mean = sum /n;
00061 double vsum = 0;
00062 for(size_t i=0;i<n;i++) {
00063 vsum += SQR(it->second[i] - mean);
00064 }
00065 double var = vsum / n;
00066 cout << it->first << " " << mean << " "<<sqrt(var)<< " ("<<n<<" obs)"<< endl;
00067 }
00068 }
00069
00070 void trackCallback(const TrackMsgConstPtr& track)
00071 {
00072 ROS_INFO("Received track id [%d]", track->id);
00073
00074 articulation_msgs::ModelMsg model_track;
00075 model_track.track = *track;
00076 setParamIfNotDefined(model_track.params, "sigma_position",
00077 sigma_position, ParamMsg::PRIOR);
00078 setParamIfNotDefined(model_track.params, "sigma_orientation",
00079 sigma_orientation, ParamMsg::PRIOR);
00080
00081 TIC("createModels");
00082
00083 GenericModelVector models_new = factory.createModels( model_track );
00084 TOC("createModels");
00085
00086 GenericModelVector models_old = models_valid;
00087
00088 models_valid.clear();
00089 models_old.clear();
00090
00091
00092 for(size_t i=0;i<models_old.size();i++) {
00093 models_old[i]->setTrack(*track);
00094 models_old[i]->projectPoseToConfiguration();
00095 if( !models_old[i]->fitMinMaxConfigurations() ) continue;
00096 if( !models_old[i]->evaluateModel() ) continue;
00097
00098 models_valid.push_back( models_old[i] );
00099 }
00100
00101
00102 TIC("per_track");
00103 for(size_t i=0;i<models_new.size();i++) {
00104 TIC("fitModel" + models_new[i]->getModelName());
00105 if( !models_new[i]->fitModel() ) {
00106 if(DEBUG) cout <<"fitting of "<<models_new[i]->getModelName()<<" failed"<<endl;
00107 continue;
00108 }
00109 TOC("fitModel" + models_new[i]->getModelName());
00110 TIC("projectPoseToConfiguration" + models_new[i]->getModelName());
00111 models_new[i]->projectPoseToConfiguration();
00112 TOC("projectPoseToConfiguration" + models_new[i]->getModelName());
00113 TIC("fitMinMaxConfigurations" + models_new[i]->getModelName());
00114 if( !models_new[i]->fitMinMaxConfigurations() ) {
00115 if(DEBUG) cout <<"fitting of min/max conf of "<<models_new[i]->getModelName()<<" failed"<<endl;
00116 continue;
00117 }
00118 TOC("fitMinMaxConfigurations" + models_new[i]->getModelName());
00119
00120 TIC("evaluateModel" + models_new[i]->getModelName());
00121 if( !models_new[i]->evaluateModel() ) {
00122 if(DEBUG) cout <<"evaluation of "<<models_new[i]->getModelName()<<" failed"<<endl;
00123 continue;
00124 }
00125 TOC("evaluateModel" + models_new[i]->getModelName());
00126
00127 models_valid.push_back( models_new[i] );
00128
00129 }
00130 TOC("per_track");
00131
00132
00133 map<double,GenericModelPtr> models_sorted;
00134 for(size_t i=0;i<models_valid.size();i++) {
00135 if(isnan( models_valid[i]->getBIC() )) {
00136 if(DEBUG) cout <<"BIC eval of "<<models_new[i]->getModelName()<<" is nan, skipping"<<endl;
00137 continue;
00138 }
00139 models_sorted[models_valid[i]->getBIC()] = models_valid[i];
00140 }
00141
00142 if(models_sorted.size()==0) {
00143 cout << "no valid models found"<<endl;
00144 return;
00145 }
00146
00147 for(map<double,GenericModelPtr>::iterator it=models_sorted.begin();it!=models_sorted.end();it++) {
00148 cout << it->second->getModelName()<<
00149 " pos_err=" << it->second->getPositionError()<<
00150 " rot_err=" << it->second->getOrientationError()<<
00151 " bic=" << it->second->getBIC()<<
00152 " k=" << it->second->getParam("complexity") <<
00153 " n=" << it->second->getTrack().pose.size() <<
00154 endl;
00155 }
00156
00157 map<double,GenericModelPtr>::iterator it = models_sorted.begin();
00158 models_valid.clear();
00159 models_valid.push_back(it->second);
00160
00161
00162
00163 if(it->second->getModelName()=="rotational") {
00164 it->second->sampleConfigurationSpace( 0.05 );
00165 } else {
00166 it->second->sampleConfigurationSpace( 0.01 );
00167 }
00168
00169 ModelMsg msg = it->second->getModel();
00170 msg.id = msg.track.id;
00171 model_pub.publish( msg );
00172
00173
00174 }
00175
00176 int main(int argc, char** argv)
00177 {
00178 ros::init(argc, argv, "model_learner");
00179 ros::NodeHandle n;
00180
00181 std::string filter_models("rigid rotational prismatic");
00182 ros::NodeHandle("~").getParam("filter_models", filter_models);
00183 ros::NodeHandle("~").getParam("sigma_position",sigma_position);
00184 ros::NodeHandle("~").getParam("sigma_orientation",sigma_orientation);
00185 factory.setFilter(filter_models);
00186
00187 cout <<"(param) sigma_position=" << sigma_position << endl;
00188 cout <<"(param) sigma_orientation=" << sigma_orientation << endl;
00189
00190 model_pub = n.advertise<ModelMsg>("model", 1);
00191
00192 ros::Subscriber track_sub = n.subscribe("track", 1, trackCallback);
00193 ros::spin();
00194 }