00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035 #include "laser_processor.h"
00036 #include "calc_leg_features.h"
00037
00038 #include "opencv/cxcore.h"
00039 #include "opencv/cv.h"
00040 #include "opencv/ml.h"
00041
00042 #include "people_msgs/PositionMeasurement.h"
00043 #include "sensor_msgs/LaserScan.h"
00044
00045 using namespace std;
00046 using namespace laser_processor;
00047 using namespace ros;
00048
// Selects which sample set the bag files named on the command line are
// loaded into. LOADING_NONE means no destination has been chosen yet.
enum LoadType
{
  LOADING_NONE,  // no set selected; files are ignored
  LOADING_POS,   // positive training samples
  LOADING_NEG,   // negative training samples
  LOADING_TEST   // held-out test samples
};
00050
00051 class TrainLegDetector
00052 {
00053 public:
00054 ScanMask mask_;
00055 int mask_count_;
00056
00057 vector< vector<float> > pos_data_;
00058 vector< vector<float> > neg_data_;
00059 vector< vector<float> > test_data_;
00060
00061 CvRTrees forest;
00062
00063 float connected_thresh_;
00064
00065 int feat_count_;
00066
00067 TrainLegDetector() : mask_count_(0), connected_thresh_(0.06), feat_count_(0)
00068 {
00069 }
00070
00071 void loadData(LoadType load, char* file)
00072 {
00073 if (load != LOADING_NONE)
00074 {
00075 switch (load)
00076 {
00077 case LOADING_POS:
00078 printf("Loading positive training data from file: %s\n", file);
00079 break;
00080 case LOADING_NEG:
00081 printf("Loading negative training data from file: %s\n", file);
00082 break;
00083 case LOADING_TEST:
00084 printf("Loading test data from file: %s\n", file);
00085 break;
00086 default:
00087 break;
00088 }
00089
00090 ros::record::Player p;
00091 if (p.open(file, ros::Time()))
00092 {
00093 mask_.clear();
00094 mask_count_ = 0;
00095
00096 switch (load)
00097 {
00098 case LOADING_POS:
00099 p.addHandler<sensor_msgs::LaserScan>(string("*"), &TrainLegDetector::loadCb, this, &pos_data_);
00100 break;
00101 case LOADING_NEG:
00102 mask_count_ = 1000;
00103 p.addHandler<sensor_msgs::LaserScan>(string("*"), &TrainLegDetector::loadCb, this, &neg_data_);
00104 break;
00105 case LOADING_TEST:
00106 p.addHandler<sensor_msgs::LaserScan>(string("*"), &TrainLegDetector::loadCb, this, &test_data_);
00107 break;
00108 default:
00109 break;
00110 }
00111
00112 while (p.nextMsg())
00113 {}
00114 }
00115 }
00116 }
00117
00118 void loadCb(string name, sensor_msgs::LaserScan* scan, ros::Time t, ros::Time t_no_use, void* n)
00119 {
00120 vector< vector<float> >* data = (vector< vector<float> >*)(n);
00121
00122 if (mask_count_++ < 20)
00123 {
00124 mask_.addScan(*scan);
00125 }
00126 else
00127 {
00128 ScanProcessor processor(*scan, mask_);
00129 processor.splitConnected(connected_thresh_);
00130 processor.removeLessThan(5);
00131
00132 for (list<SampleSet*>::iterator i = processor.getClusters().begin();
00133 i != processor.getClusters().end();
00134 i++)
00135 data->push_back(calcLegFeatures(*i, *scan));
00136 }
00137 }
00138
00139 void train()
00140 {
00141 int sample_size = pos_data_.size() + neg_data_.size();
00142 feat_count_ = pos_data_[0].size();
00143
00144 CvMat* cv_data = cvCreateMat(sample_size, feat_count_, CV_32FC1);
00145 CvMat* cv_resp = cvCreateMat(sample_size, 1, CV_32S);
00146
00147
00148 int j = 0;
00149 for (vector< vector<float> >::iterator i = pos_data_.begin();
00150 i != pos_data_.end();
00151 i++)
00152 {
00153 float* data_row = (float*)(cv_data->data.ptr + cv_data->step * j);
00154 for (int k = 0; k < feat_count_; k++)
00155 data_row[k] = (*i)[k];
00156
00157 cv_resp->data.i[j] = 1;
00158 j++;
00159 }
00160
00161
00162 for (vector< vector<float> >::iterator i = neg_data_.begin();
00163 i != neg_data_.end();
00164 i++)
00165 {
00166 float* data_row = (float*)(cv_data->data.ptr + cv_data->step * j);
00167 for (int k = 0; k < feat_count_; k++)
00168 data_row[k] = (*i)[k];
00169
00170 cv_resp->data.i[j] = -1;
00171 j++;
00172 }
00173
00174 CvMat* var_type = cvCreateMat(1, feat_count_ + 1, CV_8U);
00175 cvSet(var_type, cvScalarAll(CV_VAR_ORDERED));
00176 cvSetReal1D(var_type, feat_count_, CV_VAR_CATEGORICAL);
00177
00178 float priors[] = {1.0, 1.0};
00179
00180 CvRTParams fparam(8, 20, 0, false, 10, priors, false, 5, 50, 0.001f, CV_TERMCRIT_ITER);
00181 fparam.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 100, 0.1);
00182
00183 forest.train(cv_data, CV_ROW_SAMPLE, cv_resp, 0, 0, var_type, 0,
00184 fparam);
00185
00186
00187 cvReleaseMat(&cv_data);
00188 cvReleaseMat(&cv_resp);
00189 cvReleaseMat(&var_type);
00190 }
00191
00192 void test()
00193 {
00194 CvMat* tmp_mat = cvCreateMat(1, feat_count_, CV_32FC1);
00195
00196 int pos_right = 0;
00197 int pos_total = 0;
00198 for (vector< vector<float> >::iterator i = pos_data_.begin();
00199 i != pos_data_.end();
00200 i++)
00201 {
00202 for (int k = 0; k < feat_count_; k++)
00203 tmp_mat->data.fl[k] = (float)((*i)[k]);
00204 if (forest.predict(tmp_mat) > 0)
00205 pos_right++;
00206 pos_total++;
00207 }
00208
00209 int neg_right = 0;
00210 int neg_total = 0;
00211 for (vector< vector<float> >::iterator i = neg_data_.begin();
00212 i != neg_data_.end();
00213 i++)
00214 {
00215 for (int k = 0; k < feat_count_; k++)
00216 tmp_mat->data.fl[k] = (float)((*i)[k]);
00217 if (forest.predict(tmp_mat) < 0)
00218 neg_right++;
00219 neg_total++;
00220 }
00221
00222 int test_right = 0;
00223 int test_total = 0;
00224 for (vector< vector<float> >::iterator i = test_data_.begin();
00225 i != test_data_.end();
00226 i++)
00227 {
00228 for (int k = 0; k < feat_count_; k++)
00229 tmp_mat->data.fl[k] = (float)((*i)[k]);
00230 if (forest.predict(tmp_mat) > 0)
00231 test_right++;
00232 test_total++;
00233 }
00234
00235 printf(" Pos train set: %d/%d %g\n", pos_right, pos_total, (float)(pos_right) / pos_total);
00236 printf(" Neg train set: %d/%d %g\n", neg_right, neg_total, (float)(neg_right) / neg_total);
00237 printf(" Test set: %d/%d %g\n", test_right, test_total, (float)(test_right) / test_total);
00238
00239 cvReleaseMat(&tmp_mat);
00240
00241 }
00242
00243 void save(char* file)
00244 {
00245 forest.save(file);
00246 }
00247 };
00248
00249 int main(int argc, char **argv)
00250 {
00251 TrainLegDetector tld;
00252
00253 LoadType loading = LOADING_NONE;
00254
00255 char save_file[100];
00256 save_file[0] = 0;
00257
00258 printf("Loading data...\n");
00259 for (int i = 1; i < argc; i++)
00260 {
00261 if (!strcmp(argv[i], "--train"))
00262 loading = LOADING_POS;
00263 else if (!strcmp(argv[i], "--neg"))
00264 loading = LOADING_NEG;
00265 else if (!strcmp(argv[i], "--test"))
00266 loading = LOADING_TEST;
00267 else if (!strcmp(argv[i], "--save"))
00268 {
00269 if (++i < argc)
00270 strncpy(save_file, argv[i], 100);
00271 continue;
00272 }
00273 else
00274 tld.loadData(loading, argv[i]);
00275 }
00276
00277 printf("Training classifier...\n");
00278 tld.train();
00279
00280 printf("Evlauating classifier...\n");
00281 tld.test();
00282
00283 if (strlen(save_file) > 0)
00284 {
00285 printf("Saving classifier as: %s\n", save_file);
00286 tld.save(save_file);
00287 }
00288 }