decpomdp_loader.cpp
Go to the documentation of this file.
00001 
00027 #ifdef HAVE_MADP
00028 
00029 #include <madp/NullPlanner.h>
00030 #include <madp/MADPParser.h>
00031 #include <madp/StateFactorDiscrete.h>
00032 
00033 
00034 
00035 #endif
00036 
00037 #include <sys/stat.h>
00038 
00039 #include <mdm_library/SymbolMetadata.h>
00040 #include <mdm_library/FactoredSymbolMetadata.h>
00041 #include <mdm_library/FactoredDistribution.h>
00042 #include <mdm_library/decpomdp_loader.h>
00043 
00044 
00045 
00046 using namespace ros;
00047 using namespace std;
00048 using namespace mdm_library;
00049 
00050 
00051 
00052 #ifdef HAVE_MADP
00053 
00054 
00055 
00056 DecPOMDPLoader::
00057 DecPOMDPLoader ( const string& problem_file ) :
00058     action_metadata_pub_ ( nh_.advertise<FactoredSymbolMetadata> ( "action_metadata", 1, true ) ),
00059     state_metadata_pub_ ( nh_.advertise<FactoredSymbolMetadata> ( "state_metadata", 1, true ) ),
00060     observation_metadata_pub_ ( nh_.advertise<FactoredSymbolMetadata> ( "observation_metadata", 1, true ) ),
00061     initial_state_distribution_pub_ ( nh_.advertise<FactoredDistribution> ( "initial_state_distribution", 1, false ) )
00062 {
00063     string ext = problem_file.substr ( problem_file.find_last_of ( "." ) );
00064 
00065     try
00066     {
00067         if ( ext == ".pgmx" )
00068         {
00069             boost::shared_ptr<FactoredDecPOMDPDiscrete> f ( new FactoredDecPOMDPDiscrete ( "", "", problem_file ) );
00070             MADPParser parser ( f.get() );
00071             bool isSparse = false;
00072             bool cacheFlatModels = false;
00073             int marginalizeStateFactor;
00074 
00075             NodeHandle private_nh ( "~" );
00076             private_nh.getParam ( "is_sparse", isSparse );
00077             private_nh.getParam ( "cache_flat_models", cacheFlatModels );
00078 
00079             if ( private_nh.getParam ( "marginalize", marginalizeStateFactor ) )
00080             {
00081                 f->MarginalizeStateFactor ( marginalizeStateFactor, isSparse ); //TODO: NOTE: This is necessary for standard Perseus, but needs to be faster.
00082             }
00083             else if ( cacheFlatModels )
00084             {
00085                 f->CacheFlatModels ( isSparse );
00086             }
00087             publishStateMetadata ( f );
00088             publishInitialStateDistribution ( f );
00089             decpomdp_ = f;
00090         }
00091         else if ( ext == ".dpomdp" )
00092         {
00093             boost::shared_ptr<DecPOMDPDiscrete> d ( new DecPOMDPDiscrete ( "", "", problem_file ) );
00094             MADPParser parser ( d.get() );
00095             publishStateMetadata ( d );
00096             publishInitialStateDistribution ( d );
00097             decpomdp_ = d;
00098         }
00099         else
00100         {
00101             ROS_ERROR_STREAM ( "Unsupported model format \"" << ext << "\"" );
00102             abort();
00103         }
00104         publishActionMetadata();
00105         publishObservationMetadata ();
00106     }
00107     catch ( E& e )
00108     {
00109         e.Print();
00110         abort();
00111     }
00112 }
00113 
00114 
00115 
00116 void
00117 DecPOMDPLoader::
00118 publishActionMetadata ()
00119 {
00120     FactoredSymbolMetadata team_metadata;
00121     for ( uint32_t ag = 0; ag < decpomdp_->GetNrAgents(); ag++ )
00122     {
00123         SymbolMetadata ag_metadata;
00124         uint32_t nr_actions = decpomdp_->GetNrActions ( ag );
00125         ag_metadata.number_of_symbols = nr_actions;
00126         for ( uint32_t action = 0; action < nr_actions; action++ )
00127         {
00128             ag_metadata.symbol_names.push_back ( decpomdp_->GetAction ( ag, action )->GetName() );
00129         }
00130         team_metadata.factors.push_back ( ag_metadata );
00131     }
00132     action_metadata_pub_.publish ( team_metadata );
00133 }
00134 
00135 
00136 
00137 void
00138 DecPOMDPLoader::
00139 publishStateMetadata ( boost::shared_ptr<FactoredDecPOMDPDiscrete> f )
00140 {
00141     FactoredSymbolMetadata state_metadata;
00142     for ( uint32_t k = 0; k < f->GetNrStateFactors(); k++ )
00143     {
00144         SymbolMetadata factor_metadata;
00145         const StateFactorDiscrete* sf = f->GetStateFactorDiscrete ( k );
00146         state_metadata.factor_names.push_back ( sf->GetName() );
00147         for ( uint32_t j = 0; j < f->GetNrValuesForFactor ( k ); j++ )
00148         {
00149             factor_metadata.symbol_names.push_back ( sf->GetStateFactorValue ( j ) );
00150         }
00151         state_metadata.factors.push_back ( factor_metadata );
00152     }
00153     state_metadata_pub_.publish ( state_metadata );
00154 }
00155 
00156 
00157 
00158 void
00159 DecPOMDPLoader::
00160 publishStateMetadata ( boost::shared_ptr<DecPOMDPDiscrete> d )
00161 {
00162     FactoredSymbolMetadata state_metadata;
00163     SymbolMetadata symbol_metadata;
00164     state_metadata.factor_names.push_back ( "Joint State" );
00165     for ( uint32_t k = 0; k < d->GetNrStates (); k++ )
00166     {
00167         symbol_metadata.symbol_names.push_back ( d->GetState ( k )->GetName() );
00168     }
00169     symbol_metadata.number_of_symbols = d->GetNrStates();
00170     state_metadata.factors.push_back ( symbol_metadata );
00171     state_metadata_pub_.publish ( state_metadata );
00172 }
00173 
00174 
00175 
00176 void
00177 DecPOMDPLoader::
00178 publishObservationMetadata ()
00179 {
00181     FactoredSymbolMetadata team_metadata;
00182     for ( uint32_t ag = 0; ag < decpomdp_->GetNrAgents(); ag++ )
00183     {
00184         SymbolMetadata ag_metadata;
00185         uint32_t nr_observations = decpomdp_->GetNrObservations ( ag );
00186         ag_metadata.number_of_symbols = nr_observations;
00187         for ( uint32_t observation = 0; observation < nr_observations; observation++ )
00188         {
00189             ag_metadata.symbol_names.push_back ( decpomdp_->GetObservation ( ag, observation )->GetName() );
00190         }
00191         team_metadata.factors.push_back ( ag_metadata );
00192     }
00193     observation_metadata_pub_.publish ( team_metadata );
00194 }
00195 
00196 
00197 
00198 void
00199 DecPOMDPLoader::
00200 publishInitialStateDistribution ( boost::shared_ptr<FactoredDecPOMDPDiscrete> f )
00201 {
00202     FSDist_COF* fsd = ( FSDist_COF* ) f->GetFactoredISD();
00203     FactoredDistribution fdist;
00204     for ( size_t i = 0; i < f->GetNrStateFactors(); i++ )
00205     {
00206         BeliefStateInfo b;
00207         for ( size_t j = 0; j < f->GetNrValuesForFactor ( i ); j++ )
00208         {
00209             b.belief.push_back ( fsd->GetReferrence ( i, j ) );
00210         }
00211         fdist.factors.push_back ( b );
00212     }
00213     initial_state_distribution_pub_.publish ( fdist );
00214 }
00215 
00216 
00217 
00218 void
00219 DecPOMDPLoader::
00220 publishInitialStateDistribution ( boost::shared_ptr<DecPOMDPDiscrete> d )
00221 {
00222     vector<double> isd = d->GetISD()->ToVectorOfDoubles();
00223     FactoredDistribution fdist;
00224     BeliefStateInfo b;
00225     b.belief = isd;
00226     fdist.factors.push_back ( b );
00227     initial_state_distribution_pub_.publish ( b );
00228 }
00229 
00230 
00231 
00232 const boost::shared_ptr<DecPOMDPDiscreteInterface>
00233 DecPOMDPLoader::
00234 GetDecPOMDP()
00235 {
00236     return decpomdp_;
00237 }
00238 
00239 
00240 
00241 #else //NO MADP -- This constructor throws an error
00242 
00243 
00244 
00245 DecPOMDPLoader::
00246 DecPOMDPLoader ( const string& problem_file )
00247 {
00248     ROS_ERROR_STREAM ( "MDM requires MADP to parse problem files." );
00249     ROS_ERROR_STREAM ( "Please install MADP and recompile MDM if you require this functionality." );
00250     abort();
00251 }
00252 
00253 
00254 
00255 #endif


mdm_library
Author(s): Joao Messias
autogenerated on Wed Aug 26 2015 12:28:41