00001
00002 #ifndef FactoredPomdp_H
00003 #define FactoredPomdp_H
00004
00005 #include <iostream>
00006 #include <fstream>
00007 #include <sstream>
00008 #include <ostream>
00009 #include "tinyxml.h"
00010 #include <string>
00011 #include <vector>
00012 #include <algorithm>
00013 #include <iterator>
00014 #include <stdlib.h>
00015 #include <map>
00016 #include <set>
00017 #include <iomanip>
00018
00019 #include "CPTimer.h"
00020 #include <signal.h>
00021 #include "MathLib.h"
00022 #include "State.h"
00023 #include "ObsAct.h"
00024
00025 #include "PreCMatrix.h"
00026 #include "POMDPLayer.h"
00027 #include "Function.h"
00028 #include "SparseTable.h"
00029 #include "MObject.h"
// String constant "rew".
// NOTE(review): presumably the value-enumeration name used for reward
// variables when parsing the model -- its uses are not visible in this header; confirm.
00030 #define REWVALUEENUM "rew"
00031
00032
00033
00034
00035
00036
// Integer status codes: 0 for a clean exit, -1 for an XML input error.
// Where they are returned or compared is not visible in this header.
// NOTE(review): XML_INPUT_ERROR expands to an unparenthesized `-1`; if it is
// ever returned through an unsigned type it will wrap -- confirm at the call sites.
00037 #define CLEAN_EXIT 0
00038 #define XML_INPUT_ERROR -1
00039
// Debug switches. When DEBUG_LOG_ON is defined every flag below is 1;
// otherwise every flag is 0. Both branches define the same twelve names, so
// code may test them unconditionally.
// The flags fall into three groups by name: XML reading, "fast" conversion and
// "factored" conversion. Only the factored group has a TERMINAL flag.
// NOTE(review): what each flag actually enables (presumably debug-file output
// for the matching conversion step) is decided in the implementation -- confirm there.
00040 #ifdef DEBUG_LOG_ON
00041
00042
// XML input parsing.
00043 #define DEBUGREADXMLINPUT 1
00044
// "Fast" conversion steps.
00045 #define DEBUGFASTCONVERSIONSTATE 1
00046 #define DEBUGFASTCONVERSIONOBS 1
00047 #define DEBUGFASTCONVERSIONREWARD 1
00048 #define DEBUGFASTCONVERSIONOTHERS 1
00049 #define DEBUGFASTCONVERSIONBELIEF 1
00050
// "Factored" conversion steps.
00051 #define DEBUGFACTOREDCONVERSIONSTATE 1
00052 #define DEBUGFACTOREDCONVERSIONOBS 1
00053 #define DEBUGFACTOREDCONVERSIONREWARD 1
00054 #define DEBUGFACTOREDCONVERSIONTERMINAL 1
00055 #define DEBUGFACTOREDCONVERSIONOTHERS 1
00056 #define DEBUGFACTOREDCONVERSIONBELIEF 1
00057 #else
// Same names as above, all disabled.
00058 #define DEBUGREADXMLINPUT 0
00059
00060 #define DEBUGFASTCONVERSIONSTATE 0
00061 #define DEBUGFASTCONVERSIONOBS 0
00062 #define DEBUGFASTCONVERSIONREWARD 0
00063 #define DEBUGFASTCONVERSIONOTHERS 0
00064 #define DEBUGFASTCONVERSIONBELIEF 0
00065
00066 #define DEBUGFACTOREDCONVERSIONSTATE 0
00067 #define DEBUGFACTOREDCONVERSIONOBS 0
00068 #define DEBUGFACTOREDCONVERSIONREWARD 0
00069 #define DEBUGFACTOREDCONVERSIONTERMINAL 0
00070 #define DEBUGFACTOREDCONVERSIONOTHERS 0
00071 #define DEBUGFACTOREDCONVERSIONBELIEF 0
00072 #endif
00073
// Using-declarations at global scope: every file that includes this header
// gets cin, cout and endl as unqualified names.
00074 using std::cin;
00075 using std::cout;
00076 using std::endl;
00077
// Shorthand for two-level nested vectors of matrices, declared at global scope
// (outside namespace momdp).
// NOTE(review): `vector` is unqualified and no using-declaration for it is
// visible in this header -- presumably one of the project headers included
// above brings it into scope; confirm.
00078 typedef vector<vector<PreSparseMatrix> > vvPreSparseMatrix;
00079 typedef vector<vector<SharedPointer<SparseMatrix> > > vvSparseMatrix;
00080
00081 namespace momdp {
00082
// Data holder and converter for a factored POMDP model description.
//
// This header only declares the class: every member function is defined
// elsewhere. The comments below therefore describe signatures and data
// members; statements about runtime behavior are marked as assumptions to
// confirm against the implementation.
//
// NOTE(review): string, vector, map, set and ofstream are used unqualified and
// no using-declaration for them is visible in this header -- presumably an
// included header provides them; confirm.
// NOTE(review): none of the int/bool/double data members has an in-class
// initializer; they must be set by the constructors -- confirm.
00083 class FactoredPomdp {
// MOMDP is granted access to the private members below.
00084 friend class MOMDP;
00085
00086 private:
00088
// NOTE(review): presumably the model file name given to FactoredPomdp(string) -- confirm.
00089 string filename;
00090
// Model parameters and variable lists. State and ObsAct are project types.
// Observations, actions, rewards and terminal-state rewards all use ObsAct.
00091 double discount;
00092 vector<State> stateList;
00093 vector<ObsAct> observationList;
00094 vector<ObsAct> actionList;
00095 vector<ObsAct> rewardList;
00096 vector<ObsAct> terminalStateRewardList;
00097
// Name -> variable lookup.
// NOTE(review): raw pointers; ownership, and whether they point into the
// vectors above (which would be invalidated by reallocation), is not visible
// here -- confirm.
00098 map<string, StateObsAct*> mymap;
// One Function list per category (state, belief, processed belief,
// observation, action, reward, terminal-state reward).
00099 vector<Function> stateFunctionList;
00100 vector<Function> beliefFunctionList;
00101 vector<Function> processedBeliefFunctionList;
00102 vector<Function> observFunctionList;
00103 vector<Function> actionFunctionList;
00104 vector<Function> rewardFunctionList;
00105 vector<Function> terminalStateRewardFunctionList;
00106
// Name -> Function lookup; same raw-pointer ownership caveat as mymap.
00107 map<string, Function*> mapFunc;
// NOTE(review): presumably filled by defineCanonicalNames() -- confirm.
00108 vector<string> canonicalNamePrev;
00109 vector<string> canonicalNameCurr;
00110 vector<string> canonicalNameForTerminal;
00111
00113
// Dense nested tables of doubles: three index levels for the state and
// observation transitions, two for the reward. The meaning of each index
// level is not visible in this header.
00114 vector<vector<vector<double> > > oldStateTransition;
00115 vector<vector<vector<double> > > oldObservTransition;
00116 vector<vector<double> > oldRewardTransition;
00117
00118 public:
00119
00120
// Public POMDPLayer member.
// NOTE(review): presumably the target of the convert* methods -- confirm.
00121 POMDPLayer layer;
00122
00126
// Construction / destruction.
// NOTE(review): the single-argument constructor is not `explicit`, so a string
// converts implicitly to FactoredPomdp.
00127 FactoredPomdp();
00128 FactoredPomdp(string f);
00129 ~FactoredPomdp();
00130
// `delimiters` defaults to a single space. `tokens` is a non-const reference.
// NOTE(review): presumably splits `str` on the delimiter characters and
// appends the pieces to `tokens` -- confirm.
00131 void Tokenize(const string& str, vector<string>& tokens,
00132 const string& delimiters = " ");
00133
// Builders that take a TinyXML element and return the variable by value.
00134 State createState(TiXmlElement* varChild);
00135 ObsAct createObservation(TiXmlElement* varChild);
00136 ObsAct createAction(TiXmlElement* varChild);
00137
// Returns nothing; where the parsed belief is stored is not visible here.
00138 void createInitialBelief(TiXmlElement* varChild);
// Name/validity predicates. The top-level `const` on the by-value bool return
// type has no effect for callers. String arguments are taken by value.
// checkInstanceMatchesParent is the only one in this run that is not declared
// as a const member function.
00139 const bool checkStateNameExists(string stateName) const;
00140 const bool checkActionNameExists(string actionName) const;
00141 const bool checkObsNameExists(string obsName) const;
00142 const bool checkRewardNameExists(string rewardName) const;
00143 const bool checkTerminalNameExists(string rewardName) const;
00144 const bool checkParentNameExists(string parentName) const;
00145 const bool checkInstanceMatchesParent(string instanceName, string parent);
// Returns a const reference to a State; what happens when no state matches
// `varName` is not visible here -- confirm before relying on it.
00146 const State& findState(string varName);
// `tokens` is taken by value (copied on each call).
00147 const bool checkIdentityIsValid(vector<string> tokens) const;
// Note that `whichFunction` is a string here but an int in createFunction and
// in the merge/expand methods further down.
00148 bool checkFunctionProbabilities(Function* f, TiXmlElement* xmlnode, string whichFunction);
00149 Function createFunction(TiXmlElement* pFunction, int whichFunction);
// Take a TiXmlBase*.
// NOTE(review): presumably used to report the node's position in the XML
// file -- confirm.
00150 void printXMLErrorHeader(TiXmlBase* base);
00151 void printXMLWarningHeader(TiXmlBase* base);
00152
00153
00154
// Reward preprocessing. The SparseTable methods take and return
// SharedPointer<SparseTable>; whether the input table is modified in place is
// not visible here.
00155 const bool checkRewardFunctionHasOnlyPreviousTimeSliceAndAction() const;
00156 const set<string> getRewardFunctionCurrentTimeSliceVars(Function* rewardFunc) ;
00157 SharedPointer<SparseTable> removeRedundantUIsFromReward(SharedPointer<SparseTable> st);
00158 SharedPointer<SparseTable> combineSimilarEntriesInReward(SharedPointer<SparseTable> st);
00159 SharedPointer<SparseTable> preprocessRewardTable();
00160 void preprocessRewardFunction();
00161
00162
00163
// NOTE(review): start() returns unsigned int; the meaning of the value (and
// whether it can carry a negative status code) is not visible here -- confirm.
00164 unsigned int start();
00165 const int checkProblemType();
00166
00167
00168
00169 void sortStateList();
00170
// Public table members and the merge routines that take a debug stream plus a
// flag saying whether to write to it.
00171 SharedPointer<SparseTable> finalRewardTable;
00172 SharedPointer<SparseTable> finalStateTable;
00173 SharedPointer<SparseTable> finalBeliefTable;
00174 SharedPointer<SparseTable> mergeTables(vector<Function>* functionList, int whichFunction, ofstream& debugfile, bool printDebugFile);
00175
// `stList` is taken by value.
00176 SharedPointer<SparseTable> mergeSparseTables(vector<SharedPointer<SparseTable> > stList, int whichFunction, ofstream& debugfile, bool printDebugFile);
00177 void preprocessBeliefTables(ofstream& debugfile, bool printDebugFile);
// NOTE(review): presumably set once preprocessBeliefTables() has run -- confirm.
00178 bool preprocessBeliefTablesDone;
00179 void mergeBeliefTables(ofstream& debugfile, bool printDebugFile);
00180
00181
// `const void` return: the const qualifier has no effect.
00182 const void defineCanonicalNames();
00183
00184 SharedPointer<SparseTable> expandObsRewSparseTable(SharedPointer<SparseTable> st, int whichFunction);
00185
// `sf` is taken by value; `info` is a non-const reference.
// NOTE(review): `info` is presumably an output message -- confirm.
00186 bool validateModel(Function sf, string& info);
00187 const bool isPreviousTimeSlice(string name) const;
00188 const bool isCurrentTimeSlice(string name) const;
00189
00190
00191
00192
// Action bookkeeping: string -> int map, a count, and the routine declared to
// build the mapping.
00193 map<string, int> actionStringIndexMap;
00194 int numActions;
00195 void mapActionsToValue();
00196
// Observation bookkeeping, parallel to the action members above, plus an
// int -> int map built from a SparseTable.
00197 map<string, int> observationStringIndexMap;
00198 int numObservations;
00199 void mapObservationsToValue();
00200 map<int, int> observationUIIndexMap;
00201 void mapObservationsUIsToValue(SharedPointer<SparseTable> st);
00202
00203 map<string, int> positionStringIndexMap;
00204
// NOTE(review): "CI"/"UI" are not expanded anywhere in this header -- confirm
// their meaning against SparseTable.
00205 map<int, int> fastPositionCIIndexMap;
00206 map<int, int> fastPositionUIIndexMap;
00207
// getStart* return a name -> int map. getNext* take that map by non-const
// reference together with int out-references and return bool.
// NOTE(review): presumably an enumeration protocol (start, then call next
// until it returns false) -- confirm.
00208 map<string, int> getStartXYVarValues();
00209 map<string, int> getStartActionXYVarValues();
00210 bool getNextActionXYVarValues(map<string, int> &curValues, int &action, int &stateX, int& stateY);
00211 bool getNextXYVarValues(map<string, int> &curValues,int &stateX, int &stateY);
00212 bool getNextActionXXpYVarValues(map<string, int> &curValues, int &action, int &stateX, int &stateXp, int &stateY);
00213 map<string, int> getStartSVarValues();
00214 map<string, int> getStartActionSVarValues();
00215 bool getNextActionSVarValues(map<string, int> &curValues, int &action, int &stateNum);
00216 bool getNextSVarValues(map<string, int> &curValues,int &stateNum);
00217
// "Fast" conversion path: a merged-state count plus one convertFast* routine
// per model component.
00218 int numMergedStates;
00219 void resortFastStateTables(ofstream& debugfile, bool printDebugFile);
00220 void mapFastStatesToValue();
00221 void mapFastIndexesToValues(SharedPointer<SparseTable> st);
00222 void convertFast();
00223 void convertFastStateTrans();
00224 void convertFastObsTrans();
00225 void convertFastNoObservationsVariables();
00226 void convertFastRewardTrans();
00227 void convertFastBelief();
00228
00229 void convertFastVariables();
00230 SharedPointer<SparseTable> reduceUnmatchedCIWithUI(SharedPointer<SparseTable> st, ofstream& debugfile, bool printDebugFile);
00231
// "Factored" conversion path: separate X and Y maps and counts.
// NOTE(review): X/Y presumably denote the observed and unobserved state
// parts (compare getFactoredObservedStatesSymbols /
// getFactoredUnobservedStatesSymbols below) -- confirm.
00232 map<string, int> positionXStringIndexMap;
00233 map<string, int> positionYStringIndexMap;
00234
00235
00236
00237 map<int, int> factoredPositionCIIndexMap;
00238 map<int, int> factoredPositionUIIndexMap;
00239 int numMergedStatesX;
00240 int numMergedStatesY;
00241
// Return type is the same type as the global vvPreSparseMatrix typedef.
// The meaning of a, b, c, d is not visible in this header.
00242 vector<vector<PreSparseMatrix> > createVvPreSparseMatrix(int a, int b, int c, int d);
// `M` is taken by value.
00243 void printSparseMatrix(string title, vector<vector<SharedPointer<SparseMatrix> > > M, ofstream& debugfile);
00244
00245 void resortFactoredStateTables(ofstream& debugfile, bool printDebugFile, const int MIXEDTYPE);
00246 void mapFactoredStatesToValue();
00247 void mapFactoredCIsToValue(SharedPointer<SparseTable> st);
00248 void mapFactoredStateUIsToValue(SharedPointer<SparseTable> st);
00249 void mapFactoredBeliefIndexesToValue(SharedPointer<SparseTable> st);
00250 void convertFactored();
00251 void convertFactoredReparam();
00252 void convertFactoredStateTrans();
00253 void convertFactoredStateReparamTrans();
00254 void expandFactoredStateTable(SharedPointer<SparseTable> sf);
00255 void convertFactoredObsTrans();
00256 void convertFactoredNoObservationsVariables();
00257 void convertFactoredRewardTrans();
00258 void convertFactoredTerminalStateReward();
00259
00260 void convertFactoredBeliefCommon(ofstream& debugfile, bool printDebugFile);
00261 void convertFactoredBelief();
00262 void convertFactoredBeliefReparam();
00263 void convertFactoredVariables();
00264
00265
// Two overloads: one for two-level and one for three-level nested vectors of
// PreSparseMatrix, each returning the same nesting of
// SharedPointer<SparseMatrix>. Arguments are taken by value.
00266 vector<vector<SharedPointer<SparseMatrix> > > helperPreSparseMatrixToSparseMatrix(vector< vector<PreSparseMatrix> > precm);
00267 vector<vector<vector<SharedPointer<SparseMatrix> > > > helperPreSparseMatrixToSparseMatrix(vector< vector<vector<PreSparseMatrix> > >precm);
00268
00269
// Each takes an integer index and returns a string -> string map.
// NOTE(review): presumably variable name -> value name for that index -- confirm.
00270 map<string, string> getFactoredObservedStatesSymbols(int stateNum);
00271 map<string, string> getFactoredUnobservedStatesSymbols(int stateNum);
00272 map<string, string> getActionsSymbols(int actionNum);
00273 map<string, string> getObservationsSymbols(int observationNum);
00274
00275
00276 };
00277 }
00278
00279 #endif