svm_train.java
Go to the documentation of this file.
1 import libsvm.*;
2 import java.io.*;
3 import java.util.*;
4 
5 class svm_train {
6  private svm_parameter param; // set by parse_command_line
7  private svm_problem prob; // set by read_problem
8  private svm_model model;
9  private String input_file_name; // set by parse_command_line
10  private String model_file_name; // set by parse_command_line
11  private String error_msg;
12  private int cross_validation;
13  private int nr_fold;
14 
15  private static svm_print_interface svm_print_null = new svm_print_interface()
16  {
17  public void print(String s) {}
18  };
19 
20  private static void exit_with_help()
21  {
22  System.out.print(
23  "Usage: svm_train [options] training_set_file [model_file]\n"
24  +"options:\n"
25  +"-s svm_type : set type of SVM (default 0)\n"
26  +" 0 -- C-SVC (multi-class classification)\n"
27  +" 1 -- nu-SVC (multi-class classification)\n"
28  +" 2 -- one-class SVM\n"
29  +" 3 -- epsilon-SVR (regression)\n"
30  +" 4 -- nu-SVR (regression)\n"
31  +"-t kernel_type : set type of kernel function (default 2)\n"
32  +" 0 -- linear: u'*v\n"
33  +" 1 -- polynomial: (gamma*u'*v + coef0)^degree\n"
34  +" 2 -- radial basis function: exp(-gamma*|u-v|^2)\n"
35  +" 3 -- sigmoid: tanh(gamma*u'*v + coef0)\n"
36  +" 4 -- precomputed kernel (kernel values in training_set_file)\n"
37  +"-d degree : set degree in kernel function (default 3)\n"
38  +"-g gamma : set gamma in kernel function (default 1/num_features)\n"
39  +"-r coef0 : set coef0 in kernel function (default 0)\n"
40  +"-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)\n"
41  +"-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)\n"
42  +"-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)\n"
43  +"-m cachesize : set cache memory size in MB (default 100)\n"
44  +"-e epsilon : set tolerance of termination criterion (default 0.001)\n"
45  +"-h shrinking : whether to use the shrinking heuristics, 0 or 1 (default 1)\n"
46  +"-b probability_estimates : whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)\n"
47  +"-wi weight : set the parameter C of class i to weight*C, for C-SVC (default 1)\n"
48  +"-v n : n-fold cross validation mode\n"
49  +"-q : quiet mode (no outputs)\n"
50  );
51  System.exit(1);
52  }
53 
54  private void do_cross_validation()
55  {
56  int i;
57  int total_correct = 0;
58  double total_error = 0;
59  double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
60  double[] target = new double[prob.l];
61 
62  svm.svm_cross_validation(prob,param,nr_fold,target);
63  if(param.svm_type == svm_parameter.EPSILON_SVR ||
64  param.svm_type == svm_parameter.NU_SVR)
65  {
66  for(i=0;i<prob.l;i++)
67  {
68  double y = prob.y[i];
69  double v = target[i];
70  total_error += (v-y)*(v-y);
71  sumv += v;
72  sumy += y;
73  sumvv += v*v;
74  sumyy += y*y;
75  sumvy += v*y;
76  }
77  System.out.print("Cross Validation Mean squared error = "+total_error/prob.l+"\n");
78  System.out.print("Cross Validation Squared correlation coefficient = "+
79  ((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
80  ((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))+"\n"
81  );
82  }
83  else
84  {
85  for(i=0;i<prob.l;i++)
86  if(target[i] == prob.y[i])
87  ++total_correct;
88  System.out.print("Cross Validation Accuracy = "+100.0*total_correct/prob.l+"%\n");
89  }
90  }
91 
92  private void run(String argv[]) throws IOException
93  {
94  parse_command_line(argv);
95  read_problem();
96  error_msg = svm.svm_check_parameter(prob,param);
97 
98  if(error_msg != null)
99  {
100  System.err.print("ERROR: "+error_msg+"\n");
101  System.exit(1);
102  }
103 
104  if(cross_validation != 0)
105  {
107  }
108  else
109  {
110  model = svm.svm_train(prob,param);
111  svm.svm_save_model(model_file_name,model);
112  }
113  }
114 
115  public static void main(String argv[]) throws IOException
116  {
117  svm_train t = new svm_train();
118  t.run(argv);
119  }
120 
121  private static double atof(String s)
122  {
123  double d = Double.valueOf(s).doubleValue();
124  if (Double.isNaN(d) || Double.isInfinite(d))
125  {
126  System.err.print("NaN or Infinity in input\n");
127  System.exit(1);
128  }
129  return(d);
130  }
131 
132  private static int atoi(String s)
133  {
134  return Integer.parseInt(s);
135  }
136 
137  private void parse_command_line(String argv[])
138  {
139  int i;
140  svm_print_interface print_func = null; // default printing to stdout
141 
142  param = new svm_parameter();
143  // default values
144  param.svm_type = svm_parameter.C_SVC;
145  param.kernel_type = svm_parameter.RBF;
146  param.degree = 3;
147  param.gamma = 0; // 1/num_features
148  param.coef0 = 0;
149  param.nu = 0.5;
150  param.cache_size = 100;
151  param.C = 1;
152  param.eps = 1e-3;
153  param.p = 0.1;
154  param.shrinking = 1;
155  param.probability = 0;
156  param.nr_weight = 0;
157  param.weight_label = new int[0];
158  param.weight = new double[0];
159  cross_validation = 0;
160 
161  // parse options
162  for(i=0;i<argv.length;i++)
163  {
164  if(argv[i].charAt(0) != '-') break;
165  if(++i>=argv.length)
166  exit_with_help();
167  switch(argv[i-1].charAt(1))
168  {
169  case 's':
170  param.svm_type = atoi(argv[i]);
171  break;
172  case 't':
173  param.kernel_type = atoi(argv[i]);
174  break;
175  case 'd':
176  param.degree = atoi(argv[i]);
177  break;
178  case 'g':
179  param.gamma = atof(argv[i]);
180  break;
181  case 'r':
182  param.coef0 = atof(argv[i]);
183  break;
184  case 'n':
185  param.nu = atof(argv[i]);
186  break;
187  case 'm':
188  param.cache_size = atof(argv[i]);
189  break;
190  case 'c':
191  param.C = atof(argv[i]);
192  break;
193  case 'e':
194  param.eps = atof(argv[i]);
195  break;
196  case 'p':
197  param.p = atof(argv[i]);
198  break;
199  case 'h':
200  param.shrinking = atoi(argv[i]);
201  break;
202  case 'b':
203  param.probability = atoi(argv[i]);
204  break;
205  case 'q':
206  print_func = svm_print_null;
207  i--;
208  break;
209  case 'v':
210  cross_validation = 1;
211  nr_fold = atoi(argv[i]);
212  if(nr_fold < 2)
213  {
214  System.err.print("n-fold cross validation: n must >= 2\n");
215  exit_with_help();
216  }
217  break;
218  case 'w':
219  ++param.nr_weight;
220  {
221  int[] old = param.weight_label;
222  param.weight_label = new int[param.nr_weight];
223  System.arraycopy(old,0,param.weight_label,0,param.nr_weight-1);
224  }
225 
226  {
227  double[] old = param.weight;
228  param.weight = new double[param.nr_weight];
229  System.arraycopy(old,0,param.weight,0,param.nr_weight-1);
230  }
231 
232  param.weight_label[param.nr_weight-1] = atoi(argv[i-1].substring(2));
233  param.weight[param.nr_weight-1] = atof(argv[i]);
234  break;
235  default:
236  System.err.print("Unknown option: " + argv[i-1] + "\n");
237  exit_with_help();
238  }
239  }
240 
241  svm.svm_set_print_string_function(print_func);
242 
243  // determine filenames
244 
245  if(i>=argv.length)
246  exit_with_help();
247 
248  input_file_name = argv[i];
249 
250  if(i<argv.length-1)
251  model_file_name = argv[i+1];
252  else
253  {
254  int p = argv[i].lastIndexOf('/');
255  ++p; // whew...
256  model_file_name = argv[i].substring(p)+".model";
257  }
258  }
259 
260  // read in a problem (in svmlight format)
261 
262  private void read_problem() throws IOException
263  {
264  BufferedReader fp = new BufferedReader(new FileReader(input_file_name));
265  Vector<Double> vy = new Vector<Double>();
266  Vector<svm_node[]> vx = new Vector<svm_node[]>();
267  int max_index = 0;
268 
269  while(true)
270  {
271  String line = fp.readLine();
272  if(line == null) break;
273 
274  StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
275 
276  vy.addElement(atof(st.nextToken()));
277  int m = st.countTokens()/2;
278  svm_node[] x = new svm_node[m];
279  for(int j=0;j<m;j++)
280  {
281  x[j] = new svm_node();
282  x[j].index = atoi(st.nextToken());
283  x[j].value = atof(st.nextToken());
284  }
285  if(m>0) max_index = Math.max(max_index, x[m-1].index);
286  vx.addElement(x);
287  }
288 
289  prob = new svm_problem();
290  prob.l = vy.size();
291  prob.x = new svm_node[prob.l][];
292  for(int i=0;i<prob.l;i++)
293  prob.x[i] = vx.elementAt(i);
294  prob.y = new double[prob.l];
295  for(int i=0;i<prob.l;i++)
296  prob.y[i] = vy.elementAt(i);
297 
298  if(param.gamma == 0 && max_index > 0)
299  param.gamma = 1.0/max_index;
300 
302  for(int i=0;i<prob.l;i++)
303  {
304  if (prob.x[i][0].index != 0)
305  {
306  System.err.print("Wrong kernel matrix: first column must be 0:sample_serial_number\n");
307  System.exit(1);
308  }
309  if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
310  {
311  System.err.print("Wrong input format: sample_serial_number out of range\n");
312  System.exit(1);
313  }
314  }
315 
316  fp.close();
317  }
318 }
static final int C_SVC
d
int nr_fold
Definition: svmtrain.c:64
struct svm_problem prob
Definition: svmtrain.c:60
static final int NU_SVR
struct svm_parameter param
Definition: svmtrain.c:59
Definition: svm.py:1
def svm_train(arg1, arg2=None, arg3=None)
Definition: svmutil.py:77
void read_problem(const char *filename, mxArray *plhs[])
Definition: libsvmread.c:56
void exit_with_help()
Definition: libsvmread.c:21
Definition: svm.java:5
int parse_command_line(int nrhs, const mxArray *prhs[], char *model_file_name)
Definition: svmtrain.c:111
ROSCONSOLE_DECL void print(FilterBase *filter, void *logger, Level level, const char *file, int line, const char *function, const char *fmt,...) ROSCONSOLE_PRINTF_ATTRIBUTE(7
double value
Definition: svm_node.java:5
int cross_validation
Definition: svmtrain.c:63
struct svm_node * x
Definition: svm-predict.c:12
static final int RBF
int max_index
Definition: svm-scale.c:29
char * line
Definition: svm-scale.c:21
svm_node[][] x
Definition: svm_problem.java:6
static final int PRECOMPUTED
struct svm_model * model
Definition: svmtrain.c:61
void run(ClassLoader *loader)
int main(int argc, char **argv)
double do_cross_validation()
Definition: svmtrain.c:67
static final int EPSILON_SVR


ml_classifiers
Author(s): Scott Niekum , Joshua Whitley
autogenerated on Sun Dec 15 2019 03:53:50