LinearSplitsTree.hh
Go to the documentation of this file.
00001 #ifndef _LINEARSPLITS_HH_
00002 #define _LINEARSPLITS_HH_
00003 
00004 #include <rl_common/Random.h>
00005 #include <rl_common/core.hh>
00006 #include <vector>
00007 #include <set>
00008 #include <map>
00009 
// Capacities of the fixed-size storage embedded in every LinearSplitsTree:
// N_LST_EXP is the length of the allExp array (stored training experiences)
// and N_LS_NODES is the length of the allNodes array (tree-node pool).
00010 #define N_LST_EXP 200000
00011 #define N_LS_NODES 2500
00012 
// Tree (re)build policies. Presumably these are the values accepted by the
// constructor's 'trainMode' argument (rebuild every step / on error / every
// 'trainFreq' steps / on terminal / on terminal and error) - confirm against
// the implementation file.
00013 #define BUILD_EVERY 0
00014 #define BUILD_ON_ERROR 1
00015 #define BUILD_EVERY_N 2
00016 #define BUILD_ON_TERMINAL 3
00017 #define BUILD_ON_TERMINAL_AND_ERROR 4
00018 
// LinearSplitsTree - a Classifier whose model is a binary tree. Each node has a
// split criterion (tree_node::dim / tree_node::val) and children l / r; leaf
// nodes carry a regression model (tree_node::constant and
// tree_node::coefficients). Training examples are kept as tree_experience
// records, and nodes come from a fixed-size pool (allNodes / freeNodes).
// NOTE(review): each object embeds N_LST_EXP tree_experience elements and
// N_LS_NODES tree_node elements by value (allExp / allNodes), so an instance is
// very large - presumably it must be heap-allocated, never put on the stack;
// confirm with callers.
00020 class LinearSplitsTree: public Classifier {
00021 
00022 public:
00023 
00024   // Constructor.
00025   //   trainMode - when to re-build the tree: every step, only on
      //               misclassifications, or every 'trainFreq' steps
      //               (presumably one of the BUILD_* macros of this header).
      //   id        - identifier for this tree (presumably stored in 'id').
      //   trainFreq - rebuild period, presumably used with BUILD_EVERY_N.
      //   m         - presumably stored as M; looks like a minimum instance
      //               count used when deciding to split - TODO confirm.
      //   featPct   - presumably the fraction of input features examined when
      //               testing splits - TODO confirm.
      //   simple    - presumably stored as SIMPLE; looks like it selects
      //               fitSimpleLinearModel over fitMultiLinearModel - confirm.
      //   min_er    - presumably stored as MIN_ER; a threshold on the split
      //               score - TODO confirm its exact meaning.
      //   rng       - random number generator, passed by value (presumably
      //               copied into the 'rng' member).
00026   LinearSplitsTree(int id, int trainMode, int trainFreq, int m, 
00027                    float featPct, bool simple, float min_er, Random rng);
00028 

      // Copy constructor. No copy-assignment operator is declared, and because
      // the class has const data members the implicit one cannot be used:
      // trees can be copied but not assigned.
00029   LinearSplitsTree(const LinearSplitsTree&);
      // Returns a copy of this tree. NOTE(review): presumably heap-allocated
      // with ownership passing to the caller - confirm in the implementation.
00030   virtual LinearSplitsTree* getCopy();
00031 
      // NOTE(review): not declared virtual here; it is virtual only if
      // Classifier's destructor is - confirm in rl_common/core.hh.
00032   ~LinearSplitsTree();
00033 
00034   // forward declarations; both structs are defined just below
00035   struct tree_node;
00036   struct tree_experience;
00037   
00038     
      // One node of the tree. An internal node routes inputs to child l or r
      // using its split criterion (dim, val); a leaf (leaf == true) predicts
      // with its regression model (constant, coefficients).
00040   struct tree_node {
00041     int id;
00042 
00043     // split criterion: input dimension and threshold value
00044     int dim;
00045     float val;
          // presumably the average error of the instances at this node
          // (cf. calcAvgErrorforSet) - confirm
00046     float avgError;
00047 
          // true if this node is a leaf (uses the regression model below)
00048     bool leaf;
00049     
00050     // for regression model: intercept and per-input coefficients
00051     float constant;
00052     std::vector<float> coefficients;
00053 
00054     // next nodes in tree (children of an internal node)
00055     tree_node *l;
00056     tree_node *r;
00057 
00058     // count of training instances at this leaf/node (an int, not a set)
00059     int nInstances;
00060 
00061   };
00062 
      // One stored training example: an input feature vector and its scalar
      // output.
00063   struct tree_experience {
00064     std::vector<float> input;
00065     float output;
00066   };
00067 
      // Add one training example / a batch of examples; presumably the tree is
      // rebuilt according to the configured train mode. NOTE(review): the
      // meaning of the bool result is not visible here - confirm in the .cc.
00068   bool trainInstance(classPair &instance);
00069   bool trainInstances(std::vector<classPair> &instances);
      // Prediction for 'input', written into *retval (presumably a map from
      // predicted value to probability - confirm).
00070   void testInstance(const std::vector<float> &input, std::map<float, float>* retval);
      // Confidence associated with the prediction for 'input' (exact
      // semantics not visible in this header).
00071   float getConf(const std::vector<float> &input);
00072 
      // (Re)build the subtree rooted at 'node' from 'instances'; 'changed'
      // presumably says whether the instance set changed - confirm.
00073   void buildTree(tree_node* node, const std::vector<tree_experience*> &instances,
00074                  bool changed);
      // Copy the subtree rooted at origNode into newNode (presumably deep).
00075   void copyTree(tree_node* newNode, tree_node* origNode);
00076 
00077 
00078   // helper functions
      // NOTE(review): these look like implementation details but are public;
      // consider moving them to the private section.
00079   void initTree();
00080   void rebuildTree();
00081   void initTreeNode(tree_node* node);
      // Presumably follows 'input' down from 'node' to the leaf it reaches.
00082   tree_node* traverseTree(tree_node* node, const std::vector<float> &input);
      // Presumably returns whichever child (l or r) 'input' is routed to.
00083   tree_node* getCorrectChild(tree_node* node, const std::vector<float> &input);
      // Whether 'input' satisfies the split (dim, val); the comparison
      // direction is not visible in this header.
00084   bool passTest(int dim, float val, const std::vector<float> &input);
      // Evaluate splitting 'instances' on (dim, val): fills 'left'/'right' and
      // *leftError/*rightError and returns a score (presumably the error
      // reduction relative to 'error' - confirm).
00085   float calcER(int dim, float val, 
00086                const std::vector<tree_experience*> &instances, float error,
00087                std::vector<tree_experience*> &left,
00088                std::vector<tree_experience*> &right,
00089                float *leftError, float *rightError);
      // Presumably returns the values of 'instances' along 'dim', sorted.
      // NOTE(review): ownership of the returned raw float* is not visible
      // here - confirm who frees it.
00090   float* sortOnDim(int dim, const std::vector<tree_experience*> &instances);
      // Distinct values of dimension 'dim' across 'instances'; minVal/maxVal
      // are out-parameters (presumably the smallest/largest value seen).
00091   std::set<float> getUniques(int dim, const std::vector<tree_experience*> &instances, float & minVal, float& maxVal);
      // Release the subtree rooted at 'node' (presumably back to the pool).
00092   void deleteTree(tree_node* node);
00093   float calcAvgErrorforSet(const std::vector<tree_experience*> &instances);
      // Debug print of the subtree rooted at 't'; 'level' is presumably the
      // depth used for indentation.
00094   void printTree(tree_node *t, int level);
      // Search candidate splits of 'instances' and report the best one through
      // the best* out-parameters.
00095   void testPossibleSplits(float avgError, const std::vector<tree_experience*> &instances, 
00096                           float *bestER, int *bestDim, 
00097                           float *bestVal, 
00098                           std::vector<tree_experience*> *bestLeft, 
00099                           std::vector<tree_experience*> *bestRight,
00100                           float *bestLeftError, float *bestRightError);
      // Apply the chosen split at 'node' (or presumably make it a leaf if the
      // split is not worthwhile - confirm).
00101   void implementSplit(tree_node* node, const std::vector<tree_experience*> &instances, 
00102                       float bestER, int bestDim,
00103                       float bestVal, 
00104                       const std::vector<tree_experience*> &left, 
00105                       const std::vector<tree_experience*> &right,
00106                       bool changed, float leftError, float rightError);
      // Compare one candidate split against the current best and update the
      // best* out-parameters; *nties presumably counts equally-scored
      // candidates (for random tie-breaking) - confirm.
00107   void compareSplits(float er, int dim, float val, 
00108                      const std::vector<tree_experience*> &left, 
00109                      const std::vector<tree_experience*> &right,
00110                      int *nties, float leftError, float rightError,
00111                      float *bestER, int *bestDim, 
00112                      float *bestVal, 
00113                      std::vector<tree_experience*> *bestLeft, 
00114                      std::vector<tree_experience*> *bestRight,
00115                      float *bestLeftError, float *bestRightError);
      // Prediction of leaf 't' for input 'in', written into *retval.
00116   void leafPrediction(tree_node *t, const std::vector<float> &in, std::map<float, float>* retval);
      // Turn 'node' into a leaf modelling 'instances'.
00117   void makeLeaf(tree_node* node, const std::vector<tree_experience*> &instances);
00118 
      // Fit a linear model to 'instances', writing the intercept to *constant
      // and the weights to *coeff; the float result is presumably the fit
      // error - confirm.
00119   float fitSimpleLinearModel(const std::vector<tree_experience*> &instances,
00120                              float* constant, std::vector<float> *coeff);
00121   float fitMultiLinearModel(const std::vector<tree_experience*> &instances,
00122                             float* constant, std::vector<float> * coeff);
00123 
      // Node-pool management over allNodes / freeNodes (presumably).
00124   tree_node* allocateNode();
00125   void deallocateNode(tree_node* node);
00126   void initNodes();
00127 
      // Debug switches (presumably checked by the implementation to enable
      // diagnostic output).
00128   bool INCDEBUG;
00129   bool DTDEBUG;
00130   bool LMDEBUG;
00131   bool SPLITDEBUG;
00132   bool STOCH_DEBUG;
00133   bool NODEDEBUG;
      // presumably the number of stored training experiences - confirm
00134   int nExperiences;
00135   bool COPYDEBUG;
00136 
      // NOTE(review): purpose not visible in this header.
00137   float SPLIT_MARGIN;
00138 
00139 private:
00140 
      // Configuration, fixed at construction (const); presumably copies of the
      // constructor arguments. These const members are why the class cannot
      // be copy-assigned.
00141   const int id;
00142   
00143   const int mode;
00144   const int freq;
00145   const int M;
00146   const float featPct; 
00147   const bool SIMPLE;
00148   const float MIN_ER;
00149 
00150   Random rng;
00151 
      // Bookkeeping; semantics not visible in this header (hadError is
      // presumably consulted by the BUILD_ON_ERROR policies - confirm).
00152   int nOutput;
00153   int nnodes;
00154   bool hadError;
00155   int totalnodes;
00156   int maxnodes;
00157 
00158   // INSTANCES
      // 'experiences' presumably points into the allExp array; freeNodes
      // presumably holds indices of unused allNodes slots - confirm.
00159   std::vector<tree_experience*> experiences;
00160   tree_experience allExp[N_LST_EXP];
00161   tree_node allNodes[N_LS_NODES];
00162   std::vector<int> freeNodes;
00163 
00164   // TREE
      // root of the tree; lastNode's role is not visible here (presumably the
      // most recently reached leaf - confirm).
00165   tree_node* root;
00166   tree_node* lastNode;
00167 
00168 };
00169 
00170 
00171 #endif
00172   


rl_agent
Author(s): Todd Hester
autogenerated on Thu Jun 6 2019 22:00:13