00001 #include "SepPlanExplore.hh"
00002
00003
00004
00005 SepPlanExplore::SepPlanExplore(int id, int modelType, int predType,
00006 int nModels, int trainMode,
00007 int trainFreq,
00008 float featPct, float expPct,
00009 float treeThreshold, bool stoch,
00010 float featRange, Random rng):
00011 id(id), modelType(modelType), predType(predType), nModels(nModels),
00012 mode(trainMode), freq(trainFreq),
00013 featPct(featPct), expPct(expPct),
00014 treeThresh(treeThreshold), stoch(stoch),
00015 featRange(featRange), rng(rng)
00016 {
00017 SPEDEBUG = false;
00018
00019 cout << "Created Sep Plan & Explore models " << id << " with nModels: " << nModels << endl;
00020
00021 for (int i = 0; i < id; i++)
00022 rng.uniform(0,1);
00023
00024 initModels();
00025
00026 }
00027
00028 SepPlanExplore::SepPlanExplore(const SepPlanExplore& spe):
00029 id(spe.id), modelType(spe.modelType),
00030 predType(spe.predType), nModels(spe.nModels),
00031 mode(spe.mode), freq(spe.freq),
00032 featPct(spe.featPct), expPct(spe.expPct),
00033 treeThresh(spe.treeThresh), stoch(spe.stoch),
00034 featRange(spe.featRange), rng(spe.rng)
00035 {
00036 cout << "spe get copy" << endl;
00037 SPEDEBUG = spe.SPEDEBUG;
00038 expModel = spe.expModel->getCopy();
00039 planModel = spe.planModel->getCopy();
00040 }
00041
00042 SepPlanExplore* SepPlanExplore::getCopy(){
00043 SepPlanExplore* copy = new SepPlanExplore(*this);
00044 return copy;
00045 }
00046
00047 SepPlanExplore::~SepPlanExplore() {
00048 delete expModel;
00049 delete planModel;
00050 }
00051
00052
00053 bool SepPlanExplore::trainInstances(std::vector<classPair> &instances){
00054 if (SPEDEBUG) cout << id << "SPE trainInstances: " << instances.size() << endl;
00055
00056
00057 bool expChanged = expModel->trainInstances(instances);
00058 bool planChanged = planModel->trainInstances(instances);
00059
00060 return (expChanged || planChanged);
00061
00062 }
00063
00064
00065
00066
00067 bool SepPlanExplore::trainInstance(classPair &instance){
00068 if (SPEDEBUG) cout << id << "SPE trainInstance: " << endl;
00069
00070
00071 bool expChanged = expModel->trainInstance(instance);
00072 bool planChanged = planModel->trainInstance(instance);
00073
00074 return (expChanged || planChanged);
00075
00076 }
00077
00078
00079 void SepPlanExplore::testInstance(const std::vector<float> &input, std::map<float, float>* retval){
00080 if (SPEDEBUG) cout << id << " testInstance" << endl;
00081
00082 retval->clear();
00083
00084
00085 planModel->testInstance(input, retval);
00086
00087 }
00088
00089
00090
00091 float SepPlanExplore::getConf(const std::vector<float> &input){
00092 if (SPEDEBUG) cout << "getConf" << endl;
00093
00094
00095 return expModel->getConf(input);
00096
00097 }
00098
00099
00100
00101 void SepPlanExplore::initModels(){
00102 if (SPEDEBUG) cout << "initModels()" << endl;
00103
00104 if (nModels < 2){
00105 cout << "Should really use Sep plan & explore models with multiple models" << endl;
00106 exit(-1);
00107 }
00108
00109
00110 expModel = new MultipleClassifiers(id, modelType, predType,
00111 nModels, mode, freq,
00112 featPct, expPct, treeThresh, stoch, featRange, rng);
00113
00114
00115 if (modelType == C45TREE){
00116 planModel = new C45Tree(id, mode, freq, 0, 0.0, rng);
00117 }
00118 else if (modelType == M5MULTI){
00119 planModel = new M5Tree(id, mode, freq, 0, 0.0, false, false, treeThresh, rng);
00120 }
00121 else if (modelType == M5ALLMULTI){
00122 planModel = new M5Tree(id, mode, freq, 0, 0.0, false, true, treeThresh, rng);
00123 }
00124 else if (modelType == M5ALLSINGLE){
00125 planModel = new M5Tree(id, mode, freq, 0, 0.0, true, true, treeThresh, rng);
00126 }
00127 else if (modelType == M5SINGLE){
00128 planModel = new M5Tree(id, mode, freq, 0, 0.0, true, false, treeThresh, rng);
00129 }
00130 else if (modelType == LSTSINGLE){
00131 planModel = new LinearSplitsTree(id, mode, freq, 0, 0.0, true, treeThresh, rng);
00132 }
00133 else if (modelType == LSTMULTI){
00134 planModel = new LinearSplitsTree(id, mode, freq, 0, 0.0, false, treeThresh, rng);
00135 }
00136 else if (modelType == STUMP){
00137 planModel = new Stump(id, mode, freq, 0, 0.0, rng);
00138 }
00139 else if (modelType == ALLM5TYPES){
00140
00141 bool simple = rng.bernoulli(0.5);
00142 bool allFeats = rng.bernoulli(0.5);
00143
00144 planModel = new M5Tree(id, mode, freq, 0, 0.0, simple, allFeats, treeThresh, rng);
00145 }
00146 else {
00147 cout << "Invalid model type for this committee" << endl;
00148 exit(-1);
00149 }
00150 }
00151
00152