M5Tree.hh
Go to the documentation of this file.
00001 
00008 #ifndef _M5TREE_HH_
00009 #define _M5TREE_HH_
00010 
00011 #include <rl_common/Random.h>
00012 #include <rl_common/core.hh>
00013 #include <vector>
00014 #include <set>
00015 #include <map>
00016 
// Capacities of the fixed-size member arrays declared inside class M5Tree:
// N_M5_EXP is the length of allExp[] and N_M5_NODES the length of allNodes[].
00017 #define N_M5_EXP 200000
00018 #define N_M5_NODES 2500
00019 
// Integer codes (0-4) naming tree (re)build policies.
// NOTE(review): presumably these are the accepted values of the M5Tree
// constructor's trainMode argument; their use is not visible in this header,
// so confirm against the implementation file.
00020 #define BUILD_EVERY 0
00021 #define BUILD_ON_ERROR 1
00022 #define BUILD_EVERY_N 2
00023 #define BUILD_ON_TERMINAL 3
00024 #define BUILD_ON_TERMINAL_AND_ERROR 4
00025 
// M5Tree: a Classifier subclass declaring a binary tree whose nodes carry a
// split test (dim/val) and a linear model (constant + coefficients).
// This header only declares members; every function is defined elsewhere, so
// notes below are hedged wherever the declarations alone cannot confirm them.
00027 class M5Tree: public Classifier {
00028 
00029 public:
00030 
        // Constructor.
        // NOTE(review): the arguments presumably initialise the private members
        // id, mode, freq, M, featPct, SIMPLE, ALLOW_ALL_FEATS, MIN_SDR and rng
        // (declared in that order below) -- confirm against the definition.
        // rng is passed by value, i.e. a Random copy is made at each call.
00042   M5Tree(int id, int trainMode, int trainFreq, int m, 
00043          float featPct, bool simple, bool allowAllFeats, 
00044          float min_sdr, Random rng);
00045 
        // Copy constructor (user-declared; definition not in this header).
00047   M5Tree(const M5Tree&);
00048 
        // Virtual copy hook returning a raw M5Tree*.
        // NOTE(review): ownership of the returned pointer is not stated here.
00049   virtual M5Tree* getCopy();
00050 
        // Destructor. Not marked virtual in this declaration; it is virtual only
        // if Classifier's destructor is -- TODO confirm in rl_common/core.hh.
00051   ~M5Tree();
00052 
00053   // structs to be defined
00054   struct tree_node;
00055   struct tree_experience;
00056   
00057     
        // One node of the tree: identifier, split criterion, child links,
        // instance counter and the node's linear-model parameters.
00059   struct tree_node {
00060     int id;
00061 
00062     // split criterion
00063     int dim;
00064     float val;
00065     bool leaf;
00066 
00067     // next nodes in tree
00068     tree_node *l;
00069     tree_node *r;
00070     
00071     // set of all outputs seen at this leaf/node
00072     int nInstances;
00073 
00074     // for regression model
00075     float constant;
00076     std::vector<float> coefficients;
00077 
00078   };
00079 
        // One stored experience: a float feature vector and a single float output.
00081   struct tree_experience {
00082     std::vector<float> input;
00083     float output;
00084   };
00085   
        // Takes a destination node and an original node.
        // NOTE(review): presumably copies the subtree under origNode into newNode.
00087   void copyTree(tree_node* newNode, tree_node* origNode);
00088 
        // Virtual training / query interface.
        // NOTE(review): presumably these override Classifier's virtuals (base
        // class not visible here). testInstance receives a std::map<float, float>
        // pointer, presumably filled as the output; getConf returns a float.
00089   virtual bool trainInstance(classPair &instance);
00090   virtual bool trainInstances(std::vector<classPair> &instances);
00091   virtual void testInstance(const std::vector<float> &input, std::map<float, float>* retval);
00092   virtual float getConf(const std::vector<float> &input);
00093 
        // Operates on a node and a read-only list of experience pointers.
        // NOTE(review): the meaning of `changed` is not visible in this header.
00095   void buildTree(tree_node* node, const std::vector<tree_experience*> &instances,  bool changed);
00096 
00097   // helper functions
        // The helpers below are public although they look internal; their exact
        // contracts live in the implementation file and are not restated here.
00099   void initTree();
00100 
00102   void rebuildTree();
00103 
00105   void initTreeNode(tree_node* node);
00106 
00108   tree_node* traverseTree(tree_node* node, const std::vector<float> &input);
00109 
00111   tree_node* getCorrectChild(tree_node* node, const std::vector<float> &input);
00112 
00114   bool passTest(int dim, float val, const std::vector<float> &input);
00115 
        // left/right are non-const references.
        // NOTE(review): presumably they receive the two partitions of `instances`.
00117   float calcSDR(int dim, float val, 
00118                        const std::vector<tree_experience*> &instances, float sd,
00119                         std::vector<tree_experience*> &left,
00120                         std::vector<tree_experience*> &right);
00121 
        // Returns a raw float*.
        // NOTE(review): who frees the returned buffer is not visible here -- confirm.
00123   float* sortOnDim(int dim, const std::vector<tree_experience*> &instances);
00124 
        // Returns a std::set<float> by value; minVal/maxVal are non-const
        // references, presumably used as additional outputs.
00126   std::set<float> getUniques(int dim, const std::vector<tree_experience*> &instances, float & minVal, float& maxVal);
00127 
00129   void deleteTree(tree_node* node);
00130 
00132   float calcSDforSet(const std::vector<tree_experience*> &instances);
00133 
00135   void printTree(tree_node *t, int level);
00136 
        // All best* parameters are non-const pointers, presumably written with
        // the best split found for `instances`.
00138   void testPossibleSplits(const std::vector<tree_experience*> &instances, 
00139                           float *bestSDR, int *bestDim, 
00140                           float *bestVal, 
00141                           std::vector<tree_experience*> *bestLeft, 
00142                           std::vector<tree_experience*> *bestRight);
00143   
00145   void implementSplit(tree_node* node, 
00146                       const std::vector<tree_experience*> &instances,
00147                       float bestSDR, int bestDim,
00148                       float bestVal, 
00149                       const std::vector<tree_experience*> &left, 
00150                       const std::vector<tree_experience*> &right,
00151                       bool changed);
00152   
        // Candidate split passed by value/const reference; nties and the best*
        // arguments are non-const pointers, presumably updated in place.
00154   void compareSplits(float sdr, int dim, float val, 
00155                      const std::vector<tree_experience*> &left, 
00156                      const std::vector<tree_experience*> &right,
00157                      int *nties, float *bestSDR, int *bestDim, 
00158                      float *bestVal, 
00159                      std::vector<tree_experience*> *bestLeft, 
00160                      std::vector<tree_experience*> *bestRight);
00161   
00163   void leafPrediction(tree_node *t, const std::vector<float> &in, std::map<float, float>* retval);
00164   
00166   void makeLeaf(tree_node* node);
00167 
00169   void removeChildren(tree_node* node);
00170 
00172   void pruneTree(tree_node* node, const std::vector<tree_experience*> &instances);
00173   
        // featureMask is taken by value, so the std::vector<bool> is copied on
        // every call. resSum is a non-const float*, presumably an output.
        // NOTE(review): the meaning of the int return value is not visible here.
00175   int fitLinearModel(tree_node* node, const std::vector<tree_experience*> &instances,
00176                      std::vector<bool> featureMask, int nFeats, float* resSum);
00177 
        // Same parameter list as fitLinearModel (featureMask again by value).
00179   int fitSimpleLinearModel(tree_node* node, const std::vector<tree_experience*> &instances,
00180                      std::vector<bool> featureMask, int nFeats, float* resSum);
00181 
00183   void getFeatsUsed(tree_node* node, std::vector<bool> *featsUsed);
00184 
        // Node allocation helpers.
        // NOTE(review): presumably these hand out and return slots of the
        // allNodes[] pool tracked by freeNodes -- confirm in the implementation.
00186   tree_node* allocateNode();
00187   
00189   void deallocateNode(tree_node* node);
00190   
00192   void initNodes();
00193 
        // Public, mutable bool flags; the names suggest per-area debug output
        // switches (confirm in the implementation).
00194   bool INCDEBUG;
00195   bool DTDEBUG;
00196   bool LMDEBUG;
00197   bool SPLITDEBUG;
00198   bool STOCH_DEBUG;
00199   bool NODEDEBUG;
00200   bool COPYDEBUG;
00201   int nExperiences;
00202 
00203   float SPLIT_MARGIN;
00204 
00205 private:
00206 
        // Configuration members. Several are const, so the implicitly-declared
        // copy assignment operator of M5Tree is deleted; copies go through the
        // copy constructor / getCopy().
00207   const int id;
00208   
00209   const int mode;
00210   const int freq;
00211   const int M;
00212   float featPct; 
00213   const bool SIMPLE;
00214   const bool ALLOW_ALL_FEATS;
00215   const float MIN_SDR;
00216 
00217   Random rng;
00218 
00219   int nfeat;
00220 
00221   int nOutput;
00222   int nnodes;
00223   bool hadError;
00224   int totalnodes;
00225   int maxnodes;
00226 
00227   // INSTANCES
        // Pointers to experiences (presumably elements of allExp[] below).
00229   std::vector<tree_experience*> experiences;
00230   
        // Fixed-size storage embedded in every M5Tree object: N_M5_EXP
        // experiences (each holding a std::vector<float>), so the object itself
        // is large. NOTE(review): prefer creating M5Tree objects on the heap.
00232   tree_experience allExp[N_M5_EXP];
00233   
        // Fixed-size pool of N_M5_NODES nodes, also embedded in the object.
        // freeNodes holds ints, presumably indices of unused allNodes[] slots.
00235   tree_node allNodes[N_M5_NODES];
00236   std::vector<int> freeNodes;
00237 
00238   // TREE
        // Raw, non-owning-looking pointers to the tree root and to a "last" node.
        // NOTE(review): what lastNode tracks is not visible in this header.
00240   tree_node* root;
00242   tree_node* lastNode;
00243 
00244 };
00245 
00246 
00247 #endif
00248   


rl_agent
Author(s): Todd Hester
autogenerated on Thu Jun 6 2019 22:00:13