10 #if MX_API_VER < 0x07030000 25 samples = mxGetPr(prhs);
29 low = (int)jc[index], high = (
int)jc[index+1];
32 x[j].
index = (int)ir[i] + 1;
33 x[j].
value = samples[i];
41 plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
42 plhs[1] = mxCreateDoubleMatrix(0, 0, mxREAL);
43 plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
48 int label_vector_row_num, label_vector_col_num;
49 int feature_number, testing_instance_number;
51 double *ptr_instance, *ptr_label, *ptr_predict_label;
52 double *ptr_prob_estimates, *ptr_dec_values, *ptr;
59 double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;
63 double *prob_estimates=NULL;
66 feature_number = (int)mxGetN(prhs[1]);
67 testing_instance_number = (int)mxGetM(prhs[1]);
68 label_vector_row_num = (int)mxGetM(prhs[0]);
69 label_vector_col_num = (int)mxGetN(prhs[0]);
71 if(label_vector_row_num!=testing_instance_number)
73 mexPrintf(
"Length of label vector does not match # of instances.\n");
77 if(label_vector_col_num!=1)
79 mexPrintf(
"label (1st argument) should be a vector (# of column is 1).\n");
84 ptr_instance = mxGetPr(prhs[1]);
85 ptr_label = mxGetPr(prhs[0]);
88 if(mxIsSparse(prhs[1]))
93 mxArray *rhs[1], *lhs[1];
94 rhs[0] = mxDuplicateArray(prhs[1]);
95 if(mexCallMATLAB(1, lhs, 1, rhs,
"full"))
97 mexPrintf(
"Error: cannot full testing instance matrix\n");
101 ptr_instance = mxGetPr(lhs[0]);
102 mxDestroyArray(rhs[0]);
107 pprhs[0] = mxDuplicateArray(prhs[1]);
108 if(mexCallMATLAB(1, pplhs, 1, pprhs,
"transpose"))
110 mexPrintf(
"Error: cannot transpose testing instance matrix\n");
117 if(predict_probability)
120 mexPrintf(
"Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",
svm_get_svr_probability(model));
122 prob_estimates = (
double *) malloc(nr_class*
sizeof(
double));
125 plhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
126 if(predict_probability)
130 plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL);
132 plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
141 plhs[2] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
143 plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class*(nr_class-1)/2, mxREAL);
146 ptr_predict_label = mxGetPr(plhs[0]);
147 ptr_prob_estimates = mxGetPr(plhs[2]);
148 ptr_dec_values = mxGetPr(plhs[2]);
150 for(instance_index=0;instance_index<testing_instance_number;instance_index++)
153 double target_label, predict_label;
155 target_label = ptr_label[instance_index];
161 for(i=0;i<feature_number;i++)
164 x[i].
value = ptr_instance[testing_instance_number*i+instance_index];
166 x[feature_number].
index = -1;
169 if(predict_probability)
174 ptr_predict_label[instance_index] = predict_label;
175 for(i=0;i<nr_class;i++)
176 ptr_prob_estimates[instance_index + i * testing_instance_number] = prob_estimates[i];
179 ptr_predict_label[instance_index] = predict_label;
190 ptr_dec_values[instance_index] = res;
194 double *dec_values = (
double *) malloc(
sizeof(
double) * nr_class*(nr_class-1)/2);
197 ptr_dec_values[instance_index] = 1;
199 for(i=0;i<(nr_class*(nr_class-1))/2;i++)
200 ptr_dec_values[instance_index + i * testing_instance_number] = dec_values[i];
203 ptr_predict_label[instance_index] = predict_label;
206 if(predict_label == target_label)
208 error += (predict_label-target_label)*(predict_label-target_label);
209 sump += predict_label;
210 sumt += target_label;
211 sumpp += predict_label*predict_label;
212 sumtt += target_label*target_label;
213 sumpt += predict_label*target_label;
218 mexPrintf(
"Mean squared error = %g (regression)\n",error/total);
219 mexPrintf(
"Squared correlation coefficient = %g (regression)\n",
220 ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
221 ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt))
225 mexPrintf(
"Accuracy = %g%% (%d/%d) (classification)\n",
226 (
double)correct/total*100,correct,total);
229 plhs[1] = mxCreateDoubleMatrix(3, 1, mxREAL);
230 ptr = mxGetPr(plhs[1]);
231 ptr[0] = (double)correct/total*100;
232 ptr[1] = error/total;
233 ptr[2] = ((total*sumpt-sump*sumt)*(total*sumpt-sump*sumt))/
234 ((total*sumpp-sump*sump)*(total*sumtt-sumt*sumt));
237 if(prob_estimates != NULL)
238 free(prob_estimates);
244 "Usage: [predicted_label, accuracy, decision_values/prob_estimates] = svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')\n" 246 " model: SVM model structure from svmtrain.\n" 248 " -b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n" 250 " predicted_label: SVM prediction output vector.\n" 251 " accuracy: a vector with accuracy, mean squared error, squared correlation coefficient.\n" 252 " prob_estimates: If selected, probability estimate vector.\n" 257 int nrhs,
const mxArray *prhs[] )
259 int prob_estimate_flag = 0;
262 if(nrhs > 4 || nrhs < 3)
269 if(!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1])) {
270 mexPrintf(
"Error: label vector and instance matrix must be double\n");
275 if(mxIsStruct(prhs[2]))
277 const char *error_msg;
286 mxGetString(prhs[3], cmd, mxGetN(prhs[3]) + 1);
287 if((argv[argc] = strtok(cmd,
" ")) != NULL)
288 while((argv[++argc] = strtok(NULL,
" ")) != NULL)
293 if(argv[i][0] !=
'-')
break;
303 prob_estimate_flag = atoi(argv[i]);
306 mexPrintf(
"Unknown option: -%c\n", argv[i-1][1]);
317 mexPrintf(
"Error: can't read model: %s\n", error_msg);
322 if(prob_estimate_flag)
326 mexPrintf(
"Model does not support probabiliy estimates\n");
335 mexPrintf(
"Model supports probability estimates, but disabled in predicton.\n");
338 predict(plhs, prhs, model, prob_estimate_flag);
344 mexPrintf(
"model file should be a struct array\n");
void read_sparse_instance(const mxArray *prhs, int index, struct svm_node *x)
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
void predict(mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, const int predict_probability)
double svm_get_svr_probability(const svm_model *model)
def svm_predict(y, x, m, options="")
static void fake_answer(mxArray *plhs[])
int svm_get_nr_class(const svm_model *model)
int svm_check_probability_model(const svm_model *model)
struct svm_model * matlab_matrix_to_model(const mxArray *matlab_struct, const char **msg)
void svm_free_and_destroy_model(svm_model **model_ptr_ptr)
double svm_predict_values(const svm_model *model, const svm_node *x, double *dec_values)
double svm_predict_probability(const svm_model *model, const svm_node *x, double *prob_estimates)
struct svm_parameter param
int svm_get_svm_type(const svm_model *model)