00001 #include "pomdp.h"
00002
00003 Pomdp::Pomdp(){
00004
00005 int numstates, numactions, numobservations;
00006
00007 this->actioncount = 0;
00008
00009 std::string pomdpfile("data/pomdpfile.txt");
00010 if(node_handle_.getParam("pomdpfile", pomdpfile)) {
00011 ROS_INFO("pomdpfile global path found, using %s", pomdpfile.c_str());
00012 }else{
00013 ROS_WARN("pomdpfile global path not found, using default %s", pomdpfile.c_str());
00014 }
00015
00016 srand( time(NULL) );
00017 Policy::getDimensions(pomdpfile,&numstates,&numactions,&numobservations, 1);
00018 myPolicy = new Policy(pomdpfile,numactions,numstates,numobservations);
00019
00020
00021 this->execution_flag = false;
00022
00023
00024
00025
00026 this->focused_obj_label_publisher = this->node_handle_.advertise<std_msgs::Int32>("/planner/focused_obj_label", 5);
00027 this->belief_summary_publisher = this->node_handle_.advertise<iri_wam_common_msgs::belief_summary>("/planner/belief_summary", 5);
00028 this->belief_publisher = this->node_handle_.advertise<iri_wam_common_msgs::belief>("/planner/belief", 5);
00029
00030
00031
00032
00033 this->execution_flow_control_server = this->node_handle_.advertiseService("/planner/execution_flow_control", &Pomdp::execution_flow_controlCallback, this);
00034
00035
00036 obs_client = this->node_handle_.serviceClient<iri_wam_common_msgs::obs_request>("/planner/obs_client");
00037 wam_action_client = this->node_handle_.serviceClient<iri_wam_common_msgs::wamaction>("/planner/wam_action");
00038
00039
00040
00041
00042
00043 }
00044
00045 void Pomdp::mainLoop(void){
00046
00047 int action, wam_action, obs, state;
00048 int zone, hand;
00049 if(execution_flag){
00050 state = myPolicy->mostProbableState();
00051
00052 ROS_INFO("Most probable state %s with probability %f", myPolicy->domain.state2string(state).c_str(), myPolicy->getStateProbability(state)*2 );
00053
00054 publish_belief();
00055
00056 if(!myPolicy->domain.isFinal(state) || myPolicy->getStateProbability(state) < 0.45){
00057 myPolicy->logState(LOGFILE);
00058
00059 action = myPolicy->getBestAction();
00060 ROS_INFO("Best action selected %s", myPolicy->domain.action2string(action).c_str());
00061 pomdp_action2wam_action(action, wam_action, zone, hand);
00062
00063 wam_actions_client_srv.request.action = wam_action;
00064 wam_actions_client_srv.request.zone = zone;
00065 wam_actions_client_srv.request.hand = hand;
00066 if(wam_action_client.call(wam_actions_client_srv))
00067 {
00068 ROS_INFO("Service result: %d", wam_actions_client_srv.response.success);
00069 } else {
00070 ROS_ERROR("Failed to call service wam actions service");
00071 return;
00072 }
00073
00074
00075 if(obs_client.call(obs_request_srv))
00076 {
00077 ROS_INFO("Service result: %d objects, %d on A, %d on B", obs_request_srv.response.num_objects,obs_request_srv.response.num_objectsA,obs_request_srv.response.num_objectsB);
00078 } else {
00079 ROS_ERROR("Failed to call service observation client");
00080 return;
00081 }
00082 obs = myPolicy->domain.get_observation(obs_request_srv.response.num_objectsA, obs_request_srv.response.num_objectsB, MAXNUMB);
00083 ROS_INFO("Pomdp equivalent Observation is %s (%d)", myPolicy->domain.obs2string(obs).c_str(),obs);
00084
00085
00086
00087 myPolicy->transformState(action,obs);
00088 ROS_INFO("State transformed");
00089
00090 this->ObjLabel_msg.data = 1;
00091
00092 focused_obj_label_publisher.publish(this->ObjLabel_msg);
00093 }else{
00094
00095 ROS_INFO("Final state reached with probability %f", myPolicy->getStateProbability(myPolicy->mostProbableState())*2);
00096 myPolicy->logState(LOGFILE);
00097
00098 sleep(5);
00099 }
00100 }else{
00101
00102 ROS_INFO("Awaiting trigger");
00103 sleep(5);
00104 }
00105 }
00106
00107 bool Pomdp::execution_flow_controlCallback(std_srvs::Empty::Request &req, std_srvs::Empty::Response &res){
00108
00109
00110 this->execution_flag = !this->execution_flag;
00111 return true;
00112 }
00113
00114
00115 void Pomdp::pomdp_action2wam_action(int pomdp_action, int& wam_action, int& zone, int& hand){
00116 wam_action = pomdp_action;
00117 if(wam_action < 5)
00118 zone = AZONEMSG;
00119 else
00120 zone = BZONEMSG;
00121 wam_action = pomdp_action%5;
00122
00123 switch(wam_action){
00124 case 0:
00125 wam_action = TAKEHIGH;
00126 hand = STRAIGHT;
00127 break;
00128 case 1:
00129 wam_action = TAKELOW;
00130 hand = STRAIGHT;
00131 break;
00132 case 2:
00133 wam_action = TAKEHIGH;
00134 hand = ISOMETRIC;
00135 break;
00136 case 3:
00137 wam_action = TAKELOW;
00138 hand = ISOMETRIC;
00139 break;
00140 case 4:
00141 wam_action = MOVEAWAY;
00142 hand = STRAIGHT;
00143 break;
00144 default:
00145 ROS_ERROR("Unrecognized pomdp_action2wam_action equivalent");
00146 break;
00147 }
00148
00149 }
00150
00151 void Pomdp::publish_belief(){
00152 std::list<std::pair<double,int> > state_probability_list;
00153 std::vector<int> indexlist;
00154 std::list<std::pair<double,int> >::iterator it;
00155 std::vector<double> objInA, objInB;
00156 double averageObjInA=0, averageObjInB=0;
00157 double uncertaintyA=0, uncertaintyB=0;
00158 int numA, numB, wrinkle, firstNumA, firstNumB;
00159
00160 ROS_INFO("Belief (top four):");
00161 myPolicy->probableStatesList(state_probability_list, indexlist, 0.05);
00162 for(int state=0; state < 4;state++){
00163 if(myPolicy->getStateProbability(indexlist[state])*2 < 0.05)
00164 break;
00165 if(indexlist[state] > indexlist.size()/2)
00166 continue;
00167 ROS_INFO("%s %f", myPolicy->domain.state2string(indexlist[state]).c_str(), myPolicy->getStateProbability(indexlist[state])*2 );
00168 }
00169
00170 myPolicy->getBeliefAverageAtPiles(objInA,objInB);
00171 for ( int i=0; i < objInA.size(); i++ )
00172 averageObjInA += i*objInA[i];
00173 for ( int i=0; i < objInB.size(); i++ )
00174 averageObjInB += i*objInB[i];
00175
00176
00177
00178 myPolicy->domain.get_num_obj(state_probability_list.front().second, firstNumA, firstNumB, MAXNUMB, wrinkle);
00179
00180 for ( it=state_probability_list.begin(); it != state_probability_list.end(); it++ ){
00181 std::pair<double,int> aux = *it;
00182 if(aux.first == 0)
00183 break;
00184 myPolicy->domain.get_num_obj(aux.second, numA, numB, MAXNUMB, wrinkle);
00185 if( numA == firstNumA)
00186 uncertaintyA += aux.first;
00187 if( numB == firstNumB)
00188 uncertaintyB += aux.first;
00189 }
00190
00191 this->belief_summary_msg.num_objects_A = averageObjInA;
00192 this->belief_summary_msg.uncertainty_A = 1 - uncertaintyA;
00193 this->belief_summary_msg.num_objects_B = averageObjInB;
00194 this->belief_summary_msg.uncertainty_B = 1 - uncertaintyB;
00195 this->belief_summary_msg.uncertainty_total = 1 - myPolicy->getStateProbability(indexlist[0])*2;
00196 ROS_INFO("Publishing belief and overall uncertainty %f", belief_summary_msg.uncertainty_total);
00197 belief_summary_publisher.publish(this->belief_summary_msg);
00198
00199
00200 this->belief_msg.states_probabilities.clear();
00201 for ( it=state_probability_list.begin(); it != state_probability_list.end(); it++ ){
00202 std::pair<double,int> aux = *it;
00203 this->belief_msg.states_probabilities.push_back(aux.first);
00204 }
00205
00206 belief_publisher.publish(this->belief_msg);
00207
00208 }
00209
00210
00211 int main(int argc,char *argv[])
00212 {
00213 ros::init(argc, argv, "pomdp");
00214 Pomdp pomdp;
00215 ros::Rate loop_rate(10);
00216 while(ros::ok()){
00217 pomdp.mainLoop();
00218 ros::spinOnce();
00219 loop_rate.sleep();
00220 }
00221 }