vl_kmeans.c
Go to the documentation of this file.
00001 
00006 /*
00007 Copyright (C) 2007-12 Andrea Vedaldi and Brian Fulkerson.
00008 All rights reserved.
00009 
00010 This file is part of the VLFeat library and is made available under
00011 the terms of the BSD license (see the COPYING file).
00012 */
00013 
00014 #include <vl/kmeans.h>
00015 #include <mexutils.h>
00016 #include <string.h>
00017 #include <stdio.h>
00018 
/* Identifiers of the optional arguments understood by this MEX file.
 * The option table maps each option name to one of these values, and
 * the option-parsing switch in mexFunction() dispatches on them. */
enum {
  opt_max_num_iterations   = 0,  /* "MaxNumIterations" */
  opt_algorithm            = 1,  /* "Algorithm" */
  opt_distance             = 2,  /* "Distance" */
  opt_initialization       = 3,  /* "Initialization" / "Initialisation" */
  opt_num_repetitions      = 4,  /* "NumRepetitions" */
  opt_verbose              = 5,  /* "Verbose" */
  opt_num_comparisons      = 6,  /* "MaxNumComparisons" */
  opt_min_energy_variation = 7,  /* "MinEnergyVariation" */
  opt_num_trees            = 8,  /* "NumTrees" */
  opt_multithreading       = 9   /* declared, but no option name maps to it here */
} ;
00031 
/* File-local identifiers for the two center initialization strategies:
 * random selection of data points and "plusplus". */
