00001
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015 #include<mexutils.h>
00016
00017 #include<stdio.h>
00018 #include<stdlib.h>
00019 #include<math.h>
00020 #include<string.h>
00021 #include<assert.h>
00022
00023 #include <vl/hikmeans.h>
00024 #include <vl/generic.h>
00025
/** @brief Number of elements of the array @a field_names.
 ** Only valid on a true array object, not on a pointer (an array that
 ** decayed to a pointer would yield sizeof(pointer)/sizeof(element)).
 **/
#define NFIELDS(field_names) (sizeof(field_names) / sizeof((field_names)[0]))
00027
/* Codes of the optional arguments. They appear as the third column of
   the options table and are the values mexFunction switches on after
   calling vlmxNextOption. */
enum {
  opt_max_niters = 0, /* "MaxIters" */
  opt_method     = 1, /* "Method"   */
  opt_verbose    = 2  /* "Verbose"  */
} ;
00033
00034 vlmxOption options [] = {
00035 {"MaxIters", 1, opt_max_niters },
00036 {"Method", 1, opt_method },
00037 {"Verbose", 0, opt_verbose },
00038 {0, 0, 0 }
00039 } ;
00040
00046 static void
00047 xcreate (mxArray *mnode, int i, VlHIKMNode *node)
00048 {
00049 int node_K = vl_ikm_get_K (node->filter) ;
00050 int M = vl_ikm_get_ndims (node->filter) ;
00051 vl_ikmacc_t const *centers = vl_ikm_get_centers (node->filter) ;
00052
00053 mxArray *mcenters ;
00054
00055 mcenters = mxCreateNumericMatrix (M, node_K, mxINT32_CLASS, mxREAL);
00056 memcpy (mxGetPr(mcenters), centers, sizeof(*centers) * M * node_K) ;
00057 mxSetField (mnode, i, "centers", mcenters) ;
00058
00059 if (node->children) {
00060 mxArray * msub ;
00061 const char * field_names[] = {"centers", "sub" } ;
00062 mwSize dims [2] ;
00063 int k ;
00064
00065 dims[0] = 1 ;
00066 dims[1] = node_K ;
00067
00068 msub = mxCreateStructArray (2, dims, 2, field_names) ;
00069
00070 for (k = 0 ; k < node_K ; ++k) {
00071 xcreate (msub, k, node -> children [k]) ;
00072 }
00073
00074 mxSetField (mnode, i, "sub", msub) ;
00075 }
00076 }
00077
00083 mxArray *
00084 hikm_to_matlab (VlHIKMTree * tree)
00085 {
00086 vl_size K = vl_hikm_get_K (tree) ;
00087 vl_size depth = vl_hikm_get_depth (tree) ;
00088 mwSize dims [2] = {1, 1} ;
00089 mxArray *mtree ;
00090 const char *field_names[] = {"K", "depth", "centers", "sub"} ;
00091
00092
00093 mtree = mxCreateStructArray
00094 (2, dims, NFIELDS(field_names), field_names) ;
00095 mxSetField (mtree, 0, "K", mxCreateDoubleScalar (K)) ;
00096 mxSetField (mtree, 0, "depth", mxCreateDoubleScalar (depth)) ;
00097 if (tree->root) xcreate (mtree, 0, tree->root) ;
00098 return mtree;
00099 }
00100
00105 void mexFunction (int nout, mxArray * out[], int nin, const mxArray * in[])
00106 {
00107 enum {IN_DATA = 0, IN_K, IN_NLEAVES, IN_END} ;
00108 enum {OUT_TREE = 0, OUT_ASGN} ;
00109 VlHIKMTree* tree ;
00110 int nleaves = 1 ;
00111 int method_type = VL_IKM_LLOYD ;
00112 int max_niters = 200 ;
00113 int verb = 0 ;
00114 vl_uint8 *data ;
00115 vl_size M, N, K = 2 ;
00116 vl_size depth = 0 ;
00117
00118 int opt ;
00119 int next = IN_END ;
00120 mxArray const *optarg ;
00121
00122 VL_USE_MATLAB_ENV ;
00123
00124
00125
00126
00127
00128 if (nin < 3) {
00129 mexErrMsgTxt ("At least three arguments required.");
00130 } else if (nout > 2) {
00131 mexErrMsgTxt ("Too many output arguments.");
00132 }
00133
00134 if (mxGetClassID (in[IN_DATA]) != mxUINT8_CLASS) {
00135 mexErrMsgTxt ("DATA must be of class UINT8.");
00136 }
00137
00138 if (! vlmxIsPlainScalar (in[IN_NLEAVES]) ||
00139 (nleaves = (int) *mxGetPr (in[IN_NLEAVES])) < 1) {
00140 mexErrMsgTxt ("NLEAVES must be a scalar not smaller than 2.") ;
00141 }
00142
00143 M = mxGetM (in[IN_DATA]);
00144 N = mxGetN (in[IN_DATA]);
00145
00146 if (! vlmxIsPlainScalar (in[IN_K]) ||
00147 (K = (int) *mxGetPr (in[IN_K])) > N) {
00148 mexErrMsgTxt ("Cannot have more clusters than data.") ;
00149 }
00150
00151 data = (vl_uint8 *) mxGetPr (in[IN_DATA]) ;
00152
00153 while ((opt = vlmxNextOption (in, nin, options, &next, &optarg)) >= 0) {
00154 char buf [1024] ;
00155
00156 switch (opt) {
00157
00158 case opt_verbose :
00159 ++ verb ;
00160 break ;
00161
00162 case opt_max_niters :
00163 if (!vlmxIsPlainScalar(optarg) ||
00164 (max_niters = (int) *mxGetPr(optarg)) < 1) {
00165 mexErrMsgTxt("MaxNiters must be not smaller than 1.") ;
00166 }
00167 break ;
00168
00169 case opt_method :
00170 if (!vlmxIsString (optarg, -1)) {
00171 mexErrMsgTxt("'Method' must be a string.") ;
00172 }
00173 if (mxGetString (optarg, buf, sizeof(buf))) {
00174 mexErrMsgTxt("Option argument too long.") ;
00175 }
00176 if (strcmp("lloyd", buf) == 0) {
00177 method_type = VL_IKM_LLOYD ;
00178 } else if (strcmp("elkan", buf) == 0) {
00179 method_type = VL_IKM_ELKAN ;
00180 } else {
00181 mexErrMsgTxt("Unknown cost type.") ;
00182 }
00183 break ;
00184
00185 default :
00186 abort() ; break ;
00187 }
00188 }
00189
00190
00191
00192
00193
00194 depth = VL_MAX(1, ceil (log (nleaves) / log(K))) ;
00195 tree = vl_hikm_new (method_type) ;
00196
00197 if (verb) {
00198 mexPrintf("hikmeans: # dims: %d\n", M) ;
00199 mexPrintf("hikmeans: # data: %d\n", N) ;
00200 mexPrintf("hikmeans: K: %d\n", K) ;
00201 mexPrintf("hikmeans: depth: %d\n", depth) ;
00202 }
00203
00204 vl_hikm_set_verbosity (tree, verb) ;
00205 vl_hikm_init (tree, M, K, depth) ;
00206 vl_hikm_train (tree, data, N) ;
00207
00208 out[OUT_TREE] = hikm_to_matlab (tree) ;
00209
00210 if (nout > 1) {
00211 vl_uint *asgn ;
00212 vl_uindex j ;
00213 out [OUT_ASGN] = mxCreateNumericMatrix
00214 (vl_hikm_get_depth (tree), N, mxUINT32_CLASS, mxREAL) ;
00215 asgn = mxGetData(out[OUT_ASGN]) ;
00216 vl_hikm_push (tree, asgn, data, N) ;
00217 for (j = 0 ; j < N*depth ; ++ j) asgn [j] ++ ;
00218 }
00219
00220 if (verb) {
00221 mexPrintf("hikmeans: done.\n") ;
00222 }
00223
00224
00225 }