00001 #include <stdio.h>
00002 #include <stdlib.h>
00003 #include <string.h>
00004 #include "../svm.h"
00005
00006 #include "mex.h"
00007 #include "svm_model_matlab.h"
00008
/*
 * Compatibility shim: when the MEX API version macro is defined and older
 * than 0x07030000, supply mwIndex as a plain int.
 */
#if defined(MX_API_VER) && (MX_API_VER < 0x07030000)
typedef int mwIndex;
#endif

/* Size in bytes of the buffer that receives the option string. */
#define CMD_LEN 2048
00016
/*
 * Silent printf-compatible sink, installed as `info` for quiet mode.
 *
 * Accepts a format string and any arguments, ignores them all, and
 * reports that zero characters were written.
 */
int print_null(const char *s,...)
{
	(void)s;	/* intentionally unused */
	return 0;
}
00018 int (*info)(const char *fmt,...) = &mexPrintf;
00019
00020 void read_sparse_instance(const mxArray *prhs, int index, struct svm_node *x)
00021 {
00022 int i, j, low, high;
00023 mwIndex *ir, *jc;
00024 double *samples;
00025
00026 ir = mxGetIr(prhs);
00027 jc = mxGetJc(prhs);
00028 samples = mxGetPr(prhs);
00029
00030
00031 j = 0;
00032 low = (int)jc[index], high = (int)jc[index+1];
00033 for(i=low;i<high;i++)
00034 {
00035 x[j].index = (int)ir[i] + 1;
00036 x[j].value = samples[i];
00037 j++;
00038 }
00039 x[j].index = -1;
00040 }
00041
00042 static void fake_answer(mxArray *plhs[])
00043 {
00044 plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
00045 plhs[1] = mxCreateDoubleMatrix(0, 0, mxREAL);
00046 plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
00047 }
00048
00049 void predict(mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, const int predict_probability)
00050 {
00051 int label_vector_row_num, label_vector_col_num;
00052 int feature_number, testing_instance_number;
00053 int instance_index;
00054 double *ptr_instance, *ptr_label, *ptr_predict_label;
00055 double *ptr_prob_estimates, *ptr_dec_values, *ptr;
00056 struct svm_node *x;
00057 mxArray *pplhs[1];
00058
00059 int correct = 0;
00060 int total = 0;
00061 double error = 0;
00062 double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;
00063
00064 int svm_type=svm_get_svm_type(model);
00065 int nr_class=svm_get_nr_class(model);
00066 double *prob_estimates=NULL;
00067
00068
00069 feature_number = (int)mxGetN(prhs[1]);
00070 testing_instance_number = (int)mxGetM(prhs[1]);
00071 label_vector_row_num = (int)mxGetM(prhs[0]);
00072 label_vector_col_num = (int)mxGetN(prhs[0]);
00073
00074 if(label_vector_row_num!=testing_instance_number)
00075 {
00076 mexPrintf("Length of label vector does not match # of instances.\n");
00077 fake_answer(plhs);
00078 return;
00079 }
00080 if(label_vector_col_num!=1)
00081 {
00082 mexPrintf("label (1st argument) should be a vector (# of column is 1).\n");
00083 fake_answer(plhs);
00084 return;
00085 }
00086
00087 ptr_instance = mxGetPr(prhs[1]);
00088 ptr_label = mxGetPr(prhs[0]);
00089
00090
00091 if(mxIsSparse(prhs[1]))
00092 {
00093 if(model->param.kernel_type == PRECOMPUTED)
00094 {
00095
00096 mxArray *rhs[1], *lhs[1];
00097 rhs[0] = mxDuplicateArray(prhs[1]);
00098 if(mexCallMATLAB(1, lhs, 1, rhs, "full"))
00099 {
00100 mexPrintf("Error: cannot full testing instance matrix\n");
00101 fake_answer(plhs);
00102 return;
00103 }
00104 ptr_instance = mxGetPr(lhs[0]);
00105 mxDestroyArray(rhs[0]);
00106 }
00107 else
00108 {
00109 mxArray *pprhs[1];
00110 pprhs[0] = mxDuplicateArray(prhs[1]);
00111 if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
00112 {
00113 mexPrintf("Error: cannot transpose testing instance matrix\n");
00114 fake_answer(plhs);
00115 return;
00116 }
00117 }
00118 }
00119
00120 if(predict_probability)
00121 {
00122 if(svm_type==NU_SVR || svm_type==EPSILON_SVR)
00123 info("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",svm_get_svr_probability(model));
00124 else
00125 prob_estimates = (double *) malloc(nr_class*sizeof(double));
00126 }
00127
00128 plhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
00129 if(predict_probability)
00130 {
00131
00132 if(svm_type==C_SVC || svm_type==NU_SVC)
00133 plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL);
00134 else
00135 plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
00136 }
00137 else
00138 {
00139
00140 if(svm_type == ONE_CLASS ||
00141 svm_type == EPSILON_SVR ||
00142 svm_type == NU_SVR ||
00143 nr_class == 1)
00144 plhs[2] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
00145 else
00146 plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class*(nr_class-1)/2, mxREAL);
00147 }
00148
00149 ptr_predict_label = mxGetPr(plhs[0]);
00150 ptr_prob_estimates = mxGetPr(plhs[2]);
00151 ptr_dec_values = mxGetPr(plhs[2]);
00152 x = (struct svm_node*)malloc((feature_number+1)*sizeof(struct svm_node) );
00153 for(instance_index=0;instance_index<testing_instance_number;instance_index++)
00154 {
00155 int i;
00156 double target_label, predict_label;
00157
00158 target_label = ptr_label[instance_index];
00159
00160 if(mxIsSparse(prhs[1]) && model->param.kernel_type != PRECOMPUTED)
00161 read_sparse_instance(pplhs[0], instance_index, x);
00162 else
00163 {
00164 for(i=0;i<feature_number;i++)
00165 {
00166 x[i].index = i+1;
00167 x[i].value = ptr_instance[testing_instance_number*i+instance_index];
00168 }
00169 x[feature_number].index = -1;
00170 }
00171
00172 if(predict_probability)
00173 {
00174 if(svm_type==C_SVC || svm_type==NU_SVC)
00175 {
00176 predict_label = svm_predict_probability(model, x, prob_estimates);
00177 ptr_predict_label[instance_index] = predict_label;
00178 for(i=0;i<nr_class;i++)
00179 ptr_prob_estimates[instance_index + i * testing_instance_number] = prob_estimates[i];
00180 } else {
00181 predict_label = svm_predict(model,x);
00182 ptr_predict_label[instance_index] = predict_label;
00183 }
00184 }
00185 else
00186 {
00187 if(svm_type == ONE_CLASS ||
00188 svm_type == EPSILON_SVR ||
00189 svm_type == NU_SVR)
00190 {
00191 double res;
00192 predict_label = svm_predict_values(model, x, &res);
00193 ptr_dec_values[instance_index] = res;
00194 }
00195 else
00196 {
00197 double *dec_values = (double *) malloc(sizeof(double) * nr_class*(nr_class-1)/2);
00198 predict_label = svm_predict_values(model, x, dec_values);
00199 if(nr_class == 1)
00200 ptr_dec_values[instance_index] = 1;
00201 else
00202 for(i=0;i<(nr_class*(nr_class-1))/2;i++)
00203 ptr_dec_values[instance_index + i * testing_instance_number] = dec_values[i];
00204 free(dec_values);
00205 }
00206 ptr_predict_label[instance_index] = predict_label;
00207 }
00208
00209 if(predict_label == target_label)
00210 ++correct;
00211 error += (predict_label-target_label)*(predict_label-target_label);
00212 sump += predict_label;
00213 sumt += target_label;
00214 sumpp += predict_label*predict_label;
00215 sumtt += target_label*target_label;
00216 sumpt += predict_label*target_label;
00217 ++total;
00218 }
00219 if(svm_type==NU_SVR || svm_type==EPSILON_SVR)
00220 {
00221 info("Mean squared error = %g (regression)\n",error/total);
00222 info("Squared correlation coefficient = %g (regression)\n",
00223 ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
00224 ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt))
00225 );
00226 }
00227 else
00228 info("Accuracy = %g%% (%d/%d) (classification)\n",
00229 (double)correct/total*100,correct,total);
00230
00231
00232 plhs[1] = mxCreateDoubleMatrix(3, 1, mxREAL);
00233 ptr = mxGetPr(plhs[1]);
00234 ptr[0] = (double)correct/total*100;
00235 ptr[1] = error/total;
00236 ptr[2] = ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
00237 ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt));
00238
00239 free(x);
00240 if(prob_estimates != NULL)
00241 free(prob_estimates);
00242 }
00243
/*
 * Print the svmpredict usage summary (call syntax, options, return values)
 * to the MATLAB console.
 */
