#include <algorithm>
#include <cmath>
#include <cstdio>
#include <math.h>
#include <numeric>
#include <string>
#include <vector>

#include <boost/bind.hpp>
#include <boost/foreach.hpp>
#include <boost/math/distributions/students_t.hpp>

#include <ros/ros.h>
#include <ros/package.h>
#include <pr2_controllers_msgs/JointTrajectoryControllerState.h>
#include <std_msgs/Bool.h>
#include <std_msgs/Float64.h>
#include <std_srvs/Empty.h>
#include <rosbag/bag.h>
#include <rosbag/exceptions.h>
#include <rosbag/view.h>

#include <pr2_collision_monitor/JointErrorData.h>
#include <pr2_collision_monitor/JointDetectionStart.h>
00018
00019 using namespace std;
00020
00021 namespace pr2_collision_monitor {
00022
00023 class JointCollDetect {
00024 public:
00025 JointCollDetect();
00026 void onInit();
00027 void writeErrorData();
00028 bool isTraining() { return training_mode; }
00029 ~JointCollDetect();
00030
00031 protected:
00032 ros::NodeHandle nh;
00033 ros::NodeHandle nh_priv;
00034 std::string arm, behavior_name, data_filename;
00035
00036 bool monitoring_collisions, training_mode, significance_mode, collision_detected;
00037 double start_time, end_time;
00038 vector<double> min_errors, max_errors;
00039 vector<float> cur_min_data, cur_max_data;
00040 JointErrorData error_data;
00041 vector<std::string> behavior_name_list;
00042 vector<vector<vector<float> > > total_min_data, total_max_data;
00043
00044 bool startDetection(std::string&, float specificity);
00045 void stopDetection();
00046 bool triggerCollision();
00047 void errorCallback(pr2_controllers_msgs::JointTrajectoryControllerState::ConstPtr message);
00048 bool srvStartDetection(JointDetectionStart::Request&,
00049 JointDetectionStart::Response&);
00050 bool srvStopDetection(std_srvs::Empty::Request&, std_srvs::Empty::Response&);
00051 bool srvTriggerCollision(std_srvs::Empty::Request&, std_srvs::Empty::Response&);
00052 void loadErrorBag(const std::string& load_filename,
00053 JointErrorData::Ptr& err_data_ptr);
00054 void loadAllErrorData(vector<std::string>& filename_list);
00055
00056 ros::Publisher detect_pub;
00057 ros::Subscriber error_sub;
00058 ros::ServiceServer start_srv, stop_srv, trig_srv;
00059 };
00060
00061 JointCollDetect::JointCollDetect() : nh_priv("~"),
00062 min_errors(7),
00063 max_errors(7),
00064 cur_min_data(7),
00065 cur_max_data(7) {
00066 onInit();
00067 }
00068
00069 void JointCollDetect::onInit() {
00070 nh_priv.param<std::string>("arm", arm, std::string("r"));
00071
00072
00073
00074
00075 nh_priv.param<bool>("training_mode", training_mode, false);
00076
00077
00078
00079 nh_priv.param<bool>("significance_mode", significance_mode, false);
00080 if(training_mode) {
00081
00082 nh_priv.param<std::string>("behavior_name", behavior_name,
00083 std::string("collision_behavior"));
00084 nh_priv.param<std::string>("data_filename", data_filename,
00085 std::string("joint_error_data"));
00086 error_data.arm = arm;
00087 error_data.behavior = behavior_name;
00088 } else {
00089
00090 if(significance_mode) {
00091
00092
00093 std::string filename_prefix;
00094 nh_priv.getParam("filename_prefix", filename_prefix);
00095 vector<std::string> filename_list;
00096 XmlRpc::XmlRpcValue xml_filenames;
00097 if(nh_priv.hasParam("filename_list")) {
00098 nh_priv.getParam("filename_list", xml_filenames);
00099 for(int i=0;i<xml_filenames.size();i++) {
00100 filename_list.push_back(filename_prefix +
00101 static_cast<std::string>(xml_filenames[i]));
00102 }
00103 } else {
00104 ROS_ERROR("[joint_coll_detect] MUST PROVIDE FILENAMES IN significance_mode (filename_list)");
00105 ros::shutdown();
00106 return;
00107 }
00108 loadAllErrorData(filename_list);
00109 } else {
00110
00111 XmlRpc::XmlRpcValue xml_min_errors, xml_max_errors;
00112 if(nh_priv.hasParam("min_errors")) {
00113 nh_priv.getParam("min_errors", xml_min_errors);
00114 for(int i=0;i<xml_min_errors.size();i++)
00115 min_errors[i] = static_cast<double>(xml_min_errors[i]);
00116 } else {
00117 ROS_ERROR("[joint_coll_detect] MUST PROVIDE THRESHOLDS (min_errors)");
00118 ros::shutdown();
00119 return;
00120 }
00121 if(nh_priv.hasParam("max_errors")) {
00122 nh_priv.getParam("max_errors", xml_max_errors);
00123 for(int i=0;i<xml_max_errors.size();i++)
00124 max_errors[i] = static_cast<double>(xml_max_errors[i]);
00125 } else {
00126 ROS_ERROR("[joint_coll_detect] MUST PROVIDE THRESHOLDS (max_errors)");
00127 ros::shutdown();
00128 return;
00129 }
00130 }
00131 }
00132
00133 monitoring_collisions = false;
00134
00135 detect_pub = nh_priv.advertise<std_msgs::Bool>("arm_collision_detected", 1);
00136 ROS_INFO("[joint_coll_detect] Publishing on arm_collision_detected");
00137 start_srv = nh_priv.advertiseService("start_detection",
00138 &JointCollDetect::srvStartDetection, this);
00139 ROS_INFO("[joint_coll_detect] Service advertised at start_detection");
00140 stop_srv = nh_priv.advertiseService("stop_detection",
00141 &JointCollDetect::srvStopDetection, this);
00142 ROS_INFO("[joint_coll_detect] Service advertised at stop_detection");
00143 trig_srv = nh_priv.advertiseService("trigger_collision",
00144 &JointCollDetect::srvTriggerCollision, this);
00145 ROS_INFO("[joint_coll_detect] Service advertised at trigger_collision");
00146
00147 error_sub = nh.subscribe(arm + "_arm_controller/state", 2,
00148 &JointCollDetect::errorCallback, this);
00149 ROS_INFO("[joint_coll_detect] JointCollDetect loaded.");
00150 }
00151
// Fold step for a sum of squared deviations: returns the running total `a`
// plus the squared distance between sample `b` and reference value `c`.
float minus_squared(float a, float b, float c) {
  const float offset = b - c;
  return offset * offset + a;
}
00153
00154 bool JointCollDetect::startDetection(std::string& behavior, float sig_level) {
00155 if(!monitoring_collisions) {
00156 collision_detected = false;
00157
00158 if(significance_mode) {
00159
00160 uint32_t behavior_ind = std::find(behavior_name_list.begin(),
00161 behavior_name_list.end(),
00162 behavior) - behavior_name_list.begin();
00163 if(behavior_ind == behavior_name_list.size() ||
00164 sig_level > 1 || sig_level < 0) {
00165 ROS_WARN("[joint_coll_detect] Behavior %s not loaded (bad parameters)!",
00166 behavior.c_str());
00167 return false;
00168 }
00169
00170
00171
00172 for(int i=0;i<7;i++) {
00173
00174
00175 int Sn = total_min_data[behavior_ind][i].size();
00176 float Sm_min = std::accumulate(total_min_data[behavior_ind][i].begin(),
00177 total_min_data[behavior_ind][i].end(), 0.0) / Sn;
00178 float Sm_max = std::accumulate(total_max_data[behavior_ind][i].begin(),
00179 total_max_data[behavior_ind][i].end(), 0.0) / Sn;
00180 boost::function<float(float, float)> minus_squared_bind;
00181 minus_squared_bind = boost::bind(&minus_squared, _1, _2, Sm_min);
00182 float Sd_min = std::sqrt(std::accumulate(
00183 total_min_data[behavior_ind][i].begin(),
00184 total_min_data[behavior_ind][i].end(),
00185 0.0, minus_squared_bind) / Sn);
00186 minus_squared_bind = boost::bind(&minus_squared, _1, _2, Sm_max);
00187 float Sd_max = std::sqrt(std::accumulate(
00188 total_max_data[behavior_ind][i].begin(),
00189 total_max_data[behavior_ind][i].end(),
00190 0.0, minus_squared_bind) / Sn);
00191 boost::math::students_t st_dist(Sn-1);
00192 float T = boost::math::quantile(st_dist, sig_level);
00193 float thresh_min = Sm_min - T * Sd_min * std::sqrt(1 + 1.0/Sn);
00194 float thresh_max = Sm_max + T * Sd_max * std::sqrt(1 + 1.0/Sn);
00195 min_errors[i] = thresh_min;
00196 max_errors[i] = thresh_max;
00197
00198
00199 }
00200 }
00201 printf("Min thresh: [");
00202 for(int i=0;i<7;i++)
00203 printf("%1.3f, ", min_errors[i]);
00204 printf("]\nMax thresh: [");
00205 for(int i=0;i<7;i++)
00206 printf("%1.3f, ", max_errors[i]);
00207 printf("]\n");
00208 start_time = ros::Time::now().toSec();
00209 monitoring_collisions = true;
00210 ROS_INFO("[joint_coll_detect] Monitoring for collisions.");
00211 if(training_mode) {
00212 std::fill(cur_min_data.begin(), cur_min_data.end(), 10000);
00213 std::fill(cur_max_data.begin(), cur_max_data.end(), -10000);
00214 }
00215 return true;
00216 }
00217 return false;
00218 }
00219
00220 void JointCollDetect::stopDetection() {
00221 if(monitoring_collisions) {
00222 end_time = ros::Time::now().toSec();
00223 monitoring_collisions = false;
00224 ROS_INFO("[joint_coll_detect] Stopping monitoring (time passed: %2.1f).", end_time-start_time);
00225 if(training_mode) {
00226 error_data.min_errors.insert(error_data.min_errors.end(),
00227 cur_min_data.begin(), cur_min_data.end());
00228 error_data.max_errors.insert(error_data.max_errors.end(),
00229 cur_max_data.begin(), cur_max_data.end());
00230 printf("Min data: [");
00231 for(int i=0;i<7;i++)
00232 printf("%1.3f, ", cur_min_data[i]);
00233 printf("]\nMax data: [");
00234 for(int i=0;i<7;i++)
00235 printf("%1.3f, ", cur_max_data[i]);
00236 printf("]\n");
00237 }
00238 }
00239 }
00240
00241 bool JointCollDetect::srvStartDetection(JointDetectionStart::Request& req,
00242 JointDetectionStart::Response& resp) {
00243 return startDetection(req.behavior, req.sig_level);
00244 }
00245
00246 bool JointCollDetect::srvStopDetection(std_srvs::Empty::Request&, std_srvs::Empty::Response&) {
00247 stopDetection();
00248 return true;
00249 }
00250
00251 bool JointCollDetect::triggerCollision() {
00252 if(!training_mode) {
00253 if(monitoring_collisions)
00254 stopDetection();
00255 std_msgs::Bool bool_true;
00256 bool_true.data = true;
00257 detect_pub.publish(bool_true);
00258 collision_detected = true;
00259 return true;
00260 }
00261 return false;
00262 }
00263
00264 bool JointCollDetect::srvTriggerCollision(std_srvs::Empty::Request& req, std_srvs::Empty::Response& res) {
00265 triggerCollision();
00266 return true;
00267 }
00268
00269 void JointCollDetect::errorCallback(pr2_controllers_msgs::JointTrajectoryControllerState::ConstPtr message) {
00270 if(!monitoring_collisions || message->error.positions.size() < 7)
00271 return;
00272
00273 for(int i=0;i<7;i++) {
00274 if(!training_mode) {
00275 if(message->error.positions[i] < min_errors[i] ||
00276 message->error.positions[i] > max_errors[i]) {
00277 if(triggerCollision())
00278 ROS_INFO("[joint_coll_detect] Collision detected on joint %d. Min: %1.3f, Max: %1.3f, Cur: %1.3f", i, min_errors[i], max_errors[i], message->error.positions[i]);
00279 }
00280 } else {
00281 if(message->error.positions[i] < cur_min_data[i])
00282 cur_min_data[i] = message->error.positions[i];
00283 if(message->error.positions[i] > cur_max_data[i])
00284 cur_max_data[i] = message->error.positions[i];
00285 }
00286 }
00287 if(!collision_detected) {
00288 std_msgs::Bool bool_false;
00289 bool_false.data = false;
00290 detect_pub.publish(bool_false);
00291 }
00292 }
00293
00294 void JointCollDetect::writeErrorData() {
00295 ROS_INFO("[joint_coll_detect] Writing error data to file.");
00296 rosbag::Bag data_bag;
00297 data_bag.open(data_filename, rosbag::bagmode::Write);
00298 data_bag.write("/error_data", ros::Time::now(), error_data);
00299 data_bag.close();
00300 ROS_INFO("[joint_coll_detect] Bag file written.");
00301 }
00302
00303 void JointCollDetect::loadErrorBag(const std::string& load_filename,
00304 JointErrorData::Ptr& err_data_ptr) {
00305 rosbag::Bag data_bag;
00306 data_bag.open(load_filename, rosbag::bagmode::Read);
00307 rosbag::View view(data_bag, rosbag::TopicQuery("/error_data"));
00308 if(view.size() == 0 || view.size() > 1) {
00309 ROS_ERROR("[joint_coll_detect] Badly formed error_data file (%s)", load_filename.c_str());
00310 ros::shutdown();
00311 return;
00312 }
00313 BOOST_FOREACH(rosbag::MessageInstance const m, view) {
00314 err_data_ptr = m.instantiate<JointErrorData>();
00315 }
00316 }
00317
00318 void JointCollDetect::loadAllErrorData(vector<std::string>& filename_list) {
00319 total_min_data.resize(filename_list.size()); total_max_data.resize(filename_list.size());
00320 behavior_name_list.resize(filename_list.size());
00321 int beh_ind = 0;
00322 BOOST_FOREACH(std::string const filename, filename_list) {
00323 JointErrorData::Ptr err_data_ptr;
00324 loadErrorBag(filename, err_data_ptr);
00325 total_min_data[beh_ind].resize(7); total_max_data[beh_ind].resize(7);
00326 uint32_t data_pt_ind = 0;
00327 while(data_pt_ind < err_data_ptr->min_errors.size()) {
00328 for(int i=0;i<7;i++) {
00329 total_min_data[beh_ind][i].push_back(err_data_ptr->min_errors[data_pt_ind]);
00330 total_max_data[beh_ind][i].push_back(err_data_ptr->max_errors[data_pt_ind]);
00331 data_pt_ind++;
00332 }
00333 }
00334 behavior_name_list[beh_ind] = err_data_ptr->behavior;
00335 beh_ind++;
00336 }
00337 }
00338
00339 JointCollDetect::~JointCollDetect() {
00340 }
00341
00342 };
00343
00344 int main(int argc, char **argv)
00345 {
00346 ros::init(argc, argv, "joint_coll_detect", ros::init_options::AnonymousName);
00347 pr2_collision_monitor::JointCollDetect cm;
00348 ros::spin();
00349 if(cm.isTraining())
00350 cm.writeErrorData();
00351 return 0;
00352 }
00353