enum {
  INIT_RANDSEL  = 0,
  INIT_PLUSPLUS = 1
} ;
00036 
00037 vlmxOption  options [] = {
00038   {"MaxNumIterations",  1,   opt_max_num_iterations  },
00039   {"Algorithm",         1,   opt_algorithm           },
00040   {"Distance",          1,   opt_distance            },
00041   {"Verbose",           0,   opt_verbose             },
00042   {"NumRepetitions",    1,   opt_num_repetitions,    },
00043   {"Initialization",    1,   opt_initialization      },
00044   {"Initialisation",    1,   opt_initialization      }, /* UK spelling */
00045   {"NumTrees",          1,   opt_num_trees           },
00046   {"MaxNumComparisons", 1,   opt_num_comparisons     },
00047   {"MinEnergyVariation",1,   opt_min_energy_variation},
00048   {0,                   0,   0                       }
00049 } ;
00050 
00051 /* driver */
00052 void
00053 mexFunction (int nout, mxArray * out[], int nin, const mxArray * in[])
00054 {
00055 
00056   enum {IN_DATA = 0, IN_NUMCENTERS, IN_END} ;
00057   enum {OUT_CENTERS = 0, OUT_ASSIGNMENTS, OUT_ENERGY} ;
00058 
00059   int opt ;
00060   int next = IN_END ;
00061   mxArray const  *optarg ;
00062 
00063   vl_size numCenters ;
00064   vl_size dimension ;
00065   vl_size numData ;
00066 
00067   void const * data = NULL ;
00068 
00069   VlKMeansAlgorithm algorithm = VlKMeansLloyd ;
00070   VlVectorComparisonType distance = VlDistanceL2 ;
00071   vl_size maxNumIterations = 100 ;
00072   vl_size numRepetitions = 1 ;
00073   double minEnergyVariation = -1 ;
00074   double energy ;
00075   int verbosity = 0 ;
00076   int initialization = INIT_PLUSPLUS ;
00077   vl_size maxNumComparisons = 100 ;
00078   vl_size numTrees = 3;
00079 
00080   vl_type dataType ;
00081   mxClassID classID ;
00082 
00083   VlKMeans * kmeans ;
00084 
00085   VL_USE_MATLAB_ENV ;
00086 
00087   /* -----------------------------------------------------------------
00088    *                                               Check the arguments
00089    * -------------------------------------------------------------- */
00090 
00091   if (nin < 2) {
00092     vlmxError (vlmxErrInvalidArgument,
00093               "At least two arguments required.");
00094   }
00095   else if (nout > 3) {
00096     vlmxError (vlmxErrInvalidArgument,
00097               "Too many output arguments.");
00098   }
00099 
00100   classID = mxGetClassID (IN(DATA)) ;
00101   switch (classID) {
00102     case mxSINGLE_CLASS: dataType = VL_TYPE_FLOAT ; break ;
00103     case mxDOUBLE_CLASS: dataType = VL_TYPE_DOUBLE ; break ;
00104     default:
00105       vlmxError (vlmxErrInvalidArgument,
00106                 "DATA must be of class SINGLE or DOUBLE") ;
00107       abort() ;
00108   }
00109 
00110   dimension = mxGetM (IN(DATA)) ;
00111   numData = mxGetN (IN(DATA)) ;
00112 
00113   if (dimension == 0) {
00114     vlmxError (vlmxErrInvalidArgument, "SIZE(DATA,1) is zero") ;
00115   }
00116 
00117   if (!vlmxIsPlainScalar(IN(NUMCENTERS)) ||
00118       (numCenters = (vl_size) mxGetScalar(IN(NUMCENTERS))) < 1  ||
00119       numCenters > numData) {
00120     vlmxError (vlmxErrInvalidArgument,
00121               "NUMCENTERS must be a positive integer not greater "
00122               "than the number of data.") ;
00123   }
00124 
00125   while ((opt = vlmxNextOption (in, nin, options, &next, &optarg)) >= 0) {
00126     char buf [1024] ;
00127 
00128     switch (opt) {
00129 
00130       case opt_verbose :
00131         ++ verbosity ;
00132         break ;
00133 
00134       case opt_max_num_iterations :
00135         if (!vlmxIsPlainScalar(optarg) || mxGetScalar(optarg) < 0) {
00136           vlmxError (vlmxErrInvalidArgument,
00137                     "MAXNUMITERATIONS must be a non-negative integer scalar") ;
00138         }
00139         maxNumIterations = (vl_size) mxGetScalar(optarg) ;
00140         break ;
00141         
00142       case opt_min_energy_variation :
00143         if (!vlmxIsPlainScalar(optarg) || mxGetScalar(optarg) < 0) {
00144           vlmxError (vlmxErrInvalidArgument,
00145                      "MINENERGYVARIATION must be a non-negative scalar") ;
00146         }
00147         minEnergyVariation = mxGetScalar(optarg) ;
00148         break ;
00149 
00150       case opt_algorithm :
00151         if (!vlmxIsString (optarg, -1)) {
00152           vlmxError (vlmxErrInvalidArgument,
00153                     "ALGORITHM must be a string.") ;
00154         }
00155         if (mxGetString (optarg, buf, sizeof(buf))) {
00156           vlmxError (vlmxErrInvalidArgument,
00157                     "ALGORITHM argument too long.") ;
00158         }
00159         if (vlmxCompareStringsI("lloyd", buf) == 0) {
00160           algorithm = VlKMeansLloyd ;
00161         } else if (vlmxCompareStringsI("elkan", buf) == 0) {
00162           algorithm = VlKMeansElkan ;
00163         } else if (vlmxCompareStringsI("ann", buf) == 0) {
00164           algorithm = VlKMeansANN ;
00165         } else {
00166           vlmxError (vlmxErrInvalidArgument,
00167                     "Invalid value %s for ALGORITHM", buf) ;
00168         }
00169         break ;
00170 
00171       case opt_initialization :
00172         if (!vlmxIsString (optarg, -1)) {
00173           vlmxError (vlmxErrInvalidArgument,
00174                     "INITLAIZATION must be a string.") ;
00175         }
00176         if (mxGetString (optarg, buf, sizeof(buf))) {
00177           vlmxError (vlmxErrInvalidArgument,
00178                     "INITIALIZATION argument too long.") ;
00179         }
00180         if (vlmxCompareStringsI("plusplus", buf) == 0 ||
00181             vlmxCompareStringsI("++", buf) == 0) {
00182           initialization = VlKMeansPlusPlus ;
00183         } else if (vlmxCompareStringsI("randsel", buf) == 0) {
00184           initialization = VlKMeansRandomSelection ;
00185         } else {
00186           vlmxError (vlmxErrInvalidArgument,
00187                     "Invalid value %s for INITIALISATION.", buf) ;
00188         }
00189         break ;
00190 
00191       case opt_distance :
00192         if (!vlmxIsString (optarg, -1)) {
00193           vlmxError (vlmxErrInvalidArgument,
00194                     "DISTANCE must be a string.") ;
00195         }
00196         if (mxGetString (optarg, buf, sizeof(buf))) {
00197           vlmxError (vlmxErrInvalidArgument,
00198                     "DISTANCE argument too long.") ;
00199         }
00200         if (vlmxCompareStringsI("l2", buf) == 0) {
00201           distance = VlDistanceL2 ;
00202         } else if (vlmxCompareStringsI("l1", buf) == 0) {
00203           distance = VlDistanceL1 ;
00204         } else if (vlmxCompareStringsI("chi2", buf) == 0) {
00205           distance = VlDistanceChi2 ;
00206         } else {
00207           vlmxError (vlmxErrInvalidArgument,
00208                     "Invalid value %s for DISTANCE", buf) ;
00209         }
00210         break ;
00211 
00212       case opt_num_repetitions :
00213         if (!vlmxIsPlainScalar (optarg)) {
00214           vlmxError (vlmxErrInvalidArgument,
00215                      "NUMREPETITIONS must be a scalar.") ;
00216         }
00217         if (mxGetScalar (optarg) < 1) {
00218           vlmxError (vlmxErrInvalidArgument,
00219                      "NUMREPETITIONS must be larger than or equal to 1.") ;
00220         }
00221         numRepetitions = (vl_size) mxGetScalar (optarg) ;
00222         break ;
00223 
00224        case opt_num_trees :
00225             if (!vlmxIsPlainScalar (optarg)) {
00226               vlmxError (vlmxErrInvalidArgument,
00227                      "NUMTREES must be a scalar.") ;
00228             }
00229             if (mxGetScalar (optarg) < 1) {
00230               vlmxError (vlmxErrInvalidArgument,
00231                     "NUMTREES must be larger than or equal to 1.") ;
00232             }
00233             numTrees = (vl_size) mxGetScalar (optarg) ;
00234          break;
00235 
00236 
00237        case opt_num_comparisons :
00238             if (!vlmxIsPlainScalar (optarg)) {
00239               vlmxError (vlmxErrInvalidArgument,
00240                      "NUMCOMPARISONS must be a scalar.") ;
00241             }
00242             if (mxGetScalar (optarg) < 0) {
00243               vlmxError (vlmxErrInvalidArgument,
00244                     "NUMCOMPARISONS must be larger than or equal to 0.") ;
00245             }
00246             maxNumComparisons = (vl_size) mxGetScalar (optarg) ;
00247          break;
00248 
00249       default :
00250         abort() ;
00251         break ;
00252     }
00253   }
00254 
00255   /* -----------------------------------------------------------------
00256    *                                                        Do the job
00257    * -------------------------------------------------------------- */
00258 
00259   data = mxGetPr(IN(DATA)) ;
00260 
00261   kmeans = vl_kmeans_new (dataType, distance) ;
00262 
00263   vl_kmeans_set_verbosity (kmeans, verbosity) ;
00264   vl_kmeans_set_num_repetitions (kmeans, numRepetitions) ;
00265   vl_kmeans_set_algorithm (kmeans, algorithm) ;
00266   vl_kmeans_set_initialization (kmeans, initialization) ;
00267   vl_kmeans_set_max_num_iterations (kmeans, maxNumIterations) ;
00268   vl_kmeans_set_max_num_comparisons (kmeans, maxNumComparisons) ;
00269   vl_kmeans_set_num_trees (kmeans, numTrees);
00270   
00271   if (minEnergyVariation >= 0) {
00272     vl_kmeans_set_min_energy_variation (kmeans, minEnergyVariation) ;
00273   }
00274 
00275   if (verbosity) {
00276     char const * algorithmName = 0 ;
00277     char const * initializationName = 0 ;
00278 
00279     switch (vl_kmeans_get_algorithm(kmeans)) {
00280       case VlKMeansLloyd: algorithmName = "Lloyd" ; break ;
00281       case VlKMeansElkan: algorithmName = "Elkan" ; break ;
00282       case VlKMeansANN:   algorithmName = "ANN" ; break ;
00283       default : abort() ;
00284     }
00285     switch (vl_kmeans_get_initialization(kmeans)) {
00286       case VlKMeansPlusPlus : initializationName = "plusplus" ; break ;
00287       case VlKMeansRandomSelection : initializationName = "randsel" ; break ;
00288       default: abort() ;
00289     }
00290     mexPrintf("kmeans: Initialization = %s\n", initializationName) ;
00291     mexPrintf("kmeans: Algorithm = %s\n", algorithmName) ;
00292     mexPrintf("kmeans: MaxNumIterations = %d\n", vl_kmeans_get_max_num_iterations(kmeans)) ;
00293     mexPrintf("kmeans: MinEnergyVariation = %f\n", vl_kmeans_get_min_energy_variation(kmeans)) ;
00294     mexPrintf("kmeans: NumRepetitions = %d\n", vl_kmeans_get_num_repetitions(kmeans)) ;
00295     mexPrintf("kmeans: data type = %s\n", vl_get_type_name(vl_kmeans_get_data_type(kmeans))) ;
00296     mexPrintf("kmeans: distance = %s\n", vl_get_vector_comparison_type_name(vl_kmeans_get_distance(kmeans))) ;
00297     mexPrintf("kmeans: data dimension = %d\n", dimension) ;
00298     mexPrintf("kmeans: num. data points = %d\n", numData) ;
00299     mexPrintf("kmeans: num. centers = %d\n", numCenters) ;
00300     mexPrintf("kmeans: max num. comparisons = %d\n", maxNumComparisons) ;
00301     mexPrintf("kmeans: num. trees = %d\n", numTrees) ;
00302     mexPrintf("\n") ;
00303   }
00304 
00305   /* -------------------------------------------------------------- */
00306   /*                                    Clustering and quantization */
00307   /* -------------------------------------------------------------- */
00308 
00309   energy = vl_kmeans_cluster(kmeans, data, dimension, numData, numCenters) ;
00310 
00311   /* copy centers */
00312   OUT(CENTERS) = mxCreateNumericMatrix (dimension, numCenters, classID, mxREAL) ;
00313   memcpy (mxGetData(OUT(CENTERS)),
00314           vl_kmeans_get_centers (kmeans),
00315           vl_get_type_size (dataType) * dimension * vl_kmeans_get_num_centers(kmeans)) ;
00316 
00317   /* optionally qunatize */
00318   if (nout > 1) {
00319     vl_uindex j ;
00320     vl_uint32 * assignments  ;
00321     OUT(ASSIGNMENTS) = mxCreateNumericMatrix (1, numData, mxUINT32_CLASS, mxREAL) ;
00322     assignments = mxGetData (OUT(ASSIGNMENTS)) ;
00323 
00324     vl_kmeans_quantize (kmeans, assignments, NULL, data, numData) ;
00325 
00326     /* use MATLAB indexing convention */
00327     for (j = 0 ; j < numData ; ++j) { assignments[j] += 1 ; }
00328   }
00329 
00330   /* optionally return energy */
00331   if (nout > 2) {
00332     OUT(ENERGY) = vlmxCreatePlainScalar (energy) ;
00333   }
00334 
00335   vl_kmeans_delete (kmeans) ;
00336 }


libvlfeat
Author(s): Andrea Vedaldi
autogenerated on Thu Jun 6 2019 20:25:51