00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035 #include "laser_processor.h"
00036 #include "calc_leg_features.h"
00037
00038 #include "opencv/cxcore.h"
00039 #include "opencv/cv.h"
00040 #include "opencv/ml.h"
00041
00042 #include "people_msgs/PositionMeasurement.h"
00043 #include "sensor_msgs/LaserScan.h"
00044
00045 using namespace std;
00046 using namespace laser_processor;
00047 using namespace ros;
00048
// Selects which data set the file arguments that follow a command-line flag
// are loaded into (see main and TrainLegDetector::loadData).
enum LoadType
{
  LOADING_NONE = 0,  // no mode flag seen yet: loadData ignores the file
  LOADING_POS  = 1,  // after --train: positive training samples
  LOADING_NEG  = 2,  // after --neg: negative training samples
  LOADING_TEST = 3   // after --test: test samples
};
00050
00051 class TrainLegDetector
00052 {
00053 public:
00054 ScanMask mask_;
00055 int mask_count_;
00056
00057 vector< vector<float> > pos_data_;
00058 vector< vector<float> > neg_data_;
00059 vector< vector<float> > test_data_;
00060
00061 CvRTrees forest;
00062
00063 float connected_thresh_;
00064
00065 int feat_count_;
00066
00067 TrainLegDetector() : mask_count_(0), connected_thresh_(0.06), feat_count_(0)
00068 {
00069 }
00070
00071 void loadData(LoadType load, char* file)
00072 {
00073 if (load != LOADING_NONE)
00074 {
00075 switch (load)
00076 {
00077 case LOADING_POS:
00078 printf("Loading positive training data from file: %s\n",file); break;
00079 case LOADING_NEG:
00080 printf("Loading negative training data from file: %s\n",file); break;
00081 case LOADING_TEST:
00082 printf("Loading test data from file: %s\n",file); break;
00083 default:
00084 break;
00085 }
00086
00087 ros::record::Player p;
00088 if (p.open(file, ros::Time()))
00089 {
00090 mask_.clear();
00091 mask_count_ = 0;
00092
00093 switch (load)
00094 {
00095 case LOADING_POS:
00096 p.addHandler<sensor_msgs::LaserScan>(string("*"), &TrainLegDetector::loadCb, this, &pos_data_);
00097 break;
00098 case LOADING_NEG:
00099 mask_count_ = 1000;
00100 p.addHandler<sensor_msgs::LaserScan>(string("*"), &TrainLegDetector::loadCb, this, &neg_data_);
00101 break;
00102 case LOADING_TEST:
00103 p.addHandler<sensor_msgs::LaserScan>(string("*"), &TrainLegDetector::loadCb, this, &test_data_);
00104 break;
00105 default:
00106 break;
00107 }
00108
00109 while (p.nextMsg())
00110 {}
00111 }
00112 }
00113 }
00114
00115 void loadCb(string name, sensor_msgs::LaserScan* scan, ros::Time t, ros::Time t_no_use, void* n)
00116 {
00117 vector< vector<float> >* data = (vector< vector<float> >*)(n);
00118
00119 if (mask_count_++ < 20)
00120 {
00121 mask_.addScan(*scan);
00122 }
00123 else
00124 {
00125 ScanProcessor processor(*scan,mask_);
00126 processor.splitConnected(connected_thresh_);
00127 processor.removeLessThan(5);
00128
00129 for (list<SampleSet*>::iterator i = processor.getClusters().begin();
00130 i != processor.getClusters().end();
00131 i++)
00132 data->push_back( calcLegFeatures(*i, *scan));
00133 }
00134 }
00135
00136 void train()
00137 {
00138 int sample_size = pos_data_.size() + neg_data_.size();
00139 feat_count_ = pos_data_[0].size();
00140
00141 CvMat* cv_data = cvCreateMat( sample_size, feat_count_, CV_32FC1);
00142 CvMat* cv_resp = cvCreateMat( sample_size, 1, CV_32S);
00143
00144
00145 int j = 0;
00146 for (vector< vector<float> >::iterator i = pos_data_.begin();
00147 i != pos_data_.end();
00148 i++)
00149 {
00150 float* data_row = (float*)(cv_data->data.ptr + cv_data->step*j);
00151 for (int k = 0; k < feat_count_; k++)
00152 data_row[k] = (*i)[k];
00153
00154 cv_resp->data.i[j] = 1;
00155 j++;
00156 }
00157
00158
00159 for (vector< vector<float> >::iterator i = neg_data_.begin();
00160 i != neg_data_.end();
00161 i++)
00162 {
00163 float* data_row = (float*)(cv_data->data.ptr + cv_data->step*j);
00164 for (int k = 0; k < feat_count_; k++)
00165 data_row[k] = (*i)[k];
00166
00167 cv_resp->data.i[j] = -1;
00168 j++;
00169 }
00170
00171 CvMat* var_type = cvCreateMat( 1, feat_count_ + 1, CV_8U );
00172 cvSet( var_type, cvScalarAll(CV_VAR_ORDERED));
00173 cvSetReal1D( var_type, feat_count_, CV_VAR_CATEGORICAL );
00174
00175 float priors[] = {1.0, 1.0};
00176
00177 CvRTParams fparam(8,20,0,false,10,priors,false,5,50,0.001f,CV_TERMCRIT_ITER);
00178 fparam.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 100, 0.1);
00179
00180 forest.train( cv_data, CV_ROW_SAMPLE, cv_resp, 0, 0, var_type, 0,
00181 fparam);
00182
00183
00184 cvReleaseMat(&cv_data);
00185 cvReleaseMat(&cv_resp);
00186 cvReleaseMat(&var_type);
00187 }
00188
00189 void test()
00190 {
00191 CvMat* tmp_mat = cvCreateMat(1,feat_count_,CV_32FC1);
00192
00193 int pos_right = 0;
00194 int pos_total = 0;
00195 for (vector< vector<float> >::iterator i = pos_data_.begin();
00196 i != pos_data_.end();
00197 i++)
00198 {
00199 for (int k = 0; k < feat_count_; k++)
00200 tmp_mat->data.fl[k] = (float)((*i)[k]);
00201 if (forest.predict( tmp_mat) > 0)
00202 pos_right++;
00203 pos_total++;
00204 }
00205
00206 int neg_right = 0;
00207 int neg_total = 0;
00208 for (vector< vector<float> >::iterator i = neg_data_.begin();
00209 i != neg_data_.end();
00210 i++)
00211 {
00212 for (int k = 0; k < feat_count_; k++)
00213 tmp_mat->data.fl[k] = (float)((*i)[k]);
00214 if (forest.predict( tmp_mat ) < 0)
00215 neg_right++;
00216 neg_total++;
00217 }
00218
00219 int test_right = 0;
00220 int test_total = 0;
00221 for (vector< vector<float> >::iterator i = test_data_.begin();
00222 i != test_data_.end();
00223 i++)
00224 {
00225 for (int k = 0; k < feat_count_; k++)
00226 tmp_mat->data.fl[k] = (float)((*i)[k]);
00227 if (forest.predict( tmp_mat ) > 0)
00228 test_right++;
00229 test_total++;
00230 }
00231
00232 printf(" Pos train set: %d/%d %g\n",pos_right, pos_total, (float)(pos_right)/pos_total);
00233 printf(" Neg train set: %d/%d %g\n",neg_right, neg_total, (float)(neg_right)/neg_total);
00234 printf(" Test set: %d/%d %g\n",test_right, test_total, (float)(test_right)/test_total);
00235
00236 cvReleaseMat(&tmp_mat);
00237
00238 }
00239
00240 void save(char* file)
00241 {
00242 forest.save(file);
00243 }
00244 };
00245
00246 int main(int argc, char **argv)
00247 {
00248 TrainLegDetector tld;
00249
00250 LoadType loading = LOADING_NONE;
00251
00252 char save_file[100];
00253 save_file[0] = 0;
00254
00255 printf("Loading data...\n");
00256 for (int i = 1; i < argc; i++)
00257 {
00258 if (!strcmp(argv[i],"--train"))
00259 loading = LOADING_POS;
00260 else if (!strcmp(argv[i],"--neg"))
00261 loading = LOADING_NEG;
00262 else if (!strcmp(argv[i],"--test"))
00263 loading = LOADING_TEST;
00264 else if (!strcmp(argv[i],"--save"))
00265 {
00266 if (++i < argc)
00267 strncpy(save_file,argv[i],100);
00268 continue;
00269 }
00270 else
00271 tld.loadData(loading, argv[i]);
00272 }
00273
00274 printf("Training classifier...\n");
00275 tld.train();
00276
00277 printf("Evlauating classifier...\n");
00278 tld.test();
00279
00280 if (strlen(save_file) > 0)
00281 {
00282 printf("Saving classifier as: %s\n", save_file);
00283 tld.save(save_file);
00284 }
00285 }