#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <ctype.h>
#include <errno.h>
#include <limits.h>
#include "svm.h"
/* Allocate n objects of `type` with malloc(). The result is not checked for
 * NULL here; callers are responsible for that. The whole expansion is
 * parenthesized so it behaves as a single expression in any context (for
 * example when directly followed by a subscript). */
#define Malloc(type,n) ((type *)malloc((n)*sizeof(type)))
00008
/* Print callback that swallows every message; installed by the -q option. */
void print_null(const char *s)
{
	(void)s; /* intentionally unused: output is discarded */
}
00010
/* Print the svm-train usage summary to stdout and terminate with status 1. */
void exit_with_help()
{
	static const char usage_text[] =
	"Usage: svm-train [options] training_set_file [model_file]\n"
	"options:\n"
	"-s svm_type : set type of SVM (default 0)\n"
	" 0 -- C-SVC (multi-class classification)\n"
	" 1 -- nu-SVC (multi-class classification)\n"
	" 2 -- one-class SVM\n"
	" 3 -- epsilon-SVR (regression)\n"
	" 4 -- nu-SVR (regression)\n"
	"-t kernel_type : set type of kernel function (default 2)\n"
	" 0 -- linear: u'*v\n"
	" 1 -- polynomial: (gamma*u'*v + coef0)^degree\n"
	" 2 -- radial basis function: exp(-gamma*|u-v|^2)\n"
	" 3 -- sigmoid: tanh(gamma*u'*v + coef0)\n"
	" 4 -- precomputed kernel (kernel values in training_set_file)\n"
	"-d degree : set degree in kernel function (default 3)\n"
	"-g gamma : set gamma in kernel function (default 1/num_features)\n"
	"-r coef0 : set coef0 in kernel function (default 0)\n"
	"-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)\n"
	"-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)\n"
	"-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)\n"
	"-m cachesize : set cache memory size in MB (default 100)\n"
	"-e epsilon : set tolerance of termination criterion (default 0.001)\n"
	"-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)\n"
	"-b probability_estimates : whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)\n"
	"-wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1)\n"
	"-v n: n-fold cross validation mode\n"
	"-q : quiet mode (no outputs)\n";

	/* The text contains no conversion specifiers, so fputs emits exactly
	 * what printf would. */
	fputs(usage_text, stdout);
	exit(1);
}
00044
/* Report a malformed training-file line on stderr and terminate with status 1.
 * Callers in read_problem() pass a 1-based line number. */
