00001 #include <stdio.h>
00002 #include <stdlib.h>
00003 #include <string.h>
00004 #include "../svm.h"
00005
00006 #include "mex.h"
00007 #include "svm_model_matlab.h"
00008
/*
 * When MX_API_VER is defined and is below 0x07030000, define mwIndex as a
 * plain int (presumably those MATLAB API versions do not provide the type --
 * TODO confirm against the MATLAB release notes).
 */
#ifdef MX_API_VER
#if MX_API_VER < 0x07030000
typedef int mwIndex;
#endif
#endif

/*
 * Capacity, in chars, of the buffer that receives the 'libsvm_options'
 * string in mexFunction; the token array there is sized from it as well.
 */
#define CMD_LEN 2048
00016
00017 void read_sparse_instance(const mxArray *prhs, int index, struct svm_node *x)
00018 {
00019 int i, j, low, high;
00020 mwIndex *ir, *jc;
00021 double *samples;
00022
00023 ir = mxGetIr(prhs);
00024 jc = mxGetJc(prhs);
00025 samples = mxGetPr(prhs);
00026
00027
00028 j = 0;
00029 low = (int)jc[index], high = (int)jc[index+1];
00030 for(i=low;i<high;i++)
00031 {
00032 x[j].index = (int)ir[i] + 1;
00033 x[j].value = samples[i];
00034 j++;
00035 }
00036 x[j].index = -1;
00037 }
00038
00039 static void fake_answer(mxArray *plhs[])
00040 {
00041 plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
00042 plhs[1] = mxCreateDoubleMatrix(0, 0, mxREAL);
00043 plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
00044 }
00045
00046 void predict(mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, const int predict_probability)
00047 {
00048 int label_vector_row_num, label_vector_col_num;
00049 int feature_number, testing_instance_number;
00050 int instance_index;
00051 double *ptr_instance, *ptr_label, *ptr_predict_label;
00052 double *ptr_prob_estimates, *ptr_dec_values, *ptr;
00053 struct svm_node *x;
00054 mxArray *pplhs[1];
00055
00056 int correct = 0;
00057 int total = 0;
00058 double error = 0;
00059 double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;
00060
00061 int svm_type=svm_get_svm_type(model);
00062 int nr_class=svm_get_nr_class(model);
00063 double *prob_estimates=NULL;
00064
00065
00066 feature_number = (int)mxGetN(prhs[1]);
00067 testing_instance_number = (int)mxGetM(prhs[1]);
00068 label_vector_row_num = (int)mxGetM(prhs[0]);
00069 label_vector_col_num = (int)mxGetN(prhs[0]);
00070
00071 if(label_vector_row_num!=testing_instance_number)
00072 {
00073 mexPrintf("Length of label vector does not match # of instances.\n");
00074 fake_answer(plhs);
00075 return;
00076 }
00077 if(label_vector_col_num!=1)
00078 {
00079 mexPrintf("label (1st argument) should be a vector (# of column is 1).\n");
00080 fake_answer(plhs);
00081 return;
00082 }
00083
00084 ptr_instance = mxGetPr(prhs[1]);
00085 ptr_label = mxGetPr(prhs[0]);
00086
00087
00088 if(mxIsSparse(prhs[1]))
00089 {
00090 if(model->param.kernel_type == PRECOMPUTED)
00091 {
00092
00093 mxArray *rhs[1], *lhs[1];
00094 rhs[0] = mxDuplicateArray(prhs[1]);
00095 if(mexCallMATLAB(1, lhs, 1, rhs, "full"))
00096 {
00097 mexPrintf("Error: cannot full testing instance matrix\n");
00098 fake_answer(plhs);
00099 return;
00100 }
00101 ptr_instance = mxGetPr(lhs[0]);
00102 mxDestroyArray(rhs[0]);
00103 }
00104 else
00105 {
00106 mxArray *pprhs[1];
00107 pprhs[0] = mxDuplicateArray(prhs[1]);
00108 if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
00109 {
00110 mexPrintf("Error: cannot transpose testing instance matrix\n");
00111 fake_answer(plhs);
00112 return;
00113 }
00114 }
00115 }
00116
00117 if(predict_probability)
00118 {
00119 if(svm_type==NU_SVR || svm_type==EPSILON_SVR)
00120 mexPrintf("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",svm_get_svr_probability(model));
00121 else
00122 prob_estimates = (double *) malloc(nr_class*sizeof(double));
00123 }
00124
00125 plhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
00126 if(predict_probability)
00127 {
00128
00129 if(svm_type==C_SVC || svm_type==NU_SVC)
00130 plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL);
00131 else
00132 plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
00133 }
00134 else
00135 {
00136
00137 if(svm_type == ONE_CLASS ||
00138 svm_type == EPSILON_SVR ||
00139 svm_type == NU_SVR ||
00140 nr_class == 1)
00141 plhs[2] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
00142 else
00143 plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class*(nr_class-1)/2, mxREAL);
00144 }
00145
00146 ptr_predict_label = mxGetPr(plhs[0]);
00147 ptr_prob_estimates = mxGetPr(plhs[2]);
00148 ptr_dec_values = mxGetPr(plhs[2]);
00149 x = (struct svm_node*)malloc((feature_number+1)*sizeof(struct svm_node) );
00150 for(instance_index=0;instance_index<testing_instance_number;instance_index++)
00151 {
00152 int i;
00153 double target_label, predict_label;
00154
00155 target_label = ptr_label[instance_index];
00156
00157 if(mxIsSparse(prhs[1]) && model->param.kernel_type != PRECOMPUTED)
00158 read_sparse_instance(pplhs[0], instance_index, x);
00159 else
00160 {
00161 for(i=0;i<feature_number;i++)
00162 {
00163 x[i].index = i+1;
00164 x[i].value = ptr_instance[testing_instance_number*i+instance_index];
00165 }
00166 x[feature_number].index = -1;
00167 }
00168
00169 if(predict_probability)
00170 {
00171 if(svm_type==C_SVC || svm_type==NU_SVC)
00172 {
00173 predict_label = svm_predict_probability(model, x, prob_estimates);
00174 ptr_predict_label[instance_index] = predict_label;
00175 for(i=0;i<nr_class;i++)
00176 ptr_prob_estimates[instance_index + i * testing_instance_number] = prob_estimates[i];
00177 } else {
00178 predict_label = svm_predict(model,x);
00179 ptr_predict_label[instance_index] = predict_label;
00180 }
00181 }
00182 else
00183 {
00184 if(svm_type == ONE_CLASS ||
00185 svm_type == EPSILON_SVR ||
00186 svm_type == NU_SVR)
00187 {
00188 double res;
00189 predict_label = svm_predict_values(model, x, &res);
00190 ptr_dec_values[instance_index] = res;
00191 }
00192 else
00193 {
00194 double *dec_values = (double *) malloc(sizeof(double) * nr_class*(nr_class-1)/2);
00195 predict_label = svm_predict_values(model, x, dec_values);
00196 if(nr_class == 1)
00197 ptr_dec_values[instance_index] = 1;
00198 else
00199 for(i=0;i<(nr_class*(nr_class-1))/2;i++)
00200 ptr_dec_values[instance_index + i * testing_instance_number] = dec_values[i];
00201 free(dec_values);
00202 }
00203 ptr_predict_label[instance_index] = predict_label;
00204 }
00205
00206 if(predict_label == target_label)
00207 ++correct;
00208 error += (predict_label-target_label)*(predict_label-target_label);
00209 sump += predict_label;
00210 sumt += target_label;
00211 sumpp += predict_label*predict_label;
00212 sumtt += target_label*target_label;
00213 sumpt += predict_label*target_label;
00214 ++total;
00215 }
00216 if(svm_type==NU_SVR || svm_type==EPSILON_SVR)
00217 {
00218 mexPrintf("Mean squared error = %g (regression)\n",error/total);
00219 mexPrintf("Squared correlation coefficient = %g (regression)\n",
00220 ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
00221 ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt))
00222 );
00223 }
00224 else
00225 mexPrintf("Accuracy = %g%% (%d/%d) (classification)\n",
00226 (double)correct/total*100,correct,total);
00227
00228
00229 plhs[1] = mxCreateDoubleMatrix(3, 1, mxREAL);
00230 ptr = mxGetPr(plhs[1]);
00231 ptr[0] = (double)correct/total*100;
00232 ptr[1] = error/total;
00233 ptr[2] = ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
00234 ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt));
00235
00236 free(x);
00237 if(prob_estimates != NULL)
00238 free(prob_estimates);
00239 }
00240
/*
 * Print the svmpredict usage text (call syntax, parameters and return
 * values) to the MATLAB console. Despite the name, it only prints and
 * then returns to the caller.
 */
