model_learner_prior.cpp
Go to the documentation of this file.
00001 /*
00002  * model_learner_prior.cpp
00003  *
00004  *  Created on: Oct 21, 2009
00005  *      Author: sturm
00006  */
00007 
00008 #include <ros/ros.h>
00009 
00010 #include "articulation_msgs/ModelMsg.h"
00011 #include "articulation_msgs/TrackMsg.h"
00012 #include "articulation_msgs/ParamMsg.h"
00013 #include "articulation_msgs/TrackModelSrv.h"
00014 #include "articulation_msgs/AlignModelSrv.h"
00015 #include "articulation_msgs/GetModelPriorSrv.h"
00016 #include "articulation_msgs/SetModelPriorSrv.h"
00017 
00018 #include "articulation_models/models/factory.h"
00019 #include "articulation_models/utils.h"
00020 
00021 #include "icp/icp_utils.h"
00022 
00023 using namespace std;
00024 using namespace articulation_models;
00025 using namespace articulation_msgs;
00026 
00027 ros::Publisher model_pub; // publisher handle; not referenced by any function shown in this file
00028 
00029 MultiModelFactory factory; // creates candidate models and restores models from ModelMsg
00030 
00031 std::map<int, GenericModelPtr> model_database; // stored models (the "prior"), keyed by model id
00032 
00033 double sigma_position = 0.01; // default for the "sigma_position" model parameter; overridable via the private ROS parameter of the same name
00034 double sigma_orientation = 360 * M_PI / 180.0; // default for the "sigma_orientation" model parameter (360 degrees expressed in radians)
00035 
00036 bool single_model = false; // if set and the database is non-empty, model_select with prior discards newly fitted candidates
00037 bool do_align = false; // if set, model_select asks the "icp_align" service to align a stored model before merging
00038 
00039 double sigma_align_position = 0.2; // sigma associated with the translational alignment distance
00040 double sigma_align_orientation = 20 * M_PI / 180.0; // sigma associated with the rotational alignment distance (20 degrees in radians)
00041 
00042 #define SQR(a) ((a)*(a)) // square of a; note that the argument is evaluated twice
00043 
00044 int next_id = 0; // next id handed out by model_store to models whose id is -1
00045 
00046 double getSimpleLikelihood(GenericModelPtr model) {
00047         return model->getParam("loglikelihood") +
00048                         (model->getParam("samples"))*model->getParam("dofs")*log(model->getParam("samples"));
00049 }
00050 
00051 bool model_select(articulation_msgs::TrackModelSrv::Request &request,
00052                 articulation_msgs::TrackModelSrv::Response &response, bool use_prior) {
00053         ROS_INFO("selecting model, id=%d, poses=%d", (int)request.model.track.id,(int)request.model.track.pose.size());
00054 
00055         map<double, GenericModelPtr> evaluated_models;
00056 
00057         // set some parameters
00058         request.model.id = -1; // cluster assigment not known yet
00059         setParamIfNotDefined(request.model.params, "sigma_position",
00060                         sigma_position, ParamMsg::PRIOR);
00061         setParamIfNotDefined(request.model.params, "sigma_orientation",
00062                         sigma_orientation, ParamMsg::PRIOR);
00063 
00064         double total_loglikelihood=0;
00065         double total_n=0;
00066         double total_k=0;
00067         for (std::map<int, GenericModelPtr>::iterator it = model_database.begin(); it
00068                         != model_database.end(); it++) {
00069                 total_loglikelihood += getSimpleLikelihood(it->second) ;
00070                 total_n += it->second->model.track.pose.size();
00071                 total_k += it->second->getParam("complexity");
00072         }
00073 //      cout << "total_loglikelihood ="<<total_loglikelihood <<endl;
00074 //      cout << "total_k ="<<total_k <<endl;
00075 //      cout << "total_n ="<<total_n <<endl;
00076 
00077         // possibility 1: fit a completely new model
00078         GenericModelVector candidate_models = factory.createModels(request.model);
00079         // fit candidates and sort
00080         for (size_t i = 0; i < candidate_models.size(); i++) {
00081                 if (!candidate_models[i]->fitModel())
00082                         continue;
00083                 candidate_models[i]->projectPoseToConfiguration();
00084                 if (!candidate_models[i]->fitMinMaxConfigurations())
00085                         continue;
00086                 if (!candidate_models[i]->evaluateModel())
00087                         continue;
00088                 if (isnan(candidate_models[i]->getBIC()))
00089                         continue;
00090 
00091 //              cout << "candidate_models->getParam(loglikelihood) ="<<candidate_models[i]->getParam("loglikelihood") <<endl;
00092 //              cout << "candidate_models->getParam(complexity) ="<<candidate_models[i]->getParam("complexity") <<endl;
00093 //              cout << "request.model.track.pose.size() ="<<request.model.track.pose.size() <<endl;
00094                 double bic =
00095                                 -2*(total_loglikelihood +
00096                                                 getSimpleLikelihood(candidate_models[i]))
00097                                 +( total_k + candidate_models[i]->getParam("complexity") )
00098                                 * log(total_n + request.model.track.pose.size() );
00099                 evaluated_models[ bic ] = candidate_models[i];
00100 //              cout << candidate_models[i]->getModelName()<< " candidate bic="<< bic <<endl;
00101         }
00102 
00103         if(single_model && use_prior) {
00104                 if(model_database.size()>0) {
00105                         cout << "single_model mode active!! clearing new candidates!!"<<endl;
00106                         evaluated_models.clear();
00107                 }
00108         }
00109 
00110         if(use_prior) {
00111                 // possibility 2: combine trajectory with a stored model, then re-fit
00112                 for (std::map<int, GenericModelPtr>::iterator it = model_database.begin(); it
00113                                 != model_database.end(); it++) {
00114                         GenericModelPtr stored_model = it->second;
00115                         articulation_msgs::ModelMsg stored_model_msg = stored_model->getModel();
00116 
00117                         if(do_align) {
00118                                 double k_align = 0;
00119                                 double lh_align = 0;
00120                                 if(request.model.track.pose.size()>5 && stored_model_msg.track.pose.size()>5) {
00121                                         ros::ServiceClient client = ros::NodeHandle().serviceClient<articulation_msgs::AlignModelSrv> ("icp_align");
00122                                         articulation_msgs::AlignModelSrv srv;
00123                                         srv.request.model = stored_model_msg;
00124                                         srv.request.data = request.model;
00125                                         if (client.call(srv)) {
00126                                                 stored_model_msg = srv.response.model_aligned;
00127                                                 k_align = 3;
00128                                                 lh_align =
00129                                                                                 ( SQR(srv.response.dist_trans) / SQR(sigma_align_position) ) +
00130                                                                                 ( SQR(srv.response.dist_rot) / SQR(sigma_align_orientation) );
00131                                         }
00132 
00133                                 }
00134                         }
00135 
00136                                 //        if(request.model.track.pose.size()>10 && stored_model_msg.track.pose.size()>10) {
00137                                 //                icp::IcpAlign alignment(stored_model_msg.track, request.model.track);
00138                                 //                alignment.TransformModel(stored_model_msg.track);
00139                                 //        }
00140 
00141                                 // now join tracks
00142                                 articulation_msgs::ModelMsg merged_model_msg = stored_model_msg;
00143                         merged_model_msg.track.id = request.model.track.id;
00144                         merged_model_msg.track.pose.insert(merged_model_msg.track.pose.end(),
00145                                         request.model.track.pose.begin(),
00146                                         request.model.track.pose.end());
00147                         for (size_t i = 0; i < merged_model_msg.track.pose_flags.size(); i++) {
00148                                 merged_model_msg.track.pose_flags[i] &= ~TrackMsg::POSE_VISIBLE;
00149                         }
00150                         merged_model_msg.track.pose_flags.back()
00151                                         |= TrackMsg::POSE_END_OF_SEGMENT;
00152                         merged_model_msg.track.pose_flags.insert(
00153                                         merged_model_msg.track.pose_flags.end(),
00154                                         request.model.track.pose_flags.begin(),
00155                                         request.model.track.pose_flags.end());
00156 
00157                         //        cout << " merged model, n="<<merged_model_msg.track.pose.size()<<endl;
00158                         GenericModelPtr merged_model = factory.restoreModel(merged_model_msg);
00159                         if (!merged_model->fitModel())
00160                                 continue;
00161                         merged_model->projectPoseToConfiguration();
00162                         if (!merged_model->fitMinMaxConfigurations())
00163                                 continue;
00164                         if (!merged_model->evaluateModel())
00165                                 continue;
00166                         if (isnan(merged_model->getBIC()))
00167                                 continue;
00168 
00169         //              cout << "stored_model->getParam(loglikelihood) ="<<stored_model->getParam("loglikelihood") <<endl;
00170         //              cout << "stored_model->getParam(complexity) ="<<stored_model->getParam("complexity") <<endl;
00171         //              cout << "request.model.track.pose.size() ="<<request.model.track.pose.size() <<endl;
00172         //              cout << "merged_model->getParam(loglikelihood) ="<<merged_model->getParam("loglikelihood") <<endl;
00173         //              cout << "merged_model->getParam(complexity) ="<<merged_model->getParam("complexity") <<endl;
00174                         double bic =
00175                                         -2*(total_loglikelihood
00176                                                         - getSimpleLikelihood(stored_model)
00177                                                         + getSimpleLikelihood(merged_model)
00178                                                 )
00179                                         + ( total_k - stored_model->getParam("complexity") + merged_model->getParam("complexity") )
00180                                         * log(total_n + request.model.track.pose.size() );
00181                         evaluated_models[ bic ] = merged_model;
00182         //              cout << " relative bic="<< (bic) <<endl;
00183                 }
00184         }
00185 
00186         if (evaluated_models.size() == 0) {
00187                 cout << "no valid models found" << endl;
00188                 return false;
00189         }
00190 
00191         // best model
00192         GenericModelPtr selected_model = evaluated_models.begin()->second;
00193         selected_model->bic = evaluated_models.begin()->first;
00194 
00195         // print some information on selected model
00196         cout << selected_model ->getModelName() << selected_model->getId()<<" "<< " pos_err="
00197                         << selected_model ->getPositionError() << " rot_err="
00198                         << selected_model ->getOrientationError() << " bic="
00199                         << selected_model ->getBIC() << " k=" << selected_model ->getParam(
00200                         "complexity") << " n=" << selected_model ->getTrack().pose.size()
00201                         << endl << flush;
00202 
00203         selected_model->sampleConfigurationSpace(0.01);
00204         response.model = selected_model->getModel();
00205         return (true);
00206 }
00207 
00208 bool model_select(articulation_msgs::TrackModelSrv::Request &request,
00209                 articulation_msgs::TrackModelSrv::Response &response) {
00210         return model_select(request,response,true);
00211 }
00212 
00213 bool model_store(articulation_msgs::TrackModelSrv::Request &request,
00214                 articulation_msgs::TrackModelSrv::Response &response) {
00215         ROS_INFO("storing model, id=%d, poses=%d, name=%s", (int)request.model.track.id,(int)request.model.track.pose.size(),request.model.name.c_str());
00216         GenericModelPtr model = factory.restoreModel(request.model);
00217 
00218         if (model->getId() == -1) {
00219                 model->setId(next_id++);
00220         }
00221 
00222         model_database[model->getId()] = model;
00223         response.model = model->getModel();
00224 
00225         return (true);
00226 }
00227 
00228 bool model_get_prior(articulation_msgs::GetModelPriorSrv::Request &request,
00229                 articulation_msgs::GetModelPriorSrv::Response &response) {
00230         ROS_INFO("model_get_prior, returning n=%d models", (int)model_database.size());
00231         for (std::map<int, GenericModelPtr>::iterator it = model_database.begin(); it
00232                         != model_database.end(); it++) {
00233                 GenericModelPtr stored_model = it->second;
00234                 response.model.push_back(stored_model->getModel());
00235 
00236         }
00237         return (true);
00238 }
00239 
00240 bool model_set_prior(articulation_msgs::SetModelPriorSrv::Request &request,
00241                 articulation_msgs::GetModelPriorSrv::Response &response) {
00242         ROS_INFO("model_set_prior, restoring n=%d models", (int)request.model.size());
00243 
00244         string filter_models("rigid rotational prismatic");
00245         ros::NodeHandle("~").getParam("filter_models", filter_models);
00246         factory.setFilter(filter_models);
00247         factory.listModelFactories();
00248         model_database.clear();
00249         next_id = 0;
00250         for (size_t i = 0; i < request.model.size(); i++) {
00251                 GenericModelPtr model = factory.restoreModel(request.model[i]);
00252                 model_database[model->getId()] = model;
00253                 if(request.model[i].id>=next_id)
00254                         next_id = request.model[i].id+1;
00255         }
00256         return (true);
00257 }
00258 
00259 
00260 bool model_select_eval(articulation_msgs::TrackModelSrv::Request &request,
00261                 articulation_msgs::TrackModelSrv::Response &response) {
00262         ROS_INFO("evaluation model selection, id=%d, poses=%d", (int)request.model.track.id,(int)request.model.track.pose.size());
00263 
00264         model_select(request,response,false);   // first, find the "right" model, no prior
00265         GenericModelPtr model = factory.restoreModel(response.model);
00266         if(!model) {
00267                 ROS_INFO("sorry, no model");
00268                 return true;
00269         }
00270         int ch = model->openChannel("avg_error_position_cutoff");
00271         int ch_time = model->openChannel("timing");
00272 
00273         articulation_msgs::TrackModelSrv::Request partial_request;
00274         articulation_msgs::TrackModelSrv::Response partial_response;
00275         for(size_t n=0;n<request.model.track.pose.size();n++) {
00276                 partial_request = request;
00277                 partial_request.model.track.pose.erase(
00278                                 partial_request.model.track.pose.begin() + n,
00279                                 partial_request.model.track.pose.end() );
00280 
00281                 ros::Time t_start = ros::Time::now();
00282                 model_select(partial_request,partial_response);
00283                 model->model.track.channels[ ch_time ].values[n] = (ros::Time::now() - t_start ).toSec();
00284 
00285                 GenericModelPtr partial_model = factory.restoreModel(partial_response.model);
00286                 if(!partial_model)
00287                         continue;
00288                 partial_model->fitModel();      // gp need to be constructed first (=re-fitting on partial data).. :(
00289                 partial_model->model.track = request.model.track;
00290 
00291 
00292                 partial_model->evaluateModel();
00293                 double avg_error_position = partial_model->getParam("avg_error_position");
00294                 model->model.track.channels[ ch ].values[n] = avg_error_position;
00295                 cout << n <<"/"<< request.model.track.pose.size()<<": avg_error_position="<<avg_error_position<<endl;
00296         }
00297         response.model = model->getModel();
00298         return true;
00299 }
00300 
00301 int main(int argc, char** argv) {
00302         // init ROS node
00303         ros::init(argc, argv, "model_learner_prior");
00304         ros::NodeHandle n;
00305 
00306         // read ROS params
00307         std::string filter_models("rigid rotational prismatic");
00308         ros::NodeHandle("~").getParam("filter_models", filter_models);
00309         factory.setFilter(filter_models);
00310         factory.listModelFactories();
00311 
00312         ros::NodeHandle("~").getParam("sigma_position", sigma_position);
00313         ros::NodeHandle("~").getParam("sigma_orientation", sigma_orientation);
00314 
00315         ros::NodeHandle("~").getParam("single_model", single_model);
00316         ros::NodeHandle("~").getParam("do_align", do_align);
00317 
00318         // advertise ROS services
00319         ros::ServiceServer model_select_service = n.advertiseService(
00320                         "model_select", model_select);
00321         ros::ServiceServer model_store_service = n.advertiseService("model_store",
00322                         model_store);
00323 
00324         ros::ServiceServer model_get_prior_service = n.advertiseService(
00325                         "model_prior_get", model_get_prior);
00326         ros::ServiceServer model_set_prior_service = n.advertiseService(
00327                         "model_prior_set", model_set_prior);
00328 
00329         ros::ServiceServer model_select_eval_service = n.advertiseService(
00330                         "model_select_eval", model_select_eval);
00331 
00332         ROS_INFO("model_learner_prior, service ready");
00333         ros::spin();
00334 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Properties Friends Defines


articulation_models
Author(s): Juergen Sturm
autogenerated on Wed Dec 26 2012 15:35:18