00001
00008 #ifndef _M5TREE_HH_
00009 #define _M5TREE_HH_
00010
00011 #include <rl_common/Random.h>
00012 #include <rl_common/core.hh>
00013 #include <vector>
00014 #include <set>
00015 #include <map>
00016
/** Capacity of the experience store: M5Tree::allExp is declared with exactly this many tree_experience elements. */
00017 #define N_M5_EXP 200000
/** Capacity of the node store: M5Tree::allNodes is declared with exactly this many tree_node elements. */
00018 #define N_M5_NODES 2500
00019
/**
 * Integer codes naming when the tree should be (re)built.
 * NOTE(review): presumably these are the accepted values of the M5Tree
 * constructor's trainMode argument; the code that interprets them is not in
 * this header, so the per-value notes below are read off the names only.
 */
/** Presumably: rebuild on every training call. */
00020 #define BUILD_EVERY 0
/** Presumably: rebuild only when the current tree made a prediction error. */
00021 #define BUILD_ON_ERROR 1
/** Presumably: rebuild once every N training calls (N likely being trainFreq). */
00022 #define BUILD_EVERY_N 2
/** Presumably: rebuild when a terminal transition is seen. */
00023 #define BUILD_ON_TERMINAL 3
/** Presumably: combination of the terminal and error triggers. */
00024 #define BUILD_ON_TERMINAL_AND_ERROR 4
00025
/**
 * M5Tree: a Classifier subclass that declares a binary tree of tree_node
 * objects (each node has two child pointers) in which every node can carry a
 * constant plus a vector of coefficients, together with the machinery to
 * train on tree_experience samples (a float input vector and one float output).
 *
 * Only declarations live here; all member functions are defined elsewhere.
 *
 * NOTE(review): judging by the names (calcSDR, MIN_SDR, fitLinearModel) this
 * is presumably an M5-style model/regression tree that splits on
 * standard-deviation reduction and fits linear models in the nodes -- confirm
 * against the implementation file.
 *
 * NOTE(review): allExp (N_M5_EXP elements) and allNodes (N_M5_NODES elements)
 * are fixed-size arrays stored directly inside every instance, so the object
 * itself is large; be careful about creating instances on the stack.
 *
 * NOTE(review): a copy constructor and a destructor are declared but no
 * copy-assignment operator; because of the const data members the compiler
 * cannot generate one either, so instances are copy-constructible but not
 * assignable.
 */
