00001 #include <stdlib.h>
00002 #include <string.h>
00003 #include "../svm.h"
00004
00005 #include "mex.h"
00006
/*
 * MATLAB releases older than 7.3 (MX_API_VER < 0x07030000) do not provide
 * mwIndex; fall back to int there so the mxGetIr/mxGetJc results used by the
 * conversion functions have a type.
 */
#if defined(MX_API_VER) && MX_API_VER < 0x07030000
typedef int mwIndex;
#endif

/* Number of fields in the MATLAB model struct; equals the length of field_names. */
#define NUM_OF_RETURN_FIELD 11

/* malloc() room for n objects of the given type, cast to (type *); NULL on failure. */
#define Malloc(type,n) (type *)malloc((n)*sizeof(type))

/*
 * Field names of the MATLAB model struct.  model_to_matlab_structure() builds
 * the field values in this order, and matlab_matrix_to_model() reads them back
 * by field number, so the order is part of the model format.
 */
static const char *field_names[NUM_OF_RETURN_FIELD] = {
	"Parameters",	/* [svm_type; kernel_type; degree; gamma; coef0] */
	"nr_class",
	"totalSV",
	"rho",
	"Label",
	"sv_indices",
	"ProbA",
	"ProbB",
	"nSV",
	"sv_coef",
	"SVs"		/* sparse matrix, one support vector per row */
};
00030
00031 const char *model_to_matlab_structure(mxArray *plhs[], int num_of_feature, struct svm_model *model)
00032 {
00033 int i, j, n;
00034 double *ptr;
00035 mxArray *return_model, **rhs;
00036 int out_id = 0;
00037
00038 rhs = (mxArray **)mxMalloc(sizeof(mxArray *)*NUM_OF_RETURN_FIELD);
00039
00040
00041 rhs[out_id] = mxCreateDoubleMatrix(5, 1, mxREAL);
00042 ptr = mxGetPr(rhs[out_id]);
00043 ptr[0] = model->param.svm_type;
00044 ptr[1] = model->param.kernel_type;
00045 ptr[2] = model->param.degree;
00046 ptr[3] = model->param.gamma;
00047 ptr[4] = model->param.coef0;
00048 out_id++;
00049
00050
00051 rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
00052 ptr = mxGetPr(rhs[out_id]);
00053 ptr[0] = model->nr_class;
00054 out_id++;
00055
00056
00057 rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
00058 ptr = mxGetPr(rhs[out_id]);
00059 ptr[0] = model->l;
00060 out_id++;
00061
00062
00063 n = model->nr_class*(model->nr_class-1)/2;
00064 rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
00065 ptr = mxGetPr(rhs[out_id]);
00066 for(i = 0; i < n; i++)
00067 ptr[i] = model->rho[i];
00068 out_id++;
00069
00070
00071 if(model->label)
00072 {
00073 rhs[out_id] = mxCreateDoubleMatrix(model->nr_class, 1, mxREAL);
00074 ptr = mxGetPr(rhs[out_id]);
00075 for(i = 0; i < model->nr_class; i++)
00076 ptr[i] = model->label[i];
00077 }
00078 else
00079 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
00080 out_id++;
00081
00082
00083 if(model->sv_indices)
00084 {
00085 rhs[out_id] = mxCreateDoubleMatrix(model->l, 1, mxREAL);
00086 ptr = mxGetPr(rhs[out_id]);
00087 for(i = 0; i < model->l; i++)
00088 ptr[i] = model->sv_indices[i];
00089 }
00090 else
00091 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
00092 out_id++;
00093
00094
00095 if(model->probA != NULL)
00096 {
00097 rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
00098 ptr = mxGetPr(rhs[out_id]);
00099 for(i = 0; i < n; i++)
00100 ptr[i] = model->probA[i];
00101 }
00102 else
00103 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
00104 out_id ++;
00105
00106
00107 if(model->probB != NULL)
00108 {
00109 rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
00110 ptr = mxGetPr(rhs[out_id]);
00111 for(i = 0; i < n; i++)
00112 ptr[i] = model->probB[i];
00113 }
00114 else
00115 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
00116 out_id++;
00117
00118
00119 if(model->nSV)
00120 {
00121 rhs[out_id] = mxCreateDoubleMatrix(model->nr_class, 1, mxREAL);
00122 ptr = mxGetPr(rhs[out_id]);
00123 for(i = 0; i < model->nr_class; i++)
00124 ptr[i] = model->nSV[i];
00125 }
00126 else
00127 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
00128 out_id++;
00129
00130
00131 rhs[out_id] = mxCreateDoubleMatrix(model->l, model->nr_class-1, mxREAL);
00132 ptr = mxGetPr(rhs[out_id]);
00133 for(i = 0; i < model->nr_class-1; i++)
00134 for(j = 0; j < model->l; j++)
00135 ptr[(i*(model->l))+j] = model->sv_coef[i][j];
00136 out_id++;
00137
00138
00139 {
00140 int ir_index, nonzero_element;
00141 mwIndex *ir, *jc;
00142 mxArray *pprhs[1], *pplhs[1];
00143
00144 if(model->param.kernel_type == PRECOMPUTED)
00145 {
00146 nonzero_element = model->l;
00147 num_of_feature = 1;
00148 }
00149 else
00150 {
00151 nonzero_element = 0;
00152 for(i = 0; i < model->l; i++) {
00153 j = 0;
00154 while(model->SV[i][j].index != -1)
00155 {
00156 nonzero_element++;
00157 j++;
00158 }
00159 }
00160 }
00161
00162
00163 rhs[out_id] = mxCreateSparse(num_of_feature, model->l, nonzero_element, mxREAL);
00164 ir = mxGetIr(rhs[out_id]);
00165 jc = mxGetJc(rhs[out_id]);
00166 ptr = mxGetPr(rhs[out_id]);
00167 jc[0] = ir_index = 0;
00168 for(i = 0;i < model->l; i++)
00169 {
00170 if(model->param.kernel_type == PRECOMPUTED)
00171 {
00172
00173 ir[ir_index] = 0;
00174 ptr[ir_index] = model->SV[i][0].value;
00175 ir_index++;
00176 jc[i+1] = jc[i] + 1;
00177 }
00178 else
00179 {
00180 int x_index = 0;
00181 while (model->SV[i][x_index].index != -1)
00182 {
00183 ir[ir_index] = model->SV[i][x_index].index - 1;
00184 ptr[ir_index] = model->SV[i][x_index].value;
00185 ir_index++, x_index++;
00186 }
00187 jc[i+1] = jc[i] + x_index;
00188 }
00189 }
00190
00191 pprhs[0] = rhs[out_id];
00192 if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
00193 return "cannot transpose SV matrix";
00194 rhs[out_id] = pplhs[0];
00195 out_id++;
00196 }
00197
00198
00199 return_model = mxCreateStructMatrix(1, 1, NUM_OF_RETURN_FIELD, field_names);
00200
00201
00202 for(i = 0; i < NUM_OF_RETURN_FIELD; i++)
00203 mxSetField(return_model,0,field_names[i],mxDuplicateArray(rhs[i]));
00204
00205 plhs[0] = return_model;
00206 mxFree(rhs);
00207
00208 return NULL;
00209 }
00210
00211 struct svm_model *matlab_matrix_to_model(const mxArray *matlab_struct, const char **msg)
00212 {
00213 int i, j, n, num_of_fields;
00214 double *ptr;
00215 int id = 0;
00216 struct svm_node *x_space;
00217 struct svm_model *model;
00218 mxArray **rhs;
00219
00220 num_of_fields = mxGetNumberOfFields(matlab_struct);
00221 if(num_of_fields != NUM_OF_RETURN_FIELD)
00222 {
00223 *msg = "number of return field is not correct";
00224 return NULL;
00225 }
00226 rhs = (mxArray **) mxMalloc(sizeof(mxArray *)*num_of_fields);
00227
00228 for(i=0;i<num_of_fields;i++)
00229 rhs[i] = mxGetFieldByNumber(matlab_struct, 0, i);
00230
00231 model = Malloc(struct svm_model, 1);
00232 model->rho = NULL;
00233 model->probA = NULL;
00234 model->probB = NULL;
00235 model->label = NULL;
00236 model->sv_indices = NULL;
00237 model->nSV = NULL;
00238 model->free_sv = 1;
00239
00240 ptr = mxGetPr(rhs[id]);
00241 model->param.svm_type = (int)ptr[0];
00242 model->param.kernel_type = (int)ptr[1];
00243 model->param.degree = (int)ptr[2];
00244 model->param.gamma = ptr[3];
00245 model->param.coef0 = ptr[4];
00246 id++;
00247
00248 ptr = mxGetPr(rhs[id]);
00249 model->nr_class = (int)ptr[0];
00250 id++;
00251
00252 ptr = mxGetPr(rhs[id]);
00253 model->l = (int)ptr[0];
00254 id++;
00255
00256
00257 n = model->nr_class * (model->nr_class-1)/2;
00258 model->rho = (double*) malloc(n*sizeof(double));
00259 ptr = mxGetPr(rhs[id]);
00260 for(i=0;i<n;i++)
00261 model->rho[i] = ptr[i];
00262 id++;
00263
00264
00265 if(mxIsEmpty(rhs[id]) == 0)
00266 {
00267 model->label = (int*) malloc(model->nr_class*sizeof(int));
00268 ptr = mxGetPr(rhs[id]);
00269 for(i=0;i<model->nr_class;i++)
00270 model->label[i] = (int)ptr[i];
00271 }
00272 id++;
00273
00274
00275 if(mxIsEmpty(rhs[id]) == 0)
00276 {
00277 model->sv_indices = (int*) malloc(model->l*sizeof(int));
00278 ptr = mxGetPr(rhs[id]);
00279 for(i=0;i<model->l;i++)
00280 model->sv_indices[i] = (int)ptr[i];
00281 }
00282 id++;
00283
00284
00285 if(mxIsEmpty(rhs[id]) == 0)
00286 {
00287 model->probA = (double*) malloc(n*sizeof(double));
00288 ptr = mxGetPr(rhs[id]);
00289 for(i=0;i<n;i++)
00290 model->probA[i] = ptr[i];
00291 }
00292 id++;
00293
00294
00295 if(mxIsEmpty(rhs[id]) == 0)
00296 {
00297 model->probB = (double*) malloc(n*sizeof(double));
00298 ptr = mxGetPr(rhs[id]);
00299 for(i=0;i<n;i++)
00300 model->probB[i] = ptr[i];
00301 }
00302 id++;
00303
00304
00305 if(mxIsEmpty(rhs[id]) == 0)
00306 {
00307 model->nSV = (int*) malloc(model->nr_class*sizeof(int));
00308 ptr = mxGetPr(rhs[id]);
00309 for(i=0;i<model->nr_class;i++)
00310 model->nSV[i] = (int)ptr[i];
00311 }
00312 id++;
00313
00314
00315 ptr = mxGetPr(rhs[id]);
00316 model->sv_coef = (double**) malloc((model->nr_class-1)*sizeof(double));
00317 for( i=0 ; i< model->nr_class -1 ; i++ )
00318 model->sv_coef[i] = (double*) malloc((model->l)*sizeof(double));
00319 for(i = 0; i < model->nr_class - 1; i++)
00320 for(j = 0; j < model->l; j++)
00321 model->sv_coef[i][j] = ptr[i*(model->l)+j];
00322 id++;
00323
00324
00325 {
00326 int sr, sc, elements;
00327 int num_samples;
00328 mwIndex *ir, *jc;
00329 mxArray *pprhs[1], *pplhs[1];
00330
00331
00332 pprhs[0] = rhs[id];
00333 if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
00334 {
00335 svm_free_and_destroy_model(&model);
00336 *msg = "cannot transpose SV matrix";
00337 return NULL;
00338 }
00339 rhs[id] = pplhs[0];
00340
00341 sr = (int)mxGetN(rhs[id]);
00342 sc = (int)mxGetM(rhs[id]);
00343
00344 ptr = mxGetPr(rhs[id]);
00345 ir = mxGetIr(rhs[id]);
00346 jc = mxGetJc(rhs[id]);
00347
00348 num_samples = (int)mxGetNzmax(rhs[id]);
00349
00350 elements = num_samples + sr;
00351
00352 model->SV = (struct svm_node **) malloc(sr * sizeof(struct svm_node *));
00353 x_space = (struct svm_node *)malloc(elements * sizeof(struct svm_node));
00354
00355
00356 for(i=0;i<sr;i++)
00357 {
00358 int low = (int)jc[i], high = (int)jc[i+1];
00359 int x_index = 0;
00360 model->SV[i] = &x_space[low+i];
00361 for(j=low;j<high;j++)
00362 {
00363 model->SV[i][x_index].index = (int)ir[j] + 1;
00364 model->SV[i][x_index].value = ptr[j];
00365 x_index++;
00366 }
00367 model->SV[i][x_index].index = -1;
00368 }
00369
00370 id++;
00371 }
00372 mxFree(rhs);
00373
00374 return model;
00375 }