svmpredict.c
Go to the documentation of this file.
00001 #include <stdio.h>
00002 #include <stdlib.h>
00003 #include <string.h>
00004 #include "../svm.h"
00005 
00006 #include "mex.h"
00007 #include "svm_model_matlab.h"
00008 
/* When the MEX API version is older than 0x07030000, supply mwIndex as a
 * plain int (this file assumes such headers do not provide it). */
#if defined(MX_API_VER) && MX_API_VER < 0x07030000
typedef int mwIndex;
#endif

/* Capacity (in chars, including the terminator) of the buffer that receives
 * the libsvm_options string in mexFunction. */
enum { CMD_LEN = 2048 };
00016 
00017 void read_sparse_instance(const mxArray *prhs, int index, struct svm_node *x)
00018 {
00019         int i, j, low, high;
00020         mwIndex *ir, *jc;
00021         double *samples;
00022 
00023         ir = mxGetIr(prhs);
00024         jc = mxGetJc(prhs);
00025         samples = mxGetPr(prhs);
00026 
00027         // each column is one instance
00028         j = 0;
00029         low = (int)jc[index], high = (int)jc[index+1];
00030         for(i=low;i<high;i++)
00031         {
00032                 x[j].index = (int)ir[i] + 1;
00033                 x[j].value = samples[i];
00034                 j++;
00035         }
00036         x[j].index = -1;
00037 }
00038 
00039 static void fake_answer(mxArray *plhs[])
00040 {
00041         plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
00042         plhs[1] = mxCreateDoubleMatrix(0, 0, mxREAL);
00043         plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
00044 }
00045 
00046 void predict(mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, const int predict_probability)
00047 {
00048         int label_vector_row_num, label_vector_col_num;
00049         int feature_number, testing_instance_number;
00050         int instance_index;
00051         double *ptr_instance, *ptr_label, *ptr_predict_label; 
00052         double *ptr_prob_estimates, *ptr_dec_values, *ptr;
00053         struct svm_node *x;
00054         mxArray *pplhs[1]; // transposed instance sparse matrix
00055 
00056         int correct = 0;
00057         int total = 0;
00058         double error = 0;
00059         double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;
00060 
00061         int svm_type=svm_get_svm_type(model);
00062         int nr_class=svm_get_nr_class(model);
00063         double *prob_estimates=NULL;
00064 
00065         // prhs[1] = testing instance matrix
00066         feature_number = (int)mxGetN(prhs[1]);
00067         testing_instance_number = (int)mxGetM(prhs[1]);
00068         label_vector_row_num = (int)mxGetM(prhs[0]);
00069         label_vector_col_num = (int)mxGetN(prhs[0]);
00070 
00071         if(label_vector_row_num!=testing_instance_number)
00072         {
00073                 mexPrintf("Length of label vector does not match # of instances.\n");
00074                 fake_answer(plhs);
00075                 return;
00076         }
00077         if(label_vector_col_num!=1)
00078         {
00079                 mexPrintf("label (1st argument) should be a vector (# of column is 1).\n");
00080                 fake_answer(plhs);
00081                 return;
00082         }
00083 
00084         ptr_instance = mxGetPr(prhs[1]);
00085         ptr_label    = mxGetPr(prhs[0]);
00086 
00087         // transpose instance matrix
00088         if(mxIsSparse(prhs[1]))
00089         {
00090                 if(model->param.kernel_type == PRECOMPUTED)
00091                 {
00092                         // precomputed kernel requires dense matrix, so we make one
00093                         mxArray *rhs[1], *lhs[1];
00094                         rhs[0] = mxDuplicateArray(prhs[1]);
00095                         if(mexCallMATLAB(1, lhs, 1, rhs, "full"))
00096                         {
00097                                 mexPrintf("Error: cannot full testing instance matrix\n");
00098                                 fake_answer(plhs);
00099                                 return;
00100                         }
00101                         ptr_instance = mxGetPr(lhs[0]);
00102                         mxDestroyArray(rhs[0]);
00103                 }
00104                 else
00105                 {
00106                         mxArray *pprhs[1];
00107                         pprhs[0] = mxDuplicateArray(prhs[1]);
00108                         if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
00109                         {
00110                                 mexPrintf("Error: cannot transpose testing instance matrix\n");
00111                                 fake_answer(plhs);
00112                                 return;
00113                         }
00114                 }
00115         }
00116 
00117         if(predict_probability)
00118         {
00119                 if(svm_type==NU_SVR || svm_type==EPSILON_SVR)
00120                         mexPrintf("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",svm_get_svr_probability(model));
00121                 else
00122                         prob_estimates = (double *) malloc(nr_class*sizeof(double));
00123         }
00124 
00125         plhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
00126         if(predict_probability)
00127         {
00128                 // prob estimates are in plhs[2]
00129                 if(svm_type==C_SVC || svm_type==NU_SVC)
00130                         plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL);
00131                 else
00132                         plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
00133         }
00134         else
00135         {
00136                 // decision values are in plhs[2]
00137                 if(svm_type == ONE_CLASS ||
00138                    svm_type == EPSILON_SVR ||
00139                    svm_type == NU_SVR ||
00140                    nr_class == 1) // if only one class in training data, decision values are still returned.
00141                         plhs[2] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
00142                 else
00143                         plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class*(nr_class-1)/2, mxREAL);
00144         }
00145 
00146         ptr_predict_label = mxGetPr(plhs[0]);
00147         ptr_prob_estimates = mxGetPr(plhs[2]);
00148         ptr_dec_values = mxGetPr(plhs[2]);
00149         x = (struct svm_node*)malloc((feature_number+1)*sizeof(struct svm_node) );
00150         for(instance_index=0;instance_index<testing_instance_number;instance_index++)
00151         {
00152                 int i;
00153                 double target_label, predict_label;
00154 
00155                 target_label = ptr_label[instance_index];
00156 
00157                 if(mxIsSparse(prhs[1]) && model->param.kernel_type != PRECOMPUTED) // prhs[1]^T is still sparse
00158                         read_sparse_instance(pplhs[0], instance_index, x);
00159                 else
00160                 {
00161                         for(i=0;i<feature_number;i++)
00162                         {
00163                                 x[i].index = i+1;
00164                                 x[i].value = ptr_instance[testing_instance_number*i+instance_index];
00165                         }
00166                         x[feature_number].index = -1;
00167                 }
00168 
00169                 if(predict_probability)
00170                 {
00171                         if(svm_type==C_SVC || svm_type==NU_SVC)
00172                         {
00173                                 predict_label = svm_predict_probability(model, x, prob_estimates);
00174                                 ptr_predict_label[instance_index] = predict_label;
00175                                 for(i=0;i<nr_class;i++)
00176                                         ptr_prob_estimates[instance_index + i * testing_instance_number] = prob_estimates[i];
00177                         } else {
00178                                 predict_label = svm_predict(model,x);
00179                                 ptr_predict_label[instance_index] = predict_label;
00180                         }
00181                 }
00182                 else
00183                 {
00184                         if(svm_type == ONE_CLASS ||
00185                            svm_type == EPSILON_SVR ||
00186                            svm_type == NU_SVR)
00187                         {
00188                                 double res;
00189                                 predict_label = svm_predict_values(model, x, &res);
00190                                 ptr_dec_values[instance_index] = res;
00191                         }
00192                         else
00193                         {
00194                                 double *dec_values = (double *) malloc(sizeof(double) * nr_class*(nr_class-1)/2);
00195                                 predict_label = svm_predict_values(model, x, dec_values);
00196                                 if(nr_class == 1) 
00197                                         ptr_dec_values[instance_index] = 1;
00198                                 else
00199                                         for(i=0;i<(nr_class*(nr_class-1))/2;i++)
00200                                                 ptr_dec_values[instance_index + i * testing_instance_number] = dec_values[i];
00201                                 free(dec_values);
00202                         }
00203                         ptr_predict_label[instance_index] = predict_label;
00204                 }
00205 
00206                 if(predict_label == target_label)
00207                         ++correct;
00208                 error += (predict_label-target_label)*(predict_label-target_label);
00209                 sump += predict_label;
00210                 sumt += target_label;
00211                 sumpp += predict_label*predict_label;
00212                 sumtt += target_label*target_label;
00213                 sumpt += predict_label*target_label;
00214                 ++total;
00215         }
00216         if(svm_type==NU_SVR || svm_type==EPSILON_SVR)
00217         {
00218                 mexPrintf("Mean squared error = %g (regression)\n",error/total);
00219                 mexPrintf("Squared correlation coefficient = %g (regression)\n",
00220                         ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
00221                         ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt))
00222                         );
00223         }
00224         else
00225                 mexPrintf("Accuracy = %g%% (%d/%d) (classification)\n",
00226                         (double)correct/total*100,correct,total);
00227 
00228         // return accuracy, mean squared error, squared correlation coefficient
00229         plhs[1] = mxCreateDoubleMatrix(3, 1, mxREAL);
00230         ptr = mxGetPr(plhs[1]);
00231         ptr[0] = (double)correct/total*100;
00232         ptr[1] = error/total;
00233         ptr[2] = ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
00234                                 ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt));
00235 
00236         free(x);
00237         if(prob_estimates != NULL)
00238                 free(prob_estimates);
00239 }
00240 
/*
 * Print the svmpredict usage text through mexPrintf.
 * Declared with (void) so the compiler sees a real prototype.
 */
