joint_coll_detect.cpp
Go to the documentation of this file.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <string>
#include <vector>
#include <math.h>
#include <boost/bind.hpp>
#include <boost/foreach.hpp>
#include <boost/function.hpp>
#include <boost/math/distributions/students_t.hpp>

#include <ros/ros.h>
#include <ros/package.h>
#include <pr2_controllers_msgs/JointTrajectoryControllerState.h>
#include <std_msgs/Bool.h>
#include <std_msgs/Float64.h>
#include <std_srvs/Empty.h>
#include <rosbag/bag.h>
#include <rosbag/view.h>

#include <pr2_collision_monitor/JointErrorData.h>
#include <pr2_collision_monitor/JointDetectionStart.h>

using namespace std;
00020 
00021 namespace pr2_collision_monitor {
00022 
00023     class JointCollDetect {
00024         public:
00025             JointCollDetect();
00026             void onInit();
00027             void writeErrorData();
00028             bool isTraining() { return training_mode; }
00029             ~JointCollDetect(); 
00030 
00031         protected:
00032             ros::NodeHandle nh;
00033             ros::NodeHandle nh_priv;
00034             std::string arm, behavior_name, data_filename;
00035 
00036             bool monitoring_collisions, training_mode, significance_mode, collision_detected;
00037             double start_time, end_time;
00038             vector<double> min_errors, max_errors;
00039             vector<float> cur_min_data, cur_max_data;
00040             JointErrorData error_data;
00041             vector<std::string> behavior_name_list;
00042             vector<vector<vector<float> > > total_min_data, total_max_data;
00043 
00044             bool startDetection(std::string&, float specificity);
00045             void stopDetection();
00046             bool triggerCollision();
00047             void errorCallback(pr2_controllers_msgs::JointTrajectoryControllerState::ConstPtr message);
00048             bool srvStartDetection(JointDetectionStart::Request&, 
00049                                    JointDetectionStart::Response&);
00050             bool srvStopDetection(std_srvs::Empty::Request&, std_srvs::Empty::Response&);
00051             bool srvTriggerCollision(std_srvs::Empty::Request&, std_srvs::Empty::Response&);
00052             void loadErrorBag(const std::string& load_filename, 
00053                               JointErrorData::Ptr& err_data_ptr);
00054             void loadAllErrorData(vector<std::string>& filename_list);
00055 
00056             ros::Publisher detect_pub;
00057             ros::Subscriber error_sub;
00058             ros::ServiceServer start_srv, stop_srv, trig_srv;
00059     };
00060 
00061     JointCollDetect::JointCollDetect() : nh_priv("~"),
00062                                            min_errors(7),
00063                                            max_errors(7),
00064                                            cur_min_data(7),
00065                                            cur_max_data(7) {
00066         onInit();
00067     }
00068 
00069     void JointCollDetect::onInit() {
00070         nh_priv.param<std::string>("arm", arm, std::string("r"));
00071 
00072         // training_mode does not detect collisions but instead monitors
00073         // normal operation and collects statistics on joint errors for each behavior
00074         // execution
00075         nh_priv.param<bool>("training_mode", training_mode, false);
00076         // significance_mode employs the statistics used in training
00077         // if this mode is false, manual threshold must be provided through
00078         // the parameter server
00079         nh_priv.param<bool>("significance_mode", significance_mode, false);
00080         if(training_mode) {
00081             // training to find suitable thresholds
00082             nh_priv.param<std::string>("behavior_name", behavior_name, 
00083                                                         std::string("collision_behavior"));
00084             nh_priv.param<std::string>("data_filename", data_filename, 
00085                                                         std::string("joint_error_data"));
00086             error_data.arm = arm;
00087             error_data.behavior = behavior_name;
00088         } else {
00089             // loading thresholds
00090             if(significance_mode) {
00091                 // load training data and can use specificity values
00092                 // to determine thresholds
00093                 std::string filename_prefix;
00094                 nh_priv.getParam("filename_prefix", filename_prefix);
00095                 vector<std::string> filename_list;
00096                 XmlRpc::XmlRpcValue xml_filenames;
00097                 if(nh_priv.hasParam("filename_list")) {
00098                     nh_priv.getParam("filename_list", xml_filenames);
00099                     for(int i=0;i<xml_filenames.size();i++) {
00100                         filename_list.push_back(filename_prefix + 
00101                                                 static_cast<std::string>(xml_filenames[i]));
00102                     }
00103                 } else {
00104                     ROS_ERROR("[joint_coll_detect] MUST PROVIDE FILENAMES IN significance_mode (filename_list)");
00105                     ros::shutdown();
00106                     return;
00107                 }
00108                 loadAllErrorData(filename_list);
00109             } else {
00110                 // load from manually set thresholds
00111                 XmlRpc::XmlRpcValue xml_min_errors, xml_max_errors;
00112                 if(nh_priv.hasParam("min_errors")) {
00113                     nh_priv.getParam("min_errors", xml_min_errors);
00114                     for(int i=0;i<xml_min_errors.size();i++)
00115                         min_errors[i] = static_cast<double>(xml_min_errors[i]);
00116                 } else {
00117                     ROS_ERROR("[joint_coll_detect] MUST PROVIDE THRESHOLDS (min_errors)");
00118                     ros::shutdown();
00119                     return;
00120                 }
00121                 if(nh_priv.hasParam("max_errors")) {
00122                     nh_priv.getParam("max_errors", xml_max_errors);
00123                     for(int i=0;i<xml_max_errors.size();i++)
00124                         max_errors[i] = static_cast<double>(xml_max_errors[i]);
00125                 } else {
00126                     ROS_ERROR("[joint_coll_detect] MUST PROVIDE THRESHOLDS (max_errors)");
00127                     ros::shutdown();
00128                     return;
00129                 }
00130             }
00131         }
00132 
00133         monitoring_collisions = false;
00134 
00135         detect_pub = nh_priv.advertise<std_msgs::Bool>("arm_collision_detected", 1);
00136         ROS_INFO("[joint_coll_detect] Publishing on arm_collision_detected");
00137         start_srv = nh_priv.advertiseService("start_detection", 
00138                                              &JointCollDetect::srvStartDetection, this);
00139         ROS_INFO("[joint_coll_detect] Service advertised at start_detection");
00140         stop_srv = nh_priv.advertiseService("stop_detection", 
00141                                             &JointCollDetect::srvStopDetection, this);
00142         ROS_INFO("[joint_coll_detect] Service advertised at stop_detection");
00143         trig_srv = nh_priv.advertiseService("trigger_collision", 
00144                                             &JointCollDetect::srvTriggerCollision, this);
00145         ROS_INFO("[joint_coll_detect] Service advertised at trigger_collision");
00146 
00147         error_sub = nh.subscribe(arm + "_arm_controller/state", 2, 
00148                 &JointCollDetect::errorCallback, this);
00149         ROS_INFO("[joint_coll_detect] JointCollDetect loaded.");
00150     }
00151 
// Fold step for a sum of squared deviations.
// Returns `acc` plus the squared distance of `value` from `center`.
float minus_squared(float acc, float value, float center) {
    const float delta = value - center;
    return acc + delta * delta;
}
00153 
00154     bool JointCollDetect::startDetection(std::string& behavior, float sig_level) {
00155         if(!monitoring_collisions) {
00156             collision_detected = false;
00157 
00158             if(significance_mode) {
00159                 // load the behavior requested
00160                 uint32_t behavior_ind = std::find(behavior_name_list.begin(), 
00161                                                   behavior_name_list.end(),
00162                                                   behavior) - behavior_name_list.begin();
00163                 if(behavior_ind == behavior_name_list.size() || 
00164                                              sig_level > 1 || sig_level < 0) {
00165                     ROS_WARN("[joint_coll_detect] Behavior %s not loaded (bad parameters)!", 
00166                                                            behavior.c_str());
00167                     return false;
00168                 }
00169 
00170                 // set the thresholds using the statistical properties of the 
00171                 // training data
00172                 for(int i=0;i<7;i++) {
00173                     // Perform a prediction interval on the training data
00174                     // using a Student's t-test
00175                     int Sn = total_min_data[behavior_ind][i].size();
00176                     float Sm_min = std::accumulate(total_min_data[behavior_ind][i].begin(),
00177                                                    total_min_data[behavior_ind][i].end(), 0.0) / Sn;
00178                     float Sm_max = std::accumulate(total_max_data[behavior_ind][i].begin(),
00179                                                    total_max_data[behavior_ind][i].end(), 0.0) / Sn;
00180                     boost::function<float(float, float)> minus_squared_bind;
00181                     minus_squared_bind = boost::bind(&minus_squared, _1, _2, Sm_min);
00182                     float Sd_min = std::sqrt(std::accumulate(
00183                                                    total_min_data[behavior_ind][i].begin(),
00184                                                    total_min_data[behavior_ind][i].end(),
00185                                                    0.0, minus_squared_bind) / Sn);
00186                     minus_squared_bind = boost::bind(&minus_squared, _1, _2, Sm_max);
00187                     float Sd_max = std::sqrt(std::accumulate(
00188                                                    total_max_data[behavior_ind][i].begin(),
00189                                                    total_max_data[behavior_ind][i].end(),
00190                                                    0.0, minus_squared_bind) / Sn);
00191                     boost::math::students_t st_dist(Sn-1);
00192                     float T = boost::math::quantile(st_dist, sig_level);
00193                     float thresh_min = Sm_min - T * Sd_min * std::sqrt(1 + 1.0/Sn);
00194                     float thresh_max = Sm_max + T * Sd_max * std::sqrt(1 + 1.0/Sn);
00195                     min_errors[i] = thresh_min;
00196                     max_errors[i] = thresh_max;
00197                     //ROS_INFO("Sm_min: %f, Sm_max: %f, Sn: %d", Sm_min, Sm_max, Sn);
00198                     //ROS_INFO("Sd_min: %f, Sd_max: %f, Sn: %d", Sd_min, Sd_max, Sn);
00199                 }
00200             }
00201             printf("Min thresh: [");
00202             for(int i=0;i<7;i++)
00203                 printf("%1.3f, ", min_errors[i]);
00204             printf("]\nMax thresh: [");
00205             for(int i=0;i<7;i++)
00206                 printf("%1.3f, ", max_errors[i]);
00207             printf("]\n");
00208             start_time = ros::Time::now().toSec();
00209             monitoring_collisions = true;
00210             ROS_INFO("[joint_coll_detect] Monitoring for collisions.");
00211             if(training_mode) {
00212                 std::fill(cur_min_data.begin(), cur_min_data.end(), 10000);
00213                 std::fill(cur_max_data.begin(), cur_max_data.end(), -10000);
00214             }
00215             return true;
00216         }
00217         return false;
00218     }
00219 
00220     void JointCollDetect::stopDetection() {
00221         if(monitoring_collisions) {
00222             end_time = ros::Time::now().toSec();
00223             monitoring_collisions = false;
00224             ROS_INFO("[joint_coll_detect] Stopping monitoring (time passed: %2.1f).", end_time-start_time);
00225             if(training_mode) {
00226                 error_data.min_errors.insert(error_data.min_errors.end(), 
00227                                              cur_min_data.begin(), cur_min_data.end());
00228                 error_data.max_errors.insert(error_data.max_errors.end(), 
00229                                              cur_max_data.begin(), cur_max_data.end());
00230                 printf("Min data: [");
00231                 for(int i=0;i<7;i++)
00232                     printf("%1.3f, ", cur_min_data[i]);
00233                 printf("]\nMax data: [");
00234                 for(int i=0;i<7;i++)
00235                     printf("%1.3f, ", cur_max_data[i]);
00236                 printf("]\n");
00237             }
00238         }
00239     }
00240 
00241     bool JointCollDetect::srvStartDetection(JointDetectionStart::Request& req, 
00242                                              JointDetectionStart::Response& resp) {
00243         return startDetection(req.behavior, req.sig_level);
00244     }
00245 
00246     bool JointCollDetect::srvStopDetection(std_srvs::Empty::Request&, std_srvs::Empty::Response&) {
00247         stopDetection();
00248         return true;
00249     }
00250 
00251     bool JointCollDetect::triggerCollision() {
00252         if(!training_mode) {
00253             if(monitoring_collisions)
00254                 stopDetection();
00255             std_msgs::Bool bool_true;
00256             bool_true.data = true;
00257             detect_pub.publish(bool_true);
00258             collision_detected = true;
00259             return true;
00260         }
00261         return false;
00262     }
00263 
00264     bool JointCollDetect::srvTriggerCollision(std_srvs::Empty::Request& req, std_srvs::Empty::Response& res) {
00265         triggerCollision();
00266         return true;
00267     }
00268 
00269     void JointCollDetect::errorCallback(pr2_controllers_msgs::JointTrajectoryControllerState::ConstPtr message) {
00270         if(!monitoring_collisions || message->error.positions.size() < 7)
00271             return;
00272 
00273         for(int i=0;i<7;i++) {
00274             if(!training_mode) {
00275                 if(message->error.positions[i] < min_errors[i] ||
00276                    message->error.positions[i] > max_errors[i]) {
00277                     if(triggerCollision())
00278                         ROS_INFO("[joint_coll_detect] Collision detected on joint %d. Min: %1.3f, Max: %1.3f, Cur: %1.3f", i, min_errors[i], max_errors[i], message->error.positions[i]);
00279                 }
00280             } else {
00281                 if(message->error.positions[i] < cur_min_data[i])
00282                     cur_min_data[i] = message->error.positions[i];
00283                 if(message->error.positions[i] > cur_max_data[i])
00284                     cur_max_data[i] = message->error.positions[i];
00285             }
00286         }
00287         if(!collision_detected) {
00288             std_msgs::Bool bool_false;
00289             bool_false.data = false;
00290             detect_pub.publish(bool_false);
00291         }
00292     }
00293 
00294     void JointCollDetect::writeErrorData() {
00295         ROS_INFO("[joint_coll_detect] Writing error data to file.");
00296         rosbag::Bag data_bag;
00297         data_bag.open(data_filename, rosbag::bagmode::Write);
00298         data_bag.write("/error_data", ros::Time::now(), error_data);
00299         data_bag.close();
00300         ROS_INFO("[joint_coll_detect] Bag file written.");
00301     }
00302 
00303     void JointCollDetect::loadErrorBag(const std::string& load_filename, 
00304                                          JointErrorData::Ptr& err_data_ptr) {
00305         rosbag::Bag data_bag;
00306         data_bag.open(load_filename, rosbag::bagmode::Read);
00307         rosbag::View view(data_bag, rosbag::TopicQuery("/error_data"));
00308         if(view.size() == 0 || view.size() > 1) {
00309             ROS_ERROR("[joint_coll_detect] Badly formed error_data file (%s)", load_filename.c_str());
00310             ros::shutdown();
00311             return;
00312         }
00313         BOOST_FOREACH(rosbag::MessageInstance const m, view) {
00314             err_data_ptr = m.instantiate<JointErrorData>();
00315         }
00316     }
00317 
00318     void JointCollDetect::loadAllErrorData(vector<std::string>& filename_list) {
00319         total_min_data.resize(filename_list.size()); total_max_data.resize(filename_list.size());
00320         behavior_name_list.resize(filename_list.size());
00321         int beh_ind = 0;
00322         BOOST_FOREACH(std::string const filename, filename_list) {
00323             JointErrorData::Ptr err_data_ptr;
00324             loadErrorBag(filename, err_data_ptr);
00325             total_min_data[beh_ind].resize(7); total_max_data[beh_ind].resize(7);
00326             uint32_t data_pt_ind = 0;
00327             while(data_pt_ind < err_data_ptr->min_errors.size()) {
00328                 for(int i=0;i<7;i++) {
00329                     total_min_data[beh_ind][i].push_back(err_data_ptr->min_errors[data_pt_ind]);
00330                     total_max_data[beh_ind][i].push_back(err_data_ptr->max_errors[data_pt_ind]);
00331                     data_pt_ind++;
00332                 }
00333             }
00334             behavior_name_list[beh_ind] = err_data_ptr->behavior;
00335             beh_ind++;
00336         }
00337     }
00338 
00339     JointCollDetect::~JointCollDetect() {
00340     }
00341 
00342 };
00343 
00344 int main(int argc, char **argv)
00345 {
00346     ros::init(argc, argv, "joint_coll_detect", ros::init_options::AnonymousName);
00347     pr2_collision_monitor::JointCollDetect cm;
00348     ros::spin();
00349     if(cm.isTraining()) 
00350         cm.writeErrorData();
00351     return 0;
00352 }
00353 


pr2_collision_monitor
Author(s): Kelsey Hawkins; advisor: Prof. Charlie Kemp (Healthcare Robotics Lab at Georgia Tech)
Autogenerated on Wed Nov 27 2013 11:40:10