00001 #include <ros/ros.h>
00002
00003 #include <boost/thread/mutex.hpp>
00004
00005 #include <tf/tf.h>
00006 #include <tf/transform_listener.h>
00007 #include <tf/tfMessage.h>
00008 #include <tf/transform_broadcaster.h>
00009
00010 #include <sensor_msgs/JointState.h>
00011 #include <geometry_msgs/PointStamped.h>
00012 #include <geometry_msgs/PoseStamped.h>
00013 #include <geometry_msgs/QuaternionStamped.h>
00014 #include <trajectory_msgs/JointTrajectory.h>
00015 #include <visualization_msgs/Marker.h>
00016 #include <visualization_msgs/MarkerArray.h>
00017
00018 #include <pr2_controllers_msgs/Pr2GripperCommandAction.h>
00019 #include <actionlib/client/simple_action_client.h>
00020
00021 #include <motld/TrackedObjects.h>
00022
00023 #include <mpc/types.h>
00024 #include <mpc/mpc.h>
00025 #include <mpc/finite_differences.h>
00026 #include <reactive_grasping_pr2/pr2_joint_space.h>
00027
00028 #ifndef SQR
00029 #define SQR(X) ((X)*(X))
00030 #endif
00031
00032 using namespace MPC;
00033
00034 typedef actionlib::SimpleActionClient<pr2_controllers_msgs::Pr2GripperCommandAction> GripperClient;
00035
00036 class ArmMPCJoints {
00037
00038 private:
00039 bool initialized_;
00040 VectorT center_;
00041 tf::TransformBroadcaster *tf_broadcaster_;
00042 tf::TransformListener *tf_transformer_;
00043
00044 ros::Subscriber obj_sub_;
00045
00046 ros::Publisher joint_pose_pub_;
00047 ros::Publisher marker_pub_;
00048 ros::Subscriber joint_state_sub_;
00049 boost::mutex state_mutex_;
00050 std::vector<std::string> joint_names_;
00051 std::vector<int> joint_mapping_;
00052 bool joint_mapping_initialized_;
00053 VectorT curr_joint_state_;
00054
00055 int curr_t_;
00056 float eps_;
00057 float max_mov_dist_;
00058
00059 bool on_way_to_goal_;
00060 bool last_goal_set_;
00061 bool last_pos_set_;
00062 VectorT last_goal_;
00063 VectorT last_pos_;
00064
00065 int n, m, N;
00066 bool restart_ddp_;
00067 DDPParams params;
00068 Trajectory curr_traj_;
00069 MatrixT u0_;
00070 trajectory_msgs::JointTrajectory joint_traj_;
00071 GripperClient *close_gripper_;
00072
00073 pr2_joint_space::RightArmModel *r_arm_model_;
00074
00075 ros::NodeHandle nh_;
00076 ros::NodeHandle private_nh_;
00077
00078
00079 std::string root_frame_;
00080 std::string gripper_frame_;
00081 std::string object_topic_;
00082 std::string gripper_action_topic_;
00083 std::string gripper_close_topic_;
00084 std::string marker_topic_;
00085
00086 public:
00087 ArmMPCJoints(ros::NodeHandle &nh) : nh_(nh), private_nh_(ros::NodeHandle("~")), initialized_(false) {
00088
00089 private_nh_.param("root_frame", root_frame_, std::string("/torso_lift_link"));
00090 private_nh_.param("gripper_frame", gripper_frame_, std::string("/r_wrist_roll_link"));
00091 private_nh_.param("object_topic", object_topic_, std::string("/objects"));
00092 private_nh_.param("gripper_action_topic", gripper_action_topic_, std::string("/r_arm_controller/command"));
00093 private_nh_.param("gripper_close_topic", gripper_close_topic_, std::string("/r_gripper_controller/gripper_action"));
00094 private_nh_.param("marker_topic", marker_topic_, std::string("MPC_markers"));
00095
00096
00097 tf_broadcaster_ = new tf::TransformBroadcaster();
00098 tf_transformer_ = new tf::TransformListener();
00099
00100
00101 obj_sub_ = nh_.subscribe("objects", 1, &ArmMPCJoints::mpc_callback, this);
00102
00103 joint_pose_pub_ = nh_.advertise<trajectory_msgs::JointTrajectory>(gripper_action_topic_, 1);
00104
00105 close_gripper_ = new GripperClient(gripper_close_topic_, true);
00106 while(!close_gripper_->waitForServer(ros::Duration(5.0))) {
00107 ROS_INFO("Waiting for the gripper action server to come up");
00108 }
00109
00110
00111 marker_pub_ = nh_.advertise<visualization_msgs::MarkerArray>(marker_topic_, 1);
00112
00113
00114 n = 14;
00115 m = 7;
00116 N = 20;
00117 u0_.setZero(m, N);
00118 VectorT internal_goal(VectorT::Zero(14));
00119
00120 eps_ = 0.02;
00121 max_mov_dist_ = 0.05;
00122
00123 restart_ddp_ = true;
00124
00125
00127 r_arm_model_ = new pr2_joint_space::RightArmModel(nh_, root_frame_);
00128 r_arm_model_->setInternalGoal(internal_goal);
00129
00130 finite_differences::setBaseFunctions(std::bind(&pr2_joint_space::RightArmModel::dynamics, r_arm_model_, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3), std::bind(&pr2_joint_space::RightArmModel::cost, r_arm_model_, std::placeholders::_1, std::placeholders::_2));
00131 bool init_success = DDP_init(std::bind(&pr2_joint_space::RightArmModel::dynamics, r_arm_model_, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3), std::bind(&pr2_joint_space::RightArmModel::dynamicsD, r_arm_model_, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3), &finite_differences::cost, &finite_differences::costD);
00132
00133
00134
00135
00136
00137
00138
00139
00140
00141
00142 center_.resize(3);
00143 last_goal_.resize(n);
00144 last_pos_.resize(n);
00145 on_way_to_goal_ = false;
00146 last_goal_set_ = false;
00147 last_pos_set_ = false;
00148 curr_t_ = 1;
00149
00150
00151 joint_names_.resize(7);
00152 joint_names_[0] = "r_shoulder_pan_joint";
00153 joint_names_[1] = "r_shoulder_lift_joint";
00154 joint_names_[2] = "r_upper_arm_roll_joint";
00155 joint_names_[3] = "r_elbow_flex_joint";
00156 joint_names_[4] = "r_forearm_roll_joint";
00157 joint_names_[5] = "r_wrist_flex_joint";
00158 joint_names_[6] = "r_wrist_roll_joint";
00159
00160
00161
00162
00163 joint_mapping_.resize(7);
00164 joint_mapping_initialized_ = false;
00165
00166 joint_traj_.joint_names.resize(joint_names_.size());
00167 for (int i = 0; i < joint_names_.size(); ++i) {
00168 joint_traj_.joint_names[i] = joint_names_[i];
00169 }
00170 joint_traj_.points.resize(1);
00171
00172 joint_traj_.points[0].positions.resize(7);
00173 joint_traj_.points[0].velocities.resize(7);
00174 for (size_t j = 0; j < 7; ++j) {
00175 joint_traj_.points[0].velocities[j] = 0.0;
00176 }
00177
00178 joint_state_sub_ = nh_.subscribe("/joint_states", 1, &ArmMPCJoints::joint_state_callback, this);
00179 curr_joint_state_.setZero(joint_names_.size()*2);
00180
00181
00182 initialized_ = init_success;
00183 ROS_INFO("Done initializing");
00184 }
00185
00186 ~ArmMPCJoints() {
00187 delete obj_sub_;
00188 delete close_gripper_;
00189 }
00190
00191 void open_gripper(bool wait=false) {
00192 pr2_controllers_msgs::Pr2GripperCommandGoal open;
00193 open.command.position = 0.08;
00194 open.command.max_effort = -1.0;
00195
00196 close_gripper_->sendGoal(open);
00197 if (wait) {
00198 close_gripper_->waitForResult();
00199 if(close_gripper_->getState() == actionlib::SimpleClientGoalState::SUCCEEDED)
00200 ROS_INFO("The gripper opened!");
00201 else
00202 ROS_INFO("The gripper failed to open.");
00203 }
00204 }
00205
00206 void close_gripper(bool wait=false) {
00207 pr2_controllers_msgs::Pr2GripperCommandGoal close;
00208 close.command.position = 0.0;
00209 close.command.max_effort = 20.0;
00210
00211 close_gripper_->sendGoal(close);
00212 if (wait) {
00213 close_gripper_->waitForResult();
00214 if(close_gripper_->getState() == actionlib::SimpleClientGoalState::SUCCEEDED)
00215 ROS_INFO("The gripper closed!");
00216 else
00217 ROS_INFO("The gripper failed to close.");
00218 }
00219 }
00220
00221 void joint_state_callback(const sensor_msgs::JointStateConstPtr &msg){
00222 boost::mutex::scoped_lock(state_mutex_);
00223 if (!joint_mapping_initialized_) {
00224 for (int i = 0; i < joint_names_.size(); ++i) {
00225 for (int j = 0; j < msg->name.size(); ++j) {
00226 if (joint_names_[i].compare(msg->name[j]) == 0) {
00227 joint_mapping_[i] = j;
00228 }
00229 }
00230 }
00231 joint_mapping_initialized_ = true;
00232 }
00233
00234 for (int i = 0; i < joint_mapping_.size(); ++i) {
00235 curr_joint_state_(i) = msg->position[joint_mapping_[i]];
00236 curr_joint_state_(i+joint_mapping_.size()) = msg->velocity[joint_mapping_[i]];
00237 }
00238 }
00239
00240 void get_cart_and_joint_pos(VectorT &x0) {
00241 if (!joint_mapping_initialized_) {
00242 x0.setZero();
00243 return;
00244 }
00245 {
00246 boost::mutex::scoped_lock(state_mutex_);
00247
00248 x0.tail(curr_joint_state_.size()) = curr_joint_state_;
00249 }
00250 }
00251
00252
00253 void mpc_callback(const motld::TrackedObjectsConstPtr &objects) {
00254 ROS_INFO("mpc_callback called!");
00255 if (not initialized_) {
00256 ROS_WARN("mpc not initialized");
00257 return;
00258 }
00259 if (objects->name.size() < 1) {
00260 ROS_WARN("no object poses received");
00261 return;
00262 }
00263 if (objects->name.size() != objects->pose.poses.size()) {
00264 ROS_WARN("object poses and object names array sizes do not match");
00265 return;
00266 }
00267
00268 if (on_way_to_goal_ && curr_t_ +1 >= curr_traj_.x.cols()) {
00269
00270
00271 }
00272
00273 ros::Time goal_stamp = objects->pose.header.stamp;
00274 ros::Time time_stamp_now = ros::Time::now();
00275
00276 VectorT goal(3);
00277 goal(0) = objects->pose.poses[0].position.x;
00278 goal(1) = objects->pose.poses[0].position.y;
00279 goal(2) = objects->pose.poses[0].position.z;
00280
00281 if (tf_transformer_->waitForTransform(objects->pose.header.frame_id, gripper_frame_, goal_stamp, ros::Duration(1.)) == false) {
00282 ROS_ERROR("Did not find transform between %s and %s", objects->pose.header.frame_id.c_str(), gripper_frame_.c_str());
00283 return;
00284 }
00285 ROS_INFO("Got transform!");
00286
00287 tf::StampedTransform trans;
00288 tf_transformer_->lookupTransform(objects->pose.header.frame_id, gripper_frame_, goal_stamp, trans);
00289
00290
00291 center_(0) = trans.getOrigin().x();
00292 center_(1) = trans.getOrigin().y();
00293 center_(2) = trans.getOrigin().z();
00294
00295
00296 r_arm_model_->setExternalGoal(goal);
00297
00298 VectorT x0(14);
00299 get_cart_and_joint_pos(x0);
00300 VectorT curr_joint_goal = x0;
00301 r_arm_model_->ik(center_, curr_joint_goal);
00302 curr_joint_goal.tail(7).setZero();
00303 r_arm_model_->setInternalGoal(curr_joint_goal);
00304
00305 center_ -= goal;
00306
00307
00308
00309
00310
00311
00312 float goal_diff = center_.norm();
00313
00314 if (goal_diff < eps_ + 0.02) {
00315 ROS_INFO("already at goal / goal reached");
00316 close_gripper();
00317 return;
00318 }
00319
00320 float goal_change = 200.;
00321 float pos_change = 100.;
00322 if (last_goal_set_) {
00323 goal_change = (last_goal_ - goal).norm();
00324 }
00325 if (last_pos_set_) {
00326 pos_change = (center_ - last_pos_).norm();
00327 }
00328 ROS_INFO("pchange: %f gchange: %f max_dist: %f", pos_change, goal_change, max_mov_dist_);
00329 if (pos_change > max_mov_dist_ || goal_change > max_mov_dist_) {
00330
00331
00332 on_way_to_goal_ = false;
00333 restart_ddp_ = true;
00334 }
00335
00336
00337
00338 int x_len = curr_traj_.x.cols();
00339
00340 if (false && on_way_to_goal_ && goal_change < eps_ && curr_t_ < x_len - 1) {
00341 curr_t_++;
00342 } else {
00343
00344 if (restart_ddp_) {
00345 ROS_INFO("restarting ddp");
00346 u0_.setZero();
00347 DDP(x0, u0_, params, 100, curr_traj_);
00348 restart_ddp_ = false;
00349 } else {
00350 u0_ = curr_traj_.u;
00351 DDP(x0, u0_, params, 10, curr_traj_);
00352 }
00353 curr_t_ = 1;
00354 }
00355
00356
00357 const VectorT &next = curr_traj_.x.col(curr_t_);
00358
00359 joint_traj_.header.stamp = time_stamp_now;
00360
00361
00362 for (int i = 0; i < joint_traj_.points[0].positions.size(); ++i) {
00363 joint_traj_.points[0].positions[i] = next(i);
00364 joint_traj_.points[0].velocities[i] = next(i + 7);
00365 }
00366 joint_traj_.points[0].time_from_start = ros::Duration(0.8f);
00367
00368
00369 last_goal_ = goal;
00370 last_goal_set_ = true;
00371
00372
00373 last_pos_ = center_;
00374 last_pos_set_ = true;
00375
00376
00377 if (curr_traj_.x.col(curr_traj_.x.cols()-1).norm() < eps_) {
00378 on_way_to_goal_ = true;
00379 } else {
00380 on_way_to_goal_ = false;
00381 }
00382
00383
00384 joint_pose_pub_.publish(joint_traj_);
00385 open_gripper();
00386
00387
00388 if (marker_pub_.getNumSubscribers() > 0) {
00389 ROS_INFO("mpc sending markers");
00390 visualization_msgs::MarkerArray markers;
00391 VectorT tmp_pos(3);
00392 int id = 0;
00393 for (int p = 0; p < curr_traj_.x.cols(); ++p) {
00394 r_arm_model_->fk(curr_traj_.x.col(p), tmp_pos);
00395 visualization_msgs::Marker marker;
00396 marker.header.frame_id = objects->pose.header.frame_id;
00397 marker.header.stamp = ros::Time::now();
00398 marker.type = marker.SPHERE;
00399 marker.action = marker.ADD;
00400 marker.id = id;
00401 marker.scale.x = 0.02;
00402 marker.scale.y = 0.02;
00403 marker.scale.z = 0.02;
00404 marker.color.a = 1.0;
00405 marker.color.r = 1.0;
00406 marker.color.g = 1.0;
00407 marker.color.b = 0.0;
00408 marker.pose.orientation.w = 1.0;
00409 marker.pose.position.x = tmp_pos(0) + goal(0);
00410 marker.pose.position.y = tmp_pos(1) + goal(1);
00411 marker.pose.position.z = tmp_pos(2) + goal(2);
00412 markers.markers.push_back(marker);
00413 ++id;
00414 }
00415 marker_pub_.publish(markers);
00416 }
00417 }
00418 };
00419
00420 int main(int argc, char *argv[]) {
00421 ros::init(argc, argv, "mpc_node");
00422 ros::NodeHandle n;
00423 ArmMPCJoints mpc_client(n);
00424 while (ros::ok()) {
00425 ros::spinOnce();
00426 }
00427 }