svm_model_matlab.c
Go to the documentation of this file.
00001 #include <stdlib.h>
00002 #include <string.h>
00003 #include "../svm.h"
00004 
00005 #include "mex.h"
00006 
/* Provide mwIndex (used below for sparse ir/jc index arrays) as int when the
 * MX API version predates 0x07030000. */
#ifdef MX_API_VER
#if MX_API_VER < 0x07030000
typedef int mwIndex;
#endif
#endif

/* Number of fields in the MATLAB model struct; must equal the length of field_names. */
#define NUM_OF_RETURN_FIELD 11

/* Allocate n objects of `type` with malloc; the result is NULL on failure.
 * Fully parenthesized so the expansion can be used inside larger expressions. */
#define Malloc(type,n) ((type *)malloc((n)*sizeof(type)))

/*
 * Field names of the MATLAB model struct, in the order model_to_matlab_structure
 * writes them. matlab_matrix_to_model reads the fields back by position, so the
 * order here is part of the format.
 */
static const char *field_names[] = {
	"Parameters",
	"nr_class",
	"totalSV",
	"rho",
	"Label",
	"sv_indices",
	"ProbA",
	"ProbB",
	"nSV",
	"sv_coef",
	"SVs"
};

/* Compile-time check: the array size becomes -1 (a compile error) if field_names
 * and NUM_OF_RETURN_FIELD ever disagree. */
typedef char svm_field_names_match_field_count[
	(sizeof(field_names) / sizeof(field_names[0]) == NUM_OF_RETURN_FIELD) ? 1 : -1];
