00001 #include <stdlib.h>
00002 #include <string.h>
00003 #include "../svm.h"
00004
00005 #include "mex.h"
00006
/*
 * Supply mwIndex ourselves when building against a MEX API whose
 * MX_API_VER is below 0x07030000; nothing is defined when MX_API_VER
 * itself is absent.
 */
#if defined(MX_API_VER) && (MX_API_VER < 0x07030000)
typedef int mwIndex;
#endif

/* Number of fields in the MATLAB model struct (length of field_names). */
#define NUM_OF_RETURN_FIELD 10

/*
 * Allocate an uninitialised array of n objects of the given type with
 * malloc().  The whole expansion is parenthesised so that the result can
 * be indexed or dereferenced directly, e.g. Malloc(int, 4)[0]; without
 * the outer parentheses the cast would apply to the indexed expression
 * instead of to the malloc() result.  The caller owns the memory and
 * must release it with free().
 */
#define Malloc(type,n) ((type *)malloc((n)*sizeof(type)))
00016
/*
 * Names of the fields of the MATLAB model struct.  The order matters:
 * model_to_matlab_structure() fills the fields in this order and
 * matlab_matrix_to_model() reads them back by position.
 */
static const char *field_names[] = {
	"Parameters",	/* svm_type, kernel_type, degree, gamma, coef0 */
	"nr_class",	/* model->nr_class */
	"totalSV",	/* model->l */
	"rho",		/* model->rho */
	"Label",	/* model->label, or empty */
	"ProbA",	/* model->probA, or empty */
	"ProbB",	/* model->probB, or empty */
	"nSV",		/* model->nSV, or empty */
	"sv_coef",	/* model->sv_coef */
	"SVs"		/* model->SV as a sparse matrix */
};
00029
00030 const char *model_to_matlab_structure(mxArray *plhs[], int num_of_feature, struct svm_model *model)
00031 {
00032 int i, j, n;
00033 double *ptr;
00034 mxArray *return_model, **rhs;
00035 int out_id = 0;
00036
00037 rhs = (mxArray **)mxMalloc(sizeof(mxArray *)*NUM_OF_RETURN_FIELD);
00038
00039
00040 rhs[out_id] = mxCreateDoubleMatrix(5, 1, mxREAL);
00041 ptr = mxGetPr(rhs[out_id]);
00042 ptr[0] = model->param.svm_type;
00043 ptr[1] = model->param.kernel_type;
00044 ptr[2] = model->param.degree;
00045 ptr[3] = model->param.gamma;
00046 ptr[4] = model->param.coef0;
00047 out_id++;
00048
00049
00050 rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
00051 ptr = mxGetPr(rhs[out_id]);
00052 ptr[0] = model->nr_class;
00053 out_id++;
00054
00055
00056 rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
00057 ptr = mxGetPr(rhs[out_id]);
00058 ptr[0] = model->l;
00059 out_id++;
00060
00061
00062 n = model->nr_class*(model->nr_class-1)/2;
00063 rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
00064 ptr = mxGetPr(rhs[out_id]);
00065 for(i = 0; i < n; i++)
00066 ptr[i] = model->rho[i];
00067 out_id++;
00068
00069
00070 if(model->label)
00071 {
00072 rhs[out_id] = mxCreateDoubleMatrix(model->nr_class, 1, mxREAL);
00073 ptr = mxGetPr(rhs[out_id]);
00074 for(i = 0; i < model->nr_class; i++)
00075 ptr[i] = model->label[i];
00076 }
00077 else
00078 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
00079 out_id++;
00080
00081
00082 if(model->probA != NULL)
00083 {
00084 rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
00085 ptr = mxGetPr(rhs[out_id]);
00086 for(i = 0; i < n; i++)
00087 ptr[i] = model->probA[i];
00088 }
00089 else
00090 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
00091 out_id ++;
00092
00093
00094 if(model->probB != NULL)
00095 {
00096 rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
00097 ptr = mxGetPr(rhs[out_id]);
00098 for(i = 0; i < n; i++)
00099 ptr[i] = model->probB[i];
00100 }
00101 else
00102 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
00103 out_id++;
00104
00105
00106 if(model->nSV)
00107 {
00108 rhs[out_id] = mxCreateDoubleMatrix(model->nr_class, 1, mxREAL);
00109 ptr = mxGetPr(rhs[out_id]);
00110 for(i = 0; i < model->nr_class; i++)
00111 ptr[i] = model->nSV[i];
00112 }
00113 else
00114 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
00115 out_id++;
00116
00117
00118 rhs[out_id] = mxCreateDoubleMatrix(model->l, model->nr_class-1, mxREAL);
00119 ptr = mxGetPr(rhs[out_id]);
00120 for(i = 0; i < model->nr_class-1; i++)
00121 for(j = 0; j < model->l; j++)
00122 ptr[(i*(model->l))+j] = model->sv_coef[i][j];
00123 out_id++;
00124
00125
00126 {
00127 int ir_index, nonzero_element;
00128 mwIndex *ir, *jc;
00129 mxArray *pprhs[1], *pplhs[1];
00130
00131 if(model->param.kernel_type == PRECOMPUTED)
00132 {
00133 nonzero_element = model->l;
00134 num_of_feature = 1;
00135 }
00136 else
00137 {
00138 nonzero_element = 0;
00139 for(i = 0; i < model->l; i++) {
00140 j = 0;
00141 while(model->SV[i][j].index != -1)
00142 {
00143 nonzero_element++;
00144 j++;
00145 }
00146 }
00147 }
00148
00149
00150 rhs[out_id] = mxCreateSparse(num_of_feature, model->l, nonzero_element, mxREAL);
00151 ir = mxGetIr(rhs[out_id]);
00152 jc = mxGetJc(rhs[out_id]);
00153 ptr = mxGetPr(rhs[out_id]);
00154 jc[0] = ir_index = 0;
00155 for(i = 0;i < model->l; i++)
00156 {
00157 if(model->param.kernel_type == PRECOMPUTED)
00158 {
00159
00160 ir[ir_index] = 0;
00161 ptr[ir_index] = model->SV[i][0].value;
00162 ir_index++;
00163 jc[i+1] = jc[i] + 1;
00164 }
00165 else
00166 {
00167 int x_index = 0;
00168 while (model->SV[i][x_index].index != -1)
00169 {
00170 ir[ir_index] = model->SV[i][x_index].index - 1;
00171 ptr[ir_index] = model->SV[i][x_index].value;
00172 ir_index++, x_index++;
00173 }
00174 jc[i+1] = jc[i] + x_index;
00175 }
00176 }
00177
00178 pprhs[0] = rhs[out_id];
00179 if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
00180 return "cannot transpose SV matrix";
00181 rhs[out_id] = pplhs[0];
00182 out_id++;
00183 }
00184
00185
00186 return_model = mxCreateStructMatrix(1, 1, NUM_OF_RETURN_FIELD, field_names);
00187
00188
00189 for(i = 0; i < NUM_OF_RETURN_FIELD; i++)
00190 mxSetField(return_model,0,field_names[i],mxDuplicateArray(rhs[i]));
00191
00192 plhs[0] = return_model;
00193 mxFree(rhs);
00194
00195 return NULL;
00196 }
00197
00198 struct svm_model *matlab_matrix_to_model(const mxArray *matlab_struct, const char **msg)
00199 {
00200 int i, j, n, num_of_fields;
00201 double *ptr;
00202 int id = 0;
00203 struct svm_node *x_space;
00204 struct svm_model *model;
00205 mxArray **rhs;
00206
00207 num_of_fields = mxGetNumberOfFields(matlab_struct);
00208 if(num_of_fields != NUM_OF_RETURN_FIELD)
00209 {
00210 *msg = "number of return field is not correct";
00211 return NULL;
00212 }
00213 rhs = (mxArray **) mxMalloc(sizeof(mxArray *)*num_of_fields);
00214
00215 for(i=0;i<num_of_fields;i++)
00216 rhs[i] = mxGetFieldByNumber(matlab_struct, 0, i);
00217
00218 model = Malloc(struct svm_model, 1);
00219 model->rho = NULL;
00220 model->probA = NULL;
00221 model->probB = NULL;
00222 model->label = NULL;
00223 model->nSV = NULL;
00224 model->free_sv = 1;
00225
00226 ptr = mxGetPr(rhs[id]);
00227 model->param.svm_type = (int)ptr[0];
00228 model->param.kernel_type = (int)ptr[1];
00229 model->param.degree = (int)ptr[2];
00230 model->param.gamma = ptr[3];
00231 model->param.coef0 = ptr[4];
00232 id++;
00233
00234 ptr = mxGetPr(rhs[id]);
00235 model->nr_class = (int)ptr[0];
00236 id++;
00237
00238 ptr = mxGetPr(rhs[id]);
00239 model->l = (int)ptr[0];
00240 id++;
00241
00242
00243 n = model->nr_class * (model->nr_class-1)/2;
00244 model->rho = (double*) malloc(n*sizeof(double));
00245 ptr = mxGetPr(rhs[id]);
00246 for(i=0;i<n;i++)
00247 model->rho[i] = ptr[i];
00248 id++;
00249
00250
00251 if(mxIsEmpty(rhs[id]) == 0)
00252 {
00253 model->label = (int*) malloc(model->nr_class*sizeof(int));
00254 ptr = mxGetPr(rhs[id]);
00255 for(i=0;i<model->nr_class;i++)
00256 model->label[i] = (int)ptr[i];
00257 }
00258 id++;
00259
00260
00261 if(mxIsEmpty(rhs[id]) == 0)
00262 {
00263 model->probA = (double*) malloc(n*sizeof(double));
00264 ptr = mxGetPr(rhs[id]);
00265 for(i=0;i<n;i++)
00266 model->probA[i] = ptr[i];
00267 }
00268 id++;
00269
00270
00271 if(mxIsEmpty(rhs[id]) == 0)
00272 {
00273 model->probB = (double*) malloc(n*sizeof(double));
00274 ptr = mxGetPr(rhs[id]);
00275 for(i=0;i<n;i++)
00276 model->probB[i] = ptr[i];
00277 }
00278 id++;
00279
00280
00281 if(mxIsEmpty(rhs[id]) == 0)
00282 {
00283 model->nSV = (int*) malloc(model->nr_class*sizeof(int));
00284 ptr = mxGetPr(rhs[id]);
00285 for(i=0;i<model->nr_class;i++)
00286 model->nSV[i] = (int)ptr[i];
00287 }
00288 id++;
00289
00290
00291 ptr = mxGetPr(rhs[id]);
00292 model->sv_coef = (double**) malloc((model->nr_class-1)*sizeof(double));
00293 for( i=0 ; i< model->nr_class -1 ; i++ )
00294 model->sv_coef[i] = (double*) malloc((model->l)*sizeof(double));
00295 for(i = 0; i < model->nr_class - 1; i++)
00296 for(j = 0; j < model->l; j++)
00297 model->sv_coef[i][j] = ptr[i*(model->l)+j];
00298 id++;
00299
00300
00301 {
00302 int sr, sc, elements;
00303 int num_samples;
00304 mwIndex *ir, *jc;
00305 mxArray *pprhs[1], *pplhs[1];
00306
00307
00308 pprhs[0] = rhs[id];
00309 if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
00310 {
00311 svm_free_and_destroy_model(&model);
00312 *msg = "cannot transpose SV matrix";
00313 return NULL;
00314 }
00315 rhs[id] = pplhs[0];
00316
00317 sr = (int)mxGetN(rhs[id]);
00318 sc = (int)mxGetM(rhs[id]);
00319
00320 ptr = mxGetPr(rhs[id]);
00321 ir = mxGetIr(rhs[id]);
00322 jc = mxGetJc(rhs[id]);
00323
00324 num_samples = (int)mxGetNzmax(rhs[id]);
00325
00326 elements = num_samples + sr;
00327
00328 model->SV = (struct svm_node **) malloc(sr * sizeof(struct svm_node *));
00329 x_space = (struct svm_node *)malloc(elements * sizeof(struct svm_node));
00330
00331
00332 for(i=0;i<sr;i++)
00333 {
00334 int low = (int)jc[i], high = (int)jc[i+1];
00335 int x_index = 0;
00336 model->SV[i] = &x_space[low+i];
00337 for(j=low;j<high;j++)
00338 {
00339 model->SV[i][x_index].index = (int)ir[j] + 1;
00340 model->SV[i][x_index].value = ptr[j];
00341 x_index++;
00342 }
00343 model->SV[i][x_index].index = -1;
00344 }
00345
00346 id++;
00347 }
00348 mxFree(rhs);
00349
00350 return model;
00351 }