void exit_with_help()
{
	static const char usage_text[] =
		"Usage: [predicted_label, accuracy, decision_values/prob_estimates] = svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')\n"
		"Parameters:\n"
		" model: SVM model structure from svmtrain.\n"
		" libsvm_options:\n"
		" -b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n"
		"Returns:\n"
		" predicted_label: SVM prediction output vector.\n"
		" accuracy: a vector with accuracy, mean squared error, squared correlation coefficient.\n"
		" prob_estimates: If selected, probability estimate vector.\n";

	mexPrintf("%s", usage_text);
}
00255
00256 void mexFunction( int nlhs, mxArray *plhs[],
00257 int nrhs, const mxArray *prhs[] )
00258 {
00259 int prob_estimate_flag = 0;
00260 struct svm_model *model;
00261
00262 if(nrhs > 4 || nrhs < 3)
00263 {
00264 exit_with_help();
00265 fake_answer(plhs);
00266 return;
00267 }
00268
00269 if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
00270 mexPrintf("Error: label vector and instance matrix must be double\n");
00271 fake_answer(plhs);
00272 return;
00273 }
00274
00275 if(mxIsStruct(prhs[2]))
00276 {
00277 const char *error_msg;
00278
00279
00280 if(nrhs==4)
00281 {
00282 int i, argc = 1;
00283 char cmd[CMD_LEN], *argv[CMD_LEN/2];
00284
00285
00286 mxGetString(prhs[3], cmd, mxGetN(prhs[3]) + 1);
00287 if((argv[argc] = strtok(cmd, " ")) != NULL)
00288 while((argv[++argc] = strtok(NULL, " ")) != NULL)
00289 ;
00290
00291 for(i=1;i<argc;i++)
00292 {
00293 if(argv[i][0] != '-') break;
00294 if(++i>=argc)
00295 {
00296 exit_with_help();
00297 fake_answer(plhs);
00298 return;
00299 }
00300 switch(argv[i-1][1])
00301 {
00302 case 'b':
00303 prob_estimate_flag = atoi(argv[i]);
00304 break;
00305 default:
00306 mexPrintf("Unknown option: -%c\n", argv[i-1][1]);
00307 exit_with_help();
00308 fake_answer(plhs);
00309 return;
00310 }
00311 }
00312 }
00313
00314 model = matlab_matrix_to_model(prhs[2], &error_msg);
00315 if (model == NULL)
00316 {
00317 mexPrintf("Error: can't read model: %s\n", error_msg);
00318 fake_answer(plhs);
00319 return;
00320 }
00321
00322 if(prob_estimate_flag)
00323 {
00324 if(svm_check_probability_model(model)==0)
00325 {
00326 mexPrintf("Model does not support probabiliy estimates\n");
00327 fake_answer(plhs);
00328 svm_free_and_destroy_model(&model);
00329 return;
00330 }
00331 }
00332 else
00333 {
00334 if(svm_check_probability_model(model)!=0)
00335 mexPrintf("Model supports probability estimates, but disabled in predicton.\n");
00336 }
00337
00338 predict(plhs, prhs, model, prob_estimate_flag);
00339
00340 svm_free_and_destroy_model(&model);
00341 }
00342 else
00343 {
00344 mexPrintf("model file should be a struct array\n");
00345 fake_answer(plhs);
00346 }
00347
00348 return;
00349 }