void exit_with_help()
{
	static const char usage_text[] =
		"Usage: [predicted_label, accuracy, decision_values/prob_estimates] = svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')\n"
		"Parameters:\n"
		" model: SVM model structure from svmtrain.\n"
		" libsvm_options:\n"
		" -b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n"
		" -q : quiet mode (no outputs)\n"
		"Returns:\n"
		" predicted_label: SVM prediction output vector.\n"
		" accuracy: a vector with accuracy, mean squared error, squared correlation coefficient.\n"
		" prob_estimates: If selected, probability estimate vector.\n";

	/* the text holds no conversion specifiers, so "%s" emits it verbatim */
	mexPrintf("%s", usage_text);
}
00259
00260 void mexFunction( int nlhs, mxArray *plhs[],
00261 int nrhs, const mxArray *prhs[] )
00262 {
00263 int prob_estimate_flag = 0;
00264 struct svm_model *model;
00265 info = &mexPrintf;
00266
00267 if(nrhs > 4 || nrhs < 3)
00268 {
00269 exit_with_help();
00270 fake_answer(plhs);
00271 return;
00272 }
00273
00274 if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
00275 mexPrintf("Error: label vector and instance matrix must be double\n");
00276 fake_answer(plhs);
00277 return;
00278 }
00279
00280 if(mxIsStruct(prhs[2]))
00281 {
00282 const char *error_msg;
00283
00284
00285 if(nrhs==4)
00286 {
00287 int i, argc = 1;
00288 char cmd[CMD_LEN], *argv[CMD_LEN/2];
00289
00290
00291 mxGetString(prhs[3], cmd, mxGetN(prhs[3]) + 1);
00292 if((argv[argc] = strtok(cmd, " ")) != NULL)
00293 while((argv[++argc] = strtok(NULL, " ")) != NULL)
00294 ;
00295
00296 for(i=1;i<argc;i++)
00297 {
00298 if(argv[i][0] != '-') break;
00299 if((++i>=argc) && argv[i-1][1] != 'q')
00300 {
00301 exit_with_help();
00302 fake_answer(plhs);
00303 return;
00304 }
00305 switch(argv[i-1][1])
00306 {
00307 case 'b':
00308 prob_estimate_flag = atoi(argv[i]);
00309 break;
00310 case 'q':
00311 i--;
00312 info = &print_null;
00313 break;
00314 default:
00315 mexPrintf("Unknown option: -%c\n", argv[i-1][1]);
00316 exit_with_help();
00317 fake_answer(plhs);
00318 return;
00319 }
00320 }
00321 }
00322
00323 model = matlab_matrix_to_model(prhs[2], &error_msg);
00324 if (model == NULL)
00325 {
00326 mexPrintf("Error: can't read model: %s\n", error_msg);
00327 fake_answer(plhs);
00328 return;
00329 }
00330
00331 if(prob_estimate_flag)
00332 {
00333 if(svm_check_probability_model(model)==0)
00334 {
00335 mexPrintf("Model does not support probabiliy estimates\n");
00336 fake_answer(plhs);
00337 svm_free_and_destroy_model(&model);
00338 return;
00339 }
00340 }
00341 else
00342 {
00343 if(svm_check_probability_model(model)!=0)
00344 info("Model supports probability estimates, but disabled in predicton.\n");
00345 }
00346
00347 predict(plhs, prhs, model, prob_estimate_flag);
00348
00349 svm_free_and_destroy_model(&model);
00350 }
00351 else
00352 {
00353 mexPrintf("model file should be a struct array\n");
00354 fake_answer(plhs);
00355 }
00356
00357 return;
00358 }