vl_hikmeans.c
Go to the documentation of this file.
00001 
00006 /*
00007 Copyright (C) 2014 Andrea Vedaldi.
00008 Copyright (C) 2007-12 Andrea Vedaldi and Brian Fulkerson.
00009 All rights reserved.
00010 
00011 This file is part of the VLFeat library and is made available under
00012 the terms of the BSD license (see the COPYING file).
00013 */
00014 
00015 #include<mexutils.h>
00016 
00017 #include<stdio.h>
00018 #include<stdlib.h>
00019 #include<math.h>
00020 #include<string.h>
00021 #include<assert.h>
00022 
00023 #include <vl/hikmeans.h>
00024 #include <vl/generic.h>
00025 
00026 #define NFIELDS(field_names) (sizeof(field_names)/sizeof(*field_names))
00027 
00028 enum {
00029   opt_max_niters,
00030   opt_method,
00031   opt_verbose
00032 } ;
00033 
00034 vlmxOption  options [] = {
00035   {"MaxIters",     1,   opt_max_niters  },
00036   {"Method",       1,   opt_method      },
00037   {"Verbose",      0,   opt_verbose     },
00038   {0,              0,   0               }
00039 } ;
00040 
00046 static void
00047 xcreate (mxArray *mnode, int i, VlHIKMNode *node)
00048 {
00049   int node_K = vl_ikm_get_K (node->filter) ;
00050   int M = vl_ikm_get_ndims (node->filter) ;
00051   vl_ikmacc_t const *centers = vl_ikm_get_centers (node->filter) ;
00052 
00053   mxArray *mcenters ;
00054 
00055   mcenters = mxCreateNumericMatrix (M, node_K, mxINT32_CLASS, mxREAL);
00056   memcpy (mxGetPr(mcenters), centers, sizeof(*centers) * M * node_K) ;
00057   mxSetField (mnode, i, "centers", mcenters) ;
00058 
00059   if (node->children) {
00060     mxArray * msub ;
00061     const char * field_names[] = {"centers", "sub" } ;
00062     mwSize dims [2] ;
00063     int k ;
00064 
00065     dims[0] = 1 ;
00066     dims[1] = node_K ;
00067 
00068     msub = mxCreateStructArray (2, dims, 2, field_names) ;
00069 
00070     for (k = 0 ; k < node_K ; ++k) {
00071       xcreate (msub, k, node -> children [k]) ;
00072     }
00073 
00074     mxSetField (mnode, i, "sub", msub) ;
00075   }
00076 }
00077 
00083 mxArray *
00084 hikm_to_matlab (VlHIKMTree * tree)
00085 {
00086   vl_size K = vl_hikm_get_K (tree) ;
00087   vl_size depth = vl_hikm_get_depth (tree) ;
00088   mwSize  dims [2] = {1, 1} ;
00089   mxArray *mtree ;
00090   const char *field_names[] = {"K", "depth", "centers", "sub"} ;
00091 
00092   /* Create the main struct array */
00093   mtree = mxCreateStructArray
00094     (2, dims, NFIELDS(field_names), field_names) ;
00095   mxSetField (mtree, 0, "K", mxCreateDoubleScalar (K)) ;
00096   mxSetField (mtree, 0, "depth", mxCreateDoubleScalar (depth)) ;
00097   if (tree->root) xcreate (mtree, 0, tree->root) ;
00098   return mtree;
00099 }
00100 
00105 void mexFunction (int nout, mxArray * out[], int nin, const mxArray * in[])
00106 {
00107   enum {IN_DATA = 0, IN_K, IN_NLEAVES, IN_END} ;
00108   enum {OUT_TREE = 0, OUT_ASGN} ;
00109   VlHIKMTree* tree ;
00110   int nleaves = 1 ;
00111   int method_type = VL_IKM_LLOYD ;
00112   int max_niters = 200 ;
00113   int verb = 0 ;
00114   vl_uint8 *data ;
00115   vl_size M, N, K = 2 ;
00116   vl_size depth = 0 ;
00117 
00118   int opt ;
00119   int next = IN_END ;
00120   mxArray const *optarg ;
00121 
00122   VL_USE_MATLAB_ENV ;
00123 
00124   /* -----------------------------------------------------------------
00125    *                                               Check the arguments
00126    * -------------------------------------------------------------- */
00127 
00128   if (nin < 3) {
00129       mexErrMsgTxt ("At least three arguments required.");
00130   } else if (nout > 2) {
00131     mexErrMsgTxt ("Too many output arguments.");
00132   }
00133 
00134   if (mxGetClassID (in[IN_DATA]) != mxUINT8_CLASS) {
00135     mexErrMsgTxt ("DATA must be of class UINT8.");
00136   }
00137 
00138   if (! vlmxIsPlainScalar (in[IN_NLEAVES])          ||
00139       (nleaves = (int) *mxGetPr (in[IN_NLEAVES])) < 1) {
00140     mexErrMsgTxt ("NLEAVES must be a scalar not smaller than 2.") ;
00141   }
00142 
00143   M = mxGetM (in[IN_DATA]);   /* n of components */
00144   N = mxGetN (in[IN_DATA]);   /* n of elements */
00145 
00146   if (! vlmxIsPlainScalar (in[IN_K])    ||
00147       (K = (int) *mxGetPr (in[IN_K])) > N) {
00148     mexErrMsgTxt ("Cannot have more clusters than data.") ;
00149   }
00150 
00151   data = (vl_uint8 *) mxGetPr (in[IN_DATA]) ;
00152 
00153   while ((opt = vlmxNextOption (in, nin, options, &next, &optarg)) >= 0) {
00154     char buf [1024] ;
00155 
00156     switch (opt) {
00157 
00158     case opt_verbose :
00159       ++ verb ;
00160       break ;
00161 
00162     case opt_max_niters :
00163       if (!vlmxIsPlainScalar(optarg) ||
00164           (max_niters = (int) *mxGetPr(optarg)) < 1) {
00165         mexErrMsgTxt("MaxNiters must be not smaller than 1.") ;
00166       }
00167       break ;
00168 
00169     case opt_method :
00170       if (!vlmxIsString (optarg, -1)) {
00171         mexErrMsgTxt("'Method' must be a string.") ;
00172       }
00173       if (mxGetString (optarg, buf, sizeof(buf))) {
00174         mexErrMsgTxt("Option argument too long.") ;
00175       }
00176       if (strcmp("lloyd", buf) == 0) {
00177         method_type = VL_IKM_LLOYD ;
00178       } else if (strcmp("elkan", buf) == 0) {
00179         method_type = VL_IKM_ELKAN ;
00180       } else {
00181         mexErrMsgTxt("Unknown cost type.") ;
00182       }
00183       break ;
00184 
00185     default :
00186       abort() ; break ;
00187     }
00188   }
00189 
00190   /* -----------------------------------------------------------------
00191    *                                                        Do the job
00192    * -------------------------------------------------------------- */
00193 
00194   depth = VL_MAX(1, ceil (log (nleaves) / log(K))) ;
00195   tree  = vl_hikm_new  (method_type) ;
00196 
00197   if (verb) {
00198     mexPrintf("hikmeans: # dims: %d\n", M) ;
00199     mexPrintf("hikmeans: # data: %d\n", N) ;
00200     mexPrintf("hikmeans: K: %d\n", K) ;
00201     mexPrintf("hikmeans: depth: %d\n", depth) ;
00202   }
00203 
00204   vl_hikm_set_verbosity (tree, verb) ;
00205   vl_hikm_init (tree, M, K, depth) ;
00206   vl_hikm_train (tree, data, N) ;
00207 
00208   out[OUT_TREE] = hikm_to_matlab (tree) ;
00209 
00210   if (nout > 1) {
00211     vl_uint *asgn ;
00212     vl_uindex j ;
00213     out [OUT_ASGN] = mxCreateNumericMatrix
00214       (vl_hikm_get_depth (tree), N, mxUINT32_CLASS, mxREAL) ;
00215     asgn = mxGetData(out[OUT_ASGN]) ;
00216     vl_hikm_push (tree, asgn, data, N) ;
00217     for (j = 0 ; j < N*depth ; ++ j) asgn [j] ++ ;
00218   }
00219 
00220   if (verb) {
00221     mexPrintf("hikmeans: done.\n") ;
00222   }
00223 
00224   /* vl_hikm_delete (tree) ; */
00225 }


libvlfeat
Author(s): Andrea Vedaldi
autogenerated on Thu Jun 6 2019 20:25:51