svmpredict.c
Go to the documentation of this file.
00001 #include <stdio.h>
00002 #include <stdlib.h>
00003 #include <string.h>
00004 #include "../svm.h"
00005 
00006 #include "mex.h"
00007 #include "svm_model_matlab.h"
00008 
/* Compatibility shim: MEX API versions older than 0x07030000 presumably lack
 * the mwIndex type, so it is declared as int there. Nothing is declared when
 * MX_API_VER is undefined or the API is newer. */
#if defined(MX_API_VER) && MX_API_VER < 0x07030000
typedef int mwIndex;
#endif

/* Size in bytes of the buffer that receives the libsvm_options string. */
#define CMD_LEN 2048
00016 
/* printf-compatible sink that discards everything; installed as info() by
 * the -q (quiet) option. Returns 0, i.e. "no characters written". */
int print_null(const char *s,...)
{
	(void)s; // the format string is deliberately ignored
	return 0;
}
00018 int (*info)(const char *fmt,...) = &mexPrintf;
00019 
00020 void read_sparse_instance(const mxArray *prhs, int index, struct svm_node *x)
00021 {
00022         int i, j, low, high;
00023         mwIndex *ir, *jc;
00024         double *samples;
00025 
00026         ir = mxGetIr(prhs);
00027         jc = mxGetJc(prhs);
00028         samples = mxGetPr(prhs);
00029 
00030         // each column is one instance
00031         j = 0;
00032         low = (int)jc[index], high = (int)jc[index+1];
00033         for(i=low;i<high;i++)
00034         {
00035                 x[j].index = (int)ir[i] + 1;
00036                 x[j].value = samples[i];
00037                 j++;
00038         }
00039         x[j].index = -1;
00040 }
00041 
00042 static void fake_answer(mxArray *plhs[])
00043 {
00044         plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
00045         plhs[1] = mxCreateDoubleMatrix(0, 0, mxREAL);
00046         plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
00047 }
00048 
00049 void predict(mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, const int predict_probability)
00050 {
00051         int label_vector_row_num, label_vector_col_num;
00052         int feature_number, testing_instance_number;
00053         int instance_index;
00054         double *ptr_instance, *ptr_label, *ptr_predict_label; 
00055         double *ptr_prob_estimates, *ptr_dec_values, *ptr;
00056         struct svm_node *x;
00057         mxArray *pplhs[1]; // transposed instance sparse matrix
00058 
00059         int correct = 0;
00060         int total = 0;
00061         double error = 0;
00062         double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;
00063 
00064         int svm_type=svm_get_svm_type(model);
00065         int nr_class=svm_get_nr_class(model);
00066         double *prob_estimates=NULL;
00067 
00068         // prhs[1] = testing instance matrix
00069         feature_number = (int)mxGetN(prhs[1]);
00070         testing_instance_number = (int)mxGetM(prhs[1]);
00071         label_vector_row_num = (int)mxGetM(prhs[0]);
00072         label_vector_col_num = (int)mxGetN(prhs[0]);
00073 
00074         if(label_vector_row_num!=testing_instance_number)
00075         {
00076                 mexPrintf("Length of label vector does not match # of instances.\n");
00077                 fake_answer(plhs);
00078                 return;
00079         }
00080         if(label_vector_col_num!=1)
00081         {
00082                 mexPrintf("label (1st argument) should be a vector (# of column is 1).\n");
00083                 fake_answer(plhs);
00084                 return;
00085         }
00086 
00087         ptr_instance = mxGetPr(prhs[1]);
00088         ptr_label    = mxGetPr(prhs[0]);
00089 
00090         // transpose instance matrix
00091         if(mxIsSparse(prhs[1]))
00092         {
00093                 if(model->param.kernel_type == PRECOMPUTED)
00094                 {
00095                         // precomputed kernel requires dense matrix, so we make one
00096                         mxArray *rhs[1], *lhs[1];
00097                         rhs[0] = mxDuplicateArray(prhs[1]);
00098                         if(mexCallMATLAB(1, lhs, 1, rhs, "full"))
00099                         {
00100                                 mexPrintf("Error: cannot full testing instance matrix\n");
00101                                 fake_answer(plhs);
00102                                 return;
00103                         }
00104                         ptr_instance = mxGetPr(lhs[0]);
00105                         mxDestroyArray(rhs[0]);
00106                 }
00107                 else
00108                 {
00109                         mxArray *pprhs[1];
00110                         pprhs[0] = mxDuplicateArray(prhs[1]);
00111                         if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
00112                         {
00113                                 mexPrintf("Error: cannot transpose testing instance matrix\n");
00114                                 fake_answer(plhs);
00115                                 return;
00116                         }
00117                 }
00118         }
00119 
00120         if(predict_probability)
00121         {
00122                 if(svm_type==NU_SVR || svm_type==EPSILON_SVR)
00123                         info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",svm_get_svr_probability(model));
00124                 else
00125                         prob_estimates = (double *) malloc(nr_class*sizeof(double));
00126         }
00127 
00128         plhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
00129         if(predict_probability)
00130         {
00131                 // prob estimates are in plhs[2]
00132                 if(svm_type==C_SVC || svm_type==NU_SVC)
00133                         plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL);
00134                 else
00135                         plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
00136         }
00137         else
00138         {
00139                 // decision values are in plhs[2]
00140                 if(svm_type == ONE_CLASS ||
00141                    svm_type == EPSILON_SVR ||
00142                    svm_type == NU_SVR ||
00143                    nr_class == 1) // if only one class in training data, decision values are still returned.
00144                         plhs[2] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
00145                 else
00146                         plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class*(nr_class-1)/2, mxREAL);
00147         }
00148 
00149         ptr_predict_label = mxGetPr(plhs[0]);
00150         ptr_prob_estimates = mxGetPr(plhs[2]);
00151         ptr_dec_values = mxGetPr(plhs[2]);
00152         x = (struct svm_node*)malloc((feature_number+1)*sizeof(struct svm_node) );
00153         for(instance_index=0;instance_index<testing_instance_number;instance_index++)
00154         {
00155                 int i;
00156                 double target_label, predict_label;
00157 
00158                 target_label = ptr_label[instance_index];
00159 
00160                 if(mxIsSparse(prhs[1]) && model->param.kernel_type != PRECOMPUTED) // prhs[1]^T is still sparse
00161                         read_sparse_instance(pplhs[0], instance_index, x);
00162                 else
00163                 {
00164                         for(i=0;i<feature_number;i++)
00165                         {
00166                                 x[i].index = i+1;
00167                                 x[i].value = ptr_instance[testing_instance_number*i+instance_index];
00168                         }
00169                         x[feature_number].index = -1;
00170                 }
00171 
00172                 if(predict_probability)
00173                 {
00174                         if(svm_type==C_SVC || svm_type==NU_SVC)
00175                         {
00176                                 predict_label = svm_predict_probability(model, x, prob_estimates);
00177                                 ptr_predict_label[instance_index] = predict_label;
00178                                 for(i=0;i<nr_class;i++)
00179                                         ptr_prob_estimates[instance_index + i * testing_instance_number] = prob_estimates[i];
00180                         } else {
00181                                 predict_label = svm_predict(model,x);
00182                                 ptr_predict_label[instance_index] = predict_label;
00183                         }
00184                 }
00185                 else
00186                 {
00187                         if(svm_type == ONE_CLASS ||
00188                            svm_type == EPSILON_SVR ||
00189                            svm_type == NU_SVR)
00190                         {
00191                                 double res;
00192                                 predict_label = svm_predict_values(model, x, &res);
00193                                 ptr_dec_values[instance_index] = res;
00194                         }
00195                         else
00196                         {
00197                                 double *dec_values = (double *) malloc(sizeof(double) * nr_class*(nr_class-1)/2);
00198                                 predict_label = svm_predict_values(model, x, dec_values);
00199                                 if(nr_class == 1) 
00200                                         ptr_dec_values[instance_index] = 1;
00201                                 else
00202                                         for(i=0;i<(nr_class*(nr_class-1))/2;i++)
00203                                                 ptr_dec_values[instance_index + i * testing_instance_number] = dec_values[i];
00204                                 free(dec_values);
00205                         }
00206                         ptr_predict_label[instance_index] = predict_label;
00207                 }
00208 
00209                 if(predict_label == target_label)
00210                         ++correct;
00211                 error += (predict_label-target_label)*(predict_label-target_label);
00212                 sump += predict_label;
00213                 sumt += target_label;
00214                 sumpp += predict_label*predict_label;
00215                 sumtt += target_label*target_label;
00216                 sumpt += predict_label*target_label;
00217                 ++total;
00218         }
00219         if(svm_type==NU_SVR || svm_type==EPSILON_SVR)
00220         {
00221                 info("Mean squared error = %g (regression)\n",error/total);
00222                 info("Squared correlation coefficient = %g (regression)\n",
00223                         ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
00224                         ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt))
00225                         );
00226         }
00227         else
00228                 info("Accuracy = %g%% (%d/%d) (classification)\n",
00229                         (double)correct/total*100,correct,total);
00230 
00231         // return accuracy, mean squared error, squared correlation coefficient
00232         plhs[1] = mxCreateDoubleMatrix(3, 1, mxREAL);
00233         ptr = mxGetPr(plhs[1]);
00234         ptr[0] = (double)correct/total*100;
00235         ptr[1] = error/total;
00236         ptr[2] = ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
00237                                 ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt));
00238 
00239         free(x);
00240         if(prob_estimates != NULL)
00241                 free(prob_estimates);
00242 }
00243 
/*
 * Print the svmpredict usage text. Despite its name it does not exit;
 * callers return (after setting the outputs) themselves.
 */
