svmtrain.c
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <string.h>
4 #include <ctype.h>
5 #include "../svm.h"
6 
7 #include "mex.h"
8 #include "svm_model_matlab.h"
9 
10 #ifdef MX_API_VER
11 #if MX_API_VER < 0x07030000
12 typedef int mwIndex;
13 #endif
14 #endif
15 
16 #define CMD_LEN 2048
17 #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
18 
// Print callback that discards its message; parse_command_line installs it
// when the -q (quiet) option is given.
void print_null(const char *s)
{
	(void)s; // deliberately ignored
}
// Default print callback: writes the message s to the MATLAB command window.
// The message is passed as an argument to a fixed "%s" format rather than
// being used as the format itself, so '%' characters inside s are printed
// literally instead of being interpreted as conversion specifiers.
void print_string_matlab(const char *s)
{
	mexPrintf("%s", s);
}
21 
// Print the usage text and the list of libsvm options to the MATLAB command
// window. Despite its name this function does not exit or raise an error:
// it only prints and returns, so callers must bail out themselves.
void exit_with_help(void)
{
	mexPrintf(
	"Usage: model = svmtrain(training_label_vector, training_instance_matrix, 'libsvm_options');\n"
	"libsvm_options:\n"
	"-s svm_type : set type of SVM (default 0)\n"
	" 0 -- C-SVC\n"
	" 1 -- nu-SVC\n"
	" 2 -- one-class SVM\n"
	" 3 -- epsilon-SVR\n"
	" 4 -- nu-SVR\n"
	"-t kernel_type : set type of kernel function (default 2)\n"
	" 0 -- linear: u'*v\n"
	" 1 -- polynomial: (gamma*u'*v + coef0)^degree\n"
	" 2 -- radial basis function: exp(-gamma*|u-v|^2)\n"
	" 3 -- sigmoid: tanh(gamma*u'*v + coef0)\n"
	" 4 -- precomputed kernel (kernel values in training_instance_matrix)\n"
	"-d degree : set degree in kernel function (default 3)\n"
	"-g gamma : set gamma in kernel function (default 1/num_features)\n"
	"-r coef0 : set coef0 in kernel function (default 0)\n"
	"-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)\n"
	"-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)\n"
	"-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)\n"
	"-m cachesize : set cache memory size in MB (default 100)\n"
	"-e epsilon : set tolerance of termination criterion (default 0.001)\n"
	"-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)\n"
	"-b probability_estimates : whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)\n"
	"-wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1)\n"
	"-v n : n-fold cross validation mode\n"
	"-q : quiet mode (no outputs)\n"
	);
}
54 
55 // svm arguments
56 struct svm_parameter param; // set by parse_command_line
57 struct svm_problem prob; // set by read_problem
58 struct svm_model *model;
59 struct svm_node *x_space;
61 int nr_fold;
62 
63 
65 {
66  int i;
67  int total_correct = 0;
68  double total_error = 0;
69  double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
70  double *target = Malloc(double,prob.l);
71  double retval = 0.0;
72 
74  if(param.svm_type == EPSILON_SVR ||
76  {
77  for(i=0;i<prob.l;i++)
78  {
79  double y = prob.y[i];
80  double v = target[i];
81  total_error += (v-y)*(v-y);
82  sumv += v;
83  sumy += y;
84  sumvv += v*v;
85  sumyy += y*y;
86  sumvy += v*y;
87  }
88  mexPrintf("Cross Validation Mean squared error = %g\n",total_error/prob.l);
89  mexPrintf("Cross Validation Squared correlation coefficient = %g\n",
90  ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
91  ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))
92  );
93  retval = total_error/prob.l;
94  }
95  else
96  {
97  for(i=0;i<prob.l;i++)
98  if(target[i] == prob.y[i])
99  ++total_correct;
100  mexPrintf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l);
101  retval = 100.0*total_correct/prob.l;
102  }
103  free(target);
104  return retval;
105 }
106 
107 // nrhs should be 3
108 int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
109 {
110  int i, argc = 1;
111  char cmd[CMD_LEN];
112  char *argv[CMD_LEN/2];
113  void (*print_func)(const char *) = print_string_matlab; // default printing to matlab display
114 
115  // default values
116  param.svm_type = C_SVC;
118  param.degree = 3;
119  param.gamma = 0; // 1/num_features
120  param.coef0 = 0;
121  param.nu = 0.5;
122  param.cache_size = 100;
123  param.C = 1;
124  param.eps = 1e-3;
125  param.p = 0.1;
126  param.shrinking = 1;
127  param.probability = 0;
128  param.nr_weight = 0;
129  param.weight_label = NULL;
130  param.weight = NULL;
131  cross_validation = 0;
132 
133  if(nrhs <= 1)
134  return 1;
135 
136  if(nrhs > 2)
137  {
138  // put options in argv[]
139  mxGetString(prhs[2], cmd, mxGetN(prhs[2]) + 1);
140  if((argv[argc] = strtok(cmd, " ")) != NULL)
141  while((argv[++argc] = strtok(NULL, " ")) != NULL)
142  ;
143  }
144 
145  // parse options
146  for(i=1;i<argc;i++)
147  {
148  if(argv[i][0] != '-') break;
149  ++i;
150  if(i>=argc && argv[i-1][1] != 'q') // since option -q has no parameter
151  return 1;
152  switch(argv[i-1][1])
153  {
154  case 's':
155  param.svm_type = atoi(argv[i]);
156  break;
157  case 't':
158  param.kernel_type = atoi(argv[i]);
159  break;
160  case 'd':
161  param.degree = atoi(argv[i]);
162  break;
163  case 'g':
164  param.gamma = atof(argv[i]);
165  break;
166  case 'r':
167  param.coef0 = atof(argv[i]);
168  break;
169  case 'n':
170  param.nu = atof(argv[i]);
171  break;
172  case 'm':
173  param.cache_size = atof(argv[i]);
174  break;
175  case 'c':
176  param.C = atof(argv[i]);
177  break;
178  case 'e':
179  param.eps = atof(argv[i]);
180  break;
181  case 'p':
182  param.p = atof(argv[i]);
183  break;
184  case 'h':
185  param.shrinking = atoi(argv[i]);
186  break;
187  case 'b':
188  param.probability = atoi(argv[i]);
189  break;
190  case 'q':
191  print_func = &print_null;
192  i--;
193  break;
194  case 'v':
195  cross_validation = 1;
196  nr_fold = atoi(argv[i]);
197  if(nr_fold < 2)
198  {
199  mexPrintf("n-fold cross validation: n must >= 2\n");
200  return 1;
201  }
202  break;
203  case 'w':
204  ++param.nr_weight;
205  param.weight_label = (int *)realloc(param.weight_label,sizeof(int)*param.nr_weight);
206  param.weight = (double *)realloc(param.weight,sizeof(double)*param.nr_weight);
207  param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);
208  param.weight[param.nr_weight-1] = atof(argv[i]);
209  break;
210  default:
211  mexPrintf("Unknown option -%c\n", argv[i-1][1]);
212  return 1;
213  }
214  }
215 
216  svm_set_print_string_function(print_func);
217 
218  return 0;
219 }
220 
221 // read in a problem (in svmlight format)
222 int read_problem_dense(const mxArray *label_vec, const mxArray *instance_mat)
223 {
224  int i, j, k;
225  int elements, max_index, sc, label_vector_row_num;
226  double *samples, *labels;
227 
228  prob.x = NULL;
229  prob.y = NULL;
230  x_space = NULL;
231 
232  labels = mxGetPr(label_vec);
233  samples = mxGetPr(instance_mat);
234  sc = (int)mxGetN(instance_mat);
235 
236  elements = 0;
237  // the number of instance
238  prob.l = (int)mxGetM(instance_mat);
239  label_vector_row_num = (int)mxGetM(label_vec);
240 
241  if(label_vector_row_num!=prob.l)
242  {
243  mexPrintf("Length of label vector does not match # of instances.\n");
244  return -1;
245  }
246 
248  elements = prob.l * (sc + 1);
249  else
250  {
251  for(i = 0; i < prob.l; i++)
252  {
253  for(k = 0; k < sc; k++)
254  if(samples[k * prob.l + i] != 0)
255  elements++;
256  // count the '-1' element
257  elements++;
258  }
259  }
260 
261  prob.y = Malloc(double,prob.l);
262  prob.x = Malloc(struct svm_node *,prob.l);
263  x_space = Malloc(struct svm_node, elements);
264 
265  max_index = sc;
266  j = 0;
267  for(i = 0; i < prob.l; i++)
268  {
269  prob.x[i] = &x_space[j];
270  prob.y[i] = labels[i];
271 
272  for(k = 0; k < sc; k++)
273  {
274  if(param.kernel_type == PRECOMPUTED || samples[k * prob.l + i] != 0)
275  {
276  x_space[j].index = k + 1;
277  x_space[j].value = samples[k * prob.l + i];
278  j++;
279  }
280  }
281  x_space[j++].index = -1;
282  }
283 
284  if(param.gamma == 0 && max_index > 0)
285  param.gamma = 1.0/max_index;
286 
288  for(i=0;i<prob.l;i++)
289  {
290  if((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
291  {
292  mexPrintf("Wrong input format: sample_serial_number out of range\n");
293  return -1;
294  }
295  }
296 
297  return 0;
298 }
299 
300 int read_problem_sparse(const mxArray *label_vec, const mxArray *instance_mat)
301 {
302  int i, j, k, low, high;
303  mwIndex *ir, *jc;
304  int elements, max_index, num_samples, label_vector_row_num;
305  double *samples, *labels;
306  mxArray *instance_mat_col; // transposed instance sparse matrix
307 
308  prob.x = NULL;
309  prob.y = NULL;
310  x_space = NULL;
311 
312  // transpose instance matrix
313  {
314  mxArray *prhs[1], *plhs[1];
315  prhs[0] = mxDuplicateArray(instance_mat);
316  if(mexCallMATLAB(1, plhs, 1, prhs, "transpose"))
317  {
318  mexPrintf("Error: cannot transpose training instance matrix\n");
319  return -1;
320  }
321  instance_mat_col = plhs[0];
322  mxDestroyArray(prhs[0]);
323  }
324 
325  // each column is one instance
326  labels = mxGetPr(label_vec);
327  samples = mxGetPr(instance_mat_col);
328  ir = mxGetIr(instance_mat_col);
329  jc = mxGetJc(instance_mat_col);
330 
331  num_samples = (int)mxGetNzmax(instance_mat_col);
332 
333  // the number of instance
334  prob.l = (int)mxGetN(instance_mat_col);
335  label_vector_row_num = (int)mxGetM(label_vec);
336 
337  if(label_vector_row_num!=prob.l)
338  {
339  mexPrintf("Length of label vector does not match # of instances.\n");
340  return -1;
341  }
342 
343  elements = num_samples + prob.l;
344  max_index = (int)mxGetM(instance_mat_col);
345 
346  prob.y = Malloc(double,prob.l);
347  prob.x = Malloc(struct svm_node *,prob.l);
348  x_space = Malloc(struct svm_node, elements);
349 
350  j = 0;
351  for(i=0;i<prob.l;i++)
352  {
353  prob.x[i] = &x_space[j];
354  prob.y[i] = labels[i];
355  low = (int)jc[i], high = (int)jc[i+1];
356  for(k=low;k<high;k++)
357  {
358  x_space[j].index = (int)ir[k] + 1;
359  x_space[j].value = samples[k];
360  j++;
361  }
362  x_space[j++].index = -1;
363  }
364 
365  if(param.gamma == 0 && max_index > 0)
366  param.gamma = 1.0/max_index;
367 
368  return 0;
369 }
370 
371 static void fake_answer(mxArray *plhs[])
372 {
373  plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
374 }
375 
376 // Interface function of matlab
377 // now assume prhs[0]: label prhs[1]: features
378 void mexFunction( int nlhs, mxArray *plhs[],
379  int nrhs, const mxArray *prhs[] )
380 {
381  const char *error_msg;
382 
383  // fix random seed to have same results for each run
384  // (for cross validation and probability estimation)
385  srand(1);
386 
387  // Transform the input Matrix to libsvm format
388  if(nrhs > 1 && nrhs < 4)
389  {
390  int err;
391 
392  if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
393  mexPrintf("Error: label vector and instance matrix must be double\n");
394  fake_answer(plhs);
395  return;
396  }
397 
398  if(parse_command_line(nrhs, prhs, NULL))
399  {
400  exit_with_help();
402  fake_answer(plhs);
403  return;
404  }
405 
406  if(mxIsSparse(prhs[1]))
407  {
409  {
410  // precomputed kernel requires dense matrix, so we make one
411  mxArray *rhs[1], *lhs[1];
412 
413  rhs[0] = mxDuplicateArray(prhs[1]);
414  if(mexCallMATLAB(1, lhs, 1, rhs, "full"))
415  {
416  mexPrintf("Error: cannot generate a full training instance matrix\n");
418  fake_answer(plhs);
419  return;
420  }
421  err = read_problem_dense(prhs[0], lhs[0]);
422  mxDestroyArray(lhs[0]);
423  mxDestroyArray(rhs[0]);
424  }
425  else
426  err = read_problem_sparse(prhs[0], prhs[1]);
427  }
428  else
429  err = read_problem_dense(prhs[0], prhs[1]);
430 
431  // svmtrain's original code
432  error_msg = svm_check_parameter(&prob, &param);
433 
434  if(err || error_msg)
435  {
436  if (error_msg != NULL)
437  mexPrintf("Error: %s\n", error_msg);
439  free(prob.y);
440  free(prob.x);
441  free(x_space);
442  fake_answer(plhs);
443  return;
444  }
445 
446  if(cross_validation)
447  {
448  double *ptr;
449  plhs[0] = mxCreateDoubleMatrix(1, 1, mxREAL);
450  ptr = mxGetPr(plhs[0]);
451  ptr[0] = do_cross_validation();
452  }
453  else
454  {
455  int nr_feat = (int)mxGetN(prhs[1]);
456  const char *error_msg;
457  model = svm_train(&prob, &param);
458  error_msg = model_to_matlab_structure(plhs, nr_feat, model);
459  if(error_msg)
460  mexPrintf("Error: can't convert libsvm model to matrix structure: %s\n", error_msg);
462  }
464  free(prob.y);
465  free(prob.x);
466  free(x_space);
467  }
468  else
469  {
470  exit_with_help();
471  fake_answer(plhs);
472  return;
473  }
474 }
string cmd
Definition: easy.py:48
int nr_fold
Definition: svmtrain.c:61
struct svm_problem prob
Definition: svmtrain.c:57
const char * svm_check_parameter(const svm_problem *prob, const svm_parameter *param)
Definition: svm.cpp:2977
def err(line_no, msg)
Definition: checkdata.py:18
struct svm_parameter param
Definition: svmtrain.c:56
def svm_train(arg1, arg2=None, arg3=None)
Definition: svmutil.py:77
double value
Definition: svm.h:15
int nr_weight
Definition: svm.h:40
void svm_cross_validation(const svm_problem *prob, const svm_parameter *param, int nr_fold, double *target)
Definition: svm.cpp:2314
int * weight_label
Definition: svm.h:41
Definition: svm.h:25
TFSIMD_FORCE_INLINE const tfScalar & y() const
int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
Definition: svmtrain.c:108
static void fake_answer(mxArray *plhs[])
Definition: svmtrain.c:371
Definition: svm.h:52
double p
Definition: svm.h:44
int cross_validation
Definition: svmtrain.c:60
double cache_size
Definition: svm.h:37
void print_string_matlab(const char *s)
Definition: svmtrain.c:20
int read_problem_dense(const mxArray *label_vec, const mxArray *instance_mat)
Definition: svmtrain.c:222
double eps
Definition: svm.h:38
int max_index
Definition: svm-scale.c:29
int shrinking
Definition: svm.h:45
void print_null(const char *s)
Definition: svmtrain.c:19
struct svm_node ** x
Definition: svm.h:22
void svm_free_and_destroy_model(svm_model **model_ptr_ptr)
Definition: svm.cpp:2961
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Definition: svmtrain.c:378
#define CMD_LEN
Definition: svmtrain.c:16
#define Malloc(type, n)
Definition: svmtrain.c:17
Definition: svm.h:26
int index
Definition: svm.h:14
int read_problem_sparse(const mxArray *label_vec, const mxArray *instance_mat)
Definition: svmtrain.c:300
void exit_with_help()
Definition: svmtrain.c:22
int probability
Definition: svm.h:46
int degree
Definition: svm.h:32
Definition: svm.h:25
Definition: svm.h:12
double * y
Definition: svm.h:21
double gamma
Definition: svm.h:33
int l
Definition: svm.h:20
double * weight
Definition: svm.h:42
double C
Definition: svm.h:39
int svm_type
Definition: svm.h:30
double nu
Definition: svm.h:43
struct svm_model * model
Definition: svmtrain.c:58
double coef0
Definition: svm.h:34
double do_cross_validation()
Definition: svmtrain.c:64
void svm_destroy_param(svm_parameter *param)
Definition: svm.cpp:2971
const char * model_to_matlab_structure(mxArray *plhs[], int num_of_feature, struct svm_model *model)
int kernel_type
Definition: svm.h:31
struct svm_node * x_space
Definition: svmtrain.c:59
void svm_set_print_string_function(void(*print_func)(const char *))
Definition: svm.cpp:3106


haf_grasping
Author(s): David Fischinger
autogenerated on Mon Jun 10 2019 13:28:43