svmpredict.c
Go to the documentation of this file.
00001 #include <stdio.h>
00002 #include <stdlib.h>
00003 #include <string.h>
00004 #include "svm.h"
00005 
00006 #include "mex.h"
00007 #include "svm_model_matlab.h"
00008 
/*
 * MEX API versions older than 0x07030000 are assumed not to provide the
 * mwIndex type, so fall back to a plain int there.
 */
#if defined(MX_API_VER) && MX_API_VER < 0x07030000
typedef int mwIndex;
#endif

/* Capacity (in chars, including the terminator) of the buffer that receives
 * the 'libsvm_options' string in mexFunction. */
#define CMD_LEN 2048
00016 
/*
 * printf-compatible sink that discards everything it is given. The -q
 * option installs it as `info` to silence informational output.
 * Returns 0 (no characters written) so it matches mexPrintf's int result.
 */
int print_null(const char *s,...)
{
	(void)s;
	return 0;
}
00018 int (*info)(const char *fmt,...) = &mexPrintf;
00019 
00020 void read_sparse_instance(const mxArray *prhs, int index, struct svm_node *x)
00021 {
00022         int i, j, low, high;
00023         mwIndex *ir, *jc;
00024         double *samples;
00025 
00026         ir = mxGetIr(prhs);
00027         jc = mxGetJc(prhs);
00028         samples = mxGetPr(prhs);
00029 
00030         // each column is one instance
00031         j = 0;
00032         low = (int)jc[index], high = (int)jc[index+1];
00033         for(i=low;i<high;i++)
00034         {
00035                 x[j].index = (int)ir[i] + 1;
00036                 x[j].value = samples[i];
00037                 j++;
00038         }
00039         x[j].index = -1;
00040 }
00041 
00042 static void fake_answer(int nlhs, mxArray *plhs[])
00043 {
00044         int i;
00045         for(i=0;i<nlhs;i++)
00046                 plhs[i] = mxCreateDoubleMatrix(0, 0, mxREAL);
00047 }
00048 
00049 void predict(int nlhs, mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, const int predict_probability)
00050 {
00051         int label_vector_row_num, label_vector_col_num;
00052         int feature_number, testing_instance_number;
00053         int instance_index;
00054         double *ptr_instance, *ptr_label, *ptr_predict_label; 
00055         double *ptr_prob_estimates, *ptr_dec_values, *ptr;
00056         struct svm_node *x;
00057         mxArray *pplhs[1]; // transposed instance sparse matrix
00058         mxArray *tplhs[3]; // temporary storage for plhs[]
00059 
00060         int correct = 0;
00061         int total = 0;
00062         double error = 0;
00063         double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;
00064 
00065         int svm_type=svm_get_svm_type(model);
00066         int nr_class=svm_get_nr_class(model);
00067         double *prob_estimates=NULL;
00068 
00069         // prhs[1] = testing instance matrix
00070         feature_number = (int)mxGetN(prhs[1]);
00071         testing_instance_number = (int)mxGetM(prhs[1]);
00072         label_vector_row_num = (int)mxGetM(prhs[0]);
00073         label_vector_col_num = (int)mxGetN(prhs[0]);
00074 
00075         if(label_vector_row_num!=testing_instance_number)
00076         {
00077                 mexPrintf("Length of label vector does not match # of instances.\n");
00078                 fake_answer(nlhs, plhs);
00079                 return;
00080         }
00081         if(label_vector_col_num!=1)
00082         {
00083                 mexPrintf("label (1st argument) should be a vector (# of column is 1).\n");
00084                 fake_answer(nlhs, plhs);
00085                 return;
00086         }
00087 
00088         ptr_instance = mxGetPr(prhs[1]);
00089         ptr_label    = mxGetPr(prhs[0]);
00090 
00091         // transpose instance matrix
00092         if(mxIsSparse(prhs[1]))
00093         {
00094                 if(model->param.kernel_type == PRECOMPUTED)
00095                 {
00096                         // precomputed kernel requires dense matrix, so we make one
00097                         mxArray *rhs[1], *lhs[1];
00098                         rhs[0] = mxDuplicateArray(prhs[1]);
00099                         if(mexCallMATLAB(1, lhs, 1, rhs, "full"))
00100                         {
00101                                 mexPrintf("Error: cannot full testing instance matrix\n");
00102                                 fake_answer(nlhs, plhs);
00103                                 return;
00104                         }
00105                         ptr_instance = mxGetPr(lhs[0]);
00106                         mxDestroyArray(rhs[0]);
00107                 }
00108                 else
00109                 {
00110                         mxArray *pprhs[1];
00111                         pprhs[0] = mxDuplicateArray(prhs[1]);
00112                         if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
00113                         {
00114                                 mexPrintf("Error: cannot transpose testing instance matrix\n");
00115                                 fake_answer(nlhs, plhs);
00116                                 return;
00117                         }
00118                 }
00119         }
00120 
00121         if(predict_probability)
00122         {
00123                 if(svm_type==NU_SVR || svm_type==EPSILON_SVR)
00124                         info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",svm_get_svr_probability(model));
00125                 else
00126                         prob_estimates = (double *) malloc(nr_class*sizeof(double));
00127         }
00128 
00129         tplhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
00130         if(predict_probability)
00131         {
00132                 // prob estimates are in plhs[2]
00133                 if(svm_type==C_SVC || svm_type==NU_SVC)
00134                         tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL);
00135                 else
00136                         tplhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
00137         }
00138         else
00139         {
00140                 // decision values are in plhs[2]
00141                 if(svm_type == ONE_CLASS ||
00142                    svm_type == EPSILON_SVR ||
00143                    svm_type == NU_SVR ||
00144                    nr_class == 1) // if only one class in training data, decision values are still returned.
00145                         tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
00146                 else
00147                         tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class*(nr_class-1)/2, mxREAL);
00148         }
00149 
00150         ptr_predict_label = mxGetPr(tplhs[0]);
00151         ptr_prob_estimates = mxGetPr(tplhs[2]);
00152         ptr_dec_values = mxGetPr(tplhs[2]);
00153         x = (struct svm_node*)malloc((feature_number+1)*sizeof(struct svm_node) );
00154         for(instance_index=0;instance_index<testing_instance_number;instance_index++)
00155         {
00156                 int i;
00157                 double target_label, predict_label;
00158 
00159                 target_label = ptr_label[instance_index];
00160 
00161                 if(mxIsSparse(prhs[1]) && model->param.kernel_type != PRECOMPUTED) // prhs[1]^T is still sparse
00162                         read_sparse_instance(pplhs[0], instance_index, x);
00163                 else
00164                 {
00165                         for(i=0;i<feature_number;i++)
00166                         {
00167                                 x[i].index = i+1;
00168                                 x[i].value = ptr_instance[testing_instance_number*i+instance_index];
00169                         }
00170                         x[feature_number].index = -1;
00171                 }
00172 
00173                 if(predict_probability)
00174                 {
00175                         if(svm_type==C_SVC || svm_type==NU_SVC)
00176                         {
00177                                 predict_label = svm_predict_probability(model, x, prob_estimates);
00178                                 ptr_predict_label[instance_index] = predict_label;
00179                                 for(i=0;i<nr_class;i++)
00180                                         ptr_prob_estimates[instance_index + i * testing_instance_number] = prob_estimates[i];
00181                         } else {
00182                                 predict_label = svm_predict(model,x);
00183                                 ptr_predict_label[instance_index] = predict_label;
00184                         }
00185                 }
00186                 else
00187                 {
00188                         if(svm_type == ONE_CLASS ||
00189                            svm_type == EPSILON_SVR ||
00190                            svm_type == NU_SVR)
00191                         {
00192                                 double res;
00193                                 predict_label = svm_predict_values(model, x, &res);
00194                                 ptr_dec_values[instance_index] = res;
00195                         }
00196                         else
00197                         {
00198                                 double *dec_values = (double *) malloc(sizeof(double) * nr_class*(nr_class-1)/2);
00199                                 predict_label = svm_predict_values(model, x, dec_values);
00200                                 if(nr_class == 1) 
00201                                         ptr_dec_values[instance_index] = 1;
00202                                 else
00203                                         for(i=0;i<(nr_class*(nr_class-1))/2;i++)
00204                                                 ptr_dec_values[instance_index + i * testing_instance_number] = dec_values[i];
00205                                 free(dec_values);
00206                         }
00207                         ptr_predict_label[instance_index] = predict_label;
00208                 }
00209 
00210                 if(predict_label == target_label)
00211                         ++correct;
00212                 error += (predict_label-target_label)*(predict_label-target_label);
00213                 sump += predict_label;
00214                 sumt += target_label;
00215                 sumpp += predict_label*predict_label;
00216                 sumtt += target_label*target_label;
00217                 sumpt += predict_label*target_label;
00218                 ++total;
00219         }
00220         if(svm_type==NU_SVR || svm_type==EPSILON_SVR)
00221         {
00222                 info("Mean squared error = %g (regression)\n",error/total);
00223                 info("Squared correlation coefficient = %g (regression)\n",
00224                         ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
00225                         ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt))
00226                         );
00227         }
00228         else
00229                 info("Accuracy = %g%% (%d/%d) (classification)\n",
00230                         (double)correct/total*100,correct,total);
00231 
00232         // return accuracy, mean squared error, squared correlation coefficient
00233         tplhs[1] = mxCreateDoubleMatrix(3, 1, mxREAL);
00234         ptr = mxGetPr(tplhs[1]);
00235         ptr[0] = (double)correct/total*100;
00236         ptr[1] = error/total;
00237         ptr[2] = ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
00238                                 ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt));
00239 
00240         free(x);
00241         if(prob_estimates != NULL)
00242                 free(prob_estimates);
00243 
00244         switch(nlhs)
00245         {
00246                 case 3:
00247                         plhs[2] = tplhs[2];
00248                         plhs[1] = tplhs[1];
00249                 case 1:
00250                 case 0:
00251                         plhs[0] = tplhs[0];
00252         }
00253 }
00254 
/*
 * Print the MATLAB usage text for svmpredict. The "Returns" section names
 * the third output the same way the Usage line does: probability estimates
 * with -b 1, decision values otherwise (see predict()).
 */