00027 class M5Tree: public Classifier {
00028
00029 public:
00030
/**
 * Constructor.
 * The parameter names line up with the private members id, mode, freq, M,
 * featPct, SIMPLE, ALLOW_ALL_FEATS, MIN_SDR and rng, which they presumably
 * initialise (the definition is not visible here).
 * \param id            identifier for this tree
 * \param trainMode     presumably one of the BUILD_* codes -- TODO confirm
 * \param trainFreq     presumably the rebuild period used with BUILD_EVERY_N
 * \param m             stored as M; meaning not derivable from this header
 * \param featPct       a fraction of features; presumably the share withheld
 *                      from or considered at each split -- TODO confirm which
 * \param simple        presumably selects fitSimpleLinearModel over fitLinearModel
 * \param allowAllFeats presumably lets the linear model use every feature
 *                      rather than only those returned by getFeatsUsed
 * \param min_sdr       presumably the minimum standard-deviation reduction
 *                      required to accept a split
 * \param rng           random number generator, taken and stored by value
 */
00042 M5Tree(int id, int trainMode, int trainFreq, int m,
00043 float featPct, bool simple, bool allowAllFeats,
00044 float min_sdr, Random rng);
00045
/** Copy constructor (presumably duplicates the tree via copyTree -- TODO confirm). */
00047 M5Tree(const M5Tree&);
00048
/**
 * Virtual copy: returns a pointer to an M5Tree.
 * NOTE(review): ownership of the returned pointer is not expressed in the
 * type; presumably the caller owns it -- verify at the call sites.
 */
00049 virtual M5Tree* getCopy();
00050
/** Destructor. Whether it is virtual depends on Classifier, which is declared outside this file. */
00051 ~M5Tree();
00052
00053
/** Forward declarations of the nested types defined just below. */
00054 struct tree_node;
00055 struct tree_experience;
00056
00057
/**
 * One node of the tree.
 * Holds a (dim, val) pair and a leaf flag, pointers to two children, an
 * instance count, and a constant plus coefficient vector.
 */
00059 struct tree_node {
/** Integer identifier of this node. */
00060 int id;
00061
00062
/**
 * dim / val: presumably the split test at this node (feature index and
 * threshold), matching the (dim, val) arguments of passTest() and calcSDR().
 * leaf: presumably true when the node has no children.
 */
00063 int dim;
00064 float val;
00065 bool leaf;
00066
00067
/** The two child nodes. Which one corresponds to a passed test is decided in the implementation (see getCorrectChild). */
00068 tree_node *l;
00069 tree_node *r;
00070
00071
/** Presumably the number of training instances that reached this node. */
00072 int nInstances;
00073
00074
/** Presumably the node's linear model: an intercept and one weight per input feature (see fitLinearModel / fitSimpleLinearModel). */
00075 float constant;
00076 std::vector<float> coefficients;
00077
00078 };
00079
/** One training sample: a vector of float inputs and a single float output. */
00081 struct tree_experience {
00082 std::vector<float> input;
00083 float output;
00084 };
00085
/** Copies between two nodes; by the parameter names, from origNode into newNode (presumably recursing into the subtree). */
00087 void copyTree(tree_node* newNode, tree_node* origNode);
00088
/**
 * Virtual training / query interface, presumably overriding Classifier
 * (declared outside this file).
 * - trainInstance / trainInstances: take one classPair or a vector of them by
 *   non-const reference and return bool; the meaning of the return value is
 *   not visible here (presumably whether the model changed).
 * - testInstance: takes an input vector and a pointer to a float->float map
 *   named retval, i.e. an output parameter; presumably filled with
 *   predicted value -> probability.
 * - getConf: returns a float for an input vector; presumably a confidence.
 */
00089 virtual bool trainInstance(classPair &instance);
00090 virtual bool trainInstances(std::vector<classPair> &instances);
00091 virtual void testInstance(const std::vector<float> &input, std::map<float, float>* retval);
00092 virtual float getConf(const std::vector<float> &input);
00093
/** Builds the (sub)tree at node from the given instances. The role of the changed flag is not visible here. */
00095 void buildTree(tree_node* node, const std::vector<tree_experience*> &instances, bool changed);
00096
00097
/** Presumably creates/initialises the root of the tree. */
00099 void initTree();
00100
/** Presumably rebuilds the whole tree from the stored experiences. */
00102 void rebuildTree();
00103
/** Presumably resets the fields of the given node to their initial state. */
00105 void initTreeNode(tree_node* node);
00106
/** Returns a node reached from node for the given input; presumably descends until a leaf. */
00108 tree_node* traverseTree(tree_node* node, const std::vector<float> &input);
00109
/** Returns a node for the given input; presumably whichever of node->l / node->r the input belongs to. */
00111 tree_node* getCorrectChild(tree_node* node, const std::vector<float> &input);
00112
/** Evaluates the test described by (dim, val) on input and returns the result as a bool. The comparison used is defined in the implementation. */
00114 bool passTest(int dim, float val, const std::vector<float> &input);
00115
/**
 * Returns a float score for splitting instances on (dim, val).
 * left and right are non-const references, i.e. output vectors; presumably
 * they receive the two partitions. sd is presumably the standard deviation of
 * the unsplit set, and the result presumably the standard-deviation reduction.
 */
00117 float calcSDR(int dim, float val,
00118 const std::vector<tree_experience*> &instances, float sd,
00119 std::vector<tree_experience*> &left,
00120 std::vector<tree_experience*> &right);
00121
/**
 * Returns a raw float pointer; presumably an array of the dim-th input values
 * in sorted order.
 * NOTE(review): the array's length and who frees it are not expressed in the
 * signature -- check the implementation and callers.
 */
00123 float* sortOnDim(int dim, const std::vector<tree_experience*> &instances);
00124
/** Returns a set of floats; presumably the distinct values of feature dim. minVal / maxVal are reference (output) parameters, presumably the range of those values. */
00126 std::set<float> getUniques(int dim, const std::vector<tree_experience*> &instances, float & minVal, float& maxVal);
00127
/** Presumably releases the given node together with everything below it. */
00129 void deleteTree(tree_node* node);
00130
/** Returns a float for a set of instances; by its name, the standard deviation (presumably of their outputs). */
00132 float calcSDforSet(const std::vector<tree_experience*> &instances);
00133
/** Presumably prints the subtree at t for debugging; level is presumably the current depth. */
00135 void printTree(tree_node *t, int level);
00136
/** Presumably evaluates candidate splits of instances. All results are returned through the best* pointer parameters. */
00138 void testPossibleSplits(const std::vector<tree_experience*> &instances,
00139 float *bestSDR, int *bestDim,
00140 float *bestVal,
00141 std::vector<tree_experience*> *bestLeft,
00142 std::vector<tree_experience*> *bestRight);
00143
/** Presumably applies the chosen split (bestDim, bestVal) to node using the supplied left / right partitions. */
00145 void implementSplit(tree_node* node,
00146 const std::vector<tree_experience*> &instances,
00147 float bestSDR, int bestDim,
00148 float bestVal,
00149 const std::vector<tree_experience*> &left,
00150 const std::vector<tree_experience*> &right,
00151 bool changed);
00152
/** Presumably compares a candidate split (sdr, dim, val, left, right) with the best so far and updates the best* outputs; nties is presumably a tie counter used for tie-breaking. */
00154 void compareSplits(float sdr, int dim, float val,
00155 const std::vector<tree_experience*> &left,
00156 const std::vector<tree_experience*> &right,
00157 int *nties, float *bestSDR, int *bestDim,
00158 float *bestVal,
00159 std::vector<tree_experience*> *bestLeft,
00160 std::vector<tree_experience*> *bestRight);
00161
/** Presumably computes the prediction of node t for input in and writes it into the retval map. */
00163 void leafPrediction(tree_node *t, const std::vector<float> &in, std::map<float, float>* retval);
00164
/** Presumably turns the given node into a leaf. */
00166 void makeLeaf(tree_node* node);
00167
/** Presumably detaches/releases the children of the given node. */
00169 void removeChildren(tree_node* node);
00170
/** Presumably prunes the subtree at node, judged on the given instances. Returns nothing. */
00172 void pruneTree(tree_node* node, const std::vector<tree_experience*> &instances);
00173
/**
 * Fit a model for node on instances; returns an int and writes through resSum.
 * featureMask is taken by value (a copy per call). Presumably it marks which
 * features may be used, nFeats is how many that is, resSum receives a sum of
 * residuals and the return value is the number of parameters -- TODO confirm.
 */
00175 int fitLinearModel(tree_node* node, const std::vector<tree_experience*> &instances,
00176 std::vector<bool> featureMask, int nFeats, float* resSum);
00177
/** Same signature as fitLinearModel; presumably the single-feature variant selected by SIMPLE. */
00179 int fitSimpleLinearModel(tree_node* node, const std::vector<tree_experience*> &instances,
00180 std::vector<bool> featureMask, int nFeats, float* resSum);
00181
/** Writes into featsUsed (output pointer); presumably flags the features tested by splits in the subtree at node. */
00183 void getFeatsUsed(tree_node* node, std::vector<bool> *featsUsed);
00184
/** Returns a tree_node pointer; presumably a free slot of allNodes taken from freeNodes. */
00186 tree_node* allocateNode();
00187
/** Presumably returns the given node to the free pool. */
00189 void deallocateNode(tree_node* node);
00190
/** Presumably initialises allNodes / freeNodes. */
00192 void initNodes();
00193
/** Public boolean flags; by their names presumably switches for different kinds of debug output. */
00194 bool INCDEBUG;
00195 bool DTDEBUG;
00196 bool LMDEBUG;
00197 bool SPLITDEBUG;
00198 bool STOCH_DEBUG;
00199 bool NODEDEBUG;
00200 bool COPYDEBUG;
/** Public counter; presumably the number of experiences stored so far. */
00201 int nExperiences;
00202
/** Public float tunable related to splitting. Its exact use is defined in the implementation. */
00203 float SPLIT_MARGIN;
00204
00205 private:
00206
/** Identifier of this tree (const). */
00207 const int id;
00208
/**
 * Configuration. Everything except featPct is const; presumably these are set
 * from the constructor arguments trainMode, trainFreq, m, featPct, simple,
 * allowAllFeats and min_sdr respectively.
 */
00209 const int mode;
00210 const int freq;
00211 const int M;
00212 float featPct;
00213 const bool SIMPLE;
00214 const bool ALLOW_ALL_FEATS;
00215 const float MIN_SDR;
00216
/** Random number generator, held by value. */
00217 Random rng;
00218
/** Presumably the number of input features. */
00219 int nfeat;
00220
/** Bookkeeping counters and flags. Their precise meaning is defined in the implementation; hadError presumably feeds the BUILD_ON_ERROR modes. */
00221 int nOutput;
00222 int nnodes;
00223 bool hadError;
00224 int totalnodes;
00225 int maxnodes;
00226
00227
/** Pointers to the experiences in use; presumably they point into allExp. */
00229 std::vector<tree_experience*> experiences;
00230
/** Fixed-size store of N_M5_EXP experiences embedded directly in the object. */
00232 tree_experience allExp[N_M5_EXP];
00233
/** Fixed-size store of N_M5_NODES nodes embedded in the object, plus a vector of ints that presumably holds the indices of unused slots. */
00235 tree_node allNodes[N_M5_NODES];
00236 std::vector<int> freeNodes;
00237
00238
/** Pointer to the root node of the tree. */
00240 tree_node* root;
/** Pointer to a "last" node; presumably the node used by the most recent prediction. */
00242 tree_node* lastNode;
00243
00244 };
00245
00246
00247 #endif
00248