00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
#include <cstdio>
#include <cstring>
#include <list>
#include <string>
#include <vector>

#include <opencv2/opencv.hpp>

#include "cob_perception_msgs/PositionMeasurement.h"
#include "sensor_msgs/LaserScan.h"

#include "cob_leg_detection/laser_processor.h"
#include "cob_leg_detection/calc_leg_features.h"

using namespace std;
using namespace laser_processor;
using namespace ros;
00046
/// Data set that the file arguments following a command-line flag are loaded into.
enum LoadType
{
  LOADING_NONE,  ///< no data-set flag seen yet; loadData() ignores files in this state
  LOADING_POS,   ///< selected by --train: positive (leg) training examples
  LOADING_NEG,   ///< selected by --neg: negative training examples
  LOADING_TEST   ///< selected by --test: evaluation examples
};
00048
00049 class TrainLegDetector
00050 {
00051 public:
00052 ScanMask mask_;
00053 int mask_count_;
00054
00055 vector< vector<float> > pos_data_;
00056 vector< vector<float> > neg_data_;
00057 vector< vector<float> > test_data_;
00058
00059 CvRTrees forest;
00060
00061 float connected_thresh_;
00062
00063 int feat_count_;
00064
00065 TrainLegDetector() : mask_count_(0), connected_thresh_(0.06), feat_count_(0)
00066 {
00067 }
00068
00069 void loadData(LoadType load, char* file)
00070 {
00071 if (load != LOADING_NONE)
00072 {
00073 switch (load)
00074 {
00075 case LOADING_POS:
00076 printf("Loading positive training data from file: %s\n",file); break;
00077 case LOADING_NEG:
00078 printf("Loading negative training data from file: %s\n",file); break;
00079 case LOADING_TEST:
00080 printf("Loading test data from file: %s\n",file); break;
00081 default:
00082 break;
00083 }
00084
00085 ros::record::Player p;
00086 if (p.open(file, ros::Time()))
00087 {
00088 mask_.clear();
00089 mask_count_ = 0;
00090
00091 switch (load)
00092 {
00093 case LOADING_POS:
00094 p.addHandler<sensor_msgs::LaserScan>(string("*"), &TrainLegDetector::loadCb, this, &pos_data_);
00095 break;
00096 case LOADING_NEG:
00097 mask_count_ = 1000;
00098 p.addHandler<sensor_msgs::LaserScan>(string("*"), &TrainLegDetector::loadCb, this, &neg_data_);
00099 break;
00100 case LOADING_TEST:
00101 p.addHandler<sensor_msgs::LaserScan>(string("*"), &TrainLegDetector::loadCb, this, &test_data_);
00102 break;
00103 default:
00104 break;
00105 }
00106
00107 while (p.nextMsg())
00108 {}
00109 }
00110 }
00111 }
00112
00113 void loadCb(string name, sensor_msgs::LaserScan* scan, ros::Time t, ros::Time t_no_use, void* n)
00114 {
00115 vector< vector<float> >* data = (vector< vector<float> >*)(n);
00116
00117 if (mask_count_++ < 20)
00118 {
00119 mask_.addScan(*scan);
00120 }
00121 else
00122 {
00123 ScanProcessor processor(*scan,mask_);
00124 processor.splitConnected(connected_thresh_);
00125 processor.removeLessThan(5);
00126
00127 for (list<SampleSet*>::iterator i = processor.getClusters().begin();
00128 i != processor.getClusters().end();
00129 i++)
00130 data->push_back( calcLegFeatures(*i, *scan));
00131 }
00132 }
00133
00134 void train()
00135 {
00136 int sample_size = pos_data_.size() + neg_data_.size();
00137 feat_count_ = pos_data_[0].size();
00138
00139 CvMat* cv_data = cvCreateMat( sample_size, feat_count_, CV_32FC1);
00140 CvMat* cv_resp = cvCreateMat( sample_size, 1, CV_32S);
00141
00142
00143 int j = 0;
00144 for (vector< vector<float> >::iterator i = pos_data_.begin();
00145 i != pos_data_.end();
00146 i++)
00147 {
00148 float* data_row = (float*)(cv_data->data.ptr + cv_data->step*j);
00149 for (int k = 0; k < feat_count_; k++)
00150 data_row[k] = (*i)[k];
00151
00152 cv_resp->data.i[j] = 1;
00153 j++;
00154 }
00155
00156
00157 for (vector< vector<float> >::iterator i = neg_data_.begin();
00158 i != neg_data_.end();
00159 i++)
00160 {
00161 float* data_row = (float*)(cv_data->data.ptr + cv_data->step*j);
00162 for (int k = 0; k < feat_count_; k++)
00163 data_row[k] = (*i)[k];
00164
00165 cv_resp->data.i[j] = -1;
00166 j++;
00167 }
00168
00169 CvMat* var_type = cvCreateMat( 1, feat_count_ + 1, CV_8U );
00170 cvSet( var_type, cvScalarAll(CV_VAR_ORDERED));
00171 cvSetReal1D( var_type, feat_count_, CV_VAR_CATEGORICAL );
00172
00173 float priors[] = {1.0, 1.0};
00174
00175 CvRTParams fparam(8,20,0,false,10,priors,false,5,50,0.001f,CV_TERMCRIT_ITER);
00176 fparam.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 100, 0.1);
00177
00178 forest.train( cv_data, CV_ROW_SAMPLE, cv_resp, 0, 0, var_type, 0,
00179 fparam);
00180
00181
00182 cvReleaseMat(&cv_data);
00183 cvReleaseMat(&cv_resp);
00184 cvReleaseMat(&var_type);
00185 }
00186
00187 void test()
00188 {
00189 CvMat* tmp_mat = cvCreateMat(1,feat_count_,CV_32FC1);
00190
00191 int pos_right = 0;
00192 int pos_total = 0;
00193 for (vector< vector<float> >::iterator i = pos_data_.begin();
00194 i != pos_data_.end();
00195 i++)
00196 {
00197 for (int k = 0; k < feat_count_; k++)
00198 tmp_mat->data.fl[k] = (float)((*i)[k]);
00199 if (forest.predict( tmp_mat) > 0)
00200 pos_right++;
00201 pos_total++;
00202 }
00203
00204 int neg_right = 0;
00205 int neg_total = 0;
00206 for (vector< vector<float> >::iterator i = neg_data_.begin();
00207 i != neg_data_.end();
00208 i++)
00209 {
00210 for (int k = 0; k < feat_count_; k++)
00211 tmp_mat->data.fl[k] = (float)((*i)[k]);
00212 if (forest.predict( tmp_mat ) < 0)
00213 neg_right++;
00214 neg_total++;
00215 }
00216
00217 int test_right = 0;
00218 int test_total = 0;
00219 for (vector< vector<float> >::iterator i = test_data_.begin();
00220 i != test_data_.end();
00221 i++)
00222 {
00223 for (int k = 0; k < feat_count_; k++)
00224 tmp_mat->data.fl[k] = (float)((*i)[k]);
00225 if (forest.predict( tmp_mat ) > 0)
00226 test_right++;
00227 test_total++;
00228 }
00229
00230 printf(" Pos train set: %d/%d %g\n",pos_right, pos_total, (float)(pos_right)/pos_total);
00231 printf(" Neg train set: %d/%d %g\n",neg_right, neg_total, (float)(neg_right)/neg_total);
00232 printf(" Test set: %d/%d %g\n",test_right, test_total, (float)(test_right)/test_total);
00233
00234 cvReleaseMat(&tmp_mat);
00235
00236 }
00237
00238 void save(char* file)
00239 {
00240 forest.save(file);
00241 }
00242 };
00243
00244 int main(int argc, char **argv)
00245 {
00246 TrainLegDetector tld;
00247
00248 LoadType loading = LOADING_NONE;
00249
00250 char save_file[100];
00251 save_file[0] = 0;
00252
00253 printf("Loading data...\n");
00254 for (int i = 1; i < argc; i++)
00255 {
00256 if (!strcmp(argv[i],"--train"))
00257 loading = LOADING_POS;
00258 else if (!strcmp(argv[i],"--neg"))
00259 loading = LOADING_NEG;
00260 else if (!strcmp(argv[i],"--test"))
00261 loading = LOADING_TEST;
00262 else if (!strcmp(argv[i],"--save"))
00263 {
00264 if (++i < argc)
00265 strncpy(save_file,argv[i],100);
00266 continue;
00267 }
00268 else
00269 tld.loadData(loading, argv[i]);
00270 }
00271
00272 printf("Training classifier...\n");
00273 tld.train();
00274
00275 printf("Evlauating classifier...\n");
00276 tld.test();
00277
00278 if (strlen(save_file) > 0)
00279 {
00280 printf("Saving classifier as: %s\n", save_file);
00281 tld.save(save_file);
00282 }
00283 }