void exit_with_help(void)
{
	mexPrintf(
		"Usage: [predicted_label, accuracy, decision_values/prob_estimates] = svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')\n"
		"Parameters:\n"
		"  model: SVM model structure from svmtrain.\n"
		"  libsvm_options:\n"
		"    -b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n"
		"    -q : quiet mode (no outputs)\n"
		"Returns:\n"
		"  predicted_label: SVM prediction output vector.\n"
		"  accuracy: a vector with accuracy, mean squared error, squared correlation coefficient.\n"
		"  decision_values/prob_estimates: with -b 1, a matrix of probability estimates (one column per class; empty for regression); otherwise a matrix of decision values.\n"
	);
}
00259 
00260 void mexFunction( int nlhs, mxArray *plhs[],
00261                  int nrhs, const mxArray *prhs[] )
00262 {
00263         int prob_estimate_flag = 0;
00264         struct svm_model *model;
00265         info = &mexPrintf;
00266 
00267         if(nrhs > 4 || nrhs < 3)
00268         {
00269                 exit_with_help();
00270                 fake_answer(plhs);
00271                 return;
00272         }
00273 
00274         if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
00275                 mexPrintf("Error: label vector and instance matrix must be double\n");
00276                 fake_answer(plhs);
00277                 return;
00278         }
00279 
00280         if(mxIsStruct(prhs[2]))
00281         {
00282                 const char *error_msg;
00283 
00284                 // parse options
00285                 if(nrhs==4)
00286                 {
00287                         int i, argc = 1;
00288                         char cmd[CMD_LEN], *argv[CMD_LEN/2];
00289 
00290                         // put options in argv[]
00291                         mxGetString(prhs[3], cmd,  mxGetN(prhs[3]) + 1);
00292                         if((argv[argc] = strtok(cmd, " ")) != NULL)
00293                                 while((argv[++argc] = strtok(NULL, " ")) != NULL)
00294                                         ;
00295 
00296                         for(i=1;i<argc;i++)
00297                         {
00298                                 if(argv[i][0] != '-') break;
00299                                 if((++i>=argc) && argv[i-1][1] != 'q')
00300                                 {
00301                                         exit_with_help();
00302                                         fake_answer(plhs);
00303                                         return;
00304                                 }
00305                                 switch(argv[i-1][1])
00306                                 {
00307                                         case 'b':
00308                                                 prob_estimate_flag = atoi(argv[i]);
00309                                                 break;
00310                                         case 'q':
00311                                                 i--;
00312                                                 info = &print_null;
00313                                                 break;
00314                                         default:
00315                                                 mexPrintf("Unknown option: -%c\n", argv[i-1][1]);
00316                                                 exit_with_help();
00317                                                 fake_answer(plhs);
00318                                                 return;
00319                                 }
00320                         }
00321                 }
00322 
00323                 model = matlab_matrix_to_model(prhs[2], &error_msg);
00324                 if (model == NULL)
00325                 {
00326                         mexPrintf("Error: can't read model: %s\n", error_msg);
00327                         fake_answer(plhs);
00328                         return;
00329                 }
00330 
00331                 if(prob_estimate_flag)
00332                 {
00333                         if(svm_check_probability_model(model)==0)
00334                         {
00335                                 mexPrintf("Model does not support probabiliy estimates\n");
00336                                 fake_answer(plhs);
00337                                 svm_free_and_destroy_model(&model);
00338                                 return;
00339                         }
00340                 }
00341                 else
00342                 {
00343                         if(svm_check_probability_model(model)!=0)
00344                                 info("Model supports probability estimates, but disabled in predicton.\n");
00345                 }
00346 
00347                 predict(plhs, prhs, model, prob_estimate_flag);
00348                 // destroy model
00349                 svm_free_and_destroy_model(&model);
00350         }
00351         else
00352         {
00353                 mexPrintf("model file should be a struct array\n");
00354                 fake_answer(plhs);
00355         }
00356 
00357         return;
00358 }


ml_classifiers
Author(s): Scott Niekum
autogenerated on Mon Oct 6 2014 02:20:58