10 #if MX_API_VER < 0x07030000 18 int (*
info)(
const char *fmt, ...) = &mexPrintf;
28 samples = mxGetPr(prhs);
32 low = (int)jc[index], high = (
int)jc[index + 1];
33 for (i = low; i < high; i++)
35 x[j].
index = (int)ir[i] + 1;
36 x[j].
value = samples[i];
44 plhs[0] = mxCreateDoubleMatrix(0, 0, mxREAL);
45 plhs[1] = mxCreateDoubleMatrix(0, 0, mxREAL);
46 plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
51 int label_vector_row_num, label_vector_col_num;
52 int feature_number, testing_instance_number;
54 double *ptr_instance, *ptr_label, *ptr_predict_label;
55 double *ptr_prob_estimates, *ptr_dec_values, *ptr;
62 double sump = 0, sumt = 0, sumpp = 0, sumtt = 0, sumpt = 0;
66 double *prob_estimates = NULL;
69 feature_number = (int)mxGetN(prhs[1]);
70 testing_instance_number = (int)mxGetM(prhs[1]);
71 label_vector_row_num = (int)mxGetM(prhs[0]);
72 label_vector_col_num = (int)mxGetN(prhs[0]);
74 if (label_vector_row_num != testing_instance_number)
76 mexPrintf(
"Length of label vector does not match # of instances.\n");
80 if (label_vector_col_num != 1)
82 mexPrintf(
"label (1st argument) should be a vector (# of column is 1).\n");
87 ptr_instance = mxGetPr(prhs[1]);
88 ptr_label = mxGetPr(prhs[0]);
91 if (mxIsSparse(prhs[1]))
96 mxArray *rhs[1], *lhs[1];
97 rhs[0] = mxDuplicateArray(prhs[1]);
98 if (mexCallMATLAB(1, lhs, 1, rhs,
"full"))
100 mexPrintf(
"Error: cannot full testing instance matrix\n");
104 ptr_instance = mxGetPr(lhs[0]);
105 mxDestroyArray(rhs[0]);
110 pprhs[0] = mxDuplicateArray(prhs[1]);
111 if (mexCallMATLAB(1, pplhs, 1, pprhs,
"transpose"))
113 mexPrintf(
"Error: cannot transpose testing instance matrix\n");
120 if (predict_probability)
123 info(
"Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",
svm_get_svr_probability(model));
125 prob_estimates = (
double *) malloc(nr_class *
sizeof(
double));
128 plhs[0] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
129 if (predict_probability)
133 plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class, mxREAL);
135 plhs[2] = mxCreateDoubleMatrix(0, 0, mxREAL);
144 plhs[2] = mxCreateDoubleMatrix(testing_instance_number, 1, mxREAL);
146 plhs[2] = mxCreateDoubleMatrix(testing_instance_number, nr_class * (nr_class - 1) / 2, mxREAL);
149 ptr_predict_label = mxGetPr(plhs[0]);
150 ptr_prob_estimates = mxGetPr(plhs[2]);
151 ptr_dec_values = mxGetPr(plhs[2]);
152 x = (
struct svm_node*)malloc((feature_number + 1) *
sizeof(
struct svm_node));
153 for (instance_index = 0; instance_index < testing_instance_number; instance_index++)
156 double target_label, predict_label;
158 target_label = ptr_label[instance_index];
164 for (i = 0; i < feature_number; i++)
167 x[i].
value = ptr_instance[testing_instance_number * i + instance_index];
169 x[feature_number].
index = -1;
172 if (predict_probability)
177 ptr_predict_label[instance_index] = predict_label;
178 for (i = 0; i < nr_class; i++)
179 ptr_prob_estimates[instance_index + i * testing_instance_number] = prob_estimates[i];
184 ptr_predict_label[instance_index] = predict_label;
195 ptr_dec_values[instance_index] = res;
199 double *dec_values = (
double *) malloc(
sizeof(
double) * nr_class * (nr_class - 1) / 2);
202 ptr_dec_values[instance_index] = 1;
204 for (i = 0; i < (nr_class * (nr_class - 1)) / 2; i++)
205 ptr_dec_values[instance_index + i * testing_instance_number] = dec_values[i];
208 ptr_predict_label[instance_index] = predict_label;
211 if (predict_label == target_label)
213 error += (predict_label - target_label) * (predict_label - target_label);
214 sump += predict_label;
215 sumt += target_label;
216 sumpp += predict_label * predict_label;
217 sumtt += target_label * target_label;
218 sumpt += predict_label * target_label;
223 info(
"Mean squared error = %g (regression)\n", error / total);
224 info(
"Squared correlation coefficient = %g (regression)\n",
225 ((total * sumpt - sump * sumt) * (total * sumpt - sump * sumt)) /
226 ((total * sumpp - sump * sump) * (total * sumtt - sumt * sumt))
230 info(
"Accuracy = %g%% (%d/%d) (classification)\n",
231 (
double)correct / total * 100, correct, total);
234 plhs[1] = mxCreateDoubleMatrix(3, 1, mxREAL);
235 ptr = mxGetPr(plhs[1]);
236 ptr[0] = (double)correct / total * 100;
237 ptr[1] = error / total;
238 ptr[2] = ((total * sumpt - sump * sumt) * (total * sumpt - sump * sumt)) /
239 ((total * sumpp - sump * sump) * (total * sumtt - sumt * sumt));
242 if (prob_estimates != NULL)
243 free(prob_estimates);
249 "Usage: [predicted_label, accuracy, decision_values/prob_estimates] = svmpredict(testing_label_vector, testing_instance_matrix, model, 'libsvm_options')\n" 251 " model: SVM model structure from svmtrain.\n" 253 " -b probability_estimates: whether to predict probability estimates, 0 or 1 (default 0); one-class SVM not supported yet\n" 254 " -q : quiet mode (no outputs)\n" 256 " predicted_label: SVM prediction output vector.\n" 257 " accuracy: a vector with accuracy, mean squared error, squared correlation coefficient.\n" 258 " prob_estimates: If selected, probability estimate vector.\n" 263 int nrhs,
const mxArray *prhs[])
265 int prob_estimate_flag = 0;
269 if (nrhs > 4 || nrhs < 3)
276 if (!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1]))
278 mexPrintf(
"Error: label vector and instance matrix must be double\n");
283 if (mxIsStruct(prhs[2]))
285 const char *error_msg;
294 mxGetString(prhs[3], cmd, mxGetN(prhs[3]) + 1);
295 if ((argv[argc] = strtok(cmd,
" ")) != NULL)
296 while ((argv[++argc] = strtok(NULL,
" ")) != NULL)
299 for (i = 1; i < argc; i++)
301 if (argv[i][0] !=
'-')
break;
302 if ((++i >= argc) && argv[i - 1][1] !=
'q')
308 switch (argv[i - 1][1])
311 prob_estimate_flag = atoi(argv[i]);
318 mexPrintf(
"Unknown option: -%c\n", argv[i - 1][1]);
329 mexPrintf(
"Error: can't read model: %s\n", error_msg);
334 if (prob_estimate_flag)
338 mexPrintf(
"Model does not support probabiliy estimates\n");
347 info(
"Model supports probability estimates, but disabled in predicton.\n");
350 predict(plhs, prhs, model, prob_estimate_flag);
356 mexPrintf(
"model file should be a struct array\n");
void read_sparse_instance(const mxArray *prhs, int index, struct svm_node *x)
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
void predict(mxArray *plhs[], const mxArray *prhs[], struct svm_model *model, const int predict_probability)
double svm_get_svr_probability(const svm_model *model)
def svm_predict(y, x, m, options="")
static void fake_answer(mxArray *plhs[])
int svm_get_nr_class(const svm_model *model)
int svm_check_probability_model(const svm_model *model)
int(* info)(const char *fmt,...)
struct svm_model * matlab_matrix_to_model(const mxArray *matlab_struct, const char **msg)
void svm_free_and_destroy_model(svm_model **model_ptr_ptr)
double svm_predict_values(const svm_model *model, const svm_node *x, double *dec_values)
double svm_predict_probability(const svm_model *model, const svm_node *x, double *prob_estimates)
struct svm_parameter param
int svm_get_svm_type(const svm_model *model)
int print_null(const char *s,...)