svm_model_matlab.c
Go to the documentation of this file.
00001 #include <stdlib.h>
00002 #include <string.h>
00003 #include "../svm.h"
00004 
00005 #include "mex.h"
00006 
/* MEX headers older than API version 7.3 do not provide mwIndex. */
#ifdef MX_API_VER
#if MX_API_VER < 0x07030000
typedef int mwIndex;
#endif
#endif

/* Number of fields in the MATLAB model struct; must equal the length of field_names. */
enum { NUM_OF_RETURN_FIELD = 10 };

/* Allocate n objects of the given type with malloc (the result is not checked). */
#define Malloc(type,n) (type *)malloc((n)*sizeof(type))

/*
 * Field names of the MATLAB model struct. Both converters in this file rely on
 * this exact order: one writes the fields in it, the other reads them by position.
 */
static const char *field_names[] = {
	"Parameters", "nr_class", "totalSV", "rho", "Label",
	"ProbA",      "ProbB",    "nSV",     "sv_coef", "SVs"
};
00029 
00030 const char *model_to_matlab_structure(mxArray *plhs[], int num_of_feature, struct svm_model *model)
00031 {
00032         int i, j, n;
00033         double *ptr;
00034         mxArray *return_model, **rhs;
00035         int out_id = 0;
00036 
00037         rhs = (mxArray **)mxMalloc(sizeof(mxArray *)*NUM_OF_RETURN_FIELD);
00038 
00039         // Parameters
00040         rhs[out_id] = mxCreateDoubleMatrix(5, 1, mxREAL);
00041         ptr = mxGetPr(rhs[out_id]);
00042         ptr[0] = model->param.svm_type;
00043         ptr[1] = model->param.kernel_type;
00044         ptr[2] = model->param.degree;
00045         ptr[3] = model->param.gamma;
00046         ptr[4] = model->param.coef0;
00047         out_id++;
00048 
00049         // nr_class
00050         rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
00051         ptr = mxGetPr(rhs[out_id]);
00052         ptr[0] = model->nr_class;
00053         out_id++;
00054 
00055         // total SV
00056         rhs[out_id] = mxCreateDoubleMatrix(1, 1, mxREAL);
00057         ptr = mxGetPr(rhs[out_id]);
00058         ptr[0] = model->l;
00059         out_id++;
00060 
00061         // rho
00062         n = model->nr_class*(model->nr_class-1)/2;
00063         rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
00064         ptr = mxGetPr(rhs[out_id]);
00065         for(i = 0; i < n; i++)
00066                 ptr[i] = model->rho[i];
00067         out_id++;
00068 
00069         // Label
00070         if(model->label)
00071         {
00072                 rhs[out_id] = mxCreateDoubleMatrix(model->nr_class, 1, mxREAL);
00073                 ptr = mxGetPr(rhs[out_id]);
00074                 for(i = 0; i < model->nr_class; i++)
00075                         ptr[i] = model->label[i];
00076         }
00077         else
00078                 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
00079         out_id++;
00080 
00081         // probA
00082         if(model->probA != NULL)
00083         {
00084                 rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
00085                 ptr = mxGetPr(rhs[out_id]);
00086                 for(i = 0; i < n; i++)
00087                         ptr[i] = model->probA[i];
00088         }
00089         else
00090                 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
00091         out_id ++;
00092 
00093         // probB
00094         if(model->probB != NULL)
00095         {
00096                 rhs[out_id] = mxCreateDoubleMatrix(n, 1, mxREAL);
00097                 ptr = mxGetPr(rhs[out_id]);
00098                 for(i = 0; i < n; i++)
00099                         ptr[i] = model->probB[i];
00100         }
00101         else
00102                 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
00103         out_id++;
00104 
00105         // nSV
00106         if(model->nSV)
00107         {
00108                 rhs[out_id] = mxCreateDoubleMatrix(model->nr_class, 1, mxREAL);
00109                 ptr = mxGetPr(rhs[out_id]);
00110                 for(i = 0; i < model->nr_class; i++)
00111                         ptr[i] = model->nSV[i];
00112         }
00113         else
00114                 rhs[out_id] = mxCreateDoubleMatrix(0, 0, mxREAL);
00115         out_id++;
00116 
00117         // sv_coef
00118         rhs[out_id] = mxCreateDoubleMatrix(model->l, model->nr_class-1, mxREAL);
00119         ptr = mxGetPr(rhs[out_id]);
00120         for(i = 0; i < model->nr_class-1; i++)
00121                 for(j = 0; j < model->l; j++)
00122                         ptr[(i*(model->l))+j] = model->sv_coef[i][j];
00123         out_id++;
00124 
00125         // SVs
00126         {
00127                 int ir_index, nonzero_element;
00128                 mwIndex *ir, *jc;
00129                 mxArray *pprhs[1], *pplhs[1];   
00130 
00131                 if(model->param.kernel_type == PRECOMPUTED)
00132                 {
00133                         nonzero_element = model->l;
00134                         num_of_feature = 1;
00135                 }
00136                 else
00137                 {
00138                         nonzero_element = 0;
00139                         for(i = 0; i < model->l; i++) {
00140                                 j = 0;
00141                                 while(model->SV[i][j].index != -1) 
00142                                 {
00143                                         nonzero_element++;
00144                                         j++;
00145                                 }
00146                         }
00147                 }
00148 
00149                 // SV in column, easier accessing
00150                 rhs[out_id] = mxCreateSparse(num_of_feature, model->l, nonzero_element, mxREAL);
00151                 ir = mxGetIr(rhs[out_id]);
00152                 jc = mxGetJc(rhs[out_id]);
00153                 ptr = mxGetPr(rhs[out_id]);
00154                 jc[0] = ir_index = 0;           
00155                 for(i = 0;i < model->l; i++)
00156                 {
00157                         if(model->param.kernel_type == PRECOMPUTED)
00158                         {
00159                                 // make a (1 x model->l) matrix
00160                                 ir[ir_index] = 0; 
00161                                 ptr[ir_index] = model->SV[i][0].value;
00162                                 ir_index++;
00163                                 jc[i+1] = jc[i] + 1;
00164                         }
00165                         else
00166                         {
00167                                 int x_index = 0;
00168                                 while (model->SV[i][x_index].index != -1)
00169                                 {
00170                                         ir[ir_index] = model->SV[i][x_index].index - 1; 
00171                                         ptr[ir_index] = model->SV[i][x_index].value;
00172                                         ir_index++, x_index++;
00173                                 }
00174                                 jc[i+1] = jc[i] + x_index;
00175                         }
00176                 }
00177                 // transpose back to SV in row
00178                 pprhs[0] = rhs[out_id];
00179                 if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose"))
00180                         return "cannot transpose SV matrix";
00181                 rhs[out_id] = pplhs[0];
00182                 out_id++;
00183         }
00184 
00185         /* Create a struct matrix contains NUM_OF_RETURN_FIELD fields */
00186         return_model = mxCreateStructMatrix(1, 1, NUM_OF_RETURN_FIELD, field_names);
00187 
00188         /* Fill struct matrix with input arguments */
00189         for(i = 0; i < NUM_OF_RETURN_FIELD; i++)
00190                 mxSetField(return_model,0,field_names[i],mxDuplicateArray(rhs[i]));
00191         /* return */
00192         plhs[0] = return_model;
00193         mxFree(rhs);
00194 
00195         return NULL;
00196 }
00197 
00198 struct svm_model *matlab_matrix_to_model(const mxArray *matlab_struct, const char **msg)
00199 {
00200         int i, j, n, num_of_fields;
00201         double *ptr;
00202         int id = 0;
00203         struct svm_node *x_space;
00204         struct svm_model *model;
00205         mxArray **rhs;
00206 
00207         num_of_fields = mxGetNumberOfFields(matlab_struct);
00208         if(num_of_fields != NUM_OF_RETURN_FIELD) 
00209         {
00210                 *msg = "number of return field is not correct";
00211                 return NULL;
00212         }
00213         rhs = (mxArray **) mxMalloc(sizeof(mxArray *)*num_of_fields);
00214 
00215         for(i=0;i<num_of_fields;i++)
00216                 rhs[i] = mxGetFieldByNumber(matlab_struct, 0, i);
00217 
00218         model = Malloc(struct svm_model, 1);
00219         model->rho = NULL;
00220         model->probA = NULL;
00221         model->probB = NULL;
00222         model->label = NULL;
00223         model->nSV = NULL;
00224         model->free_sv = 1; // XXX
00225 
00226         ptr = mxGetPr(rhs[id]);
00227         model->param.svm_type = (int)ptr[0];
00228         model->param.kernel_type  = (int)ptr[1];
00229         model->param.degree       = (int)ptr[2];
00230         model->param.gamma        = ptr[3];
00231         model->param.coef0        = ptr[4];
00232         id++;
00233 
00234         ptr = mxGetPr(rhs[id]);
00235         model->nr_class = (int)ptr[0];
00236         id++;
00237 
00238         ptr = mxGetPr(rhs[id]);
00239         model->l = (int)ptr[0];
00240         id++;
00241 
00242         // rho
00243         n = model->nr_class * (model->nr_class-1)/2;
00244         model->rho = (double*) malloc(n*sizeof(double));
00245         ptr = mxGetPr(rhs[id]);
00246         for(i=0;i<n;i++)
00247                 model->rho[i] = ptr[i];
00248         id++;
00249 
00250         // label
00251         if(mxIsEmpty(rhs[id]) == 0)
00252         {
00253                 model->label = (int*) malloc(model->nr_class*sizeof(int));
00254                 ptr = mxGetPr(rhs[id]);
00255                 for(i=0;i<model->nr_class;i++)
00256                         model->label[i] = (int)ptr[i];
00257         }
00258         id++;
00259 
00260         // probA
00261         if(mxIsEmpty(rhs[id]) == 0)
00262         {
00263                 model->probA = (double*) malloc(n*sizeof(double));
00264                 ptr = mxGetPr(rhs[id]);
00265                 for(i=0;i<n;i++)
00266                         model->probA[i] = ptr[i];
00267         }
00268         id++;
00269 
00270         // probB
00271         if(mxIsEmpty(rhs[id]) == 0)
00272         {
00273                 model->probB = (double*) malloc(n*sizeof(double));
00274                 ptr = mxGetPr(rhs[id]);
00275                 for(i=0;i<n;i++)
00276                         model->probB[i] = ptr[i];
00277         }
00278         id++;
00279 
00280         // nSV
00281         if(mxIsEmpty(rhs[id]) == 0)
00282         {
00283                 model->nSV = (int*) malloc(model->nr_class*sizeof(int));
00284                 ptr = mxGetPr(rhs[id]);
00285                 for(i=0;i<model->nr_class;i++)
00286                         model->nSV[i] = (int)ptr[i];
00287         }
00288         id++;
00289 
00290         // sv_coef
00291         ptr = mxGetPr(rhs[id]);
00292         model->sv_coef = (double**) malloc((model->nr_class-1)*sizeof(double));
00293         for( i=0 ; i< model->nr_class -1 ; i++ )
00294                 model->sv_coef[i] = (double*) malloc((model->l)*sizeof(double));
00295         for(i = 0; i < model->nr_class - 1; i++)
00296                 for(j = 0; j < model->l; j++)
00297                         model->sv_coef[i][j] = ptr[i*(model->l)+j];
00298         id++;
00299 
00300         // SV
00301         {
00302                 int sr, sc, elements;
00303                 int num_samples;
00304                 mwIndex *ir, *jc;
00305                 mxArray *pprhs[1], *pplhs[1];
00306 
00307                 // transpose SV
00308                 pprhs[0] = rhs[id];
00309                 if(mexCallMATLAB(1, pplhs, 1, pprhs, "transpose")) 
00310                 {
00311                         svm_free_and_destroy_model(&model);
00312                         *msg = "cannot transpose SV matrix";
00313                         return NULL;
00314                 }
00315                 rhs[id] = pplhs[0];
00316 
00317                 sr = (int)mxGetN(rhs[id]);
00318                 sc = (int)mxGetM(rhs[id]);
00319 
00320                 ptr = mxGetPr(rhs[id]);
00321                 ir = mxGetIr(rhs[id]);
00322                 jc = mxGetJc(rhs[id]);
00323 
00324                 num_samples = (int)mxGetNzmax(rhs[id]);
00325 
00326                 elements = num_samples + sr;
00327 
00328                 model->SV = (struct svm_node **) malloc(sr * sizeof(struct svm_node *));
00329                 x_space = (struct svm_node *)malloc(elements * sizeof(struct svm_node));
00330 
00331                 // SV is in column
00332                 for(i=0;i<sr;i++)
00333                 {
00334                         int low = (int)jc[i], high = (int)jc[i+1];
00335                         int x_index = 0;
00336                         model->SV[i] = &x_space[low+i];
00337                         for(j=low;j<high;j++)
00338                         {
00339                                 model->SV[i][x_index].index = (int)ir[j] + 1; 
00340                                 model->SV[i][x_index].value = ptr[j];
00341                                 x_index++;
00342                         }
00343                         model->SV[i][x_index].index = -1;
00344                 }
00345 
00346                 id++;
00347         }
00348         mxFree(rhs);
00349 
00350         return model;
00351 }


haf_grasping
Author(s): David Fischinger
autogenerated on Thu Jun 6 2019 18:35:09