00001 
00007 #include <assert.h>
00008 #include <stdlib.h>
00009 #include <sstream>
00010 #include "GlobalResource.h"
00011 
00012 #ifdef _MSC_VER
00013 #else
00014 #include <unistd.h>
00015 #endif
00016 
00017 #if defined(_MSC_VER) || defined(_CYGWIN)
00018 #include "getopt.h"
00019 #else
00020 #include <getopt.h>
00021 #endif
00022 
00023 #ifdef __cplusplus
00024 extern "C"
00025 #endif
00026 {
00027         extern unsigned long GlobalMemLimit;
00028 }
00029 
00030 
00031 
00032 #include <stdio.h>
00033 #include <ctime>
00034 
00035 #include <iostream>
00036 #include <fstream>
00037 #include <string.h>
00038 #include <algorithm>
00039 
00040 #include "solverUtils.h"
00041 
00042 using namespace std;
00043 using namespace momdp;
00044 
00045 namespace momdp{
00046 
00047         
00048 
00049 
00050 
00051         static EnumEntry strategiesG[] = {
00052                 {"sarsop", S_SARSOP},
00053                 {"sarsop_bp", S_BPS},
00054                 {NULL, -1}
00055         };
00056 
00057         int getEnum(const char* key, EnumEntry* table, const char* cmdName, const char *opt)
00058         {
00059                 EnumEntry* i = table;
00060                 for (; NULL != i->key; i++) {
00061                         if (0 == strcmp(i->key, key)) {
00062                                 return i->val;
00063                         }
00064                 }
00065                 fprintf(stderr, "ERROR: invalid parameter %s for option %s\n\n", key, opt);
00066                 
00067                 exit(EXIT_FAILURE);
00068         }
00069 
/**
 * Reports whether `s` ends with `suffix`.
 *
 * The comparison is exact (case-sensitive). An empty suffix matches every
 * string, including the empty string.
 *
 * @param s       string to inspect
 * @param suffix  trailing text to look for
 * @return true if the last suffix.size() characters of `s` equal `suffix`
 */
bool endsWith(const std::string& s,
        const std::string& suffix)
{
        const std::string::size_type suffixLen = suffix.size();
        if (suffixLen > s.size()) {
                // A string can never end with something longer than itself.
                return false;
        }
        // Compare the tail of s against suffix in place.
        return s.compare(s.size() - suffixLen, suffixLen, suffix) == 0;
}
00076 
00077         
00078 
00079 
00080 
00081         SolverParams::SolverParams(void)
00082         {
00083                 MDPSolution = false;
00084                 QMDPSolution = false;
00085                 FIBSolution = false;
00086                 useLookahead = true;
00087                 graphDepth = 0; 
00088                 graphMaxBranch = 0; 
00089                 graphProbThreshold = 0;
00090                 simLen = -1;
00091                 simNum = -1;
00092                 seed = time(0);
00093                 
00094                 
00095                 memoryLimit = 0; 
00096                 strategy = S_SARSOP;
00097                 
00098                 targetPrecision = 1e-3;
00099                 useFastParser = false;
00100                 doConvertPOMDP = false;
00101                 outPolicyFileName = "";
00102                 interval = -1;
00103                 timeoutSeconds = -1;
00104                 delta = 0.1;
00105                 randomizationBP = false;
00106                 overPruneThreshold = 50.0;
00107                 lowerPruneThreshold = 5.0;
00108                 dynamicDeltaPercentageMode = false;
00109                 BP_IMPROVEMENT_CONSTANT = 0.5;
00110 
00111                 targetTrials = 0;
00112                 dumpPolicyTrace = false;
00113                 dumpPolicyTraceTime = false;
00114 
00115                 outPolicyFileName = "out.policy"; 
00116                 dumpData = false;
00117 
00118 
00119 
00120         }
00121 
00122         void SolverParams::setStrategy(const char* strategyName)
00123         {
00124                 strategy = getEnum(strategyName, strategiesG, cmdName, "--search");
00125         }
00126 
00127         void SolverParams::inferMissingValues(void)
00128         {
00129                 
00130         }
00131 
00132 
00133 
00134         bool SolverParams::parseCommandLineOption(int argc, char **argv, SolverParams& p)
00135         {
00136                 static char shortOptions[] = "hp:t:v:fo:";
00137                 static struct option longOptions[]={
00138 
00139         
00140                 {"help",                        0,NULL,'h'}, 
00141                 {"version",                     0,NULL,'V'}, 
00142                 {"fast",                        0,NULL,'f'}, 
00143                 {"memory",                      1,NULL,'m'}, 
00144 
00145         
00146                 {"precision",                   1,NULL,'p'}, 
00147                 {"randomization",               0,NULL,'c'}, 
00148                 {"timeout",                     1,NULL,'T'}, 
00149                 {"output",                      1,NULL,'o'}, 
00150                 {"policy-interval",             1, NULL, 'i'}, 
00151                 {"trial-improvement-factor",     1,NULL, 'j'}, 
00152 
00153                 
00154                 {"unfactored-init",             0, NULL, 'M' }, 
00155                 {"hardcoded",                   1, NULL, 'H' }, 
00156                 {"mdp",                         0, NULL, 'W' }, 
00157                 {"qmdp",                        0, NULL, 'X' }, 
00158                 {"fib",                         0, NULL, 'F' }, 
00159                 {"overPruneThreshold",          1,NULL, 'b'}, 
00160                 {"lowerPruneThreshold",         1,NULL, 'g'}, 
00161                 {"trials",                      1,NULL, 'N'}, 
00162                 {"dump",                        0,NULL, 'D'}, 
00163                 
00164                 {"dumpPolicyTrace",             0,NULL, 'P'}, 
00165 
00166         
00167                 { "policy-file",                1, NULL, 'Q' }, 
00168 
00169                 
00170                 { "lookahead",                  1, NULL, 'L' }, 
00171 
00172         
00173                 { "simLen",                     1, NULL, 'S' }, 
00174                 { "simNum",                     1, NULL, 'U' }, 
00175                 { "srand",                      1, NULL, 'R' }, 
00176                 { "output-file",                1, NULL, 'O' }, 
00177 
00178 
00179         
00180                 
00181                 { "statemap",                   1, NULL, 'A' }, 
00182 
00183         
00184                 { "policy-graph",               1, NULL, 'G' }, 
00185                 { "graph-max-depth",            1, NULL, 'd' }, 
00186                 { "graph-max-branch",           1, NULL, 'B' }, 
00187                 { "graph-min-prob",             1, NULL, 't' }, 
00188 
00189                 {NULL,0,0,0}
00190 
00191                 };
00192 
00193                 
00194                 
00195                 
00196                 
00197                 
00198                 
00199                 
00200                 
00201                 
00202                 
00203                 
00204                 
00205                 
00206                 
00207                 
00208                 
00209                 
00210                 
00211                 
00212                 
00213 
00214 
00215                 p.cmdName = argv[0];
00216 
00217                 while (1) 
00218                 {
00219                         char optchar = getopt_long(argc,argv,shortOptions,longOptions,NULL);
00220                         if (optchar == -1) break;
00221 
00222                         switch (optchar) 
00223                         {
00224                         case 'A':
00225                                 {
00226                                         p.stateMapFile = string(optarg);
00227                                 }
00228                                 break;
00229                         case 'H':
00230                                 {
00231                                         p.hardcodedProblem = string(optarg);
00232                                 }
00233                                 break;
00234                         case 'M':
00235                                 {
00236                                         p.doConvertPOMDP = true;
00237                                 }
00238                                 break;
00239                         case 'B':
00240                                 {
00241                                         p.graphMaxBranch = atoi(optarg);
00242                                 }
00243                                 break;
00244                         case 'G':
00245                                 {
00246                                         p.policyGraphFile = string(optarg);
00247                                 }
00248                                 break;
00249                         case 'd':
00250                                 {
00251                                         p.graphDepth = atoi(optarg);
00252                                 }
00253                                 break;
00254                         case 't':
00255                                 {
00256                                         p.graphProbThreshold = atof(optarg);
00257                                 }
00258                                 break;
00259 
00260                         case 'h': 
00261                                 return false;
00262                                 break;
00263                         case 'V': 
00264                                 cout << "Approximate POMDP Planning (APPL) Toolkit Version 0.9" << endl;
00265                                 exit(EXIT_SUCCESS);
00266                                 break;
00267                         case 's': 
00268                                 p.setStrategy(optarg);
00269                                 break;
00270                         case 'f': 
00271                                 p.useFastParser = true;
00272                                 break;
00273                         case 'p': 
00274                                 p.targetPrecision = atof(optarg);
00275                                 break;
00276                         case 'b': 
00277                                 p.overPruneThreshold = atof(optarg);
00278                                 break;
00279                         case 'g': 
00280                                 p.lowerPruneThreshold = atof(optarg);
00281                                 break;
00282                         case 'j': 
00283                                 p.BP_IMPROVEMENT_CONSTANT = atof(optarg);
00284                                 break;
00285                         case 'm': 
00286                                 {
00287                                         double limit = atof(optarg);
00288                                         p.memoryLimit = (unsigned long)(limit * 1024*1024);
00289                                         GlobalMemLimit = (unsigned long)(limit * 1024*1024);
00290                                 }
00291                                 break;
00292                         case 'l':
00293                                 p.dynamicDeltaPercentageMode = true;
00294                                 break;
00295                         case 'c': 
00296                                 p.randomizationBP = true;
00297                                 break;
00298 
00299                         case 'N': 
00300                                 p.targetTrials = atoi(optarg);
00301                                 break;
00302 
00303                         case 'D': 
00304                                 p.dumpData = true;
00305                                 break;
00306                         case 'P': 
00307                                 p.dumpPolicyTrace = true;
00308                                 break;
00309                         case 'I': 
00310                                 p.dumpPolicyTraceTime = true;
00311                                 break;
00312                         case 'o': 
00313                                 p.outPolicyFileName = string(optarg);
00314                                 break;
00315                         case 'T': 
00316                                 p.timeoutSeconds = atof(optarg);
00317                                 break;
00318                         case 'i':
00319                                 p.interval = atof(optarg);
00320                                 break;
00321                         case 'a':
00322                                 p.delta = atof(optarg);
00323                                 break;
00324                         case 'W':
00325                                 p.MDPSolution = true;
00326                                 break;
00327                         case 'X':
00328                                 p.QMDPSolution = true;
00329                                 break;
00330                         case 'F':
00331                                 p.FIBSolution = true;
00332                                 break;
00333 
00334 
00335                         case 'L':
00336                                 {
00337                                         string useLookahead(optarg);
00338                                         if(useLookahead.compare("yes") == 0)
00339                                         {
00340                                                 p.useLookahead = true;
00341                                         }
00342                                         else
00343                                         {
00344                                                 p.useLookahead = false;
00345                                         }
00346                                 }
00347                                 break;
00348                         case 'Q':
00349                                 {
00350                                         p.policyFile = string(optarg);
00351                                 }
00352                                 break;
00353                         case 'O':
00354                                 {
00355                                         p.outputFile = string(optarg);
00356                                 }
00357                                 break;
00358                         case 'S':
00359                                 p.simLen = atoi(optarg);
00360                                 break;
00361                         case 'U':
00362                                 p.simNum = atoi(optarg);
00363                                 break;
00364                         case 'R':
00365                                 p.seed = atoi(optarg) >= 0 ? atoi(optarg) : time(0);
00366                                 break;
00367 
00368                         case '?': 
00369                         case ':': 
00370                                 
00371                                 cerr << endl;
00372                                 return false;
00373                                 break;
00374                         default:
00375                                 cerr << "unknowm paramter specified" << endl << endl;
00376                                 return false;
00377                         }
00378                 }
00379                 if (argc-optind != 1) 
00380                 {
00381                         if(p.hardcodedProblem.length() > 0 )
00382                         {
00383                                 cout << "Using hardcoded problem : " << p.hardcodedProblem << endl;
00384                         }
00385                         else
00386                         {
00387                                 cerr << "Error: no arguments were given." << endl << endl;
00388                                 return false;
00389                         }
00390                 }
00391 
00392                 
00393                 p.inferMissingValues();
00394 
00395                 if( p.hardcodedProblem.length() ==0 )
00396                 {
00397                         p.problemName = string(argv[optind++]);
00398                 
00399                         
00400                         std::string probNameStr = p.problemName;
00401                         std::string suffixStr(".pomdp");
00402                         std::string suffixStr2(".pomdpx");
00403                         std::transform(probNameStr.begin(), probNameStr.end(), probNameStr.begin(), ::tolower);
00404 
00405                         bool test1 = endsWith(probNameStr, suffixStr);
00406                         test1 |= endsWith(probNameStr, suffixStr2);
00407                         if (test1) 
00408                         {
00409                                 
00410                         } 
00411                         else 
00412                         {
00413                                 cerr << "ERROR: only POMDP or POMDPX file format with suffix .pomdp or .pomdpx are supported. The specified file: "<< p.problemName << " is not supported." << endl<< endl;
00414                                 return false;
00415                         }
00416 
00417                         p.problemBasenameWithoutPath = GlobalResource::parseBaseNameWithoutPath(p.problemName);
00418                         p.problemBasenameWithPath = GlobalResource::parseBaseNameWithPath(p.problemName);
00419                 }
00420                 else
00421                 {
00422                         p.problemName = p.hardcodedProblem;
00423                         p.problemBasenameWithoutPath = p.hardcodedProblem;
00424                         p.problemBasenameWithPath = p.hardcodedProblem;
00425 
00426                 }
00427                 return true;
00428         }
00429 }; 
00430