svm-train.c
Go to the documentation of this file.
00001 #include <stdio.h>
00002 #include <stdlib.h>
00003 #include <string.h>
00004 #include <ctype.h>
00005 #include <errno.h>
00006 #include "svm.h"
00007 #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
00008 
/* Print callback that discards libsvm's progress messages (installed by -q). */
void print_null(const char *s)
{
	(void)s; /* message intentionally ignored */
}
00010 
/* Write the command-line usage summary to stdout and terminate with status 1. */
void exit_with_help()
{
	static const char usage_text[] =
		"Usage: svm-train [options] training_set_file [model_file]\n"
		"options:\n"
		"-s svm_type : set type of SVM (default 0)\n"
		"       0 -- C-SVC\n"
		"       1 -- nu-SVC\n"
		"       2 -- one-class SVM\n"
		"       3 -- epsilon-SVR\n"
		"       4 -- nu-SVR\n"
		"-t kernel_type : set type of kernel function (default 2)\n"
		"       0 -- linear: u'*v\n"
		"       1 -- polynomial: (gamma*u'*v + coef0)^degree\n"
		"       2 -- radial basis function: exp(-gamma*|u-v|^2)\n"
		"       3 -- sigmoid: tanh(gamma*u'*v + coef0)\n"
		"       4 -- precomputed kernel (kernel values in training_set_file)\n"
		"-d degree : set degree in kernel function (default 3)\n"
		"-g gamma : set gamma in kernel function (default 1/num_features)\n"
		"-r coef0 : set coef0 in kernel function (default 0)\n"
		"-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)\n"
		"-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)\n"
		"-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)\n"
		"-m cachesize : set cache memory size in MB (default 100)\n"
		"-e epsilon : set tolerance of termination criterion (default 0.001)\n"
		"-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)\n"
		"-b probability_estimates : whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)\n"
		"-wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1)\n"
		"-v n: n-fold cross validation mode\n"
		"-q : quiet mode (no outputs)\n";

	/* The text has no conversion specifiers, so fputs prints it exactly as printf would. */
	fputs(usage_text, stdout);
	exit(1);
}
00044 
/* Report a malformed training-file line (1-based line_num) on stderr and exit with status 1. */
void exit_input_error(int line_num)
{
	const int failure_status = 1;

	fprintf(stderr, "Wrong input format at line %d\n", line_num);
	exit(failure_status);
}
00050 
00051 void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name);
00052 void read_problem(const char *filename);
00053 void do_cross_validation();
00054 
00055 struct svm_parameter param;             // set by parse_command_line
00056 struct svm_problem prob;                // set by read_problem
00057 struct svm_model *model;
00058 struct svm_node *x_space;
00059 int cross_validation;
00060 int nr_fold;
00061 
/* Shared input-line buffer used by readline(); allocated in read_problem, freed in main. */
static char *line = NULL;
/* Current capacity of `line` in bytes; doubled whenever a line does not fit. */
static int max_line_len;

/*
 * Read one complete line from `input` into the global buffer `line`,
 * growing the buffer until the whole line (including its '\n', if any)
 * fits.  `line` must already point at a buffer of `max_line_len` bytes.
 *
 * Returns `line`, or NULL if no characters could be read (EOF or error).
 * A last line without a trailing newline is returned as-is.
 * Exits the program if the buffer cannot be grown.
 */
