00001
00006
00007
00008
00009
00010
00011
00012
00013
00014 #include <vl/kmeans.h>
00015 #include <mexutils.h>
00016 #include <string.h>
00017 #include <stdio.h>
00018
/* Integer codes identifying the optional arguments of the MEX function.
 * The `options` table maps option names to these codes, and mexFunction
 * switches on the code returned by vlmxNextOption. Values are consecutive,
 * starting at zero. opt_multithreading has no entry in the option table
 * in the visible part of this file.
 */
enum {
  opt_max_num_iterations   = 0,
  opt_algorithm            = 1,
  opt_distance             = 2,
  opt_initialization       = 3,
  opt_num_repetitions      = 4,
  opt_verbose              = 5,
  opt_num_comparisons      = 6,
  opt_min_energy_variation = 7,
  opt_num_trees            = 8,
  opt_multithreading       = 9
} ;
00031
/* Local identifiers for the two initialization modes; presumably these
 * correspond to the "randsel" and "plusplus" option values (TODO confirm).
 */
enum {
  INIT_RANDSEL  = 0,
  INIT_PLUSPLUS = 1
} ;
00036
00037 vlmxOption options [] = {
00038 {"MaxNumIterations", 1, opt_max_num_iterations },
00039 {"Algorithm", 1, opt_algorithm },
00040 {"Distance", 1, opt_distance },
00041 {"Verbose", 0, opt_verbose },
00042 {"NumRepetitions", 1, opt_num_repetitions, },
00043 {"Initialization", 1, opt_initialization },
00044 {"Initialisation", 1, opt_initialization },
00045 {"NumTrees", 1, opt_num_trees },
00046 {"MaxNumComparisons", 1, opt_num_comparisons },
00047 {"MinEnergyVariation",1, opt_min_energy_variation},
00048 {0, 0, 0 }
00049 } ;
00050
00051
00052 void
00053 mexFunction (int nout, mxArray * out[], int nin, const mxArray * in[])
00054 {
00055
00056 enum {IN_DATA = 0, IN_NUMCENTERS, IN_END} ;
00057 enum {OUT_CENTERS = 0, OUT_ASSIGNMENTS, OUT_ENERGY} ;
00058
00059 int opt ;
00060 int next = IN_END ;
00061 mxArray const *optarg ;
00062
00063 vl_size numCenters ;
00064 vl_size dimension ;
00065 vl_size numData ;
00066
00067 void const * data = NULL ;
00068
00069 VlKMeansAlgorithm algorithm = VlKMeansLloyd ;
00070 VlVectorComparisonType distance = VlDistanceL2 ;
00071 vl_size maxNumIterations = 100 ;
00072 vl_size numRepetitions = 1 ;
00073 double minEnergyVariation = -1 ;
00074 double energy ;
00075 int verbosity = 0 ;
00076 int initialization = INIT_PLUSPLUS ;
00077 vl_size maxNumComparisons = 100 ;
00078 vl_size numTrees = 3;
00079
00080 vl_type dataType ;
00081 mxClassID classID ;
00082
00083 VlKMeans * kmeans ;
00084
00085 VL_USE_MATLAB_ENV ;
00086
00087
00088
00089
00090
00091 if (nin < 2) {
00092 vlmxError (vlmxErrInvalidArgument,
00093 "At least two arguments required.");
00094 }
00095 else if (nout > 3) {
00096 vlmxError (vlmxErrInvalidArgument,
00097 "Too many output arguments.");
00098 }
00099
00100 classID = mxGetClassID (IN(DATA)) ;
00101 switch (classID) {
00102 case mxSINGLE_CLASS: dataType = VL_TYPE_FLOAT ; break ;
00103 case mxDOUBLE_CLASS: dataType = VL_TYPE_DOUBLE ; break ;
00104 default:
00105 vlmxError (vlmxErrInvalidArgument,
00106 "DATA must be of class SINGLE or DOUBLE") ;
00107 abort() ;
00108 }
00109
00110 dimension = mxGetM (IN(DATA)) ;
00111 numData = mxGetN (IN(DATA)) ;
00112
00113 if (dimension == 0) {
00114 vlmxError (vlmxErrInvalidArgument, "SIZE(DATA,1) is zero") ;
00115 }
00116
00117 if (!vlmxIsPlainScalar(IN(NUMCENTERS)) ||
00118 (numCenters = (vl_size) mxGetScalar(IN(NUMCENTERS))) < 1 ||
00119 numCenters > numData) {
00120 vlmxError (vlmxErrInvalidArgument,
00121 "NUMCENTERS must be a positive integer not greater "
00122 "than the number of data.") ;
00123 }
00124
00125 while ((opt = vlmxNextOption (in, nin, options, &next, &optarg)) >= 0) {
00126 char buf [1024] ;
00127
00128 switch (opt) {
00129
00130 case opt_verbose :
00131 ++ verbosity ;
00132 break ;
00133
00134 case opt_max_num_iterations :
00135 if (!vlmxIsPlainScalar(optarg) || mxGetScalar(optarg) < 0) {
00136 vlmxError (vlmxErrInvalidArgument,
00137 "MAXNUMITERATIONS must be a non-negative integer scalar") ;
00138 }
00139 maxNumIterations = (vl_size) mxGetScalar(optarg) ;
00140 break ;
00141
00142 case opt_min_energy_variation :
00143 if (!vlmxIsPlainScalar(optarg) || mxGetScalar(optarg) < 0) {
00144 vlmxError (vlmxErrInvalidArgument,
00145 "MINENERGYVARIATION must be a non-negative scalar") ;
00146 }
00147 minEnergyVariation = mxGetScalar(optarg) ;
00148 break ;
00149
00150 case opt_algorithm :
00151 if (!vlmxIsString (optarg, -1)) {
00152 vlmxError (vlmxErrInvalidArgument,
00153 "ALGORITHM must be a string.") ;
00154 }
00155 if (mxGetString (optarg, buf, sizeof(buf))) {
00156 vlmxError (vlmxErrInvalidArgument,
00157 "ALGORITHM argument too long.") ;
00158 }
00159 if (vlmxCompareStringsI("lloyd", buf) == 0) {
00160 algorithm = VlKMeansLloyd ;
00161 } else if (vlmxCompareStringsI("elkan", buf) == 0) {
00162 algorithm = VlKMeansElkan ;
00163 } else if (vlmxCompareStringsI("ann", buf) == 0) {
00164 algorithm = VlKMeansANN ;
00165 } else {
00166 vlmxError (vlmxErrInvalidArgument,
00167 "Invalid value %s for ALGORITHM", buf) ;
00168 }
00169 break ;
00170
00171 case opt_initialization :
00172 if (!vlmxIsString (optarg, -1)) {
00173 vlmxError (vlmxErrInvalidArgument,
00174 "INITLAIZATION must be a string.") ;
00175 }
00176 if (mxGetString (optarg, buf, sizeof(buf))) {
00177 vlmxError (vlmxErrInvalidArgument,
00178 "INITIALIZATION argument too long.") ;
00179 }
00180 if (vlmxCompareStringsI("plusplus", buf) == 0 ||
00181 vlmxCompareStringsI("++", buf) == 0) {
00182 initialization = VlKMeansPlusPlus ;
00183 } else if (vlmxCompareStringsI("randsel", buf) == 0) {
00184 initialization = VlKMeansRandomSelection ;
00185 } else {
00186 vlmxError (vlmxErrInvalidArgument,
00187 "Invalid value %s for INITIALISATION.", buf) ;
00188 }
00189 break ;
00190
00191 case opt_distance :
00192 if (!vlmxIsString (optarg, -1)) {
00193 vlmxError (vlmxErrInvalidArgument,
00194 "DISTANCE must be a string.") ;
00195 }
00196 if (mxGetString (optarg, buf, sizeof(buf))) {
00197 vlmxError (vlmxErrInvalidArgument,
00198 "DISTANCE argument too long.") ;
00199 }
00200 if (vlmxCompareStringsI("l2", buf) == 0) {
00201 distance = VlDistanceL2 ;
00202 } else if (vlmxCompareStringsI("l1", buf) == 0) {
00203 distance = VlDistanceL1 ;
00204 } else if (vlmxCompareStringsI("chi2", buf) == 0) {
00205 distance = VlDistanceChi2 ;
00206 } else {
00207 vlmxError (vlmxErrInvalidArgument,
00208 "Invalid value %s for DISTANCE", buf) ;
00209 }
00210 break ;
00211
00212 case opt_num_repetitions :
00213 if (!vlmxIsPlainScalar (optarg)) {
00214 vlmxError (vlmxErrInvalidArgument,
00215 "NUMREPETITIONS must be a scalar.") ;
00216 }
00217 if (mxGetScalar (optarg) < 1) {
00218 vlmxError (vlmxErrInvalidArgument,
00219 "NUMREPETITIONS must be larger than or equal to 1.") ;
00220 }
00221 numRepetitions = (vl_size) mxGetScalar (optarg) ;
00222 break ;
00223
00224 case opt_num_trees :
00225 if (!vlmxIsPlainScalar (optarg)) {
00226 vlmxError (vlmxErrInvalidArgument,
00227 "NUMTREES must be a scalar.") ;
00228 }
00229 if (mxGetScalar (optarg) < 1) {
00230 vlmxError (vlmxErrInvalidArgument,
00231 "NUMTREES must be larger than or equal to 1.") ;
00232 }
00233 numTrees = (vl_size) mxGetScalar (optarg) ;
00234 break;
00235
00236
00237 case opt_num_comparisons :
00238 if (!vlmxIsPlainScalar (optarg)) {
00239 vlmxError (vlmxErrInvalidArgument,
00240 "NUMCOMPARISONS must be a scalar.") ;
00241 }
00242 if (mxGetScalar (optarg) < 0) {
00243 vlmxError (vlmxErrInvalidArgument,
00244 "NUMCOMPARISONS must be larger than or equal to 0.") ;
00245 }
00246 maxNumComparisons = (vl_size) mxGetScalar (optarg) ;
00247 break;
00248
00249 default :
00250 abort() ;
00251 break ;
00252 }
00253 }
00254
00255
00256
00257
00258
00259 data = mxGetPr(IN(DATA)) ;
00260
00261 kmeans = vl_kmeans_new (dataType, distance) ;
00262
00263 vl_kmeans_set_verbosity (kmeans, verbosity) ;
00264 vl_kmeans_set_num_repetitions (kmeans, numRepetitions) ;
00265 vl_kmeans_set_algorithm (kmeans, algorithm) ;
00266 vl_kmeans_set_initialization (kmeans, initialization) ;
00267 vl_kmeans_set_max_num_iterations (kmeans, maxNumIterations) ;
00268 vl_kmeans_set_max_num_comparisons (kmeans, maxNumComparisons) ;
00269 vl_kmeans_set_num_trees (kmeans, numTrees);
00270
00271 if (minEnergyVariation >= 0) {
00272 vl_kmeans_set_min_energy_variation (kmeans, minEnergyVariation) ;
00273 }
00274
00275 if (verbosity) {
00276 char const * algorithmName = 0 ;
00277 char const * initializationName = 0 ;
00278
00279 switch (vl_kmeans_get_algorithm(kmeans)) {
00280 case VlKMeansLloyd: algorithmName = "Lloyd" ; break ;
00281 case VlKMeansElkan: algorithmName = "Elkan" ; break ;
00282 case VlKMeansANN: algorithmName = "ANN" ; break ;
00283 default : abort() ;
00284 }
00285 switch (vl_kmeans_get_initialization(kmeans)) {
00286 case VlKMeansPlusPlus : initializationName = "plusplus" ; break ;
00287 case VlKMeansRandomSelection : initializationName = "randsel" ; break ;
00288 default: abort() ;
00289 }
00290 mexPrintf("kmeans: Initialization = %s\n", initializationName) ;
00291 mexPrintf("kmeans: Algorithm = %s\n", algorithmName) ;
00292 mexPrintf("kmeans: MaxNumIterations = %d\n", vl_kmeans_get_max_num_iterations(kmeans)) ;
00293 mexPrintf("kmeans: MinEnergyVariation = %f\n", vl_kmeans_get_min_energy_variation(kmeans)) ;
00294 mexPrintf("kmeans: NumRepetitions = %d\n", vl_kmeans_get_num_repetitions(kmeans)) ;
00295 mexPrintf("kmeans: data type = %s\n", vl_get_type_name(vl_kmeans_get_data_type(kmeans))) ;
00296 mexPrintf("kmeans: distance = %s\n", vl_get_vector_comparison_type_name(vl_kmeans_get_distance(kmeans))) ;
00297 mexPrintf("kmeans: data dimension = %d\n", dimension) ;
00298 mexPrintf("kmeans: num. data points = %d\n", numData) ;
00299 mexPrintf("kmeans: num. centers = %d\n", numCenters) ;
00300 mexPrintf("kmeans: max num. comparisons = %d\n", maxNumComparisons) ;
00301 mexPrintf("kmeans: num. trees = %d\n", numTrees) ;
00302 mexPrintf("\n") ;
00303 }
00304
00305
00306
00307
00308
00309 energy = vl_kmeans_cluster(kmeans, data, dimension, numData, numCenters) ;
00310
00311
00312 OUT(CENTERS) = mxCreateNumericMatrix (dimension, numCenters, classID, mxREAL) ;
00313 memcpy (mxGetData(OUT(CENTERS)),
00314 vl_kmeans_get_centers (kmeans),
00315 vl_get_type_size (dataType) * dimension * vl_kmeans_get_num_centers(kmeans)) ;
00316
00317
00318 if (nout > 1) {
00319 vl_uindex j ;
00320 vl_uint32 * assignments ;
00321 OUT(ASSIGNMENTS) = mxCreateNumericMatrix (1, numData, mxUINT32_CLASS, mxREAL) ;
00322 assignments = mxGetData (OUT(ASSIGNMENTS)) ;
00323
00324 vl_kmeans_quantize (kmeans, assignments, NULL, data, numData) ;
00325
00326
00327 for (j = 0 ; j < numData ; ++j) { assignments[j] += 1 ; }
00328 }
00329
00330
00331 if (nout > 2) {
00332 OUT(ENERGY) = vlmxCreatePlainScalar (energy) ;
00333 }
00334
00335 vl_kmeans_delete (kmeans) ;
00336 }