vl_svmtrain.c
Go to the documentation of this file.
00001 
00008 /*
00009 Copyright (C) 2012 Daniele Perrone.
00010 Copyright (C) 2013 Milan Sulc
00011 Copyright (C) 2011-13 Andrea Vedaldi.
00012 All rights reserved.
00013 
00014 This file is part of the VLFeat library and is made available under
00015 the terms of the BSD license (see the COPYING file).
00016 */
00017 
00018 #include <mexutils.h>
00019 #include <vl/svm.h>
00020 #include <vl/mathop.h>
00021 #include <vl/homkermap.h>
00022 #include <vl/stringop.h>
00023 #include <assert.h>
00024 #include <string.h>
00025 
/* Codes identifying the optional name/value arguments of vl_svmtrain.
   They are returned by vlmxNextOption() for the rows of the `options`
   table. The order below is kept stable so every code keeps its value. */
enum {
  /* options shared by all solvers */
  opt_epsilon = 0,
  opt_max_num_iterations,
  opt_bias_multiplier,
  opt_diagnostic_function,
  opt_diagnostic_frequency,
  opt_validation_subset,
  opt_loss,
  opt_model,
  opt_bias,
  opt_weights,

  /* verbosity and solver selection */
  opt_verbose,
  opt_solver,

  /* options only meaningful for the SGD solver */
  opt_starting_iteration,
  opt_bias_learning_rate
} ;
00050 
00051 
00052 /* options */
00053 vlmxOption  options [] = {
00054   {"Epsilon",             1,   opt_epsilon             },
00055   {"MaxNumIterations",    1,   opt_max_num_iterations  },
00056   {"BiasMultiplier",      1,   opt_bias_multiplier     },
00057   {"DiagnosticFunction",  1,   opt_diagnostic_function },
00058   {"DiagnosticFrequency", 1,   opt_diagnostic_frequency},
00059   {"ValidationSubset",    1,   opt_validation_subset   },
00060   {"Loss",                1,   opt_loss                },
00061   {"Verbose",             0,   opt_verbose             },
00062   {"Solver",              1,   opt_solver              },
00063   {"Model",               1,   opt_model               },
00064   {"Bias",                1,   opt_bias                },
00065   {"Weights",             1,   opt_weights             },
00066 
00067   /* SGD specific */
00068   {"StartingIteration",   1,   opt_starting_iteration  },
00069   {"BiasLearningRate",    1,   opt_bias_learning_rate  },
00070 
00071   /* DCA specific */
00072   {0,                     0,   0                       }
00073 } ;
00074 
00075 mxArray * createScalarStructArray(void const **fields)
00076 {
00077   void const **iter ;
00078   char const **niter ;
00079   char const **names ;
00080   vl_size numFields = 0 ;
00081   mxArray * s ;
00082   mwSize dims [] = {1, 1} ;
00083 
00084   for (iter = fields ; *iter ; iter += 2) numFields++ ;
00085 
00086   names = vl_calloc(numFields, sizeof(char const*)) ;
00087 
00088   for (iter = fields, niter = names ; *iter ; iter += 2, niter++) {
00089     *niter = *iter ;
00090   }
00091 
00092   s = mxCreateStructArray(sizeof(dims)/sizeof(dims[0]),
00093                           dims,
00094                           (int)numFields,
00095                           names) ;
00096   for (iter = fields, niter = names ; *iter; iter += 2, niter++) {
00097     mxSetField(s, 0, *niter, (mxArray*)(*(iter+1))) ;
00098   }
00099   return s ;
00100 }
00101 
00102 /* ---------------------------------------------------------------- */
00103 /*                                                 Parsing datasets */
00104 /* ---------------------------------------------------------------- */
00105 
00106 VlSvmDataset * parseDataset(const mxArray * dataset_array)
00107 {
00108   VlSvmDataset * dataset ;
00109   {
00110     mxArray * data_array ;
00111     mxClassID dataClass ;
00112     vl_size dimension ;
00113     vl_size numData ;
00114     vl_type dataType ;
00115     if (! mxIsStruct(dataset_array)) {
00116       vlmxError(vlmxErrInvalidArgument, "DATASET is not a structure.") ;
00117     }
00118     if (mxGetNumberOfElements(dataset_array) != 1) {
00119       vlmxError(vlmxErrInvalidArgument, "DATASET is not a singleton.") ;
00120     }
00121     data_array = mxGetField(dataset_array, 0, "data") ;
00122     if (data_array == NULL) {
00123       vlmxError(vlmxErrInvalidArgument, "DATASET is missing the DATA field.") ;
00124     }
00125     if (!vlmxIsMatrix(data_array,-1,-1)) {
00126       vlmxError(vlmxErrInvalidArgument,"DATASET.DATA is not a matrix.") ;
00127     }
00128     dimension = mxGetM (data_array) ;
00129     numData = mxGetN (data_array) ;
00130     dataClass = mxGetClassID (data_array) ;
00131 
00132     if (dimension == 0 || numData == 0) {
00133       vlmxError(vlmxErrInvalidArgument, "DATASET.DATA is empty.") ;
00134     }
00135 
00136     switch (dataClass) {
00137       case mxSINGLE_CLASS : dataType = VL_TYPE_FLOAT ; break ;
00138       case mxDOUBLE_CLASS : dataType = VL_TYPE_DOUBLE ; break ;
00139       default:
00140         vlmxError(vlmxErrInvalidArgument, "DATASET.DATA is neither either SINGLE or DOUBLE.") ;
00141     }
00142     dataset = vl_svmdataset_new(dataType, mxGetData(data_array), dimension, numData) ;
00143   }
00144 
00145   /* homogeneous kernel map support */
00146   {
00147     VlHomogeneousKernelType kernelType = VlHomogeneousKernelChi2 ;
00148     VlHomogeneousKernelMapWindowType windowType = VlHomogeneousKernelMapWindowRectangular ;
00149     double gamma = 1.0 ;
00150     double period = -1 ;
00151     int n = 1 ;
00152     VlHomogeneousKernelMap * hom = NULL ;
00153     mxArray * hom_array ;
00154     mxArray * field ;
00155 
00156     hom_array = mxGetField(dataset_array, 0, "homkermap") ;
00157     if (hom_array != NULL)
00158     {
00159       if (!mxIsStruct(hom_array)) {
00160         vlmxError(vlmxErrInvalidArgument, "DATASET.HOMKERMAP is not a structure") ;
00161       }
00162 
00163       field = mxGetField(hom_array, 0, "order") ;
00164       if (field != NULL) {
00165         if (! vlmxIsPlainScalar(field)) {
00166           vlmxError(vlmxErrInvalidArgument, "DATASET.HOMKERMAP.ORDER is not a scalar.") ;
00167         }
00168         n = *mxGetPr(field) ;
00169         if (n < 0) {
00170           vlmxError(vlmxErrInvalidArgument, "DATASET.HOMKERMAP.ORDER is negative.") ;
00171         }
00172       }
00173 
00174       field = mxGetField(hom_array, 0, "kernel") ;
00175       if (field != NULL) {
00176         char buffer [1024] ;
00177         mxGetString(field, buffer, sizeof(buffer) / sizeof(char)) ;
00178         if (vl_string_casei_cmp("kl1", buffer) == 0) {
00179           kernelType = VlHomogeneousKernelIntersection ;
00180         } else if (vl_string_casei_cmp("kchi2", buffer) == 0) {
00181           kernelType = VlHomogeneousKernelChi2 ;
00182         } else if (vl_string_casei_cmp("kjs", buffer) == 0) {
00183           kernelType = VlHomogeneousKernelJS ;
00184         } else if (vl_string_casei_cmp("kinters", buffer) == 0) {
00185           kernelType = VlHomogeneousKernelIntersection ;
00186         } else {
00187           vlmxError(vlmxErrInvalidArgument, "DATASET.HOMKERMAP.KERNEL is not a recognized kernel type.") ;
00188         }
00189       }
00190 
00191       field = mxGetField(hom_array, 0, "window") ;
00192       if (field != NULL) {
00193         char buffer [1024] ;
00194         mxGetString(field, buffer, sizeof(buffer) / sizeof(char)) ;
00195         if (vl_string_casei_cmp("uniform", buffer) == 0) {
00196           windowType = VlHomogeneousKernelMapWindowUniform ;
00197         } else if (vl_string_casei_cmp("rectangular", buffer) == 0) {
00198           windowType = VlHomogeneousKernelMapWindowRectangular;
00199         } else {
00200           vlmxError(vlmxErrInvalidArgument, "DATASET.HOMKERMAP.WINDOW is not a recognized window type.") ;
00201         }
00202       }
00203 
00204       field = mxGetField(hom_array, 0, "gamma") ;
00205       if (field != NULL) {
00206         if (! vlmxIsPlainScalar(field)) {
00207           vlmxError(vlmxErrInvalidArgument, "GAMMA is not a scalar.") ;
00208         }
00209         gamma = *mxGetPr(field) ;
00210         if (gamma <= 0) {
00211           vlmxError(vlmxErrInvalidArgument, "GAMMA is not positive.") ;
00212         }
00213       }
00214 
00215       field = mxGetField(hom_array, 0, "period") ;
00216       if (field != NULL) {
00217         if (! vlmxIsPlainScalar(field)) {
00218           vlmxError(vlmxErrInvalidArgument, "PERIOD is not a scalar.") ;
00219         }
00220         period = *mxGetPr(field) ;
00221         if (period <= 0) {
00222           vlmxError(vlmxErrInvalidArgument, "PERIOD is not positive.") ;
00223         }
00224       }
00225 
00226       hom = vl_homogeneouskernelmap_new (kernelType, gamma, n, period, windowType) ;
00227       vl_svmdataset_set_homogeneous_kernel_map (dataset, hom) ;
00228     }
00229   }
00230   return dataset ;
00231 }
00232 
00233 /* ---------------------------------------------------------------- */
00234 /*                                               Diagnostic helpers */
00235 /* ---------------------------------------------------------------- */
00236 
00237 mxArray * makeInfoStruct (VlSvm* svm)
00238 {
00239   VlSvmStatistics const * s = vl_svm_get_statistics(svm) ;
00240   mxArray * info = 0 ;
00241 
00242   switch (vl_svm_get_solver(svm)) {
00243     case VlSvmSolverSdca:
00244     {
00245       void const * fields [] = {
00246         "solver", mxCreateString("sdca"),
00247         "lambda", vlmxCreatePlainScalar(vl_svm_get_lambda(svm)),
00248         "biasMultiplier", vlmxCreatePlainScalar(vl_svm_get_bias_multiplier(svm)),
00249         "bias", vlmxCreatePlainScalar(vl_svm_get_bias(svm)),
00250         "objective", vlmxCreatePlainScalar(s->objective),
00251         "regularizer", vlmxCreatePlainScalar(s->regularizer),
00252         "loss", vlmxCreatePlainScalar(s->loss),
00253         "dualObjective", vlmxCreatePlainScalar(s->dualObjective),
00254         "dualLoss", vlmxCreatePlainScalar(s->dualLoss),
00255         "dualityGap", vlmxCreatePlainScalar(s->dualityGap),
00256         "iteration", vlmxCreatePlainScalar(s->iteration),
00257         "epoch", vlmxCreatePlainScalar(s->epoch),
00258         "elapsedTime", vlmxCreatePlainScalar(s->elapsedTime),
00259         0, 0
00260       } ;
00261       info = createScalarStructArray(fields) ;
00262       break ;
00263     }
00264 
00265     case VlSvmSolverSgd:
00266     {
00267       void const * fields [] = {
00268         "solver", mxCreateString("sgd"),
00269         "lambda", vlmxCreatePlainScalar(vl_svm_get_lambda(svm)),
00270         "biasMultiplier", vlmxCreatePlainScalar(vl_svm_get_bias_multiplier(svm)),
00271         "bias", vlmxCreatePlainScalar(vl_svm_get_bias(svm)),
00272         "objective", vlmxCreatePlainScalar(s->objective),
00273         "regularizer", vlmxCreatePlainScalar(s->regularizer),
00274         "loss", vlmxCreatePlainScalar(s->loss),
00275         "scoreVariation", vlmxCreatePlainScalar(s->scoresVariation),
00276         "iteration", vlmxCreatePlainScalar(s->iteration),
00277         "epoch", vlmxCreatePlainScalar(s->epoch),
00278         "elapsedTime", vlmxCreatePlainScalar(s->elapsedTime),
00279         0, 0
00280       } ;
00281       info = createScalarStructArray(fields) ;
00282       break ;
00283     }
00284 
00285     case VlSvmSolverNone :
00286     {
00287       void const * fields [] = {
00288         "solver", mxCreateString("none"),
00289         "lambda", vlmxCreatePlainScalar(vl_svm_get_lambda(svm)),
00290         "biasMultiplier", vlmxCreatePlainScalar(vl_svm_get_bias_multiplier(svm)),
00291         "bias", vlmxCreatePlainScalar(vl_svm_get_bias(svm)),
00292         "objective", vlmxCreatePlainScalar(s->objective),
00293         "regularizer", vlmxCreatePlainScalar(s->regularizer),
00294         "loss", vlmxCreatePlainScalar(s->loss),
00295         "elapsedTime", vlmxCreatePlainScalar(s->elapsedTime),
00296         0, 0
00297       } ;
00298       info = createScalarStructArray(fields) ;
00299       break ;
00300     }
00301 
00302   default:
00303     assert(0) ;
00304   }
00305   return info ;
00306 }
00307 
00308 /* ---------------------------------------------------------------- */
00309 /*                                          SVM diagnostic callback */
00310 /* ---------------------------------------------------------------- */
00311 
00312 typedef struct DiagnsoticOpts_
00313 {
00314   vl_bool verbose ;
00315   mxArray const * matlabDiagonsticFunctionHandle ;
00316 } DiagnosticOpts ;
00317 
00318 void diagnostic (VlSvm * svm, DiagnosticOpts * opts)
00319 {
00320   VlSvmStatistics const * s = vl_svm_get_statistics(svm) ;
00321   if ((opts->verbose && s->status != VlSvmStatusTraining) || (opts->verbose > 1)) {
00322     const char * statusName = 0 ;
00323     switch (s->status) {
00324       case VlSvmStatusTraining: statusName = "training" ; break ;
00325       case VlSvmStatusConverged: statusName = "converged" ; break ;
00326       case VlSvmStatusMaxNumIterationsReached: statusName = "max num iterations reached" ; break ;
00327     }
00328     mexPrintf("vl_svmtrain: iteration: %d (epoch: %d)\n", s->iteration+1, s->epoch+1) ;
00329     mexPrintf("\ttime elapsed: %f\n", s->elapsedTime) ;
00330     mexPrintf("\tobjective: %g (regul: %g, loss: %g)\n", s->objective, s->regularizer, s->loss) ;
00331     switch (vl_svm_get_solver(svm)) {
00332       case VlSvmSolverSgd:
00333         mexPrintf("\tscore variation: %f\n", s->scoresVariation) ;
00334         break;
00335 
00336       case VlSvmSolverSdca:
00337         mexPrintf("\tdual objective: %g (dual loss: %g)\n", s->dualObjective, s->dualLoss) ;
00338         mexPrintf("\tduality gap: %g\n", s->dualityGap) ;
00339         break;
00340 
00341       default:
00342         break;
00343     }
00344     mexPrintf("\tstatus: %s\n", statusName) ;
00345   }
00346   if (opts->matlabDiagonsticFunctionHandle) {
00347     mxArray *rhs[2] ;
00348     rhs[0] = (mxArray*) opts->matlabDiagonsticFunctionHandle ;
00349     rhs[1] = makeInfoStruct(svm) ;
00350     if (mxIsClass(rhs[0] , "function_handle")) {
00351       mexCallMATLAB(0,NULL,sizeof(rhs)/sizeof(rhs[0]),rhs,"feval") ;
00352     }
00353     mxDestroyArray(rhs[1]) ;
00354   }
00355 }
00356 
00357 /* ---------------------------------------------------------------- */
00358 /*                                                  MEX entry point */
00359 /* ---------------------------------------------------------------- */
00360 
00361 void
00362 mexFunction(int nout, mxArray *out[],
00363             int nin, const mxArray *in[])
00364 {
00365   enum {IN_DATASET = 0, IN_LABELS, IN_LAMBDA, IN_END} ;
00366   enum {OUT_MODEL = 0, OUT_BIAS, OUT_INFO, OUT_SCORES, OUT_END} ;
00367 
00368   vl_int opt, next;
00369   mxArray const *optarg ;
00370 
00371   VlSvmSolverType solver = VlSvmSolverSdca ;
00372   VlSvmLossType loss = VlSvmLossHinge ;
00373   int verbose = 0 ;
00374   VlSvmDataset * dataset ;
00375   double * labels ;
00376   double * weights = NULL ;
00377   double lambda ;
00378 
00379   double epsilon = -1 ;
00380   double biasMultipler = -1 ;
00381   vl_index maxNumIterations = -1 ;
00382   vl_index diagnosticFrequency = -1 ;
00383   mxArray const * matlabDiagnosticFunctionHandle = NULL ;
00384 
00385   mxArray const * initialModel_array = NULL ;
00386   double initialBias = VL_NAN_D ;
00387   vl_index startingIteration = -1 ;
00388 
00389   /* SGD */
00390   double sgdBiasLearningRate = -1 ;
00391 
00392   VL_USE_MATLAB_ENV ;
00393 
00394   if (nin < 3) {
00395     vlmxError(vlmxErrInvalidArgument, "At least three arguments are required.") ;
00396   }
00397   if (nout > OUT_END) {
00398     vlmxError(vlmxErrInvalidArgument, "Too many output arguments.");
00399   }
00400 
00401 #define GET_SCALAR(NAME, variable) \
00402 if (!vlmxIsPlainScalar(optarg)) { \
00403 vlmxError(vlmxErrInvalidArgument, VL_STRINGIFY(NAME) " is not a plain scalar.") ; \
00404 } \
00405 variable = (double) *mxGetPr(optarg);
00406 
00407 #define GET_NN_SCALAR(NAME, variable) GET_SCALAR(NAME, variable) \
00408 if (variable < 0) { \
00409 vlmxError(vlmxErrInvalidArgument, VL_STRINGIFY(NAME) " is negative.") ; \
00410 }
00411 
00412   /* Mode 1: pass data, labels, lambda, and options */
00413   if (mxIsNumeric(in[IN_DATASET]))
00414   {
00415     mxArray const* samples_array = in[IN_DATASET] ;
00416     vl_size dimension ;
00417     vl_size numSamples ;
00418     void * data ;
00419     vl_type dataType ;
00420 
00421     if (!vlmxIsMatrix(samples_array, -1, -1)) {
00422       vlmxError (vlmxErrInvalidArgument,
00423                  "X is not a matrix.") ;
00424     }
00425     if (mxGetClassID(samples_array) == mxDOUBLE_CLASS) {
00426       dataType = VL_TYPE_DOUBLE ;
00427     } else if (mxGetClassID(samples_array) == mxSINGLE_CLASS) {
00428       dataType = VL_TYPE_FLOAT ;
00429     } else {
00430       vlmxError (vlmxErrInvalidArgument, "X is not of class SINGLE or DOUBLE.") ;
00431     }
00432     data = mxGetData(samples_array) ;
00433     dimension = mxGetM(samples_array) ;
00434     numSamples = mxGetN(samples_array) ;
00435     dataset = vl_svmdataset_new(dataType, data, dimension, numSamples) ;
00436   }
00437   /* Mode 2: pass dataset structure */
00438   else {
00439     dataset = parseDataset(in[IN_DATASET]) ;
00440   }
00441 
00442   {
00443     mxArray const* labels_array = in[IN_LABELS] ;
00444     if (!vlmxIsPlainMatrix(labels_array, -1, -1)) {
00445       vlmxError (vlmxErrInvalidArgument, "Y is not a plain matrix.") ;
00446     }
00447     labels = mxGetPr(labels_array) ;
00448     if (mxGetNumberOfElements(labels_array) != vl_svmdataset_get_num_data(dataset)) {
00449       vlmxError  (vlmxErrInvalidArgument,
00450                   "The number of labels Y is not the same as the number of data samples X.") ;
00451     }
00452     optarg = in[IN_LAMBDA] ;
00453     GET_NN_SCALAR(LAMBDA, lambda) ;
00454   }
00455 
00456   /* Parse optional arguments */
00457   next = 3 ;
00458   while ((opt = vlmxNextOption (in, nin, options, &next, &optarg)) >= 0) {
00459     char buf [1024] ;
00460     switch (opt) {
00461       case opt_verbose: verbose ++ ; break ;
00462       case opt_epsilon: GET_NN_SCALAR(EPSLON, epsilon) ; break ;
00463       case opt_bias_multiplier: GET_NN_SCALAR(BIASMULTIPLIER, biasMultipler) ; break ;
00464       case opt_max_num_iterations: GET_NN_SCALAR(MAXNUMITERATIONS, maxNumIterations) ; break ;
00465       case opt_diagnostic_frequency: GET_NN_SCALAR(DIAGNOSTICFREQUENCY, diagnosticFrequency) ; break ;
00466       case opt_diagnostic_function:
00467         if (!mxIsClass(optarg ,"function_handle")) {
00468           mexErrMsgTxt("DIAGNOSTICSFUNCTION is not a function handle.");
00469         }
00470         matlabDiagnosticFunctionHandle = optarg ;
00471         break ;
00472 
00473       case opt_solver :
00474         if (!vlmxIsString (optarg, -1)) {
00475           vlmxError (vlmxErrInvalidArgument,
00476                      "SOLVER must be a string.") ;
00477         }
00478         if (mxGetString (optarg, buf, sizeof(buf))) {
00479           vlmxError (vlmxErrInvalidArgument,
00480                      "SOLVER argument too long.") ;
00481         }
00482         if (vlmxCompareStringsI("sgd", buf) == 0) {
00483           solver = VlSvmSolverSgd ;
00484         } else if (vlmxCompareStringsI("sdca", buf) == 0) {
00485           solver = VlSvmSolverSdca ;
00486         } else if (vlmxCompareStringsI("none", buf) == 0) {
00487           solver = VlSvmSolverNone;
00488         } else {
00489           vlmxError (vlmxErrInvalidArgument,
00490                      "Invalid value %s for SOLVER", buf) ;
00491         }
00492         break ;
00493 
00494       case opt_loss :
00495         if (!vlmxIsString (optarg, -1)) {
00496           vlmxError (vlmxErrInvalidArgument,
00497                      "LOSS must be a string.") ;
00498         }
00499         if (mxGetString (optarg, buf, sizeof(buf))) {
00500           vlmxError (vlmxErrInvalidArgument,
00501                      "LOSS argument too long.") ;
00502         }
00503         if (vlmxCompareStringsI("hinge", buf) == 0) {
00504           loss = VlSvmLossHinge ;
00505         } else if (vlmxCompareStringsI("hinge2", buf) == 0) {
00506           loss = VlSvmLossHinge2 ;
00507         } else if (vlmxCompareStringsI("l1", buf) == 0) {
00508           loss = VlSvmLossL1 ;
00509         } else if (vlmxCompareStringsI("l2", buf) == 0) {
00510           loss = VlSvmLossL2 ;
00511         } else if (vlmxCompareStringsI("logistic", buf) == 0) {
00512           loss = VlSvmLossLogistic ;
00513         } else {
00514           vlmxError (vlmxErrInvalidArgument,
00515                      "Invalid value %s for LOSS", buf) ;
00516         }
00517         break ;
00518 
00519       case opt_model :
00520         if (!vlmxIsPlainVector(optarg, vl_svmdataset_get_dimension(dataset))) {
00521           vlmxError(vlmxErrInvalidArgument, "MODEL is not a plain vector of size equal to the data dimension.") ;
00522         }
00523         initialModel_array = optarg ;
00524         break ;
00525 
00526       case opt_bias: GET_SCALAR(BIAS, initialBias) ; break ;
00527 
00528       case opt_weights:
00529         if (!vlmxIsPlainVector(optarg, vl_svmdataset_get_num_data(dataset))) {
00530           vlmxError(vlmxErrInvalidArgument, "WEIGHTS is not a plain vector of size equal to the number of training samples.") ;
00531         }
00532         weights = mxGetPr(optarg) ;
00533         break ;
00534 
00535       /* SGD specific */
00536       case opt_starting_iteration: GET_NN_SCALAR(STARTINGITERATION, startingIteration) ; break ;
00537       case opt_bias_learning_rate: GET_NN_SCALAR(BIASLEARNINGRATE, sgdBiasLearningRate) ; break ;
00538 
00539       /* DCA specific */
00540     } /* choose option */
00541   } /* next option */
00542 
00543   {
00544     VlSvm * svm = vl_svm_new_with_dataset(solver, dataset, labels, lambda) ;
00545     DiagnosticOpts dopts ;
00546 
00547     if (initialModel_array) {
00548       if (solver != VlSvmSolverNone && solver != VlSvmSolverSgd) {
00549         vlmxError(vlmxErrInvalidArgument, "MODEL cannot be specified with this type of solver.") ;
00550       }
00551       if (mxGetNumberOfElements(initialModel_array) != vl_svm_get_dimension(svm)) {
00552         vlmxError(vlmxErrInvalidArgument, "MODEL has not the same dimension as the data.") ;
00553       }
00554       vl_svm_set_model(svm, mxGetPr(initialModel_array)) ;
00555     }
00556 
00557     if (! vl_is_nan_d(initialBias)) {
00558       if (solver != VlSvmSolverNone && solver != VlSvmSolverSgd) {
00559         vlmxError(vlmxErrInvalidArgument, "BIAS cannot be specified with this type of solver.") ;
00560       }
00561       vl_svm_set_bias(svm, initialBias) ;
00562     }
00563 
00564     if (epsilon >= 0) vl_svm_set_epsilon(svm, epsilon) ;
00565     if (maxNumIterations >= 0) vl_svm_set_max_num_iterations(svm, maxNumIterations) ;
00566     if (biasMultipler >= 0) vl_svm_set_bias_multiplier(svm, biasMultipler) ;
00567     if (sgdBiasLearningRate >= 0) vl_svm_set_bias_learning_rate(svm, sgdBiasLearningRate) ;
00568     if (diagnosticFrequency >= 0) vl_svm_set_diagnostic_frequency(svm, diagnosticFrequency) ;
00569     if (startingIteration >= 0) vl_svm_set_iteration_number(svm, (unsigned)startingIteration) ;
00570     if (weights) vl_svm_set_weights(svm, weights) ;
00571     vl_svm_set_loss (svm, loss) ;
00572 
00573     dopts.verbose = verbose ;
00574     dopts.matlabDiagonsticFunctionHandle = matlabDiagnosticFunctionHandle ;
00575     vl_svm_set_diagnostic_function (svm, (VlSvmDiagnosticFunction)diagnostic, &dopts) ;
00576 
00577     if (verbose) {
00578       double C = 1.0 / (vl_svm_get_lambda(svm) * vl_svm_get_num_data (svm)) ;
00579       char const * lossName = 0 ;
00580       switch (loss) {
00581         case VlSvmLossHinge: lossName = "hinge" ; break ;
00582         case VlSvmLossHinge2: lossName = "hinge2" ; break ;
00583         case VlSvmLossL1: lossName = "l1" ; break ;
00584         case VlSvmLossL2: lossName = "l2" ; break ;
00585         case VlSvmLossLogistic: lossName = "logistic" ; break ;
00586       }
00587       mexPrintf("vl_svmtrain: parameters (verbosity: %d)\n", verbose) ;
00588       mexPrintf("\tdata dimension: %d\n",vl_svmdataset_get_dimension(dataset)) ;
00589       mexPrintf("\tnum samples: %d\n", vl_svmdataset_get_num_data(dataset)) ;
00590       mexPrintf("\tlambda: %g (C equivalent: %g)\n", vl_svm_get_lambda(svm), C) ;
00591       mexPrintf("\tloss function: %s\n", lossName) ;
00592       mexPrintf("\tmax num iterations: %d\n", vl_svm_get_max_num_iterations(svm)) ;
00593       mexPrintf("\tepsilon: %g\n", vl_svm_get_epsilon(svm)) ;
00594       mexPrintf("\tdiagnostic frequency: %d\n", vl_svm_get_diagnostic_frequency(svm)) ;
00595       mexPrintf("\tusing custom weights: %s\n", VL_YESNO(weights)) ;
00596       mexPrintf("\tbias multiplier: %g\n", vl_svm_get_bias_multiplier(svm)) ;
00597       switch (vl_svm_get_solver(svm)) {
00598         case VlSvmSolverNone:
00599           mexPrintf("\tsolver: none (evaluation mode)\n") ;
00600           break ;
00601         case VlSvmSolverSgd:
00602           mexPrintf("\tsolver: sgd\n") ;
00603           mexPrintf("\tbias learning rate: %g\n", vl_svm_get_bias_learning_rate(svm)) ;
00604           break ;
00605         case VlSvmSolverSdca:
00606           mexPrintf("\tsolver: sdca\n") ;
00607           break ;
00608       }
00609     }
00610 
00611     vl_svm_train(svm) ;
00612 
00613     {
00614       mwSize dims[2] ;
00615       dims[0] = vl_svmdataset_get_dimension(dataset) ;
00616       dims[1] = 1 ;
00617       out[OUT_MODEL] = mxCreateNumericArray(2, dims, mxDOUBLE_CLASS, mxREAL) ;
00618       memcpy(mxGetPr(out[OUT_MODEL]),
00619              vl_svm_get_model(svm),
00620              vl_svm_get_dimension(svm) * sizeof(double)) ;
00621     }
00622     out[OUT_BIAS] = vlmxCreatePlainScalar(vl_svm_get_bias(svm)) ;
00623     if (nout >= 3) {
00624       out[OUT_INFO] = makeInfoStruct(svm) ;
00625     }
00626     if (nout >= 4) {
00627       mwSize dims[2] ;
00628       dims[0] = 1 ;
00629       dims[1] = vl_svmdataset_get_num_data(dataset) ;
00630       out[OUT_SCORES] = mxCreateNumericArray(2, dims, mxDOUBLE_CLASS, mxREAL) ;
00631       memcpy(mxGetPr(out[OUT_SCORES]),
00632              vl_svm_get_scores(svm),
00633              vl_svm_get_num_data(svm) * sizeof(double)) ;
00634     }
00635 
00636 
00637     vl_svm_delete(svm) ;
00638     if (vl_svmdataset_get_homogeneous_kernel_map(dataset)) {
00639       VlHomogeneousKernelMap * hom = vl_svmdataset_get_homogeneous_kernel_map(dataset) ;
00640       vl_svmdataset_set_homogeneous_kernel_map(dataset,0) ;
00641       vl_homogeneouskernelmap_delete(hom) ;
00642     }
00643     vl_svmdataset_delete(dataset) ;
00644   }
00645 }


libvlfeat
Author(s): Andrea Vedaldi
autogenerated on Thu Jun 6 2019 20:25:51