Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00042 #include "ml_classifiers/svm_classifier.h"
00043 #include <pluginlib/class_list_macros.h>
00044 #include <math.h>
00045
00046 PLUGINLIB_DECLARE_CLASS(ml_classifiers, SVMClassifier, ml_classifiers::SVMClassifier, ml_classifiers::Classifier)
00047
00048 using namespace std;
00049
00050 namespace ml_classifiers{
00051
00052 SVMClassifier::SVMClassifier(){}
00053
00054 SVMClassifier::~SVMClassifier(){}
00055
00056 void SVMClassifier::save(const std::string filename){}
00057
00058 bool SVMClassifier::load(const std::string filename){return false;}
00059
00060 void SVMClassifier::addTrainingPoint(string target_class, const std::vector<double> point)
00061 {
00062 class_data[target_class].push_back(point);
00063 }
00064
00065 void SVMClassifier::train()
00066 {
00067 if(class_data.size() == 0){
00068 printf("SVMClassifier::train() -- No training data available! Doing nothing.\n");
00069 return;
00070 }
00071
00072 int n_classes = class_data.size();
00073
00074
00075 int n_data = 0;
00076 int dims = class_data.begin()->second[0].size();
00077 for(ClassMap::iterator iter = class_data.begin(); iter != class_data.end(); iter++){
00078 CPointList cpl = iter->second;
00079 if(cpl.size() == 1)
00080 n_data += 2;
00081 else
00082 n_data += cpl.size();
00083 }
00084
00085
00086 svm_data.l = n_data;
00087 svm_data.y = new double[n_data];
00088 svm_data.x = new svm_node*[n_data];
00089 for(int i=0; i<n_data; i++)
00090 svm_data.x[i] = new svm_node[dims+1];
00091
00092
00093 label_str_to_int.clear();
00094 label_int_to_str.clear();
00095 int label_n = 0;
00096 for(ClassMap::iterator iter = class_data.begin(); iter != class_data.end(); iter++){
00097 string cname = iter->first;
00098 label_str_to_int[cname] = label_n;
00099 label_int_to_str[label_n] = cname;
00100
00101 ++label_n;
00102 }
00103
00104
00105 scaling_factors = new double*[dims];
00106 for(int i=0; i<dims; i++)
00107 scaling_factors[i] = new double[2];
00108
00109
00110 for(int j=0; j<dims; j++){
00111
00112 double minval = INFINITY;
00113 double maxval = -INFINITY;
00114 for(ClassMap::iterator iter = class_data.begin(); iter != class_data.end(); iter++){
00115 CPointList cpl = iter->second;
00116 for(size_t i=0; i<cpl.size(); i++){
00117 if(cpl[i][j] < minval)
00118 minval = cpl[i][j];
00119 if(cpl[i][j] > maxval)
00120 maxval = cpl[i][j];
00121 }
00122 }
00123 double factor = maxval-minval;
00124 double offset = minval;
00125
00126
00127 for(ClassMap::iterator iter = class_data.begin(); iter != class_data.end(); iter++){
00128 for(size_t i=0; i<iter->second.size(); i++){
00129 iter->second[i][j] = (iter->second[i][j] - offset) / factor;
00130 }
00131 }
00132 scaling_factors[j][0] = offset;
00133 scaling_factors[j][1] = factor;
00134 }
00135
00136
00137 int n = 0;
00138 for(ClassMap::iterator iter = class_data.begin(); iter != class_data.end(); iter++){
00139 string cname = iter->first;
00140 CPointList cpl = iter->second;
00141
00142
00143 if(cpl.size() == 1){
00144 svm_data.y[n] = label_str_to_int[cname];
00145 svm_data.y[n+1] = label_str_to_int[cname];
00146 for(int j=0; j<dims; j++){
00147 svm_data.x[n][j].index = j;
00148 svm_data.x[n][j].value = cpl[0][j] + 0.001;
00149 svm_data.x[n+1][j].index = j;
00150 svm_data.x[n+1][j].value = cpl[0][j] + 0.001;
00151 }
00152 svm_data.x[n][dims].index = -1;
00153 svm_data.x[n+1][dims].index = -1;
00154 n = n + 2;
00155 }
00156 else{
00157 for(size_t i=0; i<cpl.size(); i++){
00158 svm_data.y[n] = label_str_to_int[cname];
00159 for(int j=0; j<dims; j++){
00160 svm_data.x[n][j].index = j;
00161 svm_data.x[n][j].value = cpl[i][j];
00162 }
00163 svm_data.x[n][dims].index = -1;
00164 n = n + 1;
00165 }
00166 }
00167 }
00168
00169
00170 svm_parameter params;
00171 params.svm_type = C_SVC;
00172 params.kernel_type = RBF;
00173 params.cache_size = 100.0;
00174 params.gamma = 1.0;
00175 params.C = 1.0;
00176 params.eps = 0.001;
00177 params.shrinking = 1;
00178 params.probability = 0;
00179 params.degree = 0;
00180 params.nr_weight = 0;
00181
00182
00183
00184 const char *err_str = svm_check_parameter(&svm_data, ¶ms);
00185 if(err_str){
00186 printf("SVMClassifier::train() -- Bad SVM parameters!\n");
00187 printf("%s\n",err_str);
00188 return;
00189 }
00190
00191
00192 int n_folds = min(10, n_data);
00193 double *resp = new double[n_data];
00194 double best_accy = 0.0;
00195 double best_g = 0.0;
00196 double best_c = 0.0;
00197
00198
00199 for(double c = -5.0; c <= 15.0; c += 2.0){
00200 for(double g = 3.0; g >= -15.0; g -= 2.0){
00201 params.gamma = pow(2,g);
00202 params.C = pow(2,c);
00203
00204 svm_cross_validation(&svm_data, ¶ms, n_folds, resp);
00205
00206
00207 int correct = 0;
00208 for(int i=0; i<n_data; i++){
00209 if(resp[i] == svm_data.y[i])
00210 ++correct;
00211 double accy = double(correct) / double(n_data);
00212 if(accy > best_accy){
00213 best_accy = accy;
00214 best_g = params.gamma;
00215 best_c = params.C;
00216 }
00217 }
00218 }
00219 }
00220
00221
00222 double start_c = best_c - 1.0;
00223 double end_c = best_c + 1.0;
00224 double start_g = best_g + 1.0;
00225 double end_g = best_g - 1.0;
00226 for(double c = start_c; c <= end_c; c += 0.1){
00227 for(double g = start_g; g >= end_g; g -= 0.1){
00228 params.gamma = pow(2,g);
00229 params.C = pow(2,c);
00230 svm_cross_validation(&svm_data, ¶ms, n_folds, resp);
00231
00232
00233 int correct = 0;
00234 for(int i=0; i<n_data; i++){
00235 if(resp[i] == svm_data.y[i])
00236 ++correct;
00237 double accy = double(correct) / double(n_data);
00238
00239 if(accy > best_accy){
00240 best_accy = accy;
00241 best_g = params.gamma;
00242 best_c = params.C;
00243 }
00244 }
00245 }
00246 }
00247
00248
00249 params.gamma = best_g;
00250 params.C = best_c;
00251
00252 printf("BEST PARAMS ncl: %i c: %f g: %f accy: %f \n\n", n_classes, best_c, best_g, best_accy);
00253
00254
00255 trained_model = svm_train(&svm_data, ¶ms);
00256 }
00257
00258 void SVMClassifier::clear()
00259 {
00260 class_data.clear();
00261 label_str_to_int.clear();
00262 label_int_to_str.clear();
00263 trained_model = NULL;
00264 scaling_factors = NULL;
00265 }
00266
00267 string SVMClassifier::classifyPoint(const std::vector<double> point)
00268 {
00269
00270 int dims = point.size();
00271 svm_node* test_pt = new svm_node[dims+1];
00272 for(int i=0; i<dims; i++){
00273 test_pt[i].index = i;
00274
00275 test_pt[i].value = (point[i]-scaling_factors[i][0]) / scaling_factors[i][1];
00276 }
00277 test_pt[dims].index = -1;
00278
00279
00280 int label_n = svm_predict(trained_model, test_pt);
00281 return label_int_to_str[label_n];
00282 }
00283 }
00284