MetricTrainer.cpp
Go to the documentation of this file.
00001 
00013 // RAIL Recognition
00014 #include "rail_recognition/MetricTrainer.h"
00015 #include "rail_recognition/PointCloudMetrics.h"
00016 
00017 // ROS
00018 #include <pcl_ros/point_cloud.h>
00019 
00020 using namespace std;
00021 using namespace rail::pick_and_place;
00022 
00023 MetricTrainer::MetricTrainer()
00024     : private_node_("~"), get_yes_and_no_feedback_ac_(private_node_, "get_yes_no_feedback", true),
00025       as_(private_node_, "train_metrics", boost::bind(&MetricTrainer::trainMetricsCallback,
00026                                                       this, _1), false)
00027 {
00028   // set defaults
00029   int port = graspdb::Client::DEFAULT_PORT;
00030   string host("127.0.0.1");
00031   string user("ros");
00032   string password("");
00033   string db("graspdb");
00034 
00035   // grab any parameters we need
00036   node_.getParam("/graspdb/host", host);
00037   node_.getParam("/graspdb/port", port);
00038   node_.getParam("/graspdb/user", user);
00039   node_.getParam("/graspdb/password", password);
00040   node_.getParam("/graspdb/db", db);
00041 
00042   // connect to the grasp database
00043   graspdb_ = new graspdb::Client(host, port, user, password, db);
00044   okay_ = graspdb_->connect();
00045 
00046   // setup the point cloud publishers
00047   base_pc_pub_ = private_node_.advertise<pcl::PointCloud<pcl::PointXYZRGB> >("base_pc", 1, true);
00048   aligned_pc_pub_ = private_node_.advertise<pcl::PointCloud<pcl::PointXYZRGB> >("aligned_pc", 1, true);
00049 
00050   if (okay_)
00051   {
00052     ROS_INFO("Metric Trainer Successfully Initialized");
00053     as_.start();
00054   }
00055 }
00056 
00057 MetricTrainer::~MetricTrainer()
00058 {
00059   // cleanup
00060   as_.shutdown();
00061   graspdb_->disconnect();
00062   delete graspdb_;
00063 }
00064 
00065 bool MetricTrainer::okay() const
00066 {
00067   return okay_;
00068 }
00069 
00070 void MetricTrainer::trainMetricsCallback(const rail_pick_and_place_msgs::TrainMetricsGoalConstPtr &goal)
00071 {
00072   ROS_INFO("Gathering metrics for %s. Check RViz to see the matches.", goal->object_name.c_str());
00073 
00074   // default to false
00075   rail_pick_and_place_msgs::TrainMetricsFeedback feedback;
00076   rail_pick_and_place_msgs::TrainMetricsResult result;
00077   result.success = false;
00078 
00079   // get all of the grasp demonstrations for the given object name
00080   feedback.message = "Loading grasp demonstrations...";
00081   as_.publishFeedback(feedback);
00082   vector<graspdb::GraspDemonstration> demonstrations;
00083   graspdb_->loadGraspDemonstrationsByObjectName(goal->object_name, demonstrations);
00084 
00085   // try merging every combination of grasps and gather metrics for each
00086   if (demonstrations.size() >= 2)
00087   {
00088     // convert to PCL point clouds and filter them
00089     feedback.message = "Converting to PCL point clouds...";
00090     as_.publishFeedback(feedback);
00091     vector<pcl::PointCloud<pcl::PointXYZRGB>::Ptr> point_clouds;
00092     for (size_t i = 0; i < demonstrations.size(); i++)
00093     {
00094       // create the PCL point cloud
00095       point_clouds.push_back(pcl::PointCloud<pcl::PointXYZRGB>::Ptr(new pcl::PointCloud<pcl::PointXYZRGB>));
00096       // convert from a ROS message
00097       point_cloud_metrics::rosPointCloud2ToPCLPointCloud(demonstrations[i].getPointCloud(), point_clouds[i]);
00098       // filter and move to the origin
00099       point_cloud_metrics::filterPointCloudOutliers(point_clouds[i]);
00100       point_cloud_metrics::transformToOrigin(point_clouds[i]);
00101     }
00102 
00103     // create the output file
00104     ofstream output_file;
00105     output_file.open("registration_metrics.txt", ios::out | ios::app);
00106 
00107     // check all pairs
00108     for (size_t i = 0; i < point_clouds.size() - 1; i++)
00109     {
00110       for (size_t j = i + 1; j < point_clouds.size(); j++)
00111       {
00112         stringstream ss;
00113         ss << i << " and " << j;
00114         string i_j_str = ss.str();
00115         feedback.message = "Merging point clouds " + i_j_str + "...";
00116         as_.publishFeedback(feedback);
00117 
00118         pcl::PointCloud<pcl::PointXYZRGB>::Ptr base_pc;
00119         pcl::PointCloud<pcl::PointXYZRGB>::Ptr target_pc;
00120 
00121         // set the larger point cloud as the base point cloud
00122         if (point_clouds[i]->size() > point_clouds[j]->size())
00123         {
00124           base_pc = point_clouds[i];
00125           target_pc = point_clouds[j];
00126         }
00127         else
00128         {
00129           base_pc = point_clouds[j];
00130           target_pc = point_clouds[i];
00131         }
00132 
00133         // perform ICP on the point clouds
00134         pcl::PointCloud<pcl::PointXYZRGB>::Ptr aligned_pc(new pcl::PointCloud<pcl::PointXYZRGB>);
00135         point_cloud_metrics::performICP(base_pc, target_pc, aligned_pc);
00136 
00137         // publish result for human verification
00138         base_pc_pub_.publish(base_pc);
00139         aligned_pc_pub_.publish(aligned_pc);
00140 
00141         // wait for input denoting positive or negative registration
00142         feedback.message = "Getting metrics for point clouds " + i_j_str + "...";
00143         as_.publishFeedback(feedback);
00144         // send the request
00145         rail_pick_and_place_msgs::GetYesNoFeedbackGoal goal;
00146         get_yes_and_no_feedback_ac_.sendGoal(goal);
00147 
00148         // calculate all metrics
00149         double m_o, m_c_err;
00150         point_cloud_metrics::calculateRegistrationMetricOverlap(base_pc, aligned_pc, m_o, m_c_err);
00151         double m_d_err = point_cloud_metrics::calculateRegistrationMetricDistanceError(base_pc, aligned_pc);
00152 
00153         // wait for a response
00154         feedback.message = "Waiting for feedback on point clouds " + i_j_str + "...";
00155         as_.publishFeedback(feedback);
00156         get_yes_and_no_feedback_ac_.waitForResult();
00157         string input = (get_yes_and_no_feedback_ac_.getResult()->yes) ? "y" : "n";
00158 
00159         // write the data to the file
00160         output_file << m_o << "," << m_d_err << "," << m_c_err << "," << input << endl;
00161       }
00162     }
00163 
00164     // save the file and finish
00165     output_file.close();
00166     result.success = true;
00167     as_.setSucceeded(result);
00168   } else
00169   {
00170     // not enough models
00171     string message = "Less than 2 models found for " + goal->object_name + ", ignoring request.";
00172     ROS_WARN("%s", message.c_str());
00173     as_.setSucceeded(result, message);
00174   }
00175 }


rail_recognition
Author(s): David Kent, Russell Toris, bhetherman
autogenerated on Sun Mar 6 2016 11:39:13