void exit_with_help(void)
{
	mexPrintf(
		"Usage: [predicted_label, accuracy, decision_values/prob_estimates] = svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')\n"
		"Parameters:\n"
		"  model: SVM model structure from svmtrain.\n"
		"  libsvm_options:\n"
		"    -b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n"
		"Returns:\n"
		"  predicted_label: SVM prediction output vector.\n"
		"  accuracy: a vector with accuracy, mean squared error, squared correlation coefficient.\n"
		"  decision_values/prob_estimates: decision values by default; with -b 1, probability estimates (empty for regression models).\n"
	);
}
00255 
00256 void mexFunction( int nlhs, mxArray *plhs[],
00257                  int nrhs, const mxArray *prhs[] )
00258 {
00259         int prob_estimate_flag = 0;
00260         struct svm_model *model;
00261 
00262         if(nrhs > 4 || nrhs < 3)
00263         {
00264                 exit_with_help();
00265                 fake_answer(plhs);
00266                 return;
00267         }
00268 
00269         if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
00270                 mexPrintf("Error: label vector and instance matrix must be double\n");
00271                 fake_answer(plhs);
00272                 return;
00273         }
00274 
00275         if(mxIsStruct(prhs[2]))
00276         {
00277                 const char *error_msg;
00278 
00279                 // parse options
00280                 if(nrhs==4)
00281                 {
00282                         int i, argc = 1;
00283                         char cmd[CMD_LEN], *argv[CMD_LEN/2];
00284 
00285                         // put options in argv[]
00286                         mxGetString(prhs[3], cmd,  mxGetN(prhs[3]) + 1);
00287                         if((argv[argc] = strtok(cmd, " ")) != NULL)
00288                                 while((argv[++argc] = strtok(NULL, " ")) != NULL)
00289                                         ;
00290 
00291                         for(i=1;i<argc;i++)
00292                         {
00293                                 if(argv[i][0] != '-') break;
00294                                 if(++i>=argc)
00295                                 {
00296                                         exit_with_help();
00297                                         fake_answer(plhs);
00298                                         return;
00299                                 }
00300                                 switch(argv[i-1][1])
00301                                 {
00302                                         case 'b':
00303                                                 prob_estimate_flag = atoi(argv[i]);
00304                                                 break;
00305                                         default:
00306                                                 mexPrintf("Unknown option: -%c\n", argv[i-1][1]);
00307                                                 exit_with_help();
00308                                                 fake_answer(plhs);
00309                                                 return;
00310                                 }
00311                         }
00312                 }
00313 
00314                 model = matlab_matrix_to_model(prhs[2], &error_msg);
00315                 if (model == NULL)
00316                 {
00317                         mexPrintf("Error: can't read model: %s\n", error_msg);
00318                         fake_answer(plhs);
00319                         return;
00320                 }
00321 
00322                 if(prob_estimate_flag)
00323                 {
00324                         if(svm_check_probability_model(model)==0)
00325                         {
00326                                 mexPrintf("Model does not support probabiliy estimates\n");
00327                                 fake_answer(plhs);
00328                                 svm_free_and_destroy_model(&model);
00329                                 return;
00330                         }
00331                 }
00332                 else
00333                 {
00334                         if(svm_check_probability_model(model)!=0)
00335                                 mexPrintf("Model supports probability estimates, but disabled in predicton.\n");
00336                 }
00337 
00338                 predict(plhs, prhs, model, prob_estimate_flag);
00339                 // destroy model
00340                 svm_free_and_destroy_model(&model);
00341         }
00342         else
00343         {
00344                 mexPrintf("model file should be a struct array\n");
00345                 fake_answer(plhs);
00346         }
00347 
00348         return;
00349 }


haf_grasping
Author(s): David Fischinger
autogenerated on Thu Jun 6 2019 18:35:09