00001
00002
00003
00004
00005
00006
00007
00008 #include <ros/ros.h>
00009
00010 #include "articulation_msgs/ModelMsg.h"
00011 #include "articulation_msgs/TrackMsg.h"
00012 #include "articulation_msgs/ParamMsg.h"
00013 #include "articulation_msgs/TrackModelSrv.h"
00014 #include "articulation_msgs/AlignModelSrv.h"
00015 #include "articulation_msgs/GetModelPriorSrv.h"
00016 #include "articulation_msgs/SetModelPriorSrv.h"
00017
00018 #include "articulation_models/models/factory.h"
00019 #include "articulation_models/utils.h"
00020
00021 #include "icp/icp_utils.h"
00022
00023 using namespace std;
00024 using namespace articulation_models;
00025 using namespace articulation_msgs;
00026
00027 ros::Publisher model_pub;
00028
00029 MultiModelFactory factory;
00030
00031 std::map<int, GenericModelPtr> model_database;
00032
00033 double sigma_position = 0.01;
00034 double sigma_orientation = 360 * M_PI / 180.0;
00035
00036 bool single_model = false;
00037 bool do_align = false;
00038
00039 double sigma_align_position = 0.2;
00040 double sigma_align_orientation = 20 * M_PI / 180.0;
00041
00042 #define SQR(a) ((a)*(a))
00043
00044 int next_id = 0;
00045
00046 double getSimpleLikelihood(GenericModelPtr model) {
00047 return model->getParam("loglikelihood") +
00048 (model->getParam("samples"))*model->getParam("dofs")*log(model->getParam("samples"));
00049 }
00050
00051 bool model_select(articulation_msgs::TrackModelSrv::Request &request,
00052 articulation_msgs::TrackModelSrv::Response &response, bool use_prior) {
00053 ROS_INFO("selecting model, id=%d, poses=%d", (int)request.model.track.id,(int)request.model.track.pose.size());
00054
00055 map<double, GenericModelPtr> evaluated_models;
00056
00057
00058 request.model.id = -1;
00059 setParamIfNotDefined(request.model.params, "sigma_position",
00060 sigma_position, ParamMsg::PRIOR);
00061 setParamIfNotDefined(request.model.params, "sigma_orientation",
00062 sigma_orientation, ParamMsg::PRIOR);
00063
00064 double total_loglikelihood=0;
00065 double total_n=0;
00066 double total_k=0;
00067 for (std::map<int, GenericModelPtr>::iterator it = model_database.begin(); it
00068 != model_database.end(); it++) {
00069 total_loglikelihood += getSimpleLikelihood(it->second) ;
00070 total_n += it->second->model.track.pose.size();
00071 total_k += it->second->getParam("complexity");
00072 }
00073
00074
00075
00076
00077
00078 GenericModelVector candidate_models = factory.createModels(request.model);
00079
00080 for (size_t i = 0; i < candidate_models.size(); i++) {
00081 if (!candidate_models[i]->fitModel())
00082 continue;
00083 candidate_models[i]->projectPoseToConfiguration();
00084 if (!candidate_models[i]->fitMinMaxConfigurations())
00085 continue;
00086 if (!candidate_models[i]->evaluateModel())
00087 continue;
00088 if (isnan(candidate_models[i]->getBIC()))
00089 continue;
00090
00091
00092
00093
00094 double bic =
00095 -2*(total_loglikelihood +
00096 getSimpleLikelihood(candidate_models[i]))
00097 +( total_k + candidate_models[i]->getParam("complexity") )
00098 * log(total_n + request.model.track.pose.size() );
00099 evaluated_models[ bic ] = candidate_models[i];
00100
00101 }
00102
00103 if(single_model && use_prior) {
00104 if(model_database.size()>0) {
00105 cout << "single_model mode active!! clearing new candidates!!"<<endl;
00106 evaluated_models.clear();
00107 }
00108 }
00109
00110 if(use_prior) {
00111
00112 for (std::map<int, GenericModelPtr>::iterator it = model_database.begin(); it
00113 != model_database.end(); it++) {
00114 GenericModelPtr stored_model = it->second;
00115 articulation_msgs::ModelMsg stored_model_msg = stored_model->getModel();
00116
00117 if(do_align) {
00118 double k_align = 0;
00119 double lh_align = 0;
00120 if(request.model.track.pose.size()>5 && stored_model_msg.track.pose.size()>5) {
00121 ros::ServiceClient client = ros::NodeHandle().serviceClient<articulation_msgs::AlignModelSrv> ("icp_align");
00122 articulation_msgs::AlignModelSrv srv;
00123 srv.request.model = stored_model_msg;
00124 srv.request.data = request.model;
00125 if (client.call(srv)) {
00126 stored_model_msg = srv.response.model_aligned;
00127 k_align = 3;
00128 lh_align =
00129 ( SQR(srv.response.dist_trans) / SQR(sigma_align_position) ) +
00130 ( SQR(srv.response.dist_rot) / SQR(sigma_align_orientation) );
00131 }
00132
00133 }
00134 }
00135
00136
00137
00138
00139
00140
00141
00142 articulation_msgs::ModelMsg merged_model_msg = stored_model_msg;
00143 merged_model_msg.track.id = request.model.track.id;
00144 merged_model_msg.track.pose.insert(merged_model_msg.track.pose.end(),
00145 request.model.track.pose.begin(),
00146 request.model.track.pose.end());
00147 for (size_t i = 0; i < merged_model_msg.track.pose_flags.size(); i++) {
00148 merged_model_msg.track.pose_flags[i] &= ~TrackMsg::POSE_VISIBLE;
00149 }
00150 merged_model_msg.track.pose_flags.back()
00151 |= TrackMsg::POSE_END_OF_SEGMENT;
00152 merged_model_msg.track.pose_flags.insert(
00153 merged_model_msg.track.pose_flags.end(),
00154 request.model.track.pose_flags.begin(),
00155 request.model.track.pose_flags.end());
00156
00157
00158 GenericModelPtr merged_model = factory.restoreModel(merged_model_msg);
00159 if (!merged_model->fitModel())
00160 continue;
00161 merged_model->projectPoseToConfiguration();
00162 if (!merged_model->fitMinMaxConfigurations())
00163 continue;
00164 if (!merged_model->evaluateModel())
00165 continue;
00166 if (isnan(merged_model->getBIC()))
00167 continue;
00168
00169
00170
00171
00172
00173
00174 double bic =
00175 -2*(total_loglikelihood
00176 - getSimpleLikelihood(stored_model)
00177 + getSimpleLikelihood(merged_model)
00178 )
00179 + ( total_k - stored_model->getParam("complexity") + merged_model->getParam("complexity") )
00180 * log(total_n + request.model.track.pose.size() );
00181 evaluated_models[ bic ] = merged_model;
00182
00183 }
00184 }
00185
00186 if (evaluated_models.size() == 0) {
00187 cout << "no valid models found" << endl;
00188 return false;
00189 }
00190
00191
00192 GenericModelPtr selected_model = evaluated_models.begin()->second;
00193 selected_model->bic = evaluated_models.begin()->first;
00194
00195
00196 cout << selected_model ->getModelName() << selected_model->getId()<<" "<< " pos_err="
00197 << selected_model ->getPositionError() << " rot_err="
00198 << selected_model ->getOrientationError() << " bic="
00199 << selected_model ->getBIC() << " k=" << selected_model ->getParam(
00200 "complexity") << " n=" << selected_model ->getTrack().pose.size()
00201 << endl << flush;
00202
00203 selected_model->sampleConfigurationSpace(0.01);
00204 response.model = selected_model->getModel();
00205 return (true);
00206 }
00207
00208 bool model_select(articulation_msgs::TrackModelSrv::Request &request,
00209 articulation_msgs::TrackModelSrv::Response &response) {
00210 return model_select(request,response,true);
00211 }
00212
00213 bool model_store(articulation_msgs::TrackModelSrv::Request &request,
00214 articulation_msgs::TrackModelSrv::Response &response) {
00215 ROS_INFO("storing model, id=%d, poses=%d, name=%s", (int)request.model.track.id,(int)request.model.track.pose.size(),request.model.name.c_str());
00216 GenericModelPtr model = factory.restoreModel(request.model);
00217
00218 if (model->getId() == -1) {
00219 model->setId(next_id++);
00220 }
00221
00222 model_database[model->getId()] = model;
00223 response.model = model->getModel();
00224
00225 return (true);
00226 }
00227
00228 bool model_get_prior(articulation_msgs::GetModelPriorSrv::Request &request,
00229 articulation_msgs::GetModelPriorSrv::Response &response) {
00230 ROS_INFO("model_get_prior, returning n=%d models", (int)model_database.size());
00231 for (std::map<int, GenericModelPtr>::iterator it = model_database.begin(); it
00232 != model_database.end(); it++) {
00233 GenericModelPtr stored_model = it->second;
00234 response.model.push_back(stored_model->getModel());
00235
00236 }
00237 return (true);
00238 }
00239
00240 bool model_set_prior(articulation_msgs::SetModelPriorSrv::Request &request,
00241 articulation_msgs::GetModelPriorSrv::Response &response) {
00242 ROS_INFO("model_set_prior, restoring n=%d models", (int)request.model.size());
00243
00244 string filter_models("rigid rotational prismatic");
00245 ros::NodeHandle("~").getParam("filter_models", filter_models);
00246 factory.setFilter(filter_models);
00247 factory.listModelFactories();
00248 model_database.clear();
00249 next_id = 0;
00250 for (size_t i = 0; i < request.model.size(); i++) {
00251 GenericModelPtr model = factory.restoreModel(request.model[i]);
00252 model_database[model->getId()] = model;
00253 if(request.model[i].id>=next_id)
00254 next_id = request.model[i].id+1;
00255 }
00256 return (true);
00257 }
00258
00259
00260 bool model_select_eval(articulation_msgs::TrackModelSrv::Request &request,
00261 articulation_msgs::TrackModelSrv::Response &response) {
00262 ROS_INFO("evaluation model selection, id=%d, poses=%d", (int)request.model.track.id,(int)request.model.track.pose.size());
00263
00264 model_select(request,response,false);
00265 GenericModelPtr model = factory.restoreModel(response.model);
00266 if(!model) {
00267 ROS_INFO("sorry, no model");
00268 return true;
00269 }
00270 int ch = model->openChannel("avg_error_position_cutoff");
00271 int ch_time = model->openChannel("timing");
00272
00273 articulation_msgs::TrackModelSrv::Request partial_request;
00274 articulation_msgs::TrackModelSrv::Response partial_response;
00275 for(size_t n=0;n<request.model.track.pose.size();n++) {
00276 partial_request = request;
00277 partial_request.model.track.pose.erase(
00278 partial_request.model.track.pose.begin() + n,
00279 partial_request.model.track.pose.end() );
00280
00281 ros::Time t_start = ros::Time::now();
00282 model_select(partial_request,partial_response);
00283 model->model.track.channels[ ch_time ].values[n] = (ros::Time::now() - t_start ).toSec();
00284
00285 GenericModelPtr partial_model = factory.restoreModel(partial_response.model);
00286 if(!partial_model)
00287 continue;
00288 partial_model->fitModel();
00289 partial_model->model.track = request.model.track;
00290
00291
00292 partial_model->evaluateModel();
00293 double avg_error_position = partial_model->getParam("avg_error_position");
00294 model->model.track.channels[ ch ].values[n] = avg_error_position;
00295 cout << n <<"/"<< request.model.track.pose.size()<<": avg_error_position="<<avg_error_position<<endl;
00296 }
00297 response.model = model->getModel();
00298 return true;
00299 }
00300
00301 int main(int argc, char** argv) {
00302
00303 ros::init(argc, argv, "model_learner_prior");
00304 ros::NodeHandle n;
00305
00306
00307 std::string filter_models("rigid rotational prismatic");
00308 ros::NodeHandle("~").getParam("filter_models", filter_models);
00309 factory.setFilter(filter_models);
00310 factory.listModelFactories();
00311
00312 ros::NodeHandle("~").getParam("sigma_position", sigma_position);
00313 ros::NodeHandle("~").getParam("sigma_orientation", sigma_orientation);
00314
00315 ros::NodeHandle("~").getParam("single_model", single_model);
00316 ros::NodeHandle("~").getParam("do_align", do_align);
00317
00318
00319 ros::ServiceServer model_select_service = n.advertiseService(
00320 "model_select", model_select);
00321 ros::ServiceServer model_store_service = n.advertiseService("model_store",
00322 model_store);
00323
00324 ros::ServiceServer model_get_prior_service = n.advertiseService(
00325 "model_prior_get", model_get_prior);
00326 ros::ServiceServer model_set_prior_service = n.advertiseService(
00327 "model_prior_set", model_set_prior);
00328
00329 ros::ServiceServer model_select_eval_service = n.advertiseService(
00330 "model_select_eval", model_select_eval);
00331
00332 ROS_INFO("model_learner_prior, service ready");
00333 ros::spin();
00334 }