static char* readline(FILE *input)
{
	int len;

	if(fgets(line,max_line_len,input) == NULL)
		return NULL;

	while(strrchr(line,'\n') == NULL)
	{
		char *grown;

		max_line_len *= 2;
		// keep the old pointer until realloc succeeds so it is never lost
		grown = (char *) realloc(line,max_line_len);
		if(grown == NULL)
		{
			fprintf(stderr,"can't allocate memory for input line\n");
			exit(1);
		}
		line = grown;

		// append the rest of the line after what was already read
		len = (int) strlen(line);
		if(fgets(line+len,max_line_len-len,input) == NULL)
			break;
	}
	return line;
}
00082 
00083 int main(int argc, char **argv)
00084 {
00085         char input_file_name[1024];
00086         char model_file_name[1024];
00087         const char *error_msg;
00088 
00089         parse_command_line(argc, argv, input_file_name, model_file_name);
00090         read_problem(input_file_name);
00091         error_msg = svm_check_parameter(&prob,&param);
00092 
00093         if(error_msg)
00094         {
00095                 fprintf(stderr,"Error: %s\n",error_msg);
00096                 exit(1);
00097         }
00098 
00099         if(cross_validation)
00100         {
00101                 do_cross_validation();
00102         }
00103         else
00104         {
00105                 model = svm_train(&prob,&param);
00106                 if(svm_save_model(model_file_name,model))
00107                 {
00108                         fprintf(stderr, "can't save model to file %s\n", model_file_name);
00109                         exit(1);
00110                 }
00111                 svm_free_and_destroy_model(&model);
00112         }
00113         svm_destroy_param(&param);
00114         free(prob.y);
00115         free(prob.x);
00116         free(x_space);
00117         free(line);
00118 
00119         return 0;
00120 }
00121 
00122 void do_cross_validation()
00123 {
00124         int i;
00125         int total_correct = 0;
00126         double total_error = 0;
00127         double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
00128         double *target = Malloc(double,prob.l);
00129 
00130         svm_cross_validation(&prob,&param,nr_fold,target);
00131         if(param.svm_type == EPSILON_SVR ||
00132            param.svm_type == NU_SVR)
00133         {
00134                 for(i=0;i<prob.l;i++)
00135                 {
00136                         double y = prob.y[i];
00137                         double v = target[i];
00138                         total_error += (v-y)*(v-y);
00139                         sumv += v;
00140                         sumy += y;
00141                         sumvv += v*v;
00142                         sumyy += y*y;
00143                         sumvy += v*y;
00144                 }
00145                 printf("Cross Validation Mean squared error = %g\n",total_error/prob.l);
00146                 printf("Cross Validation Squared correlation coefficient = %g\n",
00147                         ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
00148                         ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))
00149                         );
00150         }
00151         else
00152         {
00153                 for(i=0;i<prob.l;i++)
00154                         if(target[i] == prob.y[i])
00155                                 ++total_correct;
00156                 printf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l);
00157         }
00158         free(target);
00159 }
00160 
00161 void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name)
00162 {
00163         int i;
00164         void (*print_func)(const char*) = NULL; // default printing to stdout
00165 
00166         // default values
00167         param.svm_type = C_SVC;
00168         param.kernel_type = RBF;
00169         param.degree = 3;
00170         param.gamma = 0;        // 1/num_features
00171         param.coef0 = 0;
00172         param.nu = 0.5;
00173         param.cache_size = 100;
00174         param.C = 1;
00175         param.eps = 1e-3;
00176         param.p = 0.1;
00177         param.shrinking = 1;
00178         param.probability = 0;
00179         param.nr_weight = 0;
00180         param.weight_label = NULL;
00181         param.weight = NULL;
00182         cross_validation = 0;
00183 
00184         // parse options
00185         for(i=1;i<argc;i++)
00186         {
00187                 if(argv[i][0] != '-') break;
00188                 if(++i>=argc)
00189                         exit_with_help();
00190                 switch(argv[i-1][1])
00191                 {
00192                         case 's':
00193                                 param.svm_type = atoi(argv[i]);
00194                                 break;
00195                         case 't':
00196                                 param.kernel_type = atoi(argv[i]);
00197                                 break;
00198                         case 'd':
00199                                 param.degree = atoi(argv[i]);
00200                                 break;
00201                         case 'g':
00202                                 param.gamma = atof(argv[i]);
00203                                 break;
00204                         case 'r':
00205                                 param.coef0 = atof(argv[i]);
00206                                 break;
00207                         case 'n':
00208                                 param.nu = atof(argv[i]);
00209                                 break;
00210                         case 'm':
00211                                 param.cache_size = atof(argv[i]);
00212                                 break;
00213                         case 'c':
00214                                 param.C = atof(argv[i]);
00215                                 break;
00216                         case 'e':
00217                                 param.eps = atof(argv[i]);
00218                                 break;
00219                         case 'p':
00220                                 param.p = atof(argv[i]);
00221                                 break;
00222                         case 'h':
00223                                 param.shrinking = atoi(argv[i]);
00224                                 break;
00225                         case 'b':
00226                                 param.probability = atoi(argv[i]);
00227                                 break;
00228                         case 'q':
00229                                 print_func = &print_null;
00230                                 i--;
00231                                 break;
00232                         case 'v':
00233                                 cross_validation = 1;
00234                                 nr_fold = atoi(argv[i]);
00235                                 if(nr_fold < 2)
00236                                 {
00237                                         fprintf(stderr,"n-fold cross validation: n must >= 2\n");
00238                                         exit_with_help();
00239                                 }
00240                                 break;
00241                         case 'w':
00242                                 ++param.nr_weight;
00243                                 param.weight_label = (int *)realloc(param.weight_label,sizeof(int)*param.nr_weight);
00244                                 param.weight = (double *)realloc(param.weight,sizeof(double)*param.nr_weight);
00245                                 param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
00246                                 param.weight[param.nr_weight-1] = atof(argv[i]);
00247                                 break;
00248                         default:
00249                                 fprintf(stderr,"Unknown option: -%c\n", argv[i-1][1]);
00250                                 exit_with_help();
00251                 }
00252         }
00253 
00254         svm_set_print_string_function(print_func);
00255 
00256         // determine filenames
00257 
00258         if(i>=argc)
00259                 exit_with_help();
00260 
00261         strcpy(input_file_name, argv[i]);
00262 
00263         if(i<argc-1)
00264                 strcpy(model_file_name,argv[i+1]);
00265         else
00266         {
00267                 char *p = strrchr(argv[i],'/');
00268                 if(p==NULL)
00269                         p = argv[i];
00270                 else
00271                         ++p;
00272                 sprintf(model_file_name,"%s.model",p);
00273         }
00274 }
00275 
00276 // read in a problem (in svmlight format)
00277 
00278 void read_problem(const char *filename)
00279 {
00280         int elements, max_index, inst_max_index, i, j;
00281         FILE *fp = fopen(filename,"r");
00282         char *endptr;
00283         char *idx, *val, *label;
00284 
00285         if(fp == NULL)
00286         {
00287                 fprintf(stderr,"can't open input file %s\n",filename);
00288                 exit(1);
00289         }
00290 
00291         prob.l = 0;
00292         elements = 0;
00293 
00294         max_line_len = 1024;
00295         line = Malloc(char,max_line_len);
00296         while(readline(fp)!=NULL)
00297         {
00298                 char *p = strtok(line," \t"); // label
00299 
00300                 // features
00301                 while(1)
00302                 {
00303                         p = strtok(NULL," \t");
00304                         if(p == NULL || *p == '\n') // check '\n' as ' ' may be after the last feature
00305                                 break;
00306                         ++elements;
00307                 }
00308                 ++elements;
00309                 ++prob.l;
00310         }
00311         rewind(fp);
00312 
00313         prob.y = Malloc(double,prob.l);
00314         prob.x = Malloc(struct svm_node *,prob.l);
00315         x_space = Malloc(struct svm_node,elements);
00316 
00317         max_index = 0;
00318         j=0;
00319         for(i=0;i<prob.l;i++)
00320         {
00321                 inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0
00322                 readline(fp);
00323                 prob.x[i] = &x_space[j];
00324                 label = strtok(line," \t");
00325                 prob.y[i] = strtod(label,&endptr);
00326                 if(endptr == label)
00327                         exit_input_error(i+1);
00328 
00329                 while(1)
00330                 {
00331                         idx = strtok(NULL,":");
00332                         val = strtok(NULL," \t");
00333 
00334                         if(val == NULL)
00335                                 break;
00336 
00337                         errno = 0;
00338                         x_space[j].index = (int) strtol(idx,&endptr,10);
00339                         if(endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index)
00340                                 exit_input_error(i+1);
00341                         else
00342                                 inst_max_index = x_space[j].index;
00343 
00344                         errno = 0;
00345                         x_space[j].value = strtod(val,&endptr);
00346                         if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
00347                                 exit_input_error(i+1);
00348 
00349                         ++j;
00350                 }
00351 
00352                 if(inst_max_index > max_index)
00353                         max_index = inst_max_index;
00354                 x_space[j++].index = -1;
00355         }
00356 
00357         if(param.gamma == 0 && max_index > 0)
00358                 param.gamma = 1.0/max_index;
00359 
00360         if(param.kernel_type == PRECOMPUTED)
00361                 for(i=0;i<prob.l;i++)
00362                 {
00363                         if (prob.x[i][0].index != 0)
00364                         {
00365                                 fprintf(stderr,"Wrong input format: first column must be 0:sample_serial_number\n");
00366                                 exit(1);
00367                         }
00368                         if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
00369                         {
00370                                 fprintf(stderr,"Wrong input format: sample_serial_number out of range\n");
00371                                 exit(1);
00372                         }
00373                 }
00374 
00375         fclose(fp);
00376 }


libsvm3
Author(s): various
autogenerated on Wed Nov 27 2013 11:36:23