00001
00002
00003
00004
00005
00006
00007
00008 #include <ros/ros.h>
00009
00010 #include "articulation_msgs/ModelMsg.h"
00011 #include "articulation_msgs/TrackMsg.h"
00012 #include "articulation_msgs/ParamMsg.h"
00013 #include "articulation_msgs/TrackModelSrv.h"
00014 #include "articulation_msgs/AlignModelSrv.h"
00015 #include "articulation_msgs/GetModelPriorSrv.h"
00016 #include "articulation_msgs/SetModelPriorSrv.h"
00017
00018 #include "articulation_models/models/factory.h"
00019 #include "articulation_models/utils.h"
00020
00021 #include "icp/icp_utils.h"
00022
00023 using namespace std;
00024 using namespace articulation_models;
00025 using namespace articulation_msgs;
00026
00027 ros::Publisher model_pub;
00028
00029 MultiModelFactory factory;
00030
00031 std::map<int, GenericModelPtr> model_database;
00032
00033 double sigma_position = 0.01;
00034 double sigma_orientation = 360 * M_PI / 180.0;
00035
00036 bool single_model = false;
00037 bool do_align = false;
00038
00039 double sigma_align_position = 0.2;
00040 double sigma_align_orientation = 20 * M_PI / 180.0;
00041
00042 #define SQR(a) ((a)*(a))
00043
00044 int next_id = 0;
00045
00046 bool model_select(articulation_msgs::TrackModelSrv::Request &request,
00047 articulation_msgs::TrackModelSrv::Response &response, bool use_prior) {
00048 ROS_INFO("selecting model, id=%d, poses=%d", (int)request.model.track.id,(int)request.model.track.pose.size());
00049
00050 map<double, GenericModelPtr> evaluated_models;
00051
00052
00053 request.model.id = -1;
00054 setParamIfNotDefined(request.model.params, "sigma_position",
00055 sigma_position, ParamMsg::PRIOR);
00056 setParamIfNotDefined(request.model.params, "sigma_orientation",
00057 sigma_orientation, ParamMsg::PRIOR);
00058
00059 double total_loglikelihood=0;
00060 double total_n=0;
00061 double total_k=0;
00062 for (std::map<int, GenericModelPtr>::iterator it = model_database.begin(); it
00063 != model_database.end(); it++) {
00064 total_loglikelihood += it->second->getParam("loglikelihood");
00065 total_n += it->second->model.track.pose.size();
00066 total_k += it->second->getParam("complexity");
00067 }
00068
00069
00070
00071
00072
00073 GenericModelVector candidate_models = factory.createModels(request.model);
00074
00075 for (size_t i = 0; i < candidate_models.size(); i++) {
00076 if (!candidate_models[i]->fitModel())
00077 continue;
00078 candidate_models[i]->projectPoseToConfiguration();
00079 if (!candidate_models[i]->fitMinMaxConfigurations())
00080 continue;
00081 if (!candidate_models[i]->evaluateModel())
00082 continue;
00083 if (isnan(candidate_models[i]->getBIC()))
00084 continue;
00085
00086
00087
00088
00089 double bic =
00090 -2*(total_loglikelihood + candidate_models[i]->getParam("loglikelihood"))
00091 +( total_k + candidate_models[i]->getParam("complexity") )
00092 * log(total_n + request.model.track.pose.size() );
00093 evaluated_models[ bic ] = candidate_models[i];
00094
00095 }
00096
00097 if(single_model && use_prior) {
00098 if(model_database.size()>0) {
00099 cout << "single_model mode active!! clearing new candidates!!"<<endl;
00100 evaluated_models.clear();
00101 }
00102 }
00103
00104 if(use_prior) {
00105
00106 for (std::map<int, GenericModelPtr>::iterator it = model_database.begin(); it
00107 != model_database.end(); it++) {
00108 GenericModelPtr stored_model = it->second;
00109 articulation_msgs::ModelMsg stored_model_msg = stored_model->getModel();
00110
00111 if(do_align) {
00112 double k_align = 0;
00113 double lh_align = 0;
00114 if(request.model.track.pose.size()>5 && stored_model_msg.track.pose.size()>5) {
00115 ros::ServiceClient client = ros::NodeHandle().serviceClient<articulation_msgs::AlignModelSrv> ("icp_align");
00116 articulation_msgs::AlignModelSrv srv;
00117 srv.request.model = stored_model_msg;
00118 srv.request.data = request.model;
00119 if (client.call(srv)) {
00120 stored_model_msg = srv.response.model_aligned;
00121 k_align = 3;
00122 lh_align =
00123 ( SQR(srv.response.dist_trans) / SQR(sigma_align_position) ) +
00124 ( SQR(srv.response.dist_rot) / SQR(sigma_align_orientation) );
00125 }
00126
00127 }
00128 }
00129
00130
00131
00132
00133
00134
00135
00136 articulation_msgs::ModelMsg merged_model_msg = stored_model_msg;
00137 merged_model_msg.track.id = request.model.track.id;
00138 merged_model_msg.track.pose.insert(merged_model_msg.track.pose.end(),
00139 request.model.track.pose.begin(),
00140 request.model.track.pose.end());
00141 for (size_t i = 0; i < merged_model_msg.track.pose_flags.size(); i++) {
00142 merged_model_msg.track.pose_flags[i] &= ~TrackMsg::POSE_VISIBLE;
00143 }
00144 merged_model_msg.track.pose_flags.back()
00145 |= TrackMsg::POSE_END_OF_SEGMENT;
00146 merged_model_msg.track.pose_flags.insert(
00147 merged_model_msg.track.pose_flags.end(),
00148 request.model.track.pose_flags.begin(),
00149 request.model.track.pose_flags.end());
00150
00151
00152 GenericModelPtr merged_model = factory.restoreModel(merged_model_msg);
00153 if (!merged_model->fitModel())
00154 continue;
00155 merged_model->projectPoseToConfiguration();
00156 if (!merged_model->fitMinMaxConfigurations())
00157 continue;
00158 if (!merged_model->evaluateModel())
00159 continue;
00160 if (isnan(merged_model->getBIC()))
00161 continue;
00162
00163
00164
00165
00166
00167
00168 double bic =
00169 -2*(total_loglikelihood - stored_model->getParam("loglikelihood") + merged_model->getParam("loglikelihood"))
00170 + ( total_k - stored_model->getParam("complexity") + merged_model->getParam("complexity") )
00171 * log(total_n + request.model.track.pose.size() );
00172 evaluated_models[ bic ] = merged_model;
00173
00174 }
00175 }
00176
00177 if (evaluated_models.size() == 0) {
00178 cout << "no valid models found" << endl;
00179 return false;
00180 }
00181
00182
00183 GenericModelPtr selected_model = evaluated_models.begin()->second;
00184 selected_model->bic = evaluated_models.begin()->first;
00185
00186
00187 cout << selected_model ->getModelName() << selected_model->getId()<<" "<< " pos_err="
00188 << selected_model ->getPositionError() << " rot_err="
00189 << selected_model ->getOrientationError() << " bic="
00190 << selected_model ->getBIC() << " k=" << selected_model ->getParam(
00191 "complexity") << " n=" << selected_model ->getTrack().pose.size()
00192 << endl << flush;
00193
00194 selected_model->sampleConfigurationSpace(0.01);
00195 response.model = selected_model->getModel();
00196 return (true);
00197 }
00198
00199 bool model_select(articulation_msgs::TrackModelSrv::Request &request,
00200 articulation_msgs::TrackModelSrv::Response &response) {
00201 return model_select(request,response,true);
00202 }
00203
00204 bool model_store(articulation_msgs::TrackModelSrv::Request &request,
00205 articulation_msgs::TrackModelSrv::Response &response) {
00206 ROS_INFO("storing model, id=%d, poses=%d, name=%s", (int)request.model.track.id,(int)request.model.track.pose.size(),request.model.name.c_str());
00207 GenericModelPtr model = factory.restoreModel(request.model);
00208
00209 if (model->getId() == -1) {
00210 model->setId(next_id++);
00211 }
00212
00213 model_database[model->getId()] = model;
00214 response.model = model->getModel();
00215
00216 return (true);
00217 }
00218
00219 bool model_get_prior(articulation_msgs::GetModelPriorSrv::Request &request,
00220 articulation_msgs::GetModelPriorSrv::Response &response) {
00221 ROS_INFO("model_get_prior, returning n=%d models", (int)model_database.size());
00222 for (std::map<int, GenericModelPtr>::iterator it = model_database.begin(); it
00223 != model_database.end(); it++) {
00224 GenericModelPtr stored_model = it->second;
00225 response.model.push_back(stored_model->getModel());
00226
00227 }
00228 return (true);
00229 }
00230
00231 bool model_set_prior(articulation_msgs::SetModelPriorSrv::Request &request,
00232 articulation_msgs::GetModelPriorSrv::Response &response) {
00233 ROS_INFO("model_set_prior, restoring n=%d models", (int)request.model.size());
00234
00235 string filter_models("rigid rotational prismatic");
00236 ros::NodeHandle("~").getParam("filter_models", filter_models);
00237 factory.setFilter(filter_models);
00238 factory.listModelFactories();
00239 model_database.clear();
00240 next_id = 0;
00241 for (size_t i = 0; i < request.model.size(); i++) {
00242 GenericModelPtr model = factory.restoreModel(request.model[i]);
00243 model_database[model->getId()] = model;
00244 if(request.model[i].id>=next_id)
00245 next_id = request.model[i].id+1;
00246 }
00247 return (true);
00248 }
00249
00250
00251 bool model_select_eval(articulation_msgs::TrackModelSrv::Request &request,
00252 articulation_msgs::TrackModelSrv::Response &response) {
00253 ROS_INFO("evaluation model selection, id=%d, poses=%d", (int)request.model.track.id,(int)request.model.track.pose.size());
00254
00255 model_select(request,response,false);
00256 GenericModelPtr model = factory.restoreModel(response.model);
00257 if(!model) {
00258 ROS_INFO("sorry, no model");
00259 return true;
00260 }
00261 int ch = model->openChannel("avg_error_position_cutoff");
00262 int ch_time = model->openChannel("timing");
00263
00264 articulation_msgs::TrackModelSrv::Request partial_request;
00265 articulation_msgs::TrackModelSrv::Response partial_response;
00266 for(size_t n=0;n<request.model.track.pose.size();n++) {
00267 partial_request = request;
00268 partial_request.model.track.pose.erase(
00269 partial_request.model.track.pose.begin() + n,
00270 partial_request.model.track.pose.end() );
00271
00272 ros::Time t_start = ros::Time::now();
00273 model_select(partial_request,partial_response);
00274 model->model.track.channels[ ch_time ].values[n] = (ros::Time::now() - t_start ).toSec();
00275
00276 GenericModelPtr partial_model = factory.restoreModel(partial_response.model);
00277 if(!partial_model)
00278 continue;
00279 partial_model->fitModel();
00280 partial_model->model.track = request.model.track;
00281
00282
00283 partial_model->evaluateModel();
00284 double avg_error_position = partial_model->getParam("avg_error_position");
00285 model->model.track.channels[ ch ].values[n] = avg_error_position;
00286 cout << n <<"/"<< request.model.track.pose.size()<<": avg_error_position="<<avg_error_position<<endl;
00287 }
00288 response.model = model->getModel();
00289 return true;
00290 }
00291
00292 int main(int argc, char** argv) {
00293
00294 ros::init(argc, argv, "model_learner_prior");
00295 ros::NodeHandle n;
00296
00297
00298 std::string filter_models("rigid rotational prismatic");
00299 ros::NodeHandle("~").getParam("filter_models", filter_models);
00300 factory.setFilter(filter_models);
00301 factory.listModelFactories();
00302
00303 ros::NodeHandle("~").getParam("sigma_position", sigma_position);
00304 ros::NodeHandle("~").getParam("sigma_orientation", sigma_orientation);
00305
00306 ros::NodeHandle("~").getParam("single_model", single_model);
00307 ros::NodeHandle("~").getParam("do_align", do_align);
00308
00309
00310 ros::ServiceServer model_select_service = n.advertiseService(
00311 "model_select", model_select);
00312 ros::ServiceServer model_store_service = n.advertiseService("model_store",
00313 model_store);
00314
00315 ros::ServiceServer model_get_prior_service = n.advertiseService(
00316 "model_prior_get", model_get_prior);
00317 ros::ServiceServer model_set_prior_service = n.advertiseService(
00318 "model_prior_set", model_set_prior);
00319
00320 ros::ServiceServer model_select_eval_service = n.advertiseService(
00321 "model_select_eval", model_select_eval);
00322
00323 ROS_INFO("model_learner_prior, service ready");
00324 ros::spin();
00325 }