svm_predict.java
Go to the documentation of this file.
00001 import libsvm.*;
00002 import java.io.*;
00003 import java.util.*;
00004 
00005 class svm_predict {
00006         private static double atof(String s)
00007         {
00008                 return Double.valueOf(s).doubleValue();
00009         }
00010 
00011         private static int atoi(String s)
00012         {
00013                 return Integer.parseInt(s);
00014         }
00015 
00016         private static void predict(BufferedReader input, DataOutputStream output, svm_model model, int predict_probability) throws IOException
00017         {
00018                 int correct = 0;
00019                 int total = 0;
00020                 double error = 0;
00021                 double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
00022 
00023                 int svm_type=svm.svm_get_svm_type(model);
00024                 int nr_class=svm.svm_get_nr_class(model);
00025                 double[] prob_estimates=null;
00026 
00027                 if(predict_probability == 1)
00028                 {
00029                         if(svm_type == svm_parameter.EPSILON_SVR ||
00030                            svm_type == svm_parameter.NU_SVR)
00031                         {
00032                                 System.out.print("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma="+svm.svm_get_svr_probability(model)+"\n");
00033                         }
00034                         else
00035                         {
00036                                 int[] labels=new int[nr_class];
00037                                 svm.svm_get_labels(model,labels);
00038                                 prob_estimates = new double[nr_class];
00039                                 output.writeBytes("labels");
00040                                 for(int j=0;j<nr_class;j++)
00041                                         output.writeBytes(" "+labels[j]);
00042                                 output.writeBytes("\n");
00043                         }
00044                 }
00045                 while(true)
00046                 {
00047                         String line = input.readLine();
00048                         if(line == null) break;
00049 
00050                         StringTokenizer st = new StringTokenizer(line," \t\n\r\f:");
00051 
00052                         double target = atof(st.nextToken());
00053                         int m = st.countTokens()/2;
00054                         svm_node[] x = new svm_node[m];
00055                         for(int j=0;j<m;j++)
00056                         {
00057                                 x[j] = new svm_node();
00058                                 x[j].index = atoi(st.nextToken());
00059                                 x[j].value = atof(st.nextToken());
00060                         }
00061 
00062                         double v;
00063                         if (predict_probability==1 && (svm_type==svm_parameter.C_SVC || svm_type==svm_parameter.NU_SVC))
00064                         {
00065                                 v = svm.svm_predict_probability(model,x,prob_estimates);
00066                                 output.writeBytes(v+" ");
00067                                 for(int j=0;j<nr_class;j++)
00068                                         output.writeBytes(prob_estimates[j]+" ");
00069                                 output.writeBytes("\n");
00070                         }
00071                         else
00072                         {
00073                                 v = svm.svm_predict(model,x);
00074                                 output.writeBytes(v+"\n");
00075                         }
00076 
00077                         if(v == target)
00078                                 ++correct;
00079                         error += (v-target)*(v-target);
00080                         sumv += v;
00081                         sumy += target;
00082                         sumvv += v*v;
00083                         sumyy += target*target;
00084                         sumvy += v*target;
00085                         ++total;
00086                 }
00087                 if(svm_type == svm_parameter.EPSILON_SVR ||
00088                    svm_type == svm_parameter.NU_SVR)
00089                 {
00090                         System.out.print("Mean squared error = "+error/total+" (regression)\n");
00091                         System.out.print("Squared correlation coefficient = "+
00092                                  ((total*sumvy-sumv*sumy)*(total*sumvy-sumv*sumy))/
00093                                  ((total*sumvv-sumv*sumv)*(total*sumyy-sumy*sumy))+
00094                                  " (regression)\n");
00095                 }
00096                 else
00097                         System.out.print("Accuracy = "+(double)correct/total*100+
00098                                  "% ("+correct+"/"+total+") (classification)\n");
00099         }
00100 
00101         private static void exit_with_help()
00102         {
00103                 System.err.print("usage: svm_predict [options] test_file model_file output_file\n"
00104                 +"options:\n"
00105                 +"-b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n");
00106                 System.exit(1);
00107         }
00108 
00109         public static void main(String argv[]) throws IOException
00110         {
00111                 int i, predict_probability=0;
00112 
00113                 // parse options
00114                 for(i=0;i<argv.length;i++)
00115                 {
00116                         if(argv[i].charAt(0) != '-') break;
00117                         ++i;
00118                         switch(argv[i-1].charAt(1))
00119                         {
00120                                 case 'b':
00121                                         predict_probability = atoi(argv[i]);
00122                                         break;
00123                                 default:
00124                                         System.err.print("Unknown option: " + argv[i-1] + "\n");
00125                                         exit_with_help();
00126                         }
00127                 }
00128                 if(i>=argv.length-2)
00129                         exit_with_help();
00130                 try 
00131                 {
00132                         BufferedReader input = new BufferedReader(new FileReader(argv[i]));
00133                         DataOutputStream output = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(argv[i+2])));
00134                         svm_model model = svm.svm_load_model(argv[i+1]);
00135                         if(predict_probability == 1)
00136                         {
00137                                 if(svm.svm_check_probability_model(model)==0)
00138                                 {
00139                                         System.err.print("Model does not support probabiliy estimates\n");
00140                                         System.exit(1);
00141                                 }
00142                         }
00143                         else
00144                         {
00145                                 if(svm.svm_check_probability_model(model)!=0)
00146                                 {
00147                                         System.out.print("Model supports probability estimates, but disabled in prediction.\n");
00148                                 }
00149                         }
00150                         predict(input,output,model,predict_probability);
00151                         input.close();
00152                         output.close();
00153                 } 
00154                 catch(FileNotFoundException e) 
00155                 {
00156                         exit_with_help();
00157                 }
00158                 catch(ArrayIndexOutOfBoundsException e) 
00159                 {
00160                         exit_with_help();
00161                 }
00162         }
00163 }


libsvm3
Author(s): various
autogenerated on Wed Nov 27 2013 11:36:23