svm_predict.java
Go to the documentation of this file.
1 import libsvm.*;
2 import java.io.*;
3 import java.util.*;
4 
5 class svm_predict {
6  private static svm_print_interface svm_print_null = new svm_print_interface()
7  {
8  public void print(String s) {}
9  };
10 
11  private static svm_print_interface svm_print_stdout = new svm_print_interface()
12  {
13  public void print(String s)
14  {
15  System.out.print(s);
16  }
17  };
18 
19  private static svm_print_interface svm_print_string = svm_print_stdout;
20 
21  static void info(String s)
22  {
23  svm_print_string.print(s);
24  }
25 
26  private static double atof(String s)
27  {
28  return Double.valueOf(s).doubleValue();
29  }
30 
31  private static int atoi(String s)
32  {
33  return Integer.parseInt(s);
34  }
35 
36  private static void predict(BufferedReader input, DataOutputStream output, svm_model model, int predict_probability) throws IOException
37  {
38  int correct = 0;
39  int total = 0;
40  double error = 0;
41  double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
42 
43  int svm_type=svm.svm_get_svm_type(model);
44  int nr_class=svm.svm_get_nr_class(model);
45  double[] prob_estimates=null;
46 
47  if(predict_probability == 1)
48  {
49  if(svm_type == svm_parameter.EPSILON_SVR ||
50  svm_type == svm_parameter.NU_SVR)
51  {
52  svm_predict.info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma="+svm.svm_get_svr_probability(model)+"\n");
53  }
54  else
55  {
56  int[] labels=new int[nr_class];
57  svm.svm_get_labels(model,labels);
58  prob_estimates = new double[nr_class];
59  output.writeBytes("labels");
60  for(int j=0;j<nr_class;j++)
61  output.writeBytes(" "+labels[j]);
62  output.writeBytes("\n");
63  }
64  }
65  while(true)
66  {
67  String line = input.readLine();
68  if(line == null) break;
69 
70  StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
71 
72  double target = atof(st.nextToken());
73  int m = st.countTokens()/2;
74  svm_node[] x = new svm_node[m];
75  for(int j=0;j<m;j++)
76  {
77  x[j] = new svm_node();
78  x[j].index = atoi(st.nextToken());
79  x[j].value = atof(st.nextToken());
80  }
81 
82  double v;
83  if (predict_probability==1 && (svm_type==svm_parameter.C_SVC || svm_type==svm_parameter.NU_SVC))
84  {
85  v = svm.svm_predict_probability(model,x,prob_estimates);
86  output.writeBytes(v+" ");
87  for(int j=0;j<nr_class;j++)
88  output.writeBytes(prob_estimates[j]+" ");
89  output.writeBytes("\n");
90  }
91  else
92  {
93  v = svm.svm_predict(model,x);
94  output.writeBytes(v+"\n");
95  }
96 
97  if(v == target)
98  ++correct;
99  error += (v-target)*(v-target);
100  sumv += v;
101  sumy += target;
102  sumvv += v*v;
103  sumyy += target*target;
104  sumvy += v*target;
105  ++total;
106  }
107  if(svm_type == svm_parameter.EPSILON_SVR ||
108  svm_type == svm_parameter.NU_SVR)
109  {
110  svm_predict.info("Mean squared error = "+error/total+" (regression)\n");
111  svm_predict.info("Squared correlation coefficient = "+
112  ((total*sumvy-sumv*sumy)*(total*sumvy-sumv*sumy))/
113  ((total*sumvv-sumv*sumv)*(total*sumyy-sumy*sumy))+
114  " (regression)\n");
115  }
116  else
117  svm_predict.info("Accuracy = "+(double)correct/total*100+
118  "% ("+correct+"/"+total+") (classification)\n");
119  }
120 
121  private static void exit_with_help()
122  {
123  System.err.print("usage: svm_predict [options] test_file model_file output_file\n"
124  +"options:\n"
125  +"-b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n"
126  +"-q : quiet mode (no outputs)\n");
127  System.exit(1);
128  }
129 
130  public static void main(String argv[]) throws IOException
131  {
132  int i, predict_probability=0;
133  svm_print_string = svm_print_stdout;
134 
135  // parse options
136  for(i=0;i<argv.length;i++)
137  {
138  if(argv[i].charAt(0) != '-') break;
139  ++i;
140  switch(argv[i-1].charAt(1))
141  {
142  case 'b':
143  predict_probability = atoi(argv[i]);
144  break;
145  case 'q':
146  svm_print_string = svm_print_null;
147  i--;
148  break;
149  default:
150  System.err.print("Unknown option: " + argv[i-1] + "\n");
151  exit_with_help();
152  }
153  }
154  if(i>=argv.length-2)
155  exit_with_help();
156  try
157  {
158  BufferedReader input = new BufferedReader(new FileReader(argv[i]));
159  DataOutputStream output = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(argv[i+2])));
160  svm_model model = svm.svm_load_model(argv[i+1]);
161  if(predict_probability == 1)
162  {
163  if(svm.svm_check_probability_model(model)==0)
164  {
165  System.err.print("Model does not support probabiliy estimates\n");
166  System.exit(1);
167  }
168  }
169  else
170  {
171  if(svm.svm_check_probability_model(model)!=0)
172  {
173  svm_predict.info("Model supports probability estimates, but disabled in prediction.\n");
174  }
175  }
176  predict(input,output,model,predict_probability);
177  input.close();
178  output.close();
179  }
180  catch(FileNotFoundException e)
181  {
182  exit_with_help();
183  }
184  catch(ArrayIndexOutOfBoundsException e)
185  {
186  exit_with_help();
187  }
188  }
189 }
static final int C_SVC
static final int NU_SVR
Definition: svm.py:1
void predict(mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, const int predict_probability)
Definition: svmpredict.c:49
void exit_with_help()
Definition: libsvmread.c:21
Definition: svm.java:5
ROSCONSOLE_DECL void print(FilterBase *filter, void *logger, Level level, const char *file, int line, const char *function, const char *fmt,...) ROSCONSOLE_PRINTF_ATTRIBUTE(7
double value
Definition: svm_node.java:5
def svm_predict(y, x, m, options="")
Definition: svmutil.py:164
struct svm_node * x
Definition: svm-predict.c:12
void output(int index, double value)
Definition: svm-scale.c:367
static void(* svm_print_string)(const char *)
Definition: svm.cpp:57
int(* info)(const char *fmt,...)
Definition: svmpredict.c:18
int predict_probability
Definition: svm-predict.c:16
char * line
Definition: svm-scale.c:21
struct svm_model * model
Definition: svmtrain.c:61
int main(int argc, char **argv)
static final int EPSILON_SVR
static final int NU_SVC


ml_classifiers
Author(s): Scott Niekum , Joshua Whitley
autogenerated on Sun Dec 15 2019 03:53:50