svm-train.c
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <string.h>
4 #include <ctype.h>
5 #include <errno.h>
6 #include "svm.h"
/* Allocate an uninitialized array of n objects of the given type; yields
 * NULL on allocation failure, so callers must check the result. The whole
 * expansion is parenthesized so the result can be indexed, dereferenced or
 * passed to sizeof directly. */
#define Malloc(type,n) ((type *)malloc((n)*sizeof(type)))
8 
/* Output callback that silently discards every message it is given.
 * parse_command_line installs it when the -q (quiet) option is present. */
void print_null(const char *s)
{
	(void)s; /* intentionally unused */
}
10 
/*
 * Print the svm-train command-line usage to stdout and terminate the
 * process with exit status 1. Called whenever the arguments are invalid.
 */
void exit_with_help(void)
{
	printf(
	"Usage: svm-train [options] training_set_file [model_file]\n"
	"options:\n"
	"-s svm_type : set type of SVM (default 0)\n"
	" 0 -- C-SVC\n"
	" 1 -- nu-SVC\n"
	" 2 -- one-class SVM\n"
	" 3 -- epsilon-SVR\n"
	" 4 -- nu-SVR\n"
	"-t kernel_type : set type of kernel function (default 2)\n"
	" 0 -- linear: u'*v\n"
	" 1 -- polynomial: (gamma*u'*v + coef0)^degree\n"
	" 2 -- radial basis function: exp(-gamma*|u-v|^2)\n"
	" 3 -- sigmoid: tanh(gamma*u'*v + coef0)\n"
	" 4 -- precomputed kernel (kernel values in training_set_file)\n"
	"-d degree : set degree in kernel function (default 3)\n"
	"-g gamma : set gamma in kernel function (default 1/num_features)\n"
	"-r coef0 : set coef0 in kernel function (default 0)\n"
	"-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)\n"
	"-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)\n"
	"-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)\n"
	"-m cachesize : set cache memory size in MB (default 100)\n"
	"-e epsilon : set tolerance of termination criterion (default 0.001)\n"
	"-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)\n"
	"-b probability_estimates : whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)\n"
	"-wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1)\n"
	"-v n: n-fold cross validation mode\n"
	"-q : quiet mode (no outputs)\n"
	);
	exit(1);
}
44 
/*
 * Abort on a malformed training file, naming the offending line.
 * line_num: 1-based line number within the training file.
 * Never returns: terminates the process with exit status 1.
 */
void exit_input_error(int line_num)
{
	const int failure_status = 1;

	fprintf(stderr, "Wrong input format at line %d\n", line_num);
	exit(failure_status);
}
50 
51 void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name);
52 void read_problem(const char *filename);
53 void do_cross_validation();
54 
55 struct svm_parameter param; // set by parse_command_line
56 struct svm_problem prob; // set by read_problem
57 struct svm_model *model;
58 struct svm_node *x_space;
60 int nr_fold;
61 
/* Growable buffer holding the most recent line returned by readline(). */
static char *line = NULL;
static int max_line_len;	/* current capacity of `line`, in bytes */

/*
 * Read one whole line from `input` into the global buffer `line`, doubling
 * the buffer (and max_line_len) as often as needed so lines of any length fit.
 * The trailing '\n' is kept when present.
 *
 * Precondition: `line` points to a heap buffer of max_line_len bytes.
 * Returns `line`, or NULL if nothing could be read (end of file or error).
 * Exits the program if the buffer cannot be enlarged.
 */
static char* readline(FILE *input)
{
	int len;

	if(fgets(line,max_line_len,input) == NULL)
		return NULL;

	while(strrchr(line,'\n') == NULL)
	{
		/* Grow through a temporary so the old buffer is not lost (and never
		 * passed to strlen as NULL) if realloc fails. */
		char *grown = (char *) realloc(line, (size_t)max_line_len * 2);
		if(grown == NULL)
		{
			fprintf(stderr,"can't allocate memory for an input line\n");
			exit(1);
		}
		line = grown;
		max_line_len *= 2;
		len = (int) strlen(line);
		if(fgets(line+len,max_line_len-len,input) == NULL)
			break;	/* last line has no trailing newline */
	}
	return line;
}
82 
83 int main(int argc, char **argv)
84 {
85  char input_file_name[1024];
86  char model_file_name[1024];
87  const char *error_msg;
88 
89  parse_command_line(argc, argv, input_file_name, model_file_name);
90  read_problem(input_file_name);
91  error_msg = svm_check_parameter(&prob,&param);
92 
93  if(error_msg)
94  {
95  fprintf(stderr,"ERROR: %s\n",error_msg);
96  exit(1);
97  }
98 
100  {
102  }
103  else
104  {
105  model = svm_train(&prob,&param);
106  if(svm_save_model(model_file_name,model))
107  {
108  fprintf(stderr, "can't save model to file %s\n", model_file_name);
109  exit(1);
110  }
112  }
114  free(prob.y);
115  free(prob.x);
116  free(x_space);
117  free(line);
118 
119  return 0;
120 }
121 
123 {
124  int i;
125  int total_correct = 0;
126  double total_error = 0;
127  double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
128  double *target = Malloc(double,prob.l);
129 
131  if(param.svm_type == EPSILON_SVR ||
132  param.svm_type == NU_SVR)
133  {
134  for(i=0;i<prob.l;i++)
135  {
136  double y = prob.y[i];
137  double v = target[i];
138  total_error += (v-y)*(v-y);
139  sumv += v;
140  sumy += y;
141  sumvv += v*v;
142  sumyy += y*y;
143  sumvy += v*y;
144  }
145  printf("Cross Validation Mean squared error = %g\n",total_error/prob.l);
146  printf("Cross Validation Squared correlation coefficient = %g\n",
147  ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
148  ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))
149  );
150  }
151  else
152  {
153  for(i=0;i<prob.l;i++)
154  if(target[i] == prob.y[i])
155  ++total_correct;
156  printf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l);
157  }
158  free(target);
159 }
160 
161 void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name)
162 {
163  int i;
164  void (*print_func)(const char*) = NULL; // default printing to stdout
165 
166  // default values
167  param.svm_type = C_SVC;
169  param.degree = 3;
170  param.gamma = 0; // 1/num_features
171  param.coef0 = 0;
172  param.nu = 0.5;
173  param.cache_size = 100;
174  param.C = 1;
175  param.eps = 1e-3;
176  param.p = 0.1;
177  param.shrinking = 1;
178  param.probability = 0;
179  param.nr_weight = 0;
180  param.weight_label = NULL;
181  param.weight = NULL;
182  cross_validation = 0;
183 
184  // parse options
185  for(i=1;i<argc;i++)
186  {
187  if(argv[i][0] != '-') break;
188  if(++i>=argc)
189  exit_with_help();
190  switch(argv[i-1][1])
191  {
192  case 's':
193  param.svm_type = atoi(argv[i]);
194  break;
195  case 't':
196  param.kernel_type = atoi(argv[i]);
197  break;
198  case 'd':
199  param.degree = atoi(argv[i]);
200  break;
201  case 'g':
202  param.gamma = atof(argv[i]);
203  break;
204  case 'r':
205  param.coef0 = atof(argv[i]);
206  break;
207  case 'n':
208  param.nu = atof(argv[i]);
209  break;
210  case 'm':
211  param.cache_size = atof(argv[i]);
212  break;
213  case 'c':
214  param.C = atof(argv[i]);
215  break;
216  case 'e':
217  param.eps = atof(argv[i]);
218  break;
219  case 'p':
220  param.p = atof(argv[i]);
221  break;
222  case 'h':
223  param.shrinking = atoi(argv[i]);
224  break;
225  case 'b':
226  param.probability = atoi(argv[i]);
227  break;
228  case 'q':
229  print_func = &print_null;
230  i--;
231  break;
232  case 'v':
233  cross_validation = 1;
234  nr_fold = atoi(argv[i]);
235  if(nr_fold < 2)
236  {
237  fprintf(stderr,"n-fold cross validation: n must >= 2\n");
238  exit_with_help();
239  }
240  break;
241  case 'w':
242  ++param.nr_weight;
243  param.weight_label = (int *)realloc(param.weight_label,sizeof(int)*param.nr_weight);
244  param.weight = (double *)realloc(param.weight,sizeof(double)*param.nr_weight);
245  param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
246  param.weight[param.nr_weight-1] = atof(argv[i]);
247  break;
248  default:
249  fprintf(stderr,"Unknown option: -%c\n", argv[i-1][1]);
250  exit_with_help();
251  }
252  }
253 
254  svm_set_print_string_function(print_func);
255 
256  // determine filenames
257 
258  if(i>=argc)
259  exit_with_help();
260 
261  strcpy(input_file_name, argv[i]);
262 
263  if(i<argc-1)
264  strcpy(model_file_name,argv[i+1]);
265  else
266  {
267  char *p = strrchr(argv[i],'/');
268  if(p==NULL)
269  p = argv[i];
270  else
271  ++p;
272  sprintf(model_file_name,"%s.model",p);
273  }
274 }
275 
276 // read in a problem (in svmlight format)
277 
278 void read_problem(const char *filename)
279 {
280  int elements, max_index, inst_max_index, i, j;
281  FILE *fp = fopen(filename,"r");
282  char *endptr;
283  char *idx, *val, *label;
284 
285  if(fp == NULL)
286  {
287  fprintf(stderr,"can't open input file %s\n",filename);
288  exit(1);
289  }
290 
291  prob.l = 0;
292  elements = 0;
293 
294  max_line_len = 1024;
295  line = Malloc(char,max_line_len);
296  while(readline(fp)!=NULL)
297  {
298  char *p = strtok(line," \t"); // label
299 
300  // features
301  while(1)
302  {
303  p = strtok(NULL," \t");
304  if(p == NULL || *p == '\n') // check '\n' as ' ' may be after the last feature
305  break;
306  ++elements;
307  }
308  ++elements;
309  ++prob.l;
310  }
311  rewind(fp);
312 
313  prob.y = Malloc(double,prob.l);
314  prob.x = Malloc(struct svm_node *,prob.l);
315  x_space = Malloc(struct svm_node,elements);
316 
317  max_index = 0;
318  j=0;
319  for(i=0;i<prob.l;i++)
320  {
321  inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0
322  readline(fp);
323  prob.x[i] = &x_space[j];
324  label = strtok(line," \t\n");
325  if(label == NULL) // empty line
326  exit_input_error(i+1);
327 
328  prob.y[i] = strtod(label,&endptr);
329  if(endptr == label || *endptr != '\0')
330  exit_input_error(i+1);
331 
332  while(1)
333  {
334  idx = strtok(NULL,":");
335  val = strtok(NULL," \t");
336 
337  if(val == NULL)
338  break;
339 
340  errno = 0;
341  x_space[j].index = (int) strtol(idx,&endptr,10);
342  if(endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index)
343  exit_input_error(i+1);
344  else
345  inst_max_index = x_space[j].index;
346 
347  errno = 0;
348  x_space[j].value = strtod(val,&endptr);
349  if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
350  exit_input_error(i+1);
351 
352  ++j;
353  }
354 
355  if(inst_max_index > max_index)
356  max_index = inst_max_index;
357  x_space[j++].index = -1;
358  }
359 
360  if(param.gamma == 0 && max_index > 0)
361  param.gamma = 1.0/max_index;
362 
364  for(i=0;i<prob.l;i++)
365  {
366  if (prob.x[i][0].index != 0)
367  {
368  fprintf(stderr,"Wrong input format: first column must be 0:sample_serial_number\n");
369  exit(1);
370  }
371  if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
372  {
373  fprintf(stderr,"Wrong input format: sample_serial_number out of range\n");
374  exit(1);
375  }
376  }
377 
378  fclose(fp);
379 }
filename
const char * svm_check_parameter(const svm_problem *prob, const svm_parameter *param)
Definition: svm.cpp:2977
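svm_model * svm_train(const svm_problem *prob, const svm_parameter *param)
Definition: svm.cpp (C library function called by svm-train.c; not the Python wrapper in svmutil.py)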
double value
Definition: svm.h:15
int nr_fold
Definition: svm-train.c:60
int nr_weight
Definition: svm.h:40
int main(int argc, char **argv)
Definition: svm-train.c:83
void svm_cross_validation(const svm_problem *prob, const svm_parameter *param, int nr_fold, double *target)
Definition: svm.cpp:2314
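int svm_save_model(const char *model_file_name, const svm_model *model)
Definition: svm.cpp (C library function called by svm-train.c; not the Python wrapper in svmutil.py)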
#define Malloc(type, n)
Definition: svm-train.c:7
void print_null(const char *s)
Definition: svm-train.c:9
int * weight_label
Definition: svm.h:41
void do_cross_validation()
Definition: svm-train.c:122
Definition: svm.h:25
struct svm_node * x_space
Definition: svm-train.c:58
double y (local variable in do_cross_validation)
Definition: svm-train.c:136
double p
Definition: svm.h:44
void exit_with_help()
Definition: svm-train.c:11
struct svm_parameter param
Definition: svm-train.c:55
struct svm_model * model
Definition: svm-train.c:57
struct svm_problem prob
Definition: svm-train.c:56
static char * readline(FILE *input)
Definition: svm-train.c:65
void exit_input_error(int line_num)
Definition: svm-train.c:45
double cache_size
Definition: svm.h:37
index
Definition: subset.py:58
double eps
Definition: svm.h:38
int max_index (local variable in read_problem)
Definition: svm-train.c:280
int shrinking
Definition: svm.h:45
struct svm_node ** x
Definition: svm.h:22
void svm_free_and_destroy_model(svm_model **model_ptr_ptr)
Definition: svm.cpp:2961
Definition: svm.h:26
int index
Definition: svm.h:14
void parse_command_line(int argc, char **argv, char *input_file_name, char *model_file_name)
Definition: svm-train.c:161
int probability
Definition: svm.h:46
int degree
Definition: svm.h:32
int cross_validation
Definition: svm-train.c:59
Definition: svm.h:25
Definition: svm.h:12
double * y
Definition: svm.h:21
double gamma
Definition: svm.h:33
int l
Definition: svm.h:20
double * weight
Definition: svm.h:42
double C
Definition: svm.h:39
int svm_type
Definition: svm.h:30
double nu
Definition: svm.h:43
static char * line
Definition: svm-train.c:62
double coef0
Definition: svm.h:34
char * label (local variable in read_problem)
Definition: svm-train.c:283
void svm_destroy_param(svm_parameter *param)
Definition: svm.cpp:2971
void read_problem(const char *filename)
Definition: svm-train.c:278
int kernel_type
Definition: svm.h:31
static int max_line_len
Definition: svm-train.c:63
void svm_set_print_string_function(void(*print_func)(const char *))
Definition: svm.cpp:3106


haf_grasping
Author(s): David Fischinger
autogenerated on Mon Jun 10 2019 13:28:43