svmtrain.c
Go to the documentation of this file.
1 #include <stdio.h>
2 #include <stdlib.h>
3 #include <string.h>
4 #include <ctype.h>
5 #include "../svm.h"
6 
7 #include "mex.h"
8 #include "svm_model_matlab.h"
9 
10 #ifdef MX_API_VER
11 #if MX_API_VER < 0x07030000
12 typedef int mwIndex;
13 #endif
14 #endif
15 
16 #define CMD_LEN 2048
17 #define Malloc(type,n) (type *)malloc((n)*sizeof(type))
18 
// Output sink that discards the message; installed as the libsvm print
// callback when the -q (quiet) option is given.
void print_null(const char *s)
{
	(void)s; // intentionally unused
}
// Print a message in the MATLAB command window.
//
// The text is passed as an argument to a fixed "%s" format instead of being
// used as the format string itself, so a '%' inside the message is printed
// literally rather than being interpreted as a conversion specification.
void print_string_matlab(const char *s)
{
	mexPrintf("%s", s);
}
24 
// Print the usage text of svmtrain() in the MATLAB command window.
// Despite the name, the function only prints; it returns to the caller.
void exit_with_help()
{
	mexPrintf(
	"Usage: model = svmtrain(training_label_vector, training_instance_matrix, 'libsvm_options');\n"
	"libsvm_options:\n"
	"-s svm_type : set type of SVM (default 0)\n"
	" 0 -- C-SVC (multi-class classification)\n"
	" 1 -- nu-SVC (multi-class classification)\n"
	" 2 -- one-class SVM\n"
	" 3 -- epsilon-SVR (regression)\n"
	" 4 -- nu-SVR (regression)\n"
	"-t kernel_type : set type of kernel function (default 2)\n"
	" 0 -- linear: u'*v\n"
	" 1 -- polynomial: (gamma*u'*v + coef0)^degree\n"
	" 2 -- radial basis function: exp(-gamma*|u-v|^2)\n"
	" 3 -- sigmoid: tanh(gamma*u'*v + coef0)\n"
	" 4 -- precomputed kernel (kernel values in training_instance_matrix)\n"
	"-d degree : set degree in kernel function (default 3)\n"
	"-g gamma : set gamma in kernel function (default 1/num_features)\n"
	"-r coef0 : set coef0 in kernel function (default 0)\n"
	"-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)\n"
	"-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)\n"
	"-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)\n"
	"-m cachesize : set cache memory size in MB (default 100)\n"
	"-e epsilon : set tolerance of termination criterion (default 0.001)\n"
	"-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)\n"
	"-b probability_estimates : whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)\n"
	"-wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1)\n"
	"-v n : n-fold cross validation mode\n"
	"-q : quiet mode (no outputs)\n"
	);
}
57 
58 // svm arguments
59 struct svm_parameter param; // set by parse_command_line
60 struct svm_problem prob; // set by read_problem
61 struct svm_model *model;
62 struct svm_node *x_space;
64 int nr_fold;
65 
66 
68 {
69  int i;
70  int total_correct = 0;
71  double total_error = 0;
72  double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
73  double *target = Malloc(double, prob.l);
74  double retval = 0.0;
75 
77  if (param.svm_type == EPSILON_SVR ||
79  {
80  for (i = 0; i < prob.l; i++)
81  {
82  double y = prob.y[i];
83  double v = target[i];
84  total_error += (v - y) * (v - y);
85  sumv += v;
86  sumy += y;
87  sumvv += v * v;
88  sumyy += y * y;
89  sumvy += v * y;
90  }
91  mexPrintf("Cross Validation Mean squared error = %g\n", total_error / prob.l);
92  mexPrintf("Cross Validation Squared correlation coefficient = %g\n",
93  ((prob.l * sumvy - sumv * sumy) * (prob.l * sumvy - sumv * sumy)) /
94  ((prob.l * sumvv - sumv * sumv) * (prob.l * sumyy - sumy * sumy))
95  );
96  retval = total_error / prob.l;
97  }
98  else
99  {
100  for (i = 0; i < prob.l; i++)
101  if (target[i] == prob.y[i])
102  ++total_correct;
103  mexPrintf("Cross Validation Accuracy = %g%%\n", 100.0 * total_correct / prob.l);
104  retval = 100.0 * total_correct / prob.l;
105  }
106  free(target);
107  return retval;
108 }
109 
110 // nrhs should be 3
111 int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
112 {
113  int i, argc = 1;
114  char cmd[CMD_LEN];
115  char *argv[CMD_LEN / 2];
116  void (*print_func)(const char *) = print_string_matlab; // default printing to matlab display
117 
118  // default values
119  param.svm_type = C_SVC;
121  param.degree = 3;
122  param.gamma = 0; // 1/num_features
123  param.coef0 = 0;
124  param.nu = 0.5;
125  param.cache_size = 100;
126  param.C = 1;
127  param.eps = 1e-3;
128  param.p = 0.1;
129  param.shrinking = 1;
130  param.probability = 0;
131  param.nr_weight = 0;
132  param.weight_label = NULL;
133  param.weight = NULL;
134  cross_validation = 0;
135 
136  if (nrhs <= 1)
137  return 1;
138 
139  if (nrhs > 2)
140  {
141  // put options in argv[]
142  mxGetString(prhs[2], cmd, mxGetN(prhs[2]) + 1);
143  if ((argv[argc] = strtok(cmd, " ")) != NULL)
144  while ((argv[++argc] = strtok(NULL, " ")) != NULL)
145  ;
146  }
147 
148  // parse options
149  for (i = 1; i < argc; i++)
150  {
151  if (argv[i][0] != '-') break;
152  ++i;
153  if (i >= argc && argv[i - 1][1] != 'q') // since option -q has no parameter
154  return 1;
155  switch (argv[i - 1][1])
156  {
157  case 's':
158  param.svm_type = atoi(argv[i]);
159  break;
160  case 't':
161  param.kernel_type = atoi(argv[i]);
162  break;
163  case 'd':
164  param.degree = atoi(argv[i]);
165  break;
166  case 'g':
167  param.gamma = atof(argv[i]);
168  break;
169  case 'r':
170  param.coef0 = atof(argv[i]);
171  break;
172  case 'n':
173  param.nu = atof(argv[i]);
174  break;
175  case 'm':
176  param.cache_size = atof(argv[i]);
177  break;
178  case 'c':
179  param.C = atof(argv[i]);
180  break;
181  case 'e':
182  param.eps = atof(argv[i]);
183  break;
184  case 'p':
185  param.p = atof(argv[i]);
186  break;
187  case 'h':
188  param.shrinking = atoi(argv[i]);
189  break;
190  case 'b':
191  param.probability = atoi(argv[i]);
192  break;
193  case 'q':
194  print_func = &print_null;
195  i--;
196  break;
197  case 'v':
198  cross_validation = 1;
199  nr_fold = atoi(argv[i]);
200  if (nr_fold < 2)
201  {
202  mexPrintf("n-fold cross validation: n must >= 2\n");
203  return 1;
204  }
205  break;
206  case 'w':
207  ++param.nr_weight;
208  param.weight_label = (int *)realloc(param.weight_label, sizeof(int) * param.nr_weight);
209  param.weight = (double *)realloc(param.weight, sizeof(double) * param.nr_weight);
210  param.weight_label[param.nr_weight - 1] = atoi(&argv[i - 1][2]);
211  param.weight[param.nr_weight - 1] = atof(argv[i]);
212  break;
213  default:
214  mexPrintf("Unknown option -%c\n", argv[i - 1][1]);
215  return 1;
216  }
217  }
218 
219  svm_set_print_string_function(print_func);
220 
221  return 0;
222 }
223 
224 // read in a problem (in svmlight format)
225 int read_problem_dense(const mxArray *label_vec, const mxArray *instance_mat)
226 {
227  int i, j, k;
228  int elements, max_index, sc, label_vector_row_num;
229  double *samples, *labels;
230 
231  prob.x = NULL;
232  prob.y = NULL;
233  x_space = NULL;
234 
235  labels = mxGetPr(label_vec);
236  samples = mxGetPr(instance_mat);
237  sc = (int)mxGetN(instance_mat);
238 
239  elements = 0;
240  // the number of instance
241  prob.l = (int)mxGetM(instance_mat);
242  label_vector_row_num = (int)mxGetM(label_vec);
243 
244  if (label_vector_row_num != prob.l)
245  {
246  mexPrintf("Length of label vector does not match # of instances.\n");
247  return -1;
248  }
249 
251  elements = prob.l * (sc + 1);
252  else
253  {
254  for (i = 0; i < prob.l; i++)
255  {
256  for (k = 0; k < sc; k++)
257  if (samples[k * prob.l + i] != 0)
258  elements++;
259  // count the '-1' element
260  elements++;
261  }
262  }
263 
264  prob.y = Malloc(double, prob.l);
265  prob.x = Malloc(struct svm_node *, prob.l);
266  x_space = Malloc(struct svm_node, elements);
267 
268  max_index = sc;
269  j = 0;
270  for (i = 0; i < prob.l; i++)
271  {
272  prob.x[i] = &x_space[j];
273  prob.y[i] = labels[i];
274 
275  for (k = 0; k < sc; k++)
276  {
277  if (param.kernel_type == PRECOMPUTED || samples[k * prob.l + i] != 0)
278  {
279  x_space[j].index = k + 1;
280  x_space[j].value = samples[k * prob.l + i];
281  j++;
282  }
283  }
284  x_space[j++].index = -1;
285  }
286 
287  if (param.gamma == 0 && max_index > 0)
288  param.gamma = 1.0 / max_index;
289 
291  for (i = 0; i < prob.l; i++)
292  {
293  if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
294  {
295  mexPrintf("Wrong input format: sample_serial_number out of range\n");
296  return -1;
297  }
298  }
299 
300  return 0;
301 }
302 
303 int read_problem_sparse(const mxArray *label_vec, const mxArray *instance_mat)
304 {
305  int i, j, k, low, high;
306  mwIndex *ir, *jc;
307  int elements, max_index, num_samples, label_vector_row_num;
308  double *samples, *labels;
309  mxArray *instance_mat_col; // transposed instance sparse matrix
310 
311  prob.x = NULL;
312  prob.y = NULL;
313  x_space = NULL;
314 
315  // transpose instance matrix
316  {
317  mxArray *prhs[1], *plhs[1];
318  prhs[0] = mxDuplicateArray(instance_mat);
319  if (mexCallMATLAB(1, plhs, 1, prhs, "transpose"))
320  {
321  mexPrintf("Error: cannot transpose training instance matrix\n");
322  return -1;
323  }
324  instance_mat_col = plhs[0];
325  mxDestroyArray(prhs[0]);
326  }
327 
328  // each column is one instance
329  labels = mxGetPr(label_vec);
330  samples = mxGetPr(instance_mat_col);
331  ir = mxGetIr(instance_mat_col);
332  jc = mxGetJc(instance_mat_col);
333 
334  num_samples = (int)mxGetNzmax(instance_mat_col);
335 
336  // the number of instance
337  prob.l = (int)mxGetN(instance_mat_col);
338  label_vector_row_num = (int)mxGetM(label_vec);
339 
340  if (label_vector_row_num != prob.l)
341  {
342  mexPrintf("Length of label vector does not match # of instances.\n");
343  return -1;
344  }
345 
346  elements = num_samples + prob.l;
347  max_index = (int)mxGetM(instance_mat_col);
348 
349  prob.y = Malloc(double, prob.l);
350  prob.x = Malloc(struct svm_node *, prob.l);
351  x_space = Malloc(struct svm_node, elements);
352 
353  j = 0;
354  for (i = 0; i < prob.l; i++)
355  {
356  prob.x[i] = &x_space[j];
357  prob.y[i] = labels[i];
358  low = (int)jc[i], high = (int)jc[i + 1];
359  for (k = low; k < high; k++)
360  {
361  x_space[j].index = (int)ir[k] + 1;
362  x_space[j].value = samples[k];
363  j++;
364  }
365  x_space[j++].index = -1;
366  }
367 
368  if (param.gamma == 0 && max_index > 0)
369  param.gamma = 1.0 / max_index;
370 
371  return 0;
372 }
373 
374 static void fake_answer(mxArray *plhs[])
375 {
376  plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
377 }
378 
379 // Interface function of matlab
380 // now assume prhs[0]: label prhs[1]: features
381 void mexFunction(int nlhs, mxArray *plhs[],
382  int nrhs, const mxArray *prhs[])
383 {
384  const char *error_msg;
385 
386  // fix random seed to have same results for each run
387  // (for cross validation and probability estimation)
388  srand(1);
389 
390  // Transform the input Matrix to libsvm format
391  if (nrhs > 1 && nrhs < 4)
392  {
393  int err;
394 
395  if (!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1]))
396  {
397  mexPrintf("Error: label vector and instance matrix must be double\n");
398  fake_answer(plhs);
399  return;
400  }
401 
402  if (parse_command_line(nrhs, prhs, NULL))
403  {
404  exit_with_help();
406  fake_answer(plhs);
407  return;
408  }
409 
410  if (mxIsSparse(prhs[1]))
411  {
413  {
414  // precomputed kernel requires dense matrix, so we make one
415  mxArray *rhs[1], *lhs[1];
416 
417  rhs[0] = mxDuplicateArray(prhs[1]);
418  if (mexCallMATLAB(1, lhs, 1, rhs, "full"))
419  {
420  mexPrintf("Error: cannot generate a full training instance matrix\n");
422  fake_answer(plhs);
423  return;
424  }
425  err = read_problem_dense(prhs[0], lhs[0]);
426  mxDestroyArray(lhs[0]);
427  mxDestroyArray(rhs[0]);
428  }
429  else
430  err = read_problem_sparse(prhs[0], prhs[1]);
431  }
432  else
433  err = read_problem_dense(prhs[0], prhs[1]);
434 
435  // svmtrain's original code
436  error_msg = svm_check_parameter(&prob, &param);
437 
438  if (err || error_msg)
439  {
440  if (error_msg != NULL)
441  mexPrintf("Error: %s\n", error_msg);
443  free(prob.y);
444  free(prob.x);
445  free(x_space);
446  fake_answer(plhs);
447  return;
448  }
449 
450  if (cross_validation)
451  {
452  double *ptr;
453  plhs[0] = mxCreateDoubleMatrix(1, 1, mxREAL);
454  ptr = mxGetPr(plhs[0]);
455  ptr[0] = do_cross_validation();
456  }
457  else
458  {
459  int nr_feat = (int)mxGetN(prhs[1]);
460  const char *error_msg;
461  model = svm_train(&prob, &param);
462  error_msg = model_to_matlab_structure(plhs, nr_feat, model);
463  if (error_msg)
464  mexPrintf("Error: can't convert libsvm model to matrix structure: %s\n", error_msg);
466  }
468  free(prob.y);
469  free(prob.x);
470  free(x_space);
471  }
472  else
473  {
474  exit_with_help();
475  fake_answer(plhs);
476  return;
477  }
478 }
string cmd
Definition: easy.py:48
int nr_fold
Definition: svmtrain.c:64
struct svm_problem prob
Definition: svmtrain.c:60
const char * svm_check_parameter(const svm_problem *prob, const svm_parameter *param)
Definition: svm.cpp:3029
def err(line_no, msg)
Definition: checkdata.py:18
struct svm_parameter param
Definition: svmtrain.c:59
def svm_train(arg1, arg2=None, arg3=None)
Definition: svmutil.py:77
double value
Definition: svm.h:15
int nr_weight
Definition: svm.h:40
void svm_cross_validation(const svm_problem *prob, const svm_parameter *param, int nr_fold, double *target)
Definition: svm.cpp:2352
int * weight_label
Definition: svm.h:41
Definition: svm.h:25
int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
Definition: svmtrain.c:111
static void fake_answer(mxArray *plhs[])
Definition: svmtrain.c:374
Definition: svm.h:52
double p
Definition: svm.h:44
int cross_validation
Definition: svmtrain.c:63
double cache_size
Definition: svm.h:37
void print_string_matlab(const char *s)
Definition: svmtrain.c:20
int read_problem_dense(const mxArray *label_vec, const mxArray *instance_mat)
Definition: svmtrain.c:225
double eps
Definition: svm.h:38
int max_index
Definition: svm-scale.c:29
int shrinking
Definition: svm.h:45
void print_null(const char *s)
Definition: svmtrain.c:19
struct svm_node ** x
Definition: svm.h:22
void svm_free_and_destroy_model(svm_model **model_ptr_ptr)
Definition: svm.cpp:3013
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
Definition: svmtrain.c:381
#define CMD_LEN
Definition: svmtrain.c:16
#define Malloc(type, n)
Definition: svmtrain.c:17
Definition: svm.h:26
int index
Definition: svm.h:14
int read_problem_sparse(const mxArray *label_vec, const mxArray *instance_mat)
Definition: svmtrain.c:303
void exit_with_help()
Definition: svmtrain.c:25
int probability
Definition: svm.h:46
int degree
Definition: svm.h:32
Definition: svm.h:25
Definition: svm.h:12
double * y
Definition: svm.h:21
double gamma
Definition: svm.h:33
int l
Definition: svm.h:20
double * weight
Definition: svm.h:42
double C
Definition: svm.h:39
int svm_type
Definition: svm.h:30
double nu
Definition: svm.h:43
struct svm_model * model
Definition: svmtrain.c:61
double coef0
Definition: svm.h:34
double do_cross_validation()
Definition: svmtrain.c:67
void svm_destroy_param(svm_parameter *param)
Definition: svm.cpp:3023
const char * model_to_matlab_structure(mxArray *plhs[], int num_of_feature, struct svm_model *model)
int kernel_type
Definition: svm.h:31
struct svm_node * x_space
Definition: svmtrain.c:62
void svm_set_print_string_function(void(*print_func)(const char *))
Definition: svm.cpp:3158


ml_classifiers
Author(s): Scott Niekum , Joshua Whitley
autogenerated on Mon Feb 28 2022 22:46:49