void exit_with_help(void)
{
	mexPrintf(
		"Usage: [predicted_label, accuracy, decision_values/prob_estimates] = svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')\n"
		"       [predicted_label] = svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')\n"
		"Parameters:\n"
		"  model: SVM model structure from svmtrain.\n"
		"  libsvm_options:\n"
		"    -b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n"
		"    -q : quiet mode (no outputs)\n"
		"Returns:\n"
		"  predicted_label: SVM prediction output vector.\n"
		"  accuracy: a vector with accuracy, mean squared error, squared correlation coefficient.\n"
		"  decision_values/prob_estimates: with -b 1, probability estimates (one column per class for classification, empty otherwise); without it, decision values.\n"
	);
}
00271 
00272 void mexFunction( int nlhs, mxArray *plhs[],
00273                  int nrhs, const mxArray *prhs[] )
00274 {
00275         int prob_estimate_flag = 0;
00276         struct svm_model *model;
00277         info = &mexPrintf;
00278 
00279         if(nlhs == 2 || nlhs > 3 || nrhs > 4 || nrhs < 3)
00280         {
00281                 exit_with_help();
00282                 fake_answer(nlhs, plhs);
00283                 return;
00284         }
00285 
00286         if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
00287                 mexPrintf("Error: label vector and instance matrix must be double\n");
00288                 fake_answer(nlhs, plhs);
00289                 return;
00290         }
00291 
00292         if(mxIsStruct(prhs[2]))
00293         {
00294                 const char *error_msg;
00295 
00296                 // parse options
00297                 if(nrhs==4)
00298                 {
00299                         int i, argc = 1;
00300                         char cmd[CMD_LEN], *argv[CMD_LEN/2];
00301 
00302                         // put options in argv[]
00303                         mxGetString(prhs[3], cmd,  mxGetN(prhs[3]) + 1);
00304                         if((argv[argc] = strtok(cmd, " ")) != NULL)
00305                                 while((argv[++argc] = strtok(NULL, " ")) != NULL)
00306                                         ;
00307 
00308                         for(i=1;i<argc;i++)
00309                         {
00310                                 if(argv[i][0] != '-') break;
00311                                 if((++i>=argc) && argv[i-1][1] != 'q')
00312                                 {
00313                                         exit_with_help();
00314                                         fake_answer(nlhs, plhs);
00315                                         return;
00316                                 }
00317                                 switch(argv[i-1][1])
00318                                 {
00319                                         case 'b':
00320                                                 prob_estimate_flag = atoi(argv[i]);
00321                                                 break;
00322                                         case 'q':
00323                                                 i--;
00324                                                 info = &print_null;
00325                                                 break;
00326                                         default:
00327                                                 mexPrintf("Unknown option: -%c\n", argv[i-1][1]);
00328                                                 exit_with_help();
00329                                                 fake_answer(nlhs, plhs);
00330                                                 return;
00331                                 }
00332                         }
00333                 }
00334 
00335                 model = matlab_matrix_to_model(prhs[2], &error_msg);
00336                 if (model == NULL)
00337                 {
00338                         mexPrintf("Error: can't read model: %s\n", error_msg);
00339                         fake_answer(nlhs, plhs);
00340                         return;
00341                 }
00342 
00343                 if(prob_estimate_flag)
00344                 {
00345                         if(svm_check_probability_model(model)==0)
00346                         {
00347                                 mexPrintf("Model does not support probabiliy estimates\n");
00348                                 fake_answer(nlhs, plhs);
00349                                 svm_free_and_destroy_model(&model);
00350                                 return;
00351                         }
00352                 }
00353                 else
00354                 {
00355                         if(svm_check_probability_model(model)!=0)
00356                                 info("Model supports probability estimates, but disabled in predicton.\n");
00357                 }
00358 
00359                 predict(nlhs, plhs, prhs, model, prob_estimate_flag);
00360                 // destroy model
00361                 svm_free_and_destroy_model(&model);
00362         }
00363         else
00364         {
00365                 mexPrintf("model file should be a struct array\n");
00366                 fake_answer(nlhs, plhs);
00367         }
00368 
00369         return;
00370 }


target_obejct_detector
Author(s): CIR-KIT
autogenerated on Thu Jun 6 2019 20:19:57