Go to the documentation of this file.00001 import libsvm.*;
00002 import java.io.*;
00003 import java.util.*;
00004
00005 class svm_predict {
00006 private static svm_print_interface svm_print_null = new svm_print_interface()
00007 {
00008 public void print(String s) {}
00009 };
00010
00011 private static svm_print_interface svm_print_stdout = new svm_print_interface()
00012 {
00013 public void print(String s)
00014 {
00015 System.out.print(s);
00016 }
00017 };
00018
00019 private static svm_print_interface svm_print_string = svm_print_stdout;
00020
00021 static void info(String s)
00022 {
00023 svm_print_string.print(s);
00024 }
00025
00026 private static double atof(String s)
00027 {
00028 return Double.valueOf(s).doubleValue();
00029 }
00030
00031 private static int atoi(String s)
00032 {
00033 return Integer.parseInt(s);
00034 }
00035
00036 private static void predict(BufferedReader input, DataOutputStream output, svm_model model, int predict_probability) throws IOException
00037 {
00038 int correct = 0;
00039 int total = 0;
00040 double error = 0;
00041 double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
00042
00043 int svm_type=svm.svm_get_svm_type(model);
00044 int nr_class=svm.svm_get_nr_class(model);
00045 double[] prob_estimates=null;
00046
00047 if(predict_probability == 1)
00048 {
00049 if(svm_type == svm_parameter.EPSILON_SVR ||
00050 svm_type == svm_parameter.NU_SVR)
00051 {
00052 svm_predict.info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma="+svm.svm_get_svr_probability(model)+"\n");
00053 }
00054 else
00055 {
00056 int[] labels=new int[nr_class];
00057 svm.svm_get_labels(model,labels);
00058 prob_estimates = new double[nr_class];
00059 output.writeBytes("labels");
00060 for(int j=0;j<nr_class;j++)
00061 output.writeBytes(" "+labels[j]);
00062 output.writeBytes("\n");
00063 }
00064 }
00065 while(true)
00066 {
00067 String line = input.readLine();
00068 if(line == null) break;
00069
00070 StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
00071
00072 double target = atof(st.nextToken());
00073 int m = st.countTokens()/2;
00074 svm_node[] x = new svm_node[m];
00075 for(int j=0;j<m;j++)
00076 {
00077 x[j] = new svm_node();
00078 x[j].index = atoi(st.nextToken());
00079 x[j].value = atof(st.nextToken());
00080 }
00081
00082 double v;
00083 if (predict_probability==1 && (svm_type==svm_parameter.C_SVC || svm_type==svm_parameter.NU_SVC))
00084 {
00085 v = svm.svm_predict_probability(model,x,prob_estimates);
00086 output.writeBytes(v+" ");
00087 for(int j=0;j<nr_class;j++)
00088 output.writeBytes(prob_estimates[j]+" ");
00089 output.writeBytes("\n");
00090 }
00091 else
00092 {
00093 v = svm.svm_predict(model,x);
00094 output.writeBytes(v+"\n");
00095 }
00096
00097 if(v == target)
00098 ++correct;
00099 error += (v-target)*(v-target);
00100 sumv += v;
00101 sumy += target;
00102 sumvv += v*v;
00103 sumyy += target*target;
00104 sumvy += v*target;
00105 ++total;
00106 }
00107 if(svm_type == svm_parameter.EPSILON_SVR ||
00108 svm_type == svm_parameter.NU_SVR)
00109 {
00110 svm_predict.info("Mean squared error = "+error/total+" (regression)\n");
00111 svm_predict.info("Squared correlation coefficient = "+
00112 ((total*sumvy-sumv*sumy)*(total*sumvy-sumv*sumy))/
00113 ((total*sumvv-sumv*sumv)*(total*sumyy-sumy*sumy))+
00114 " (regression)\n");
00115 }
00116 else
00117 svm_predict.info("Accuracy = "+(double)correct/total*100+
00118 "% ("+correct+"/"+total+") (classification)\n");
00119 }
00120
00121 private static void exit_with_help()
00122 {
00123 System.err.print("usage: svm_predict [options] test_file model_file output_file\n"
00124 +"options:\n"
00125 +"-b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n"
00126 +"-q : quiet mode (no outputs)\n");
00127 System.exit(1);
00128 }
00129
00130 public static void main(String argv[]) throws IOException
00131 {
00132 int i, predict_probability=0;
00133 svm_print_string = svm_print_stdout;
00134
00135
00136 for(i=0;i<argv.length;i++)
00137 {
00138 if(argv[i].charAt(0) != '-') break;
00139 ++i;
00140 switch(argv[i-1].charAt(1))
00141 {
00142 case 'b':
00143 predict_probability = atoi(argv[i]);
00144 break;
00145 case 'q':
00146 svm_print_string = svm_print_null;
00147 i--;
00148 break;
00149 default:
00150 System.err.print("Unknown option: " + argv[i-1] + "\n");
00151 exit_with_help();
00152 }
00153 }
00154 if(i>=argv.length-2)
00155 exit_with_help();
00156 try
00157 {
00158 BufferedReader input = new BufferedReader(new FileReader(argv[i]));
00159 DataOutputStream output = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(argv[i+2])));
00160 svm_model model = svm.svm_load_model(argv[i+1]);
00161 if(predict_probability == 1)
00162 {
00163 if(svm.svm_check_probability_model(model)==0)
00164 {
00165 System.err.print("Model does not support probabiliy estimates\n");
00166 System.exit(1);
00167 }
00168 }
00169 else
00170 {
00171 if(svm.svm_check_probability_model(model)!=0)
00172 {
00173 svm_predict.info("Model supports probability estimates, but disabled in prediction.\n");
00174 }
00175 }
00176 predict(input,output,model,predict_probability);
00177 input.close();
00178 output.close();
00179 }
00180 catch(FileNotFoundException e)
00181 {
00182 exit_with_help();
00183 }
00184 catch(ArrayIndexOutOfBoundsException e)
00185 {
00186 exit_with_help();
00187 }
00188 }
00189 }