00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015 #include<mexutils.h>
00016
00017 #include<stdio.h>
00018 #include<stdlib.h>
00019 #include<math.h>
00020 #include<string.h>
00021 #include<assert.h>
00022
00023 #include <vl/ikmeans.h>
00024 #include <vl/generic.h>
00025
00026 enum {
00027 opt_max_niters,
00028 opt_method,
00029 opt_verbose
00030 } ;
00031
00032 vlmxOption options [] = {
00033 {"MaxIters", 1, opt_max_niters },
00034 {"Method", 1, opt_method },
00035 {"Verbose", 0, opt_verbose },
00036 {0, 0, 0 }
00037 } ;
00038
00039
00040
00041 void mexFunction (int nout, mxArray * out[], int nin, const mxArray * in[])
00042 {
00043 enum {IN_X = 0, IN_K, IN_END} ;
00044 enum {OUT_C = 0, OUT_I} ;
00045
00046 int opt ;
00047 int next = IN_END ;
00048 mxArray const *optarg ;
00049
00050
00051 VlIKMFilt *ikmf ;
00052 vl_uint8 *data ;
00053 mwSize M, N, K ;
00054 int method_type = VL_IKM_LLOYD ;
00055 int max_niters = 200 ;
00056 int verb = 0 ;
00057 int err = 0 ;
00058
00059 VL_USE_MATLAB_ENV ;
00060
00061
00062
00063
00064 if (nin < 2) {
00065 mexErrMsgTxt ("At least two arguments required.");
00066 }
00067 else if (nout > 2) {
00068 mexErrMsgTxt ("Too many output arguments.");
00069 }
00070 if (mxGetClassID (in[IN_X]) != mxUINT8_CLASS) {
00071 mexErrMsgTxt ("X must be of class UINT8.");
00072 }
00073
00074 M = mxGetM(in[IN_X]);
00075 N = mxGetN(in[IN_X]);
00076
00077 if (!vlmxIsPlainScalar (in[IN_K]) ||
00078 (K = (int) *mxGetPr(in[IN_K])) < 1 ||
00079 K > N) {
00080 mexErrMsgTxt ("K must be a positive integer not greater than the number of data.");
00081 }
00082
00083 while ((opt = vlmxNextOption (in, nin, options, &next, &optarg)) >= 0) {
00084 char buf [1024] ;
00085
00086 switch (opt) {
00087
00088 case opt_verbose :
00089 ++ verb ;
00090 break ;
00091
00092 case opt_max_niters :
00093 if (!vlmxIsPlainScalar(optarg) || (max_niters = (int) *mxGetPr(optarg)) < 1) {
00094 mexErrMsgTxt("MaxNIters must be not smaller than 1.") ;
00095 }
00096 break ;
00097
00098 case opt_method :
00099 if (!vlmxIsString (optarg, -1)) {
00100 mexErrMsgTxt("'Method' must be a string.") ;
00101 }
00102 if (mxGetString (optarg, buf, sizeof(buf))) {
00103 mexErrMsgTxt("Option argument too long.") ;
00104 }
00105 if (strcmp("lloyd", buf) == 0) {
00106 method_type = VL_IKM_LLOYD ;
00107 } else if (strcmp("elkan", buf) == 0) {
00108 method_type = VL_IKM_ELKAN ;
00109 } else {
00110 mexErrMsgTxt("Unknown method type.") ;
00111 }
00112 break ;
00113
00114 default :
00115 abort() ;
00116 }
00117 }
00118
00119
00120
00121
00122
00123 if (verb) {
00124 char const * method_name = 0 ;
00125 switch (method_type) {
00126 case VL_IKM_LLOYD: method_name = "Lloyd" ; break ;
00127 case VL_IKM_ELKAN: method_name = "Elkan" ; break ;
00128 default :
00129 abort() ;
00130 }
00131 mexPrintf("vl_ikmeans: MaxInters = %d\n", max_niters) ;
00132 mexPrintf("vl_ikmeans: Method = %s\n", method_name) ;
00133 }
00134
00135 data = (vl_uint8*) mxGetData(in[IN_X]) ;
00136 ikmf = vl_ikm_new (method_type) ;
00137
00138 vl_ikm_set_verbosity (ikmf, verb) ;
00139 vl_ikm_set_max_niters (ikmf, max_niters) ;
00140 vl_ikm_init_rand_data (ikmf, data, M, N, K) ;
00141
00142 err = vl_ikm_train (ikmf, data, N) ;
00143 if (err) mexWarnMsgTxt("vl_ikmeans: possible overflow!") ;
00144
00145
00146
00147
00148
00149 {
00150 out[OUT_C] = mxCreateNumericMatrix (M, K, mxINT32_CLASS, mxREAL) ;
00151 memcpy(mxGetData(OUT(C)),
00152 vl_ikm_get_centers(ikmf),
00153 sizeof(vl_ikmacc_t) * M * K) ;
00154 }
00155
00156 if (nout > 1) {
00157 vl_uindex i ;
00158 vl_uint32 *asgn ;
00159 out[OUT_I] = mxCreateNumericMatrix (1, N, mxUINT32_CLASS, mxREAL) ;
00160 asgn = (vl_uint32*) mxGetData (out[OUT_I]) ;
00161
00162 vl_ikm_push (ikmf, asgn, data, N) ;
00163
00164 for (i = 0 ; i < N ; ++i) { ++ asgn [i] ; }
00165 }
00166
00167 vl_ikm_delete (ikmf) ;
00168
00169 if (verb) {
00170 mexPrintf("vl_ikmeans: done\n") ;
00171 }
00172 }