00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #include <stdio.h>
00011 #include "outlet_pose_estimation/detail/learning.h"
00012
00013 CvRTrees* train_rf(CvMat* predictors, CvMat* labels)
00014 {
00015 int stat[2];
00016 get_stat(labels, stat);
00017 printf("%d negative samples, %d positive samples\n", stat[0], stat[1]);
00018
00019 const int tree_count = 500;
00020 const float priors[] = {0.25f,0.75f};
00021 CvRTrees* rtrees = new CvRTrees();
00022 CvRTParams rtparams = CvRTParams(5, 10, 0, false, 2, priors, true,
00023 (int)sqrt((float)predictors->cols), tree_count, 1e-6,
00024 CV_TERMCRIT_ITER + CV_TERMCRIT_EPS);
00025 CvMat* var_type = cvCreateMat(predictors->cols + 1, 1, CV_8UC1);
00026 for(int i = 0; i < predictors->cols; i++)
00027 {
00028 *(int*)(var_type->data.ptr + i*var_type->step) = CV_VAR_NUMERICAL;
00029 }
00030 *(int*)(var_type->data.ptr + predictors->cols*var_type->step) = CV_VAR_CATEGORICAL;
00031 rtrees->train(predictors, CV_ROW_SAMPLE, labels, 0, 0, var_type, 0, rtparams);
00032 return rtrees;
00033 }
00034
00035 void get_stat(CvMat* labels, int* stat)
00036 {
00037 stat[0] = 0;
00038 stat[1] = 0;
00039 for(int i = 0; i < labels->rows; i++)
00040 {
00041 int val = *(int*)(labels->data.ptr + labels->step*i);
00042 stat[val]++;
00043 }
00044 }