Go to the documentation of this file.00001 #ifndef _STUMP_HH_
00002 #define _STUMP_HH_
00003
00004 #include <rl_common/Random.h>
00005 #include <rl_common/core.hh>
00006 #include <vector>
00007 #include <set>
00008 #include <map>
00009
00010 #define N_STUMP_EXP 250000
00011
00013 class Stump: public Classifier {
00014
00015 public:
00016
00017
00018
00019 Stump(int id, int trainMode, int trainFreq, int m, float featPct, Random rng);
00020
00021 Stump(const Stump&);
00022 virtual Stump* getCopy();
00023
00024 ~Stump();
00025
00026
00027 struct stump_experience;
00028
00029 struct stump_experience {
00030 std::vector<float> input;
00031 float output;
00032 int id;
00033 };
00034
00035 enum splitTypes{
00036 ONLY,
00037 CUT
00038 };
00039
00040 bool trainInstances(std::vector<classPair> &instances);
00041 bool trainInstance(classPair &instance);
00042 void testInstance(const std::vector<float> &input, std::map<float, float>* retval);
00043 float getConf(const std::vector<float> &input);
00044
00045 void buildStump();
00046
00047
00048 void initStump();
00049 bool passTest(int dim, float val, int type, const std::vector<float> &input);
00050 float calcGainRatio(int dim, float val, int type,float I);
00051 float* sortOnDim(int dim);
00052 float calcIofP(float* P, int size);
00053 float calcIforSet(const std::vector<stump_experience*> &instances);
00054 void printStump();
00055 void testPossibleSplits(float *bestGainRatio, int *bestDim,
00056 float *bestVal, int *bestType);
00057 void implementSplit(float bestGainRatio, int bestDim,
00058 float bestVal, int bestType);
00059 void compareSplits(float gainRatio, int dim, float val, int type,
00060 int *nties, float *bestGainRatio, int *bestDim,
00061 float *bestVal, int *bestType);
00062
00063 void outputProbabilities(std::multiset<float> outputs, std::map<float, float>* retval);
00064 int findMatching(const std::vector<stump_experience*> &instances, int dim,
00065 int val, int minConf);
00066
00067 void setParams(float margin, float forestPct, float minRatio);
00068
00069 bool ALLOW_ONLY_SPLITS;
00070
00071 bool STDEBUG;
00072 bool SPLITDEBUG;
00073 int nExperiences;
00074
00075 float SPLIT_MARGIN;
00076 float MIN_GAIN_RATIO;
00077 float REBUILD_RATIO;
00078 float LOSS_MARGIN;
00079
00080 private:
00081
00082 const int id;
00083
00084 const int mode;
00085 const int freq;
00086 const int M;
00087 float featPct;
00088
00089 Random rng;
00090
00091 int nOutput;
00092 int nnodes;
00093
00094
00095 std::vector<stump_experience*> experiences;
00096 stump_experience allExp[N_STUMP_EXP];
00097
00098
00099 int dim;
00100 float val;
00101 int type;
00102 float gainRatio;
00103
00104
00105 std::multiset<float> leftOutputs;
00106 std::multiset<float> rightOutputs;
00107
00108
00109
00110 };
00111
00112
00113 #endif
00114