00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035 #include "laser_processor.h"
00036 #include "calc_leg_features.h"
00037
00038 #include "opencv/cxcore.h"
00039 #include "opencv/cv.h"
00040 #include "opencv/ml.h"
00041
00042
00043 #include "rosbag/bag.h"
00044 #include <rosbag/view.h>
00045 #include <boost/foreach.hpp>
00046
00047 #include "std_msgs/String.h"
00048 #include "std_msgs/Int32.h"
00049
00050 #include "srs_msgs/PositionMeasurement.h"
00051 #include "sensor_msgs/LaserScan.h"
00052
00053
00054
00055 using namespace std;
00056 using namespace laser_processor;
00057 using namespace ros;
00058
/// Selects how the following bag-file arguments on the command line are interpreted.
enum LoadType
{
  LOADING_NONE,        ///< no data-type flag seen yet; loadData() does nothing for it
  LOADING_POS,         ///< selected by --train (positive samples)
  LOADING_NEG,         ///< selected by --neg (negative samples)
  LOADING_TEST,        ///< selected by --test (evaluation samples)
  LOADING_BACKGROUND   ///< selected by --background
};
00060
00061 class TrainLegDetector
00062 {
00063 public:
00064 ScanMask mask_;
00065 Background background_;
00066
00067 int mask_count_;
00068
00069 vector< vector<float> > pos_data_;
00070 vector< vector<float> > neg_data_;
00071 vector< vector<float> > test_data_;
00072
00073 CvRTrees forest;
00074
00075 float connected_thresh_;
00076
00077 int feat_count_;
00078
00079 TrainLegDetector() : mask_count_(0), connected_thresh_(0.06), feat_count_(0)
00080 {}
00081
00082
00083 void loadData(LoadType load, char* file)
00084 {
00085
00086 if (load != LOADING_NONE)
00087 {
00088 switch (load)
00089 {
00090 case LOADING_BACKGROUND:
00091 printf("Loading background data from file: %s\n",file); break;
00092 case LOADING_POS:
00093 printf("Loading positive training data from file: %s\n",file); break;
00094 case LOADING_NEG:
00095 printf("Loading negative training data from file: %s\n",file); break;
00096 case LOADING_TEST:
00097 printf("Loading test data from file: %s\n",file); break;
00098 default:
00099 break;
00100 }
00101
00102
00103 rosbag::Bag bag;
00104
00105 bag.open(file, rosbag::bagmode::Read);
00106
00107 std::vector<std::string> topics;
00108 topics.push_back(std::string("/scan_front"));
00109
00110
00111
00112 rosbag::View view(bag, rosbag::TopicQuery(topics));
00113 printf("file size %i scans \n", view.size());
00114
00115
00116 BOOST_FOREACH(rosbag::MessageInstance const m, view)
00117 {
00118
00119 sensor_msgs::LaserScan::ConstPtr scanptr = m.instantiate<sensor_msgs::LaserScan>();
00120
00121 if (scanptr != NULL)
00122 {
00123
00124 switch (load)
00125 {
00126 case LOADING_BACKGROUND:
00127 loadBackground ( &(sensor_msgs::LaserScan(*scanptr)));
00128 break;
00129
00130 case LOADING_POS:
00131 loadCbBackgroundremoved ( &(sensor_msgs::LaserScan(*scanptr)),&pos_data_);
00132
00133 break;
00134
00135 case LOADING_NEG:
00136 mask_count_ = 1000;
00137 loadCb ( &(sensor_msgs::LaserScan(*scanptr)),&neg_data_);
00138 break;
00139
00140 case LOADING_TEST:
00141 loadCb ( &(sensor_msgs::LaserScan(*scanptr)),&test_data_);
00142 break;
00143
00144 default:
00145 break;
00146 }
00147 }
00148 else
00149 printf ("loadData(), scanptr is NULL");
00150
00151 }
00152
00153 bag.close();
00154
00155 }
00156
00157 }
00158
00159 void loadBackground(sensor_msgs::LaserScan* scan)
00160 {
00161
00162 background_.addScan(*scan, 0.3);
00163 }
00164
00165 void loadCb(sensor_msgs::LaserScan* scan, vector< vector<float> >* data)
00166 {
00167
00168
00169 if (mask_count_++ < 20)
00170 {
00171 mask_.addScan(*scan);
00172 }
00173 else
00174 {
00175 ScanProcessor processor(*scan,mask_);
00176 processor.splitConnected(connected_thresh_);
00177 processor.removeLessThan(5);
00178
00179 for (list<SampleSet*>::iterator i = processor.getClusters().begin();
00180 i != processor.getClusters().end();
00181 i++)
00182 data->push_back( calcLegFeatures(*i, *scan));
00183 }
00184 }
00185
00186 void loadCbBackgroundremoved(sensor_msgs::LaserScan* scan, vector< vector<float> >* data)
00187 {
00188
00189
00190 if (mask_count_++ < 20)
00191 {
00192 mask_.addScan(*scan);
00193 }
00194 else
00195 {
00196 ScanProcessor processor(*scan,mask_);
00197
00198 processor.splitConnected(connected_thresh_);
00199 processor.removeLessThan(5);
00200
00201 for (list<SampleSet*>::iterator i = processor.getClusters().begin();
00202 i != processor.getClusters().end();
00203 i++)
00204 data->push_back( calcLegFeatures(*i, *scan));
00205 }
00206 }
00207
00208 void traincl()
00209 {
00210
00211 int sample_size = pos_data_.size() + neg_data_.size();
00212 feat_count_ = pos_data_[0].size();
00213
00214 CvMat* cv_data = cvCreateMat( sample_size, feat_count_, CV_32FC1);
00215 CvMat* cv_resp = cvCreateMat( sample_size, 1, CV_32S);
00216
00217
00218 int j = 0;
00219 for (vector< vector<float> >::iterator i = pos_data_.begin();
00220 i != pos_data_.end();
00221 i++)
00222 {
00223 float* data_row = (float*)(cv_data->data.ptr + cv_data->step*j);
00224 for (int k = 0; k < feat_count_; k++)
00225 data_row[k] = (*i)[k];
00226
00227 cv_resp->data.i[j] = 1;
00228 j++;
00229 }
00230
00231
00232 for (vector< vector<float> >::iterator i = neg_data_.begin();
00233 i != neg_data_.end();
00234 i++)
00235 {
00236 float* data_row = (float*)(cv_data->data.ptr + cv_data->step*j);
00237 for (int k = 0; k < feat_count_; k++)
00238 data_row[k] = (*i)[k];
00239
00240 cv_resp->data.i[j] = -1;
00241 j++;
00242 }
00243
00244 CvMat* var_type = cvCreateMat( 1, feat_count_ + 1, CV_8U );
00245 cvSet( var_type, cvScalarAll(CV_VAR_ORDERED));
00246 cvSetReal1D( var_type, feat_count_, CV_VAR_CATEGORICAL );
00247
00248 float priors[] = {1.0, 1.0};
00249
00250 CvRTParams fparam(8,
00251 20,
00252 0,
00253 false,
00254 10,
00255 priors,
00256 false,
00257 5,
00258 50,
00259 0.001f,
00260 CV_TERMCRIT_ITER
00261 );
00262 fparam.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 100, 0.1);
00263
00264 forest.train( cv_data,
00265 CV_ROW_SAMPLE,
00266 cv_resp,
00267 0,
00268 0,
00269 var_type,
00270 0,
00271 fparam);
00272
00273
00274 cvReleaseMat(&cv_data);
00275 cvReleaseMat(&cv_resp);
00276 cvReleaseMat(&var_type);
00277 }
00278
00279 void test()
00280 {
00281 CvMat* tmp_mat = cvCreateMat(1,feat_count_,CV_32FC1);
00282
00283 int pos_right = 0;
00284 int pos_total = 0;
00285 for (vector< vector<float> >::iterator i = pos_data_.begin();
00286 i != pos_data_.end();
00287 i++)
00288 {
00289 for (int k = 0; k < feat_count_; k++)
00290 tmp_mat->data.fl[k] = (float)((*i)[k]);
00291 if (forest.predict( tmp_mat) > 0)
00292 pos_right++;
00293 pos_total++;
00294 }
00295
00296 int neg_right = 0;
00297 int neg_total = 0;
00298 for (vector< vector<float> >::iterator i = neg_data_.begin();
00299 i != neg_data_.end();
00300 i++)
00301 {
00302 for (int k = 0; k < feat_count_; k++)
00303 tmp_mat->data.fl[k] = (float)((*i)[k]);
00304 if (forest.predict( tmp_mat ) < 0)
00305 neg_right++;
00306 neg_total++;
00307 }
00308
00309 int test_right = 0;
00310 int test_total = 0;
00311 for (vector< vector<float> >::iterator i = test_data_.begin();
00312 i != test_data_.end();
00313 i++)
00314 {
00315 for (int k = 0; k < feat_count_; k++)
00316 tmp_mat->data.fl[k] = (float)((*i)[k]);
00317 if (forest.predict( tmp_mat ) > 0)
00318 test_right++;
00319 test_total++;
00320 }
00321
00322 printf(" Pos train set: %d/%d %g\n",pos_right, pos_total, (float)(pos_right)/pos_total);
00323 printf(" Neg train set: %d/%d %g\n",neg_right, neg_total, (float)(neg_right)/neg_total);
00324 printf(" Test set: %d/%d %g\n",test_right, test_total, (float)(test_right)/test_total);
00325
00326 cvReleaseMat(&tmp_mat);
00327
00328 }
00329
00330 void save(char* file)
00331 {
00332 forest.save(file);
00333 }
00334 };
00335
00336 int main(int argc, char **argv)
00337 {
00338 if (argc < 2)
00339 {
00340 printf("Usage: train_leg_detector --background background_file --train file1 --neg file2 --test file3 --save conf_file\n");
00341 exit (0);
00342 }
00343
00344 TrainLegDetector tld;
00345
00346 LoadType loading = LOADING_NONE;
00347
00348 char save_file[100];
00349 save_file[0] = 0;
00350
00351 printf("Loading data...\n");
00352 for (int i = 1; i < argc; i++)
00353 {
00354
00355 if (!strcmp(argv[i],"--background"))
00356 loading = LOADING_BACKGROUND;
00357 else if (!strcmp(argv[i],"--train"))
00358 loading = LOADING_POS;
00359 else if (!strcmp(argv[i],"--neg"))
00360 loading = LOADING_NEG;
00361 else if (!strcmp(argv[i],"--test"))
00362 loading = LOADING_TEST;
00363 else if (!strcmp(argv[i],"--save"))
00364 {
00365 if (++i < argc)
00366 strncpy(save_file,argv[i],100);
00367 continue;
00368 }
00369 else
00370 tld.loadData(loading, argv[i]);
00371 }
00372
00373 printf("Training classifier...\n");
00374
00375 tld.traincl();
00376
00377 printf("Evlauating classifier...\n");
00378 tld.test();
00379
00380 if (strlen(save_file) > 0)
00381 {
00382 printf("Saving classifier as: %s\n", save_file);
00383 tld.save(save_file);
00384 }
00385 }