Go to the documentation of this file.00001
00025 #include <float.h>
00026
00027 #include <madp/NullPlanner.h>
00028
00029 #include <madp/MADPComponentFactoredStates.h>
00030
00031 #include <mdm_library/controller_pomdp.h>
00032 #include <mdm_library/ActionSymbol.h>
00033 #include <std_msgs/Float32.h>
00034
00035
00036
00037 using namespace std;
00038 using namespace mdm_library;
00039
00040
00041
00042 ControllerPOMDP::
00043 ControllerPOMDP ( const string& problem_file,
00044 const CONTROLLER_STATUS initial_status ) :
00045 ControlLayerBase ( initial_status ),
00046 loader_ ( new DecPOMDPLoader ( problem_file ) ),
00047 belief_ (),
00048 prev_action_ ( 0 ),
00049 observation_sub_ ( nh_.subscribe ( "observation", 10, &ControllerPOMDP::observationCallback, this ) ),
00050 ext_belief_estimate_sub_ ( nh_.subscribe ( "ext_belief_estimate", 10, &ControllerPOMDP::extBeliefCallback, this ) ),
00051 isd_sub_ ( nh_.subscribe ( "initial_state_distribution", 10, &ControllerPOMDP::isdCallback, this ) ),
00052 current_belief_pub_ ( nh_.advertise<BeliefStateInfo> ( "current_belief", 1, false ) ),
00053 action_pub_ ( nh_.advertise<ActionSymbol> ( "action", 0, true ) ),
00054 exp_reward_pub_ ( nh_.advertise<std_msgs::Float32> ( "reward", 0, true ) )
00055 {}
00056
00057
00058
00059 void
00060 ControllerPOMDP::
00061 act ( const uint32_t observation )
00062 {
00063 if ( getStatus() == STOPPED )
00064 {
00065 return;
00066 }
00067
00068 double eta = 0;
00069
00070 if ( getDecisionEpisode() > 0 )
00071 {
00072 eta = belief_->Update ( * ( loader_->GetDecPOMDP() ), prev_action_, observation );
00073 }
00074
00075 uint32_t action = INT_MAX;
00076 double q, v = -DBL_MAX;
00077 for ( size_t a = 0; a < getNumberOfActions(); a++ )
00078 {
00079 q = Q_->GetQ ( *belief_, a );
00080 if ( q > v )
00081 {
00082 v = q;
00083 action = a;
00084 }
00085 }
00086
00087 if ( action == INT_MAX )
00088 {
00089 ROS_ERROR_STREAM ( "ControllerPOMDP:: Could not get joint action for observation " << observation
00090 << " at belief state: " << endl << belief_->SoftPrint() );
00091 abort();
00092 }
00093 prev_action_ = action;
00094 publishAction ( action );
00095 publishExpectedReward ( action );
00096 publishCurrentBelief ();
00097 ROS_INFO_STREAM ( "ControllerPOMDP:: Episode " << getDecisionEpisode() << " - Action: "
00098 << action << " (" << loader_->GetDecPOMDP()->GetJointAction ( action )->SoftPrint()
00099 << ") - Observation: " << observation << " (" << loader_->GetDecPOMDP()->GetJointObservation ( observation )->SoftPrint()
00100 << ") - P(b|a,o): " << eta );
00101
00102 if ( getDecisionEpisode() > 0 && eta <= Globals::PROB_PRECISION )
00103 {
00104 ROS_WARN ( "ControllerPOMDP:: Impossible action-observation trace! You should check your model for the probabilities of the preceding transitions and observations!" );
00105 }
00106
00107 incrementDecisionEpisode();
00108 }
00109
00110
00111
00112 void
00113 ControllerPOMDP::
00114 publishAction ( uint32_t a )
00115 {
00116 ActionSymbol aInfo;
00117 aInfo.action_symbol = a;
00118 aInfo.decision_episode = getDecisionEpisode();
00119 action_pub_.publish ( aInfo );
00120 }
00121
00122
00123
00124 void
00125 ControllerPOMDP::
00126 publishExpectedReward ( uint32_t a )
00127 {
00128 std_msgs::Float32 reward;
00129 vector<double> r_vec;
00130 for ( uint32_t s = 0; s < getNumberOfStates(); s++ )
00131 {
00132 r_vec.push_back ( loader_->GetDecPOMDP()->GetReward ( s, a ) );
00133 }
00134
00135 reward.data = belief_->InnerProduct ( r_vec );
00136 exp_reward_pub_.publish ( reward );
00137 }
00138
00139
00140
00141 void
00142 ControllerPOMDP::
00143 publishCurrentBelief ()
00144 {
00145 BeliefStateInfo b;
00146 for ( size_t i = 0; i < belief_->Size(); i++ )
00147 {
00148 b.belief.push_back ( belief_->Get ( i ) );
00149 }
00150 current_belief_pub_.publish ( b );
00151 }
00152
00153
00154
00155 size_t
00156 ControllerPOMDP::
00157 getNumberOfActions ()
00158 {
00159 return loader_->GetDecPOMDP()->GetNrJointActions();
00160 }
00161
00162
00163
00164 size_t
00165 ControllerPOMDP::
00166 getNumberOfStates ()
00167 {
00168 return loader_->GetDecPOMDP()->GetNrStates();
00169 }
00170
00171
00172
00173 size_t
00174 ControllerPOMDP::
00175 getNumberOfObservations ()
00176 {
00177 return loader_->GetDecPOMDP()->GetNrJointObservations();
00178 }
00179
00180
00181
00182 void
00183 ControllerPOMDP::
00184 extBeliefCallback ( const BeliefStateInfoConstPtr& msg )
00185 {
00186 belief_->Set ( msg->belief );
00187 if ( ! ( belief_->SanityCheck() ) )
00188 {
00189 normalizeBelief ( belief_ );
00190 }
00191 }
00192
00193
00194
00195 void
00196 ControllerPOMDP::
00197 isdCallback ( const FactoredDistributionConstPtr& msg )
00198 {
00199 if ( ISD_ == 0 )
00200 {
00201 MADPComponentFactoredStates state_factor_description;
00202 for ( size_t k = 0; k < msg->factors.size(); k++ )
00203 {
00204 state_factor_description.AddStateFactor();
00205 for ( size_t x = 0; x < msg->factors[k].belief.size(); x++ )
00206 {
00207 state_factor_description.AddStateFactorValue ( k );
00208 }
00209 }
00210 state_factor_description.SetInitialized ( true );
00211 if ( state_factor_description.GetNrStates() != getNumberOfStates() )
00212 {
00213 ROS_WARN_STREAM ( "ControllerPOMDP:: Received an initial state distribution with an incorrect number of states ("
00214 << state_factor_description.GetNrStates() << ", should be " << getNumberOfStates() << ") . Ignoring." );
00215 return;
00216 }
00217 ISD_ = boost::shared_ptr<FSDist_COF> ( new FSDist_COF ( state_factor_description ) );
00218 }
00219
00220 for ( size_t k = 0; k < msg->factors.size(); k++ )
00221 {
00222 for ( size_t x = 0; x < msg->factors[k].belief.size(); x++ )
00223 {
00224 double p = msg->factors[k].belief[x];
00225 ISD_->SetProbability ( k, x, p );
00226 }
00227 }
00228
00229 ISD_->SanityCheck();
00230 }
00231
00232
00233
00234 void
00235 ControllerPOMDP::
00236 normalizeBelief ( boost::shared_ptr<JointBeliefInterface> b )
00237 {
00238 vector<double> one_vec ( getNumberOfStates(), 1.0 );
00239 float sum = b->InnerProduct ( one_vec );
00240 if ( sum > 0 )
00241 {
00242 for ( size_t s = 0; s < getNumberOfStates(); s++ )
00243 {
00244 float p = b->Get ( s ) / sum;
00245 b->Set ( s, p );
00246 }
00247 }
00248 else
00249 {
00250 ROS_WARN ( "ControllerPOMDP:: Failed to normalize. Setting belief to default ISD." );
00251 b->Set ( * ( loader_->GetDecPOMDP()->GetISD() ) );
00252 }
00253 }