void exit_input_error(int line_num)
{
	fputs("Wrong input format at line ", stderr);
	fprintf(stderr, "%d\n", line_num);
	exit(1);
}
00050
00051 void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name);
00052 void read_problem(const char *filename);
00053 void do_cross_validation();
00054
00055 struct svm_parameter param;
00056 struct svm_problem prob;
00057 struct svm_model *model;
00058 struct svm_node *x_space;
00059 int cross_validation;
00060 int nr_fold;
00061
00062 static char *line = NULL;
00063 static int max_line_len;
00064
00065 static char* readline(FILE *input)
00066 {
00067 int len;
00068
00069 if(fgets(line,max_line_len,input) == NULL)
00070 return NULL;
00071
00072 while(strrchr(line,'\n') == NULL)
00073 {
00074 max_line_len *= 2;
00075 line = (char *) realloc(line,max_line_len);
00076 len = (int) strlen(line);
00077 if(fgets(line+len,max_line_len-len,input) == NULL)
00078 break;
00079 }
00080 return line;
00081 }
00082
00083 int main(int argc, char **argv)
00084 {
00085 char input_file_name[1024];
00086 char model_file_name[1024];
00087 const char *error_msg;
00088
00089 parse_command_line(argc, argv, input_file_name, model_file_name);
00090 read_problem(input_file_name);
00091 error_msg = svm_check_parameter(&prob,¶m);
00092
00093 if(error_msg)
00094 {
00095 fprintf(stderr,"ERROR: %s\n",error_msg);
00096 exit(1);
00097 }
00098
00099 if(cross_validation)
00100 {
00101 do_cross_validation();
00102 }
00103 else
00104 {
00105 model = svm_train(&prob,¶m);
00106 if(svm_save_model(model_file_name,model))
00107 {
00108 fprintf(stderr, "can't save model to file %s\n", model_file_name);
00109 exit(1);
00110 }
00111 svm_free_and_destroy_model(&model);
00112 }
00113 svm_destroy_param(¶m);
00114 free(prob.y);
00115 free(prob.x);
00116 free(x_space);
00117 free(line);
00118
00119 return 0;
00120 }
00121
00122 void do_cross_validation()
00123 {
00124 int i;
00125 int total_correct = 0;
00126 double total_error = 0;
00127 double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
00128 double *target = Malloc(double,prob.l);
00129
00130 svm_cross_validation(&prob,¶m,nr_fold,target);
00131 if(param.svm_type == EPSILON_SVR ||
00132 param.svm_type == NU_SVR)
00133 {
00134 for(i=0;i<prob.l;i++)
00135 {
00136 double y = prob.y[i];
00137 double v = target[i];
00138 total_error += (v-y)*(v-y);
00139 sumv += v;
00140 sumy += y;
00141 sumvv += v*v;
00142 sumyy += y*y;
00143 sumvy += v*y;
00144 }
00145 printf("Cross Validation Mean squared error = %g\n",total_error/prob.l);
00146 printf("Cross Validation Squared correlation coefficient = %g\n",
00147 ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
00148 ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))
00149 );
00150 }
00151 else
00152 {
00153 for(i=0;i<prob.l;i++)
00154 if(target[i] == prob.y[i])
00155 ++total_correct;
00156 printf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l);
00157 }
00158 free(target);
00159 }
00160
00161 void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name)
00162 {
00163 int i;
00164 void (*print_func)(const char*) = NULL;
00165
00166
00167 param.svm_type = C_SVC;
00168 param.kernel_type = RBF;
00169 param.degree = 3;
00170 param.gamma = 0;
00171 param.coef0 = 0;
00172 param.nu = 0.5;
00173 param.cache_size = 100;
00174 param.C = 1;
00175 param.eps = 1e-3;
00176 param.p = 0.1;
00177 param.shrinking = 1;
00178 param.probability = 0;
00179 param.nr_weight = 0;
00180 param.weight_label = NULL;
00181 param.weight = NULL;
00182 cross_validation = 0;
00183
00184
00185 for(i=1;i<argc;i++)
00186 {
00187 if(argv[i][0] != '-') break;
00188 if(++i>=argc)
00189 exit_with_help();
00190 switch(argv[i-1][1])
00191 {
00192 case 's':
00193 param.svm_type = atoi(argv[i]);
00194 break;
00195 case 't':
00196 param.kernel_type = atoi(argv[i]);
00197 break;
00198 case 'd':
00199 param.degree = atoi(argv[i]);
00200 break;
00201 case 'g':
00202 param.gamma = atof(argv[i]);
00203 break;
00204 case 'r':
00205 param.coef0 = atof(argv[i]);
00206 break;
00207 case 'n':
00208 param.nu = atof(argv[i]);
00209 break;
00210 case 'm':
00211 param.cache_size = atof(argv[i]);
00212 break;
00213 case 'c':
00214 param.C = atof(argv[i]);
00215 break;
00216 case 'e':
00217 param.eps = atof(argv[i]);
00218 break;
00219 case 'p':
00220 param.p = atof(argv[i]);
00221 break;
00222 case 'h':
00223 param.shrinking = atoi(argv[i]);
00224 break;
00225 case 'b':
00226 param.probability = atoi(argv[i]);
00227 break;
00228 case 'q':
00229 print_func = &print_null;
00230 i--;
00231 break;
00232 case 'v':
00233 cross_validation = 1;
00234 nr_fold = atoi(argv[i]);
00235 if(nr_fold < 2)
00236 {
00237 fprintf(stderr,"n-fold cross validation: n must >= 2\n");
00238 exit_with_help();
00239 }
00240 break;
00241 case 'w':
00242 ++param.nr_weight;
00243 param.weight_label = (int *)realloc(param.weight_label,sizeof(int)*param.nr_weight);
00244 param.weight = (double *)realloc(param.weight,sizeof(double)*param.nr_weight);
00245 param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
00246 param.weight[param.nr_weight-1] = atof(argv[i]);
00247 break;
00248 default:
00249 fprintf(stderr,"Unknown option: -%c\n", argv[i-1][1]);
00250 exit_with_help();
00251 }
00252 }
00253
00254 svm_set_print_string_function(print_func);
00255
00256
00257
00258 if(i>=argc)
00259 exit_with_help();
00260
00261 strcpy(input_file_name, argv[i]);
00262
00263 if(i<argc-1)
00264 strcpy(model_file_name,argv[i+1]);
00265 else
00266 {
00267 char *p = strrchr(argv[i],'/');
00268 if(p==NULL)
00269 p = argv[i];
00270 else
00271 ++p;
00272 sprintf(model_file_name,"%s.model",p);
00273 }
00274 }
00275
00276
00277
00278 void read_problem(const char *filename)
00279 {
00280 int max_index, inst_max_index, i;
00281 size_t elements, j;
00282 FILE *fp = fopen(filename,"r");
00283 char *endptr;
00284 char *idx, *val, *label;
00285
00286 if(fp == NULL)
00287 {
00288 fprintf(stderr,"can't open input file %s\n",filename);
00289 exit(1);
00290 }
00291
00292 prob.l = 0;
00293 elements = 0;
00294
00295 max_line_len = 1024;
00296 line = Malloc(char,max_line_len);
00297 while(readline(fp)!=NULL)
00298 {
00299 char *p = strtok(line," \t");
00300
00301
00302 while(1)
00303 {
00304 p = strtok(NULL," \t");
00305 if(p == NULL || *p == '\n')
00306 break;
00307 ++elements;
00308 }
00309 ++elements;
00310 ++prob.l;
00311 }
00312 rewind(fp);
00313
00314 prob.y = Malloc(double,prob.l);
00315 prob.x = Malloc(struct svm_node *,prob.l);
00316 x_space = Malloc(struct svm_node,elements);
00317
00318 max_index = 0;
00319 j=0;
00320 for(i=0;i<prob.l;i++)
00321 {
00322 inst_max_index = -1;
00323 readline(fp);
00324 prob.x[i] = &x_space[j];
00325 label = strtok(line," \t\n");
00326 if(label == NULL)
00327 exit_input_error(i+1);
00328
00329 prob.y[i] = strtod(label,&endptr);
00330 if(endptr == label || *endptr != '\0')
00331 exit_input_error(i+1);
00332
00333 while(1)
00334 {
00335 idx = strtok(NULL,":");
00336 val = strtok(NULL," \t");
00337
00338 if(val == NULL)
00339 break;
00340
00341 errno = 0;
00342 x_space[j].index = (int) strtol(idx,&endptr,10);
00343 if(endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index)
00344 exit_input_error(i+1);
00345 else
00346 inst_max_index = x_space[j].index;
00347
00348 errno = 0;
00349 x_space[j].value = strtod(val,&endptr);
00350 if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
00351 exit_input_error(i+1);
00352
00353 ++j;
00354 }
00355
00356 if(inst_max_index > max_index)
00357 max_index = inst_max_index;
00358 x_space[j++].index = -1;
00359 }
00360
00361 if(param.gamma == 0 && max_index > 0)
00362 param.gamma = 1.0/max_index;
00363
00364 if(param.kernel_type == PRECOMPUTED)
00365 for(i=0;i<prob.l;i++)
00366 {
00367 if (prob.x[i][0].index != 0)
00368 {
00369 fprintf(stderr,"Wrong input format: first column must be 0:sample_serial_number\n");
00370 exit(1);
00371 }
00372 if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
00373 {
00374 fprintf(stderr,"Wrong input format: sample_serial_number out of range\n");
00375 exit(1);
00376 }
00377 }
00378
00379 fclose(fp);
00380 }