$search
00001 /* 00002 * learning.cpp 00003 * outlet_model 00004 * 00005 * Created by Victor Eruhimov on 12/29/08. 00006 * Copyright 2008 Argus Corp. All rights reserved. 00007 * 00008 */ 00009 00010 #include <stdio.h> 00011 #include "outlet_pose_estimation/detail/learning.h" 00012 00013 CvRTrees* train_rf(CvMat* predictors, CvMat* labels) 00014 { 00015 int stat[2]; 00016 get_stat(labels, stat); 00017 printf("%d negative samples, %d positive samples\n", stat[0], stat[1]); 00018 00019 const int tree_count = 500; 00020 const float priors[] = {0.25f,0.75f}; 00021 CvRTrees* rtrees = new CvRTrees(); 00022 CvRTParams rtparams = CvRTParams(5, 10, 0, false, 2, priors, true, 00023 (int)sqrt((float)predictors->cols), tree_count, 1e-6, 00024 CV_TERMCRIT_ITER + CV_TERMCRIT_EPS); 00025 CvMat* var_type = cvCreateMat(predictors->cols + 1, 1, CV_8UC1); 00026 for(int i = 0; i < predictors->cols; i++) 00027 { 00028 *(int*)(var_type->data.ptr + i*var_type->step) = CV_VAR_NUMERICAL; 00029 } 00030 *(int*)(var_type->data.ptr + predictors->cols*var_type->step) = CV_VAR_CATEGORICAL; 00031 rtrees->train(predictors, CV_ROW_SAMPLE, labels, 0, 0, var_type, 0, rtparams); 00032 return rtrees; 00033 } 00034 00035 void get_stat(CvMat* labels, int* stat) 00036 { 00037 stat[0] = 0; 00038 stat[1] = 0; 00039 for(int i = 0; i < labels->rows; i++) 00040 { 00041 int val = *(int*)(labels->data.ptr + labels->step*i); 00042 stat[val]++; 00043 } 00044 }