00001 #include <stdio.h>
00002 #include <stdlib.h>
00003 #include <string.h>
00004 #include "svm.h"
00005
00006 #include "mex.h"
00007 #include "svm_model_matlab.h"
00008
/*
 * Fallback definition of mwIndex for builds where MX_API_VER is defined and
 * below 0x07030000 (presumably MATLAB releases whose headers do not provide
 * the type - confirm against the targeted MATLAB versions).
 */
#if defined(MX_API_VER) && MX_API_VER < 0x07030000
typedef int mwIndex;
#endif

/* Capacity, in bytes, of the buffer that receives the option string. */
#define CMD_LEN 2048
00016
/*
 * printf-compatible sink that discards its arguments.
 *
 * It is installed in place of mexPrintf when quiet mode is requested, so it
 * must have a well-defined return value: callers going through the function
 * pointer are entitled to read the result.
 *
 * Returns 0 (no characters written).
 */
int print_null(const char *s,...)
{
	(void)s;
	return 0;
}
00018 int (*info)(const char *fmt,...) = &mexPrintf;
00019
00020 void read_sparse_instance(const mxArray *prhs, int index, struct svm_node *x)
00021 {
00022 int i, j, low, high;
00023 mwIndex *ir, *jc;
00024 double *samples;
00025
00026 ir = mxGetIr(prhs);
00027 jc = mxGetJc(prhs);
00028 samples = mxGetPr(prhs);
00029
00030
00031 j = 0;
00032 low = (int)jc[index], high = (int)jc[index+1];
00033 for(i=low;i<high;i++)
00034 {
00035 x[j].index = (int)ir[i] + 1;
00036 x[j].value = samples[i];
00037 j++;
00038 }
00039 x[j].index = -1;
00040 }
00041
00042 static void fake_answer(int nlhs, mxArray *plhs[])
00043 {
00044 int i;
00045 for(i=0;i<nlhs;i++)
00046 plhs[i] = mxCreateDoubleMatrix(0, 0, mxREAL);
00047 }
00048
00049 void predict(int nlhs, mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, const int predict_probability)
00050 {
00051 int label_vector_row_num, label_vector_col_num;
00052 int feature_number, testing_instance_number;
00053 int instance_index;
00054 double *ptr_instance, *ptr_label, *ptr_predict_label;
00055 double *ptr_prob_estimates, *ptr_dec_values, *ptr;
00056 struct svm_node *x;
00057 mxArray *pplhs[1];
00058 mxArray *tplhs[3];
00059
00060 int correct = 0;
00061 int total = 0;
00062 double error = 0;
00063 double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;
00064
00065 int svm_type=svm_get_svm_type(model);
00066 int nr_class=svm_get_nr_class(model);
00067 double *prob_estimates=NULL;
00068
00069
00070 feature_number = (int)mxGetN(prhs[1]);
00071 testing_instance_number = (int)mxGetM(prhs[1]);
00072 label_vector_row_num = (int)mxGetM(prhs[0]);
00073 label_vector_col_num = (int)mxGetN(prhs[0]);
00074
00075 if(label_vector_row_num!=testing_instance_number)
00076 {
00077 mexPrintf("Length of label vector does not match # of instances.\n");
00078 fake_answer(nlhs, plhs);
00079 return;
00080 }
00081 if(label_vector_col_num!=1)
00082 {
00083 mexPrintf("label (1st argument) should be a vector (# of column is 1).\n");
00084 fake_answer(nlhs, plhs);
00085 return;
00086 }
00087
00088 ptr_instance = mxGetPr(prhs[1]);
00089 ptr_label = mxGetPr(prhs[0]);
00090
00091
00092 if(mxIsSparse(prhs[1]))
00093 {
00094 if(model->param.kernel_type == PRECOMPUTED)
00095 {
00096
00097 mxArray *rhs[1], *lhs[1];
00098 rhs[0] = mxDuplicateArray(prhs[1]);
00099 if(mexCallMATLAB(1, lhs, 1, rhs, "full"))
00100 {
00101 mexPrintf("Error: cannot full testing instance matrix\n");
00102 fake_answer(nlhs, plhs);
00103 return;
00104 }
00105 ptr_instance = mxGetPr(lhs[0]);
00106 mxDestroyArray(rhs[0]);
00107 }
00108 else
00109 {
00110 mxArray *pprhs[1];
00111 pprhs[0] = mxDuplicateArray(prhs[1]);
00112 if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
00113 {
00114 mexPrintf("Error: cannot transpose testing instance matrix\n");
00115 fake_answer(nlhs, plhs);
00116 return;
00117 }
00118 }
00119 }
00120
00121 if(predict_probability)
00122 {
00123 if(svm_type==NU_SVR || svm_type==EPSILON_SVR)
00124 info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",svm_get_svr_probability(model));
00125 else
00126 prob_estimates = (double *) malloc(nr_class*sizeof(double));
00127 }
00128
00129 tplhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
00130 if(predict_probability)
00131 {
00132
00133 if(svm_type==C_SVC || svm_type==NU_SVC)
00134 tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL);
00135 else
00136 tplhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
00137 }
00138 else
00139 {
00140
00141 if(svm_type == ONE_CLASS ||
00142 svm_type == EPSILON_SVR ||
00143 svm_type == NU_SVR ||
00144 nr_class == 1)
00145 tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
00146 else
00147 tplhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class*(nr_class-1)/2, mxREAL);
00148 }
00149
00150 ptr_predict_label = mxGetPr(tplhs[0]);
00151 ptr_prob_estimates = mxGetPr(tplhs[2]);
00152 ptr_dec_values = mxGetPr(tplhs[2]);
00153 x = (struct svm_node*)malloc((feature_number+1)*sizeof(struct svm_node) );
00154 for(instance_index=0;instance_index<testing_instance_number;instance_index++)
00155 {
00156 int i;
00157 double target_label, predict_label;
00158
00159 target_label = ptr_label[instance_index];
00160
00161 if(mxIsSparse(prhs[1]) && model->param.kernel_type != PRECOMPUTED)
00162 read_sparse_instance(pplhs[0], instance_index, x);
00163 else
00164 {
00165 for(i=0;i<feature_number;i++)
00166 {
00167 x[i].index = i+1;
00168 x[i].value = ptr_instance[testing_instance_number*i+instance_index];
00169 }
00170 x[feature_number].index = -1;
00171 }
00172
00173 if(predict_probability)
00174 {
00175 if(svm_type==C_SVC || svm_type==NU_SVC)
00176 {
00177 predict_label = svm_predict_probability(model, x, prob_estimates);
00178 ptr_predict_label[instance_index] = predict_label;
00179 for(i=0;i<nr_class;i++)
00180 ptr_prob_estimates[instance_index + i * testing_instance_number] = prob_estimates[i];
00181 } else {
00182 predict_label = svm_predict(model,x);
00183 ptr_predict_label[instance_index] = predict_label;
00184 }
00185 }
00186 else
00187 {
00188 if(svm_type == ONE_CLASS ||
00189 svm_type == EPSILON_SVR ||
00190 svm_type == NU_SVR)
00191 {
00192 double res;
00193 predict_label = svm_predict_values(model, x, &res);
00194 ptr_dec_values[instance_index] = res;
00195 }
00196 else
00197 {
00198 double *dec_values = (double *) malloc(sizeof(double) * nr_class*(nr_class-1)/2);
00199 predict_label = svm_predict_values(model, x, dec_values);
00200 if(nr_class == 1)
00201 ptr_dec_values[instance_index] = 1;
00202 else
00203 for(i=0;i<(nr_class*(nr_class-1))/2;i++)
00204 ptr_dec_values[instance_index + i * testing_instance_number] = dec_values[i];
00205 free(dec_values);
00206 }
00207 ptr_predict_label[instance_index] = predict_label;
00208 }
00209
00210 if(predict_label == target_label)
00211 ++correct;
00212 error += (predict_label-target_label)*(predict_label-target_label);
00213 sump += predict_label;
00214 sumt += target_label;
00215 sumpp += predict_label*predict_label;
00216 sumtt += target_label*target_label;
00217 sumpt += predict_label*target_label;
00218 ++total;
00219 }
00220 if(svm_type==NU_SVR || svm_type==EPSILON_SVR)
00221 {
00222 info("Mean squared error = %g (regression)\n",error/total);
00223 info("Squared correlation coefficient = %g (regression)\n",
00224 ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
00225 ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt))
00226 );
00227 }
00228 else
00229 info("Accuracy = %g%% (%d/%d) (classification)\n",
00230 (double)correct/total*100,correct,total);
00231
00232
00233 tplhs[1] = mxCreateDoubleMatrix(3, 1, mxREAL);
00234 ptr = mxGetPr(tplhs[1]);
00235 ptr[0] = (double)correct/total*100;
00236 ptr[1] = error/total;
00237 ptr[2] = ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
00238 ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt));
00239
00240 free(x);
00241 if(prob_estimates != NULL)
00242 free(prob_estimates);
00243
00244 switch(nlhs)
00245 {
00246 case 3:
00247 plhs[2] = tplhs[2];
00248 plhs[1] = tplhs[1];
00249 case 1:
00250 case 0:
00251 plhs[0] = tplhs[0];
00252 }
00253 }
00254
/* Print the svmpredict usage text to the MATLAB console. */
void exit_with_help()
{
	static const char usage[] =
	"Usage: [predicted_label, accuracy, decision_values/prob_estimates] = svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')\n"
	" [predicted_label] = svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')\n"
	"Parameters:\n"
	" model: SVM model structure from svmtrain.\n"
	" libsvm_options:\n"
	" -b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n"
	" -q : quiet mode (no outputs)\n"
	"Returns:\n"
	" predicted_label: SVM prediction output vector.\n"
	" accuracy: a vector with accuracy, mean squared error, squared correlation coefficient.\n"
	" prob_estimates: If selected, probability estimate vector.\n";

	mexPrintf("%s", usage);
}
00271
00272 void mexFunction( int nlhs, mxArray *plhs[],
00273 int nrhs, const mxArray *prhs[] )
00274 {
00275 int prob_estimate_flag = 0;
00276 struct svm_model *model;
00277 info = &mexPrintf;
00278
00279 if(nlhs == 2 || nlhs > 3 || nrhs > 4 || nrhs < 3)
00280 {
00281 exit_with_help();
00282 fake_answer(nlhs, plhs);
00283 return;
00284 }
00285
00286 if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
00287 mexPrintf("Error: label vector and instance matrix must be double\n");
00288 fake_answer(nlhs, plhs);
00289 return;
00290 }
00291
00292 if(mxIsStruct(prhs[2]))
00293 {
00294 const char *error_msg;
00295
00296
00297 if(nrhs==4)
00298 {
00299 int i, argc = 1;
00300 char cmd[CMD_LEN], *argv[CMD_LEN/2];
00301
00302
00303 mxGetString(prhs[3], cmd, mxGetN(prhs[3]) + 1);
00304 if((argv[argc] = strtok(cmd, " ")) != NULL)
00305 while((argv[++argc] = strtok(NULL, " ")) != NULL)
00306 ;
00307
00308 for(i=1;i<argc;i++)
00309 {
00310 if(argv[i][0] != '-') break;
00311 if((++i>=argc) && argv[i-1][1] != 'q')
00312 {
00313 exit_with_help();
00314 fake_answer(nlhs, plhs);
00315 return;
00316 }
00317 switch(argv[i-1][1])
00318 {
00319 case 'b':
00320 prob_estimate_flag = atoi(argv[i]);
00321 break;
00322 case 'q':
00323 i--;
00324 info = &print_null;
00325 break;
00326 default:
00327 mexPrintf("Unknown option: -%c\n", argv[i-1][1]);
00328 exit_with_help();
00329 fake_answer(nlhs, plhs);
00330 return;
00331 }
00332 }
00333 }
00334
00335 model = matlab_matrix_to_model(prhs[2], &error_msg);
00336 if (model == NULL)
00337 {
00338 mexPrintf("Error: can't read model: %s\n", error_msg);
00339 fake_answer(nlhs, plhs);
00340 return;
00341 }
00342
00343 if(prob_estimate_flag)
00344 {
00345 if(svm_check_probability_model(model)==0)
00346 {
00347 mexPrintf("Model does not support probabiliy estimates\n");
00348 fake_answer(nlhs, plhs);
00349 svm_free_and_destroy_model(&model);
00350 return;
00351 }
00352 }
00353 else
00354 {
00355 if(svm_check_probability_model(model)!=0)
00356 info("Model supports probability estimates, but disabled in predicton.\n");
00357 }
00358
00359 predict(nlhs, plhs, prhs, model, prob_estimate_flag);
00360
00361 svm_free_and_destroy_model(&model);
00362 }
00363 else
00364 {
00365 mexPrintf("model file should be a struct array\n");
00366 fake_answer(nlhs, plhs);
00367 }
00368
00369 return;
00370 }