00001 #ifndef _LINEARSPLITS_HH_
00002 #define _LINEARSPLITS_HH_
00003
00004 #include <rl_common/Random.h>
00005 #include <rl_common/core.hh>
00006 #include <vector>
00007 #include <set>
00008 #include <map>
00009
// Capacity of the fixed experience pool embedded in every LinearSplitsTree
// (sizes the LinearSplitsTree::allExp array).
#define N_LST_EXP 200000
// Capacity of the fixed node pool embedded in every LinearSplitsTree
// (sizes the LinearSplitsTree::allNodes array).
#define N_LS_NODES 2500

// Training-mode selectors. These are presumably the values accepted by the
// `trainMode` constructor argument, controlling when the tree is rebuilt
// (TODO confirm against the .cpp).
// NOTE(review): if other model headers define these same macros, keep the
// replacement lists token-identical so the redefinitions remain legal.
#define BUILD_EVERY 0
#define BUILD_ON_ERROR 1
#define BUILD_EVERY_N 2
#define BUILD_ON_TERMINAL 3
#define BUILD_ON_TERMINAL_AND_ERROR 4
00018
00020 class LinearSplitsTree: public Classifier {
00021
00022 public:
00023
00024
00025
00026 LinearSplitsTree(int id, int trainMode, int trainFreq, int m,
00027 float featPct, bool simple, float min_er, Random rng);
00028
00029 LinearSplitsTree(const LinearSplitsTree&);
00030 virtual LinearSplitsTree* getCopy();
00031
00032 ~LinearSplitsTree();
00033
00034
00035 struct tree_node;
00036 struct tree_experience;
00037
00038
00040 struct tree_node {
00041 int id;
00042
00043
00044 int dim;
00045 float val;
00046 float avgError;
00047
00048 bool leaf;
00049
00050
00051 float constant;
00052 std::vector<float> coefficients;
00053
00054
00055 tree_node *l;
00056 tree_node *r;
00057
00058
00059 int nInstances;
00060
00061 };
00062
00063 struct tree_experience {
00064 std::vector<float> input;
00065 float output;
00066 };
00067
00068 bool trainInstance(classPair &instance);
00069 bool trainInstances(std::vector<classPair> &instances);
00070 void testInstance(const std::vector<float> &input, std::map<float, float>* retval);
00071 float getConf(const std::vector<float> &input);
00072
00073 void buildTree(tree_node* node, const std::vector<tree_experience*> &instances,
00074 bool changed);
00075 void copyTree(tree_node* newNode, tree_node* origNode);
00076
00077
00078
00079 void initTree();
00080 void rebuildTree();
00081 void initTreeNode(tree_node* node);
00082 tree_node* traverseTree(tree_node* node, const std::vector<float> &input);
00083 tree_node* getCorrectChild(tree_node* node, const std::vector<float> &input);
00084 bool passTest(int dim, float val, const std::vector<float> &input);
00085 float calcER(int dim, float val,
00086 const std::vector<tree_experience*> &instances, float error,
00087 std::vector<tree_experience*> &left,
00088 std::vector<tree_experience*> &right,
00089 float *leftError, float *rightError);
00090 float* sortOnDim(int dim, const std::vector<tree_experience*> &instances);
00091 std::set<float> getUniques(int dim, const std::vector<tree_experience*> &instances, float & minVal, float& maxVal);
00092 void deleteTree(tree_node* node);
00093 float calcAvgErrorforSet(const std::vector<tree_experience*> &instances);
00094 void printTree(tree_node *t, int level);
00095 void testPossibleSplits(float avgError, const std::vector<tree_experience*> &instances,
00096 float *bestER, int *bestDim,
00097 float *bestVal,
00098 std::vector<tree_experience*> *bestLeft,
00099 std::vector<tree_experience*> *bestRight,
00100 float *bestLeftError, float *bestRightError);
00101 void implementSplit(tree_node* node, const std::vector<tree_experience*> &instances,
00102 float bestER, int bestDim,
00103 float bestVal,
00104 const std::vector<tree_experience*> &left,
00105 const std::vector<tree_experience*> &right,
00106 bool changed, float leftError, float rightError);
00107 void compareSplits(float er, int dim, float val,
00108 const std::vector<tree_experience*> &left,
00109 const std::vector<tree_experience*> &right,
00110 int *nties, float leftError, float rightError,
00111 float *bestER, int *bestDim,
00112 float *bestVal,
00113 std::vector<tree_experience*> *bestLeft,
00114 std::vector<tree_experience*> *bestRight,
00115 float *bestLeftError, float *bestRightError);
00116 void leafPrediction(tree_node *t, const std::vector<float> &in, std::map<float, float>* retval);
00117 void makeLeaf(tree_node* node, const std::vector<tree_experience*> &instances);
00118
00119 float fitSimpleLinearModel(const std::vector<tree_experience*> &instances,
00120 float* constant, std::vector<float> *coeff);
00121 float fitMultiLinearModel(const std::vector<tree_experience*> &instances,
00122 float* constant, std::vector<float> * coeff);
00123
00124 tree_node* allocateNode();
00125 void deallocateNode(tree_node* node);
00126 void initNodes();
00127
00128 bool INCDEBUG;
00129 bool DTDEBUG;
00130 bool LMDEBUG;
00131 bool SPLITDEBUG;
00132 bool STOCH_DEBUG;
00133 bool NODEDEBUG;
00134 int nExperiences;
00135 bool COPYDEBUG;
00136
00137 float SPLIT_MARGIN;
00138
00139 private:
00140
00141 const int id;
00142
00143 const int mode;
00144 const int freq;
00145 const int M;
00146 const float featPct;
00147 const bool SIMPLE;
00148 const float MIN_ER;
00149
00150 Random rng;
00151
00152 int nOutput;
00153 int nnodes;
00154 bool hadError;
00155 int totalnodes;
00156 int maxnodes;
00157
00158
00159 std::vector<tree_experience*> experiences;
00160 tree_experience allExp[N_LST_EXP];
00161 tree_node allNodes[N_LS_NODES];
00162 std::vector<int> freeNodes;
00163
00164
00165 tree_node* root;
00166 tree_node* lastNode;
00167
00168 };
00169
00170
00171 #endif
00172