RewardsByFunction.cpp
Go to the documentation of this file.
00001 #include "RewardsByFunction.h"
00002 #include "BeliefWithState.h"
00003 #include "Actions.h"
00004 #include "MOMDP.h"
00005 #include <sstream>
00006 RewardsByFunction::RewardsByFunction(void)
00007 {
00008 }
00009 
00010 RewardsByFunction::~RewardsByFunction(void)
00011 {
00012 }
00013 
00014 REAL_VALUE RewardsByFunction::getReward(BeliefWithState& b, int a)
00015 {
00016         int Xc = b.sval; // currrent value for observed state variable
00017         SharedPointer<belief_vector> Bc = b.bvec; // current belief for unobserved state variable
00018 
00019         if (!(getMatrix(Xc)->isColumnEmpty(a)))
00020         {
00021                 return inner_prod_column( *getMatrix(Xc), a, *Bc );
00022         }
00023         else
00024         {
00025                 return 0;
00026         }
00027 }
00028 
00029 // (unobserved states, action)
00030 SharedPointer<SparseMatrix> RewardsByFunction::getMatrix(int x)
00031 {
00032         stringstream ss;
00033         ss << "reward x" << x;
00034         string key = ss.str();
00035         if(problem->cache.hasKey(key))
00036         {
00037                 
00038         }
00039         else
00040         {
00041                 problem->cache.put(key, getMatrixInner(x,false));
00042         }
00043         return problem->cache.get(key);
00044 }
00045 //SharedPointer<SparseMatrix> RewardsByFunction::getMatrixTr(int x)
00046 //{
00047 //      return getMatrixInner(x,true);
00048 //}
00049 SharedPointer<SparseMatrix> RewardsByFunction::getMatrixInner(int x, bool transpose)
00050 {
00051         kmatrix tempMatrix;
00052         
00053         int numAction = problem->actions->size();
00054         int numUnobsState = problem->YStates->size();
00055         if(!transpose)
00056         {
00057                 tempMatrix.resize(numUnobsState, numAction);
00058         }
00059         else
00060         {
00061                 tempMatrix.resize(numAction, numUnobsState);
00062         }
00063 
00064         SharedPointer<SparseMatrix> result (new SparseMatrix());
00065         ValueSet xVals = problem->XStates->get(x);
00066 
00067         FOR(s, problem->YStates->size())
00068         {
00069                 ValueSet yVals = problem->YStates->get(s);
00070                 FOR(a, problem->actions->size())
00071                 {
00072                         ValueSet aVals = problem->actions->get(a);
00073 
00074                         map<string, SharedPointer<IVariableValue> > sourceVals;
00075                         sourceVals.insert(xVals.vals.begin(), xVals.vals.end());
00076                         sourceVals.insert(yVals.vals.begin(), yVals.vals.end());
00077                         sourceVals.insert(aVals.vals.begin(), aVals.vals.end());
00078 
00079                         vector<vector<SharedPointer<RelEntry> > > RelEntries;
00080                         double totalReward = 0.0;
00081                         FOREACH(SharedPointer<VariableRelation> , curRel, relations)
00082                         {
00083                                 vector<SharedPointer<RelEntry> > rewards = (*curRel)->getProb(sourceVals);
00084                                 if(rewards.size() > 1 )
00085                                 {
00086                                         throw runtime_error("Reward Relation should return at most one RelEntry");
00087                                 }
00088                                 if(rewards.size() > 0 )
00089                                 {
00090                                         totalReward += rewards[0]->prob;
00091                                 }
00092 
00093                         }
00094 
00095                         if(!transpose)
00096                         {
00097                                 tempMatrix.push_back(s, a, totalReward);
00098                         }
00099                         else
00100                         {
00101                                 tempMatrix.push_back(a, s, totalReward);
00102                         }
00103 
00104 
00105                 }
00106         }
00107                 
00108         copy(*result, tempMatrix);
00109         return result;
00110 }


appl
Author(s): petercai
autogenerated on Tue Jan 7 2014 11:02:29