00030 
00031 const char *model_to_matlab_structure(mxArray *plhs[], int num_of_feature, struct svm_model *model)
00032 {
00033         int i, j, n;
00034         double *ptr;
00035         mxArray *return_model, **rhs;
00036         int out_id = 0;
00037 
00038         rhs = (mxArray **)mxMalloc(sizeof(mxArray *)*NUM_OF_RETURN_FIELD);
00039 
00040         // Parameters
00041         rhs[out_id] = mxCreateDoubleMatrix(5, 1, mxREAL);
00042         ptr = mxGetPr(rhs[out_id]);
00043         ptr[0] = model->param.svm_type;
00044         ptr[1] = model->param.kernel_type;
00045         ptr[2] = model->param.degree;
00046         ptr[3] = model->param.gamma;
00047         ptr[4] = model->param.coef0;
00048         out_id++;
00049 
00050         // nr_class
00051         rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
00052         ptr = mxGetPr(rhs[out_id]);
00053         ptr[0] = model->nr_class;
00054         out_id++;
00055 
00056         // total SV
00057         rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
00058         ptr = mxGetPr(rhs[out_id]);
00059         ptr[0] = model->l;
00060         out_id++;
00061 
00062         // rho
00063         n = model->nr_class*(model->nr_class-1)/2;
00064         rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
00065         ptr = mxGetPr(rhs[out_id]);
00066         for(i = 0; i < n; i++)
00067                 ptr[i] = model->rho[i];
00068         out_id++;
00069 
00070         // Label
00071         if(model->label)
00072         {
00073                 rhs[out_id] = mxCreateDoubleMatrix(model->nr_class, 1, mxREAL);
00074                 ptr = mxGetPr(rhs[out_id]);
00075                 for(i = 0; i < model->nr_class; i++)
00076                         ptr[i] = model->label[i];
00077         }
00078         else
00079                 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
00080         out_id++;
00081 
00082         // sv_indices
00083         if(model->sv_indices)
00084         {
00085                 rhs[out_id] = mxCreateDoubleMatrix(model->l, 1, mxREAL);
00086                 ptr = mxGetPr(rhs[out_id]);
00087                 for(i = 0; i < model->l; i++)
00088                         ptr[i] = model->sv_indices[i];
00089         }
00090         else
00091                 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
00092         out_id++;
00093 
00094         // probA
00095         if(model->probA != NULL)
00096         {
00097                 rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
00098                 ptr = mxGetPr(rhs[out_id]);
00099                 for(i = 0; i < n; i++)
00100                         ptr[i] = model->probA[i];
00101         }
00102         else
00103                 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
00104         out_id ++;
00105 
00106         // probB
00107         if(model->probB != NULL)
00108         {
00109                 rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
00110                 ptr = mxGetPr(rhs[out_id]);
00111                 for(i = 0; i < n; i++)
00112                         ptr[i] = model->probB[i];
00113         }
00114         else
00115                 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
00116         out_id++;
00117 
00118         // nSV
00119         if(model->nSV)
00120         {
00121                 rhs[out_id] = mxCreateDoubleMatrix(model->nr_class, 1, mxREAL);
00122                 ptr = mxGetPr(rhs[out_id]);
00123                 for(i = 0; i < model->nr_class; i++)
00124                         ptr[i] = model->nSV[i];
00125         }
00126         else
00127                 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
00128         out_id++;
00129 
00130         // sv_coef
00131         rhs[out_id] = mxCreateDoubleMatrix(model->l, model->nr_class-1, mxREAL);
00132         ptr = mxGetPr(rhs[out_id]);
00133         for(i = 0; i < model->nr_class-1; i++)
00134                 for(j = 0; j < model->l; j++)
00135                         ptr[(i*(model->l))+j] = model->sv_coef[i][j];
00136         out_id++;
00137 
00138         // SVs
00139         {
00140                 int ir_index, nonzero_element;
00141                 mwIndex *ir, *jc;
00142                 mxArray *pprhs[1], *pplhs[1];   
00143 
00144                 if(model->param.kernel_type == PRECOMPUTED)
00145                 {
00146                         nonzero_element = model->l;
00147                         num_of_feature = 1;
00148                 }
00149                 else
00150                 {
00151                         nonzero_element = 0;
00152                         for(i = 0; i < model->l; i++) {
00153                                 j = 0;
00154                                 while(model->SV[i][j].index != -1) 
00155                                 {
00156                                         nonzero_element++;
00157                                         j++;
00158                                 }
00159                         }
00160                 }
00161 
00162                 // SV in column, easier accessing
00163                 rhs[out_id] = mxCreateSparse(num_of_feature, model->l, nonzero_element, mxREAL);
00164                 ir = mxGetIr(rhs[out_id]);
00165                 jc = mxGetJc(rhs[out_id]);
00166                 ptr = mxGetPr(rhs[out_id]);
00167                 jc[0] = ir_index = 0;           
00168                 for(i = 0;i < model->l; i++)
00169                 {
00170                         if(model->param.kernel_type == PRECOMPUTED)
00171                         {
00172                                 // make a (1 x model->l) matrix
00173                                 ir[ir_index] = 0; 
00174                                 ptr[ir_index] = model->SV[i][0].value;
00175                                 ir_index++;
00176                                 jc[i+1] = jc[i] + 1;
00177                         }
00178                         else
00179                         {
00180                                 int x_index = 0;
00181                                 while (model->SV[i][x_index].index != -1)
00182                                 {
00183                                         ir[ir_index] = model->SV[i][x_index].index - 1; 
00184                                         ptr[ir_index] = model->SV[i][x_index].value;
00185                                         ir_index++, x_index++;
00186                                 }
00187                                 jc[i+1] = jc[i] + x_index;
00188                         }
00189                 }
00190                 // transpose back to SV in row
00191                 pprhs[0] = rhs[out_id];
00192                 if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
00193                         return "cannot transpose SV matrix";
00194                 rhs[out_id] = pplhs[0];
00195                 out_id++;
00196         }
00197 
00198         /* Create a struct matrix contains NUM_OF_RETURN_FIELD fields */
00199         return_model = mxCreateStructMatrix(1, 1, NUM_OF_RETURN_FIELD, field_names);
00200 
00201         /* Fill struct matrix with input arguments */
00202         for(i = 0; i < NUM_OF_RETURN_FIELD; i++)
00203                 mxSetField(return_model,0,field_names[i],mxDuplicateArray(rhs[i]));
00204         /* return */
00205         plhs[0] = return_model;
00206         mxFree(rhs);
00207 
00208         return NULL;
00209 }
00210 
00211 struct svm_model *matlab_matrix_to_model(const mxArray *matlab_struct, const char **msg)
00212 {
00213         int i, j, n, num_of_fields;
00214         double *ptr;
00215         int id = 0;
00216         struct svm_node *x_space;
00217         struct svm_model *model;
00218         mxArray **rhs;
00219 
00220         num_of_fields = mxGetNumberOfFields(matlab_struct);
00221         if(num_of_fields != NUM_OF_RETURN_FIELD) 
00222         {
00223                 *msg = "number of return field is not correct";
00224                 return NULL;
00225         }
00226         rhs = (mxArray **) mxMalloc(sizeof(mxArray *)*num_of_fields);
00227 
00228         for(i=0;i<num_of_fields;i++)
00229                 rhs[i] = mxGetFieldByNumber(matlab_struct, 0, i);
00230 
00231         model = Malloc(struct svm_model, 1);
00232         model->rho = NULL;
00233         model->probA = NULL;
00234         model->probB = NULL;
00235         model->label = NULL;
00236         model->sv_indices = NULL;
00237         model->nSV = NULL;
00238         model->free_sv = 1; // XXX
00239 
00240         ptr = mxGetPr(rhs[id]);
00241         model->param.svm_type = (int)ptr[0];
00242         model->param.kernel_type  = (int)ptr[1];
00243         model->param.degree       = (int)ptr[2];
00244         model->param.gamma        = ptr[3];
00245         model->param.coef0        = ptr[4];
00246         id++;
00247 
00248         ptr = mxGetPr(rhs[id]);
00249         model->nr_class = (int)ptr[0];
00250         id++;
00251 
00252         ptr = mxGetPr(rhs[id]);
00253         model->l = (int)ptr[0];
00254         id++;
00255 
00256         // rho
00257         n = model->nr_class * (model->nr_class-1)/2;
00258         model->rho = (double*) malloc(n*sizeof(double));
00259         ptr = mxGetPr(rhs[id]);
00260         for(i=0;i<n;i++)
00261                 model->rho[i] = ptr[i];
00262         id++;
00263 
00264         // label
00265         if(mxIsEmpty(rhs[id]) == 0)
00266         {
00267                 model->label = (int*) malloc(model->nr_class*sizeof(int));
00268                 ptr = mxGetPr(rhs[id]);
00269                 for(i=0;i<model->nr_class;i++)
00270                         model->label[i] = (int)ptr[i];
00271         }
00272         id++;
00273 
00274         // sv_indices
00275         if(mxIsEmpty(rhs[id]) == 0)
00276         {
00277                 model->sv_indices = (int*) malloc(model->l*sizeof(int));
00278                 ptr = mxGetPr(rhs[id]);
00279                 for(i=0;i<model->l;i++)
00280                         model->sv_indices[i] = (int)ptr[i];
00281         }
00282         id++;
00283 
00284         // probA
00285         if(mxIsEmpty(rhs[id]) == 0)
00286         {
00287                 model->probA = (double*) malloc(n*sizeof(double));
00288                 ptr = mxGetPr(rhs[id]);
00289                 for(i=0;i<n;i++)
00290                         model->probA[i] = ptr[i];
00291         }
00292         id++;
00293 
00294         // probB
00295         if(mxIsEmpty(rhs[id]) == 0)
00296         {
00297                 model->probB = (double*) malloc(n*sizeof(double));
00298                 ptr = mxGetPr(rhs[id]);
00299                 for(i=0;i<n;i++)
00300                         model->probB[i] = ptr[i];
00301         }
00302         id++;
00303 
00304         // nSV
00305         if(mxIsEmpty(rhs[id]) == 0)
00306         {
00307                 model->nSV = (int*) malloc(model->nr_class*sizeof(int));
00308                 ptr = mxGetPr(rhs[id]);
00309                 for(i=0;i<model->nr_class;i++)
00310                         model->nSV[i] = (int)ptr[i];
00311         }
00312         id++;
00313 
00314         // sv_coef
00315         ptr = mxGetPr(rhs[id]);
00316         model->sv_coef = (double**) malloc((model->nr_class-1)*sizeof(double));
00317         for( i=0 ; i< model->nr_class -1 ; i++ )
00318                 model->sv_coef[i] = (double*) malloc((model->l)*sizeof(double));
00319         for(i = 0; i < model->nr_class - 1; i++)
00320                 for(j = 0; j < model->l; j++)
00321                         model->sv_coef[i][j] = ptr[i*(model->l)+j];
00322         id++;
00323 
00324         // SV
00325         {
00326                 int sr, sc, elements;
00327                 int num_samples;
00328                 mwIndex *ir, *jc;
00329                 mxArray *pprhs[1], *pplhs[1];
00330 
00331                 // transpose SV
00332                 pprhs[0] = rhs[id];
00333                 if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose")) 
00334                 {
00335                         svm_free_and_destroy_model(&model);
00336                         *msg = "cannot transpose SV matrix";
00337                         return NULL;
00338                 }
00339                 rhs[id] = pplhs[0];
00340 
00341                 sr = (int)mxGetN(rhs[id]);
00342                 sc = (int)mxGetM(rhs[id]);
00343 
00344                 ptr = mxGetPr(rhs[id]);
00345                 ir = mxGetIr(rhs[id]);
00346                 jc = mxGetJc(rhs[id]);
00347 
00348                 num_samples = (int)mxGetNzmax(rhs[id]);
00349 
00350                 elements = num_samples + sr;
00351 
00352                 model->SV = (struct svm_node **) malloc(sr * sizeof(struct svm_node *));
00353                 x_space = (struct svm_node *)malloc(elements * sizeof(struct svm_node));
00354 
00355                 // SV is in column
00356                 for(i=0;i<sr;i++)
00357                 {
00358                         int low = (int)jc[i], high = (int)jc[i+1];
00359                         int x_index = 0;
00360                         model->SV[i] = &x_space[low+i];
00361                         for(j=low;j<high;j++)
00362                         {
00363                                 model->SV[i][x_index].index = (int)ir[j] + 1; 
00364                                 model->SV[i][x_index].value = ptr[j];
00365                                 x_index++;
00366                         }
00367                         model->SV[i][x_index].index = -1;
00368                 }
00369 
00370                 id++;
00371         }
00372         mxFree(rhs);
00373 
00374         return model;
00375 }


ml_classifiers
Author(s): Scott Niekum
autogenerated on Thu Aug 27 2015 13:59:04