00001
00009 #ifndef _C45TREE_HH_
00010 #define _C45TREE_HH_
00011
00012 #include <rl_common/Random.h>
00013 #include <rl_common/core.hh>
00014 #include <vector>
00015 #include <set>
00016 #include <map>
00017
/// Capacity of the fixed experience store: C45Tree::allExp is declared with
/// this many entries, embedded by value in every C45Tree object.
00018 #define N_C45_EXP 200000
/// Capacity of the fixed node pool: C45Tree::allNodes is declared with this
/// many entries, embedded by value in every C45Tree object.
00019 #define N_C45_NODES 2500
00020
/// Tree (re)build policy codes.
/// NOTE(review): presumably the values accepted by the C45Tree constructor's
/// trainMode argument. Judging by the names: rebuild on every example, on a
/// prediction error, every N examples (N presumably = trainFreq), on terminal
/// examples, or on terminal examples and errors. Confirm against the .cpp.
/// Kept as macros (not an enum) so they remain usable in #if expressions.
00021 #define BUILD_EVERY 0
00022 #define BUILD_ON_ERROR 1
00023 #define BUILD_EVERY_N 2
00024 #define BUILD_ON_TERMINAL 3
00025 #define BUILD_ON_TERMINAL_AND_ERROR 4
00026
00027
/// C4.5-style decision-tree classifier implementing the Classifier interface.
/// Candidate splits are scored by gain ratio (see calcGainRatio and
/// testPossibleSplits).
///
/// Training examples are kept as tree_experience records: a float feature
/// vector plus one float output. Internal nodes hold a binary test given by
/// the triple (dim, val, type), the same arguments passTest() takes.
///
/// Memory: the experience store (allExp) and the node pool (allNodes) are
/// fixed-size arrays embedded by value. sizeof(C45Tree) is therefore at least
/// N_C45_EXP * sizeof(tree_experience) + N_C45_NODES * sizeof(tree_node).
/// NOTE(review): with N_C45_EXP == 200000 this is several MB on a typical
/// 64-bit build. Avoid automatic (stack) instances, and confirm how callers
/// allocate this class.
///
/// Copying: the class has const data members (id, mode, freq, M, featPct,
/// ALLOW_ONLY_SPLITS), so the implicitly-declared copy assignment operator
/// cannot be used. Copies are made with the copy constructor (or, presumably,
/// getCopy()).
00029 class C45Tree: public Classifier {
00030
00031 public:
00032
/// Creates a classifier with the given configuration.
/// @param id        classifier identifier (presumably stored in the private
///                  member of the same name).
/// @param trainMode presumably one of the BUILD_* macros; presumably stored in
///                  `mode`. Confirm in the .cpp.
/// @param trainFreq presumably the rebuild period used with BUILD_EVERY_N;
///                  presumably stored in `freq`.
/// @param m         presumably stored in `M`. NOTE(review): its meaning is not
///                  visible in this header.
/// @param featPct   presumably stored in the member of the same name.
///                  NOTE(review): looks like the fraction of input features
///                  considered per split. Confirm.
/// @param rng       random number generator, taken and stored by value.
00041 C45Tree(int id, int trainMode, int trainFreq, int m,
00042 float featPct, Random rng);
00043
/// Copy constructor.
/// NOTE(review): root and lastNode are raw pointers, presumably into this
/// object's own allNodes pool, so a memberwise copy would leave them pointing
/// into the source object. The definition presumably rebuilds the tree in the
/// new pool (e.g. via copyTree). Confirm.
00045 C45Tree(const C45Tree&);
00046
/// Destructor.
/// NOTE(review): not declared virtual here. It is virtual only if
/// Classifier's destructor is; check that before deleting a C45Tree through a
/// Classifier pointer.
00047 ~C45Tree();
00048
00049
// Forward declarations, so that copyTree() can name the nested types before
// their definitions below.
00050 struct tree_node;
00051 struct tree_experience;
00052
/// Presumably copies the subtree rooted at origNode into newNode.
/// NOTE(review): origNode is not const-qualified; verify that the definition
/// does not modify it.
00054 void copyTree(tree_node* newNode, tree_node* origNode);
00055
/// Returns a copy of this classifier.
/// NOTE(review): presumably a heap allocation made with the copy constructor
/// that the caller must delete. Confirm ownership in the .cpp.
00056 virtual C45Tree* getCopy();
00057
/// A tree node. The same struct serves as an internal split node and as a
/// leaf (see `leaf`).
00059 struct tree_node {
/// Node identifier.
00060 int id;
00061
00062
/// Split test of an internal node: input dimension, split value and test kind.
/// This is the same triple passTest(dim, val, type, input) takes. `type` is a
/// bool; if it stores splitTypes values, ONLY (0) maps to false and CUT (1)
/// maps to true.
00063 int dim;
00064 float val;
00065 bool type;
00066
00067
/// Map from output value to an int, presumably the number of training
/// instances at this node with that output. outputProbabilities() presumably
/// turns it into probabilities.
00068 std::map<float,int> outputs;
/// Presumably the total number of training instances that reached this node.
00069 int nInstances;
00070
00071
/// Left and right children. Presumably non-owning links to nodes from the
/// allNodes pool (see allocateNode()).
00072 tree_node *l;
00073 tree_node *r;
00074
/// True when this node is a leaf.
00075 bool leaf;
00076 };
00077
/// One stored training example: a feature vector and its output value.
00079 struct tree_experience {
00080 std::vector<float> input;
00081 float output;
00082 };
00083
/// Kinds of split test. The enumerators take the default values ONLY == 0 and
/// CUT == 1.
/// NOTE(review): the exact test each kind performs is defined by passTest() in
/// the .cpp, which is not visible here.
00085 enum splitTypes{
00086 ONLY,
00087 CUT
00088 };
00089
// Classifier interface.
// NOTE(review): presumably overrides of virtuals declared in Classifier.

/// Adds one training example. The bool result presumably reports whether the
/// tree was changed/rebuilt, depending on the BUILD_* training mode.
00090 virtual bool trainInstance(classPair &instance);
/// Adds a batch of training examples. The bool result presumably means the
/// same as trainInstance()'s.
00091 virtual bool trainInstances(std::vector<classPair> &instances);
/// Writes a prediction for `input` into *retval, as a map from output value
/// to (presumably) its probability; see outputProbabilities().
00092 virtual void testInstance(const std::vector<float> &input, std::map<float, float>* retval);
/// Returns a confidence score for `input`.
/// NOTE(review): its scale and meaning are defined in the .cpp.
00093 virtual float getConf(const std::vector<float> &input);
00094
/// Builds the subtree at `node` from `instances`.
/// NOTE(review): the `changed` flag and the bool result presumably track
/// whether the tree structure changed. Confirm.
00096 bool buildTree(tree_node* node, const std::vector<tree_experience*> &instances, bool changed);
00097
00098
/// Presumably (re)initializes the tree to its empty state.
00100 void initTree();
00101
/// Rebuilds the tree. The bool result presumably reports whether it changed.
00103 bool rebuildTree();
00104
/// Presumably resets the fields of `node` to a fresh state.
00106 void initTreeNode(tree_node* node);
00107
/// Presumably walks down from `node`, following the tests that `input`
/// satisfies, and returns the node reached (presumably a leaf).
00109 tree_node* traverseTree(tree_node* node, const std::vector<float> &input);
00110
/// Presumably returns the child of `node` (l or r) selected by applying the
/// node's split test to `input`.
00112 tree_node* getCorrectChild(tree_node* node, const std::vector<float> &input);
00113
/// Presumably returns true when `input` satisfies the split test
/// (dim, val, type).
00115 bool passTest(int dim, float val, bool type, const std::vector<float> &input);
00116
/// Scores the split (dim, val, type) on `instances` and returns its gain ratio.
/// @param I      presumably the information (entropy) of the unsplit set, as
///               computed by calcIforSet().
/// @param left   out: the instances sent to one side of the split (presumably
///               those that pass the test).
/// @param right  out: the remaining instances.
00118 float calcGainRatio(int dim, float val, bool type,
00119 const std::vector<tree_experience*> &instances, float I,
00120 std::vector<tree_experience*> &left,
00121 std::vector<tree_experience*> &right);
00122
/// Returns a float array, presumably the values of feature `dim` across
/// `instances` in sorted order.
/// NOTE(review): raw-pointer return. Who frees it (delete[]?) is not visible
/// here; confirm the caller releases it.
00124 float* sortOnDim(int dim, const std::vector<tree_experience*> &instances);
00125
/// Returns the distinct values, presumably of feature `dim` over `instances`.
/// The result is a std::set, so it is sorted and de-duplicated. The minimum
/// and maximum are presumably written to minVal / maxVal.
00127 std::set<float> getUniques(int dim, const std::vector<tree_experience*> &instances, float & minVal, float& maxVal);
00128
/// Presumably releases the subtree rooted at `node` back to the node pool
/// (see deallocateNode()).
00130 void deleteTree(tree_node* node);
00131
/// Presumably the information (entropy) of the `size` probabilities in P.
00133 float calcIofP(float* P, int size);
00134
/// Presumably the information (entropy) of the output distribution of
/// `instances`.
00136 float calcIforSet(const std::vector<tree_experience*> &instances);
00137
/// Prints the subtree at `t`. `level` is presumably the current depth, used
/// for indentation.
00139 void printTree(tree_node *t, int level);
00140
/// Evaluates candidate splits of `instances` and reports the best one through
/// the out-pointers: gain ratio, dimension, value, type, and the resulting
/// left/right partitions.
00142 void testPossibleSplits(const std::vector<tree_experience*> &instances, float *bestGainRatio, int *bestDim,
00143 float *bestVal, bool *bestType,
00144 std::vector<tree_experience*> *bestLeft, std::vector<tree_experience*> *bestRight);
00145
/// Applies the split described by bestGainRatio/bestDim/bestVal/bestType to
/// `node`, presumably building children from `left` and `right`.
/// NOTE(review): `changed` and the bool result presumably track structural
/// change. Confirm.
00147 bool implementSplit(tree_node* node, float bestGainRatio, int bestDim,
00148 float bestVal, bool bestType,
00149 const std::vector<tree_experience*> &left,
00150 const std::vector<tree_experience*> &right, bool changed);
00151
/// Compares the candidate split (gainRatio, dim, val, type, left, right)
/// against the best split found so far. Presumably updates the best*
/// out-values when the candidate wins.
/// NOTE(review): `nties` presumably counts equally good candidates so that
/// ties can be broken at random with rng. Confirm.
00153 void compareSplits(float gainRatio, int dim, float val, bool type,
00154 const std::vector<tree_experience*> &left,
00155 const std::vector<tree_experience*> &right,
00156 int *nties, float *bestGainRatio, int *bestDim,
00157 float *bestVal, bool *bestType,
00158 std::vector<tree_experience*> *bestLeft,
00159 std::vector<tree_experience*> *bestRight);
00160
/// Writes the output distribution stored at node `t` into *retval.
/// Presumably these are the normalized counts from t->outputs.
00162 void outputProbabilities(tree_node *t, std::map<float, float>* retval);
00163
/// Presumably turns `node` into a leaf. The bool result presumably reports
/// whether anything changed.
00165 bool makeLeaf(tree_node* node);
00166
/// Takes a node from the fixed allNodes pool (presumably using freeNodes).
/// NOTE(review): behavior when all N_C45_NODES nodes are in use is defined in
/// the .cpp.
00168 tree_node* allocateNode();
00169
/// Presumably returns `node` to the allNodes pool.
00171 void deallocateNode(tree_node* node);
00172
/// Presumably marks every slot of allNodes as free.
00174 void initNodes();
00175
00176
/// Debug-output switches (presumably enable diagnostic printing in the .cpp).
00177 bool INCDEBUG;
00178 bool DTDEBUG;
00179 bool SPLITDEBUG;
00180 bool STOCH_DEBUG;
/// Presumably the number of experiences stored so far (used slots of allExp).
00181 int nExperiences;
/// Further debug-output switches.
00182 bool NODEDEBUG;
00183 bool COPYDEBUG;
00184
/// Split-selection tuning parameters.
/// NOTE(review): presumably MIN_GAIN_RATIO is the smallest gain ratio accepted
/// for a split, and SPLIT_MARGIN a tolerance used when comparing splits.
/// Confirm in the .cpp.
00185 float SPLIT_MARGIN;
00186 float MIN_GAIN_RATIO;
00187
00188 private:
00189
/// Configuration fixed at construction. Being const, these members must be
/// initialized in every constructor's initializer list, and they make the
/// class non-copy-assignable.
00190 const int id;
00191
00192 const int mode;
00193 const int freq;
00194 const int M;
00195 const float featPct;
/// Not a constructor parameter, so its value is chosen inside the
/// constructors. NOTE(review): presumably controls whether ONLY-type splits
/// are considered.
00196 const bool ALLOW_ONLY_SPLITS;
00197
/// Random source, held by value. NOTE(review): presumably used for
/// tie-breaking (see compareSplits) and/or feature sampling (featPct).
00198 Random rng;
00199
/// Bookkeeping state. NOTE(review): meanings are defined in the .cpp.
/// nnodes, maxnodes and totalnodes presumably track pool usage; hadError
/// presumably records a prediction error for the BUILD_ON_ERROR modes.
00200 int nOutput;
00201 int nnodes;
00202 bool hadError;
00203 int maxnodes;
00204 int totalnodes;
00205
/// Presumably pointers into allExp for the examples used to build the tree.
00207 std::vector<tree_experience*> experiences;
00208
/// Fixed-capacity experience storage (N_C45_EXP entries), embedded by value.
00210 tree_experience allExp[N_C45_EXP];
00211
/// Fixed-capacity node pool (N_C45_NODES entries), embedded by value and
/// handed out by allocateNode(). freeNodes presumably holds the indices of
/// unused slots.
00213 tree_node allNodes[N_C45_NODES];
00214 std::vector<int> freeNodes;
00215
00216
/// Root of the tree (presumably a node from allNodes).
00218 tree_node* root;
/// Presumably the node most recently reached by traverseTree(). Confirm.
00220 tree_node* lastNode;
00221
00222 };
00223
00224
00225 #endif
00226