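/** @file   vl_svmtrain.c
 ** @brief  vl_svmtrain MEX definition
 **/
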
#include <mexutils.h>
#include <vl/svm.h>
#include <vl/mathop.h>
#include <vl/homkermap.h>
#include <vl/stringop.h>
#include <assert.h>
#include <string.h>

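/* Codes for the optional arguments accepted by vl_svmtrain. */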
enum {
  opt_epsilon,
  opt_max_num_iterations,
  opt_bias_multiplier,
  opt_diagnostic_function,
  opt_diagnostic_frequency,
  opt_validation_subset,
  opt_loss,
  opt_model,
  opt_bias,
  opt_weights,
  opt_verbose,
  opt_solver,
  opt_starting_iteration,
  opt_bias_learning_rate
} ;

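/* Option name table: {name, whether the option takes a value, option code}. */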
vlmxOption options [] = {
  {"Epsilon",             1, opt_epsilon             },
  {"MaxNumIterations",    1, opt_max_num_iterations  },
  {"BiasMultiplier",      1, opt_bias_multiplier     },
  {"DiagnosticFunction",  1, opt_diagnostic_function },
  {"DiagnosticFrequency", 1, opt_diagnostic_frequency},
  {"ValidationSubset",    1, opt_validation_subset   },
  {"Loss",                1, opt_loss                },
  {"Verbose",             0, opt_verbose             },
  {"Solver",              1, opt_solver              },
  {"Model",               1, opt_model               },
  {"Bias",                1, opt_bias                },
  {"Weights",             1, opt_weights             },
  {"StartingIteration",   1, opt_starting_iteration  },
  {"BiasLearningRate",    1, opt_bias_learning_rate  },
  {0,                     0, 0                       }
} ;

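/* Create a 1x1 MATLAB struct array from a NULL-terminated list of
   (field name, mxArray value) pairs. */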
mxArray * createScalarStructArray(void const **fields)
{
  void const **iter ;
  char const **niter ;
  char const **names ;
  vl_size numFields = 0 ;
  mxArray * s ;
  mwSize dims [] = {1, 1} ;

  for (iter = fields ; *iter ; iter += 2) numFields++ ;

  names = vl_calloc(numFields, sizeof(char const*)) ;

  for (iter = fields, niter = names ; *iter ; iter += 2, niter++) {
    *niter = *iter ;
  }

  s = mxCreateStructArray(sizeof(dims)/sizeof(dims[0]),
                          dims,
                          (int)numFields,
                          names) ;
  for (iter = fields, niter = names ; *iter ; iter += 2, niter++) {
    mxSetField(s, 0, *niter, (mxArray*)(*(iter+1))) ;
  }
  /* mxCreateStructArray copies the field names, so the temporary list
     can be released here. */
  vl_free(names) ;
  return s ;
}

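/* Parse a dataset structure: the DATA field holds the data matrix and
   the optional HOMKERMAP field describes a homogeneous kernel map to
   apply to the data on the fly. */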
VlSvmDataset * parseDataset(const mxArray * dataset_array)
{
  VlSvmDataset * dataset ;

  /* Get the data. */
  {
    mxArray * data_array ;
    mxClassID dataClass ;
    vl_size dimension ;
    vl_size numData ;
    vl_type dataType ;
    if (! mxIsStruct(dataset_array)) {
      vlmxError(vlmxErrInvalidArgument, "DATASET is not a structure.") ;
    }
    if (mxGetNumberOfElements(dataset_array) != 1) {
      vlmxError(vlmxErrInvalidArgument, "DATASET is not a singleton.") ;
    }
    data_array = mxGetField(dataset_array, 0, "data") ;
    if (data_array == NULL) {
      vlmxError(vlmxErrInvalidArgument, "DATASET is missing the DATA field.") ;
    }
    if (!vlmxIsMatrix(data_array,-1,-1)) {
      vlmxError(vlmxErrInvalidArgument, "DATASET.DATA is not a matrix.") ;
    }
    dimension = mxGetM (data_array) ;
    numData = mxGetN (data_array) ;
    dataClass = mxGetClassID (data_array) ;

    if (dimension == 0 || numData == 0) {
      vlmxError(vlmxErrInvalidArgument, "DATASET.DATA is empty.") ;
    }

    switch (dataClass) {
      case mxSINGLE_CLASS : dataType = VL_TYPE_FLOAT ; break ;
      case mxDOUBLE_CLASS : dataType = VL_TYPE_DOUBLE ; break ;
      default:
        vlmxError(vlmxErrInvalidArgument, "DATASET.DATA is neither SINGLE nor DOUBLE.") ;
    }
    dataset = vl_svmdataset_new(dataType, mxGetData(data_array), dimension, numData) ;
  }

  /* Parse the optional homogeneous kernel map. */
  {
    VlHomogeneousKernelType kernelType = VlHomogeneousKernelChi2 ;
    VlHomogeneousKernelMapWindowType windowType = VlHomogeneousKernelMapWindowRectangular ;
    double gamma = 1.0 ;
    double period = -1 ;
    int n = 1 ;
    VlHomogeneousKernelMap * hom = NULL ;
    mxArray * hom_array ;
    mxArray * field ;

    hom_array = mxGetField(dataset_array, 0, "homkermap") ;
    if (hom_array != NULL)
    {
      if (!mxIsStruct(hom_array)) {
        vlmxError(vlmxErrInvalidArgument, "DATASET.HOMKERMAP is not a structure.") ;
      }

      field = mxGetField(hom_array, 0, "order") ;
      if (field != NULL) {
        if (! vlmxIsPlainScalar(field)) {
          vlmxError(vlmxErrInvalidArgument, "DATASET.HOMKERMAP.ORDER is not a scalar.") ;
        }
        n = (int) *mxGetPr(field) ;
        if (n < 0) {
          vlmxError(vlmxErrInvalidArgument, "DATASET.HOMKERMAP.ORDER is negative.") ;
        }
      }

      field = mxGetField(hom_array, 0, "kernel") ;
      if (field != NULL) {
        char buffer [1024] ;
        mxGetString(field, buffer, sizeof(buffer) / sizeof(char)) ;
        if (vl_string_casei_cmp("kl1", buffer) == 0) {
          kernelType = VlHomogeneousKernelIntersection ;
        } else if (vl_string_casei_cmp("kchi2", buffer) == 0) {
          kernelType = VlHomogeneousKernelChi2 ;
        } else if (vl_string_casei_cmp("kjs", buffer) == 0) {
          kernelType = VlHomogeneousKernelJS ;
        } else if (vl_string_casei_cmp("kinters", buffer) == 0) {
          kernelType = VlHomogeneousKernelIntersection ;
        } else {
          vlmxError(vlmxErrInvalidArgument, "DATASET.HOMKERMAP.KERNEL is not a recognized kernel type.") ;
        }
      }

      field = mxGetField(hom_array, 0, "window") ;
      if (field != NULL) {
        char buffer [1024] ;
        mxGetString(field, buffer, sizeof(buffer) / sizeof(char)) ;
        if (vl_string_casei_cmp("uniform", buffer) == 0) {
          windowType = VlHomogeneousKernelMapWindowUniform ;
        } else if (vl_string_casei_cmp("rectangular", buffer) == 0) {
          windowType = VlHomogeneousKernelMapWindowRectangular ;
        } else {
          vlmxError(vlmxErrInvalidArgument, "DATASET.HOMKERMAP.WINDOW is not a recognized window type.") ;
        }
      }

      field = mxGetField(hom_array, 0, "gamma") ;
      if (field != NULL) {
        if (! vlmxIsPlainScalar(field)) {
          vlmxError(vlmxErrInvalidArgument, "GAMMA is not a scalar.") ;
        }
        gamma = *mxGetPr(field) ;
        if (gamma <= 0) {
          vlmxError(vlmxErrInvalidArgument, "GAMMA is not positive.") ;
        }
      }

      field = mxGetField(hom_array, 0, "period") ;
      if (field != NULL) {
        if (! vlmxIsPlainScalar(field)) {
          vlmxError(vlmxErrInvalidArgument, "PERIOD is not a scalar.") ;
        }
        period = *mxGetPr(field) ;
        if (period <= 0) {
          vlmxError(vlmxErrInvalidArgument, "PERIOD is not positive.") ;
        }
      }

      hom = vl_homogeneouskernelmap_new (kernelType, gamma, n, period, windowType) ;
      vl_svmdataset_set_homogeneous_kernel_map (dataset, hom) ;
    }
  }
  return dataset ;
}

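/* Package the solver statistics into a MATLAB struct (the INFO output). */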
mxArray * makeInfoStruct (VlSvm* svm)
{
  VlSvmStatistics const * s = vl_svm_get_statistics(svm) ;
  mxArray * info = 0 ;

  switch (vl_svm_get_solver(svm)) {
    case VlSvmSolverSdca:
    {
      void const * fields [] = {
        "solver", mxCreateString("sdca"),
        "lambda", vlmxCreatePlainScalar(vl_svm_get_lambda(svm)),
        "biasMultiplier", vlmxCreatePlainScalar(vl_svm_get_bias_multiplier(svm)),
        "bias", vlmxCreatePlainScalar(vl_svm_get_bias(svm)),
        "objective", vlmxCreatePlainScalar(s->objective),
        "regularizer", vlmxCreatePlainScalar(s->regularizer),
        "loss", vlmxCreatePlainScalar(s->loss),
        "dualObjective", vlmxCreatePlainScalar(s->dualObjective),
        "dualLoss", vlmxCreatePlainScalar(s->dualLoss),
        "dualityGap", vlmxCreatePlainScalar(s->dualityGap),
        "iteration", vlmxCreatePlainScalar(s->iteration),
        "epoch", vlmxCreatePlainScalar(s->epoch),
        "elapsedTime", vlmxCreatePlainScalar(s->elapsedTime),
        0, 0
      } ;
      info = createScalarStructArray(fields) ;
      break ;
    }

    case VlSvmSolverSgd:
    {
      void const * fields [] = {
        "solver", mxCreateString("sgd"),
        "lambda", vlmxCreatePlainScalar(vl_svm_get_lambda(svm)),
        "biasMultiplier", vlmxCreatePlainScalar(vl_svm_get_bias_multiplier(svm)),
        "bias", vlmxCreatePlainScalar(vl_svm_get_bias(svm)),
        "objective", vlmxCreatePlainScalar(s->objective),
        "regularizer", vlmxCreatePlainScalar(s->regularizer),
        "loss", vlmxCreatePlainScalar(s->loss),
        "scoreVariation", vlmxCreatePlainScalar(s->scoresVariation),
        "iteration", vlmxCreatePlainScalar(s->iteration),
        "epoch", vlmxCreatePlainScalar(s->epoch),
        "elapsedTime", vlmxCreatePlainScalar(s->elapsedTime),
        0, 0
      } ;
      info = createScalarStructArray(fields) ;
      break ;
    }

    case VlSvmSolverNone :
    {
      void const * fields [] = {
        "solver", mxCreateString("none"),
        "lambda", vlmxCreatePlainScalar(vl_svm_get_lambda(svm)),
        "biasMultiplier", vlmxCreatePlainScalar(vl_svm_get_bias_multiplier(svm)),
        "bias", vlmxCreatePlainScalar(vl_svm_get_bias(svm)),
        "objective", vlmxCreatePlainScalar(s->objective),
        "regularizer", vlmxCreatePlainScalar(s->regularizer),
        "loss", vlmxCreatePlainScalar(s->loss),
        "elapsedTime", vlmxCreatePlainScalar(s->elapsedTime),
        0, 0
      } ;
      info = createScalarStructArray(fields) ;
      break ;
    }

    default:
      assert(0) ;
  }
  return info ;
}

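/* Diagnostic callback: prints the solver progress when verbose and
   forwards the statistics to an optional MATLAB function handle. */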
typedef struct DiagnosticOpts_
{
  vl_bool verbose ;
  mxArray const * matlabDiagnosticFunctionHandle ;
} DiagnosticOpts ;

void diagnostic (VlSvm * svm, DiagnosticOpts * opts)
{
  VlSvmStatistics const * s = vl_svm_get_statistics(svm) ;
  if ((opts->verbose && s->status != VlSvmStatusTraining) || (opts->verbose > 1)) {
    const char * statusName = 0 ;
    switch (s->status) {
      case VlSvmStatusTraining: statusName = "training" ; break ;
      case VlSvmStatusConverged: statusName = "converged" ; break ;
      case VlSvmStatusMaxNumIterationsReached: statusName = "max num iterations reached" ; break ;
    }
    mexPrintf("vl_svmtrain: iteration: %d (epoch: %d)\n", s->iteration+1, s->epoch+1) ;
    mexPrintf("\ttime elapsed: %f\n", s->elapsedTime) ;
    mexPrintf("\tobjective: %g (regul: %g, loss: %g)\n", s->objective, s->regularizer, s->loss) ;
    switch (vl_svm_get_solver(svm)) {
      case VlSvmSolverSgd:
        mexPrintf("\tscore variation: %f\n", s->scoresVariation) ;
        break ;

      case VlSvmSolverSdca:
        mexPrintf("\tdual objective: %g (dual loss: %g)\n", s->dualObjective, s->dualLoss) ;
        mexPrintf("\tduality gap: %g\n", s->dualityGap) ;
        break ;

      default:
        break ;
    }
    mexPrintf("\tstatus: %s\n", statusName) ;
  }
  if (opts->matlabDiagnosticFunctionHandle) {
    mxArray *rhs[2] ;
    rhs[0] = (mxArray*) opts->matlabDiagnosticFunctionHandle ;
    rhs[1] = makeInfoStruct(svm) ;
    if (mxIsClass(rhs[0], "function_handle")) {
      mexCallMATLAB(0, NULL, sizeof(rhs)/sizeof(rhs[0]), rhs, "feval") ;
    }
    mxDestroyArray(rhs[1]) ;
  }
}

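/* MEX driver: [MODEL, BIAS, INFO, SCORES] = vl_svmtrain(X, Y, LAMBDA, ...). */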
void
mexFunction(int nout, mxArray *out[],
            int nin, const mxArray *in[])
{
  enum {IN_DATASET = 0, IN_LABELS, IN_LAMBDA, IN_END} ;
  enum {OUT_MODEL = 0, OUT_BIAS, OUT_INFO, OUT_SCORES, OUT_END} ;

  vl_int opt, next ;
  mxArray const *optarg ;

  VlSvmSolverType solver = VlSvmSolverSdca ;
  VlSvmLossType loss = VlSvmLossHinge ;
  int verbose = 0 ;
  VlSvmDataset * dataset ;
  double * labels ;
  double * weights = NULL ;
  double lambda ;

  double epsilon = -1 ;
  double biasMultiplier = -1 ;
  vl_index maxNumIterations = -1 ;
  vl_index diagnosticFrequency = -1 ;
  mxArray const * matlabDiagnosticFunctionHandle = NULL ;

  mxArray const * initialModel_array = NULL ;
  double initialBias = VL_NAN_D ;
  vl_index startingIteration = -1 ;

  double sgdBiasLearningRate = -1 ;

  VL_USE_MATLAB_ENV ;

  if (nin < 3) {
    vlmxError(vlmxErrInvalidArgument, "At least three arguments are required.") ;
  }
  if (nout > OUT_END) {
    vlmxError(vlmxErrInvalidArgument, "Too many output arguments.") ;
  }

#define GET_SCALAR(NAME, variable) \
  if (!vlmxIsPlainScalar(optarg)) { \
    vlmxError(vlmxErrInvalidArgument, VL_STRINGIFY(NAME) " is not a plain scalar.") ; \
  } \
  variable = (double) *mxGetPr(optarg) ;

#define GET_NN_SCALAR(NAME, variable) GET_SCALAR(NAME, variable) \
  if (variable < 0) { \
    vlmxError(vlmxErrInvalidArgument, VL_STRINGIFY(NAME) " is negative.") ; \
  }

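  /* Get the training data: X can be passed directly as a SINGLE or DOUBLE
     matrix, or wrapped in a dataset structure. */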
  if (mxIsNumeric(in[IN_DATASET])) {
    mxArray const* samples_array = in[IN_DATASET] ;
    vl_size dimension ;
    vl_size numSamples ;
    void * data ;
    vl_type dataType ;

    if (!vlmxIsMatrix(samples_array, -1, -1)) {
      vlmxError (vlmxErrInvalidArgument, "X is not a matrix.") ;
    }
    if (mxGetClassID(samples_array) == mxDOUBLE_CLASS) {
      dataType = VL_TYPE_DOUBLE ;
    } else if (mxGetClassID(samples_array) == mxSINGLE_CLASS) {
      dataType = VL_TYPE_FLOAT ;
    } else {
      vlmxError (vlmxErrInvalidArgument, "X is not of class SINGLE or DOUBLE.") ;
    }
    data = mxGetData(samples_array) ;
    dimension = mxGetM(samples_array) ;
    numSamples = mxGetN(samples_array) ;
    dataset = vl_svmdataset_new(dataType, data, dimension, numSamples) ;
  } else {
    dataset = parseDataset(in[IN_DATASET]) ;
  }

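  /* Get the labels Y and the regularization parameter LAMBDA. */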
  {
    mxArray const* labels_array = in[IN_LABELS] ;
    if (!vlmxIsPlainMatrix(labels_array, -1, -1)) {
      vlmxError (vlmxErrInvalidArgument, "Y is not a plain matrix.") ;
    }
    labels = mxGetPr(labels_array) ;
    if (mxGetNumberOfElements(labels_array) != vl_svmdataset_get_num_data(dataset)) {
      vlmxError (vlmxErrInvalidArgument,
                 "The number of labels Y is not the same as the number of data samples X.") ;
    }
    optarg = in[IN_LAMBDA] ;
    GET_NN_SCALAR(LAMBDA, lambda) ;
  }

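  /* Parse the optional arguments. */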
  next = 3 ;
  while ((opt = vlmxNextOption (in, nin, options, &next, &optarg)) >= 0) {
    char buf [1024] ;
    switch (opt) {
      case opt_verbose: verbose ++ ; break ;
      case opt_epsilon: GET_NN_SCALAR(EPSILON, epsilon) ; break ;
      case opt_bias_multiplier: GET_NN_SCALAR(BIASMULTIPLIER, biasMultiplier) ; break ;
      case opt_max_num_iterations: GET_NN_SCALAR(MAXNUMITERATIONS, maxNumIterations) ; break ;
      case opt_diagnostic_frequency: GET_NN_SCALAR(DIAGNOSTICFREQUENCY, diagnosticFrequency) ; break ;
      case opt_diagnostic_function:
        if (!mxIsClass(optarg, "function_handle")) {
          mexErrMsgTxt("DIAGNOSTICFUNCTION is not a function handle.") ;
        }
        matlabDiagnosticFunctionHandle = optarg ;
        break ;

      case opt_solver :
        if (!vlmxIsString (optarg, -1)) {
          vlmxError (vlmxErrInvalidArgument, "SOLVER must be a string.") ;
        }
        if (mxGetString (optarg, buf, sizeof(buf))) {
          vlmxError (vlmxErrInvalidArgument, "SOLVER argument too long.") ;
        }
        if (vlmxCompareStringsI("sgd", buf) == 0) {
          solver = VlSvmSolverSgd ;
        } else if (vlmxCompareStringsI("sdca", buf) == 0) {
          solver = VlSvmSolverSdca ;
        } else if (vlmxCompareStringsI("none", buf) == 0) {
          solver = VlSvmSolverNone ;
        } else {
          vlmxError (vlmxErrInvalidArgument, "Invalid value %s for SOLVER", buf) ;
        }
        break ;

      case opt_loss :
        if (!vlmxIsString (optarg, -1)) {
          vlmxError (vlmxErrInvalidArgument, "LOSS must be a string.") ;
        }
        if (mxGetString (optarg, buf, sizeof(buf))) {
          vlmxError (vlmxErrInvalidArgument, "LOSS argument too long.") ;
        }
        if (vlmxCompareStringsI("hinge", buf) == 0) {
          loss = VlSvmLossHinge ;
        } else if (vlmxCompareStringsI("hinge2", buf) == 0) {
          loss = VlSvmLossHinge2 ;
        } else if (vlmxCompareStringsI("l1", buf) == 0) {
          loss = VlSvmLossL1 ;
        } else if (vlmxCompareStringsI("l2", buf) == 0) {
          loss = VlSvmLossL2 ;
        } else if (vlmxCompareStringsI("logistic", buf) == 0) {
          loss = VlSvmLossLogistic ;
        } else {
          vlmxError (vlmxErrInvalidArgument, "Invalid value %s for LOSS", buf) ;
        }
        break ;

      case opt_model :
        if (!vlmxIsPlainVector(optarg, vl_svmdataset_get_dimension(dataset))) {
          vlmxError(vlmxErrInvalidArgument, "MODEL is not a plain vector of size equal to the data dimension.") ;
        }
        initialModel_array = optarg ;
        break ;

      case opt_bias: GET_SCALAR(BIAS, initialBias) ; break ;

      case opt_weights:
        if (!vlmxIsPlainVector(optarg, vl_svmdataset_get_num_data(dataset))) {
          vlmxError(vlmxErrInvalidArgument, "WEIGHTS is not a plain vector of size equal to the number of training samples.") ;
        }
        weights = mxGetPr(optarg) ;
        break ;

      case opt_starting_iteration: GET_NN_SCALAR(STARTINGITERATION, startingIteration) ; break ;
      case opt_bias_learning_rate: GET_NN_SCALAR(BIASLEARNINGRATE, sgdBiasLearningRate) ; break ;
    }
  }

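  /* Create and configure the SVM object. */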
  {
    VlSvm * svm = vl_svm_new_with_dataset(solver, dataset, labels, lambda) ;
    DiagnosticOpts dopts ;

    if (initialModel_array) {
      if (solver != VlSvmSolverNone && solver != VlSvmSolverSgd) {
        vlmxError(vlmxErrInvalidArgument, "MODEL cannot be specified with this type of solver.") ;
      }
      if (mxGetNumberOfElements(initialModel_array) != vl_svm_get_dimension(svm)) {
        vlmxError(vlmxErrInvalidArgument, "MODEL does not have the same dimension as the data.") ;
      }
      vl_svm_set_model(svm, mxGetPr(initialModel_array)) ;
    }

    if (! vl_is_nan_d(initialBias)) {
      if (solver != VlSvmSolverNone && solver != VlSvmSolverSgd) {
        vlmxError(vlmxErrInvalidArgument, "BIAS cannot be specified with this type of solver.") ;
      }
      vl_svm_set_bias(svm, initialBias) ;
    }

    if (epsilon >= 0) vl_svm_set_epsilon(svm, epsilon) ;
    if (maxNumIterations >= 0) vl_svm_set_max_num_iterations(svm, maxNumIterations) ;
    if (biasMultiplier >= 0) vl_svm_set_bias_multiplier(svm, biasMultiplier) ;
    if (sgdBiasLearningRate >= 0) vl_svm_set_bias_learning_rate(svm, sgdBiasLearningRate) ;
    if (diagnosticFrequency >= 0) vl_svm_set_diagnostic_frequency(svm, diagnosticFrequency) ;
    if (startingIteration >= 0) vl_svm_set_iteration_number(svm, (unsigned)startingIteration) ;
    if (weights) vl_svm_set_weights(svm, weights) ;
    vl_svm_set_loss (svm, loss) ;

    dopts.verbose = verbose ;
    dopts.matlabDiagnosticFunctionHandle = matlabDiagnosticFunctionHandle ;
    vl_svm_set_diagnostic_function (svm, (VlSvmDiagnosticFunction)diagnostic, &dopts) ;

    if (verbose) {
      double C = 1.0 / (vl_svm_get_lambda(svm) * vl_svm_get_num_data (svm)) ;
      char const * lossName = 0 ;
      switch (loss) {
        case VlSvmLossHinge: lossName = "hinge" ; break ;
        case VlSvmLossHinge2: lossName = "hinge2" ; break ;
        case VlSvmLossL1: lossName = "l1" ; break ;
        case VlSvmLossL2: lossName = "l2" ; break ;
        case VlSvmLossLogistic: lossName = "logistic" ; break ;
      }
      mexPrintf("vl_svmtrain: parameters (verbosity: %d)\n", verbose) ;
      mexPrintf("\tdata dimension: %d\n", vl_svmdataset_get_dimension(dataset)) ;
      mexPrintf("\tnum samples: %d\n", vl_svmdataset_get_num_data(dataset)) ;
      mexPrintf("\tlambda: %g (C equivalent: %g)\n", vl_svm_get_lambda(svm), C) ;
      mexPrintf("\tloss function: %s\n", lossName) ;
      mexPrintf("\tmax num iterations: %d\n", vl_svm_get_max_num_iterations(svm)) ;
      mexPrintf("\tepsilon: %g\n", vl_svm_get_epsilon(svm)) ;
      mexPrintf("\tdiagnostic frequency: %d\n", vl_svm_get_diagnostic_frequency(svm)) ;
      mexPrintf("\tusing custom weights: %s\n", VL_YESNO(weights)) ;
      mexPrintf("\tbias multiplier: %g\n", vl_svm_get_bias_multiplier(svm)) ;
      switch (vl_svm_get_solver(svm)) {
        case VlSvmSolverNone:
          mexPrintf("\tsolver: none (evaluation mode)\n") ;
          break ;
        case VlSvmSolverSgd:
          mexPrintf("\tsolver: sgd\n") ;
          mexPrintf("\tbias learning rate: %g\n", vl_svm_get_bias_learning_rate(svm)) ;
          break ;
        case VlSvmSolverSdca:
          mexPrintf("\tsolver: sdca\n") ;
          break ;
      }
    }

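    /* Run the solver. */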
    vl_svm_train(svm) ;

    /* Copy the learned model and bias to the outputs. */
    {
      mwSize dims[2] ;
      dims[0] = vl_svmdataset_get_dimension(dataset) ;
      dims[1] = 1 ;
      out[OUT_MODEL] = mxCreateNumericArray(2, dims, mxDOUBLE_CLASS, mxREAL) ;
      memcpy(mxGetPr(out[OUT_MODEL]),
             vl_svm_get_model(svm),
             vl_svm_get_dimension(svm) * sizeof(double)) ;
    }
    out[OUT_BIAS] = vlmxCreatePlainScalar(vl_svm_get_bias(svm)) ;

    /* Optionally return the diagnostic information and the sample scores. */
    if (nout >= 3) {
      out[OUT_INFO] = makeInfoStruct(svm) ;
    }
    if (nout >= 4) {
      mwSize dims[2] ;
      dims[0] = 1 ;
      dims[1] = vl_svmdataset_get_num_data(dataset) ;
      out[OUT_SCORES] = mxCreateNumericArray(2, dims, mxDOUBLE_CLASS, mxREAL) ;
      memcpy(mxGetPr(out[OUT_SCORES]),
             vl_svm_get_scores(svm),
             vl_svm_get_num_data(svm) * sizeof(double)) ;
    }

    /* Cleanup: release the SVM, the optional homogeneous kernel map, and the dataset. */
    vl_svm_delete(svm) ;
    if (vl_svmdataset_get_homogeneous_kernel_map(dataset)) {
      VlHomogeneousKernelMap * hom = vl_svmdataset_get_homogeneous_kernel_map(dataset) ;
      vl_svmdataset_set_homogeneous_kernel_map(dataset, 0) ;
      vl_homogeneouskernelmap_delete(hom) ;
    }
    vl_svmdataset_delete(dataset) ;
  }
}