00001
00007
00008
00009
00010
00011
00012
00013
00014
00015 #include <vl/fisher.h>
00016 #include <mexutils.h>
00017 #include <string.h>
00018 #include <stdio.h>
00019
/* Codes for the flag options accepted by vl_fisher.  Each code is the
   third field of the matching row in options[] and is what the option
   loop in mexFunction switches on. */
enum {
  opt_verbose     = 0,
  opt_normalized  = 1,
  opt_square_root = 2,
  opt_improved    = 3,
  opt_fast        = 4
} ;
00027
00028 vlmxOption options [] = {
00029 {"Verbose", 0, opt_verbose },
00030 {"Normalized", 0, opt_normalized },
00031 {"SquareRoot", 0, opt_square_root },
00032 {"Improved", 0, opt_improved },
00033 {"Fast", 0, opt_fast }
00034 } ;
00035
00036
00037 void
00038 mexFunction (int nout VL_UNUSED, mxArray * out[], int nin, const mxArray * in[])
00039 {
00040 enum {IN_DATA = 0, IN_MEANS, IN_COVARIANCES, IN_PRIORS, IN_END} ;
00041 enum {OUT_ENC} ;
00042
00043 int opt ;
00044 int next = IN_END ;
00045 mxArray const *optarg ;
00046
00047 vl_size numClusters = 10;
00048 vl_size dimension ;
00049 vl_size numData ;
00050 int flags = 0 ;
00051
00052 void * covariances = NULL;
00053 void * means = NULL;
00054 void * priors = NULL;
00055 void * data = NULL ;
00056 vl_size numTerms ;
00057
00058 int verbosity = 0 ;
00059
00060 vl_type dataType ;
00061 mxClassID classID ;
00062
00063 VL_USE_MATLAB_ENV ;
00064
00065
00066
00067
00068
00069 if (nin < 4) {
00070 vlmxError (vlmxErrInvalidArgument,
00071 "At least four arguments required.");
00072 }
00073 if (nout > 1) {
00074 vlmxError (vlmxErrInvalidArgument,
00075 "At most one output argument.");
00076 }
00077
00078 classID = mxGetClassID (IN(DATA)) ;
00079 switch (classID) {
00080 case mxSINGLE_CLASS: dataType = VL_TYPE_FLOAT ; break ;
00081 case mxDOUBLE_CLASS: dataType = VL_TYPE_DOUBLE ; break ;
00082 default:
00083 vlmxError (vlmxErrInvalidArgument,
00084 "DATA is neither of class SINGLE or DOUBLE.") ;
00085 }
00086
00087 if (mxGetClassID (IN(MEANS)) != classID) {
00088 vlmxError(vlmxErrInvalidArgument, "MEANS is not of the same class as DATA.") ;
00089 }
00090 if (mxGetClassID (IN(COVARIANCES)) != classID) {
00091 vlmxError(vlmxErrInvalidArgument, "COVARIANCES is not of the same class as DATA.") ;
00092 }
00093 if (mxGetClassID (IN(PRIORS)) != classID) {
00094 vlmxError(vlmxErrInvalidArgument, "PRIORS is not of the same class as DATA.") ;
00095 }
00096
00097 dimension = mxGetM (IN(DATA)) ;
00098 numData = mxGetN (IN(DATA)) ;
00099 numClusters = mxGetN (IN(MEANS)) ;
00100
00101 if (dimension == 0) {
00102 vlmxError (vlmxErrInvalidArgument, "SIZE(DATA,1) is zero.") ;
00103 }
00104 if (!vlmxIsMatrix(IN(MEANS), dimension, numClusters)) {
00105 vlmxError (vlmxErrInvalidArgument, "MEANS is not a matrix or does not have the correct size.") ;
00106 }
00107 if (!vlmxIsMatrix(IN(COVARIANCES), dimension, numClusters)) {
00108 vlmxError (vlmxErrInvalidArgument, "COVARIANCES is not a matrix or does not have the correct size.") ;
00109 }
00110 if (!vlmxIsVector(IN(PRIORS), numClusters)) {
00111 vlmxError (vlmxErrInvalidArgument, "PRIORS is not a vector or does not have the correct size.") ;
00112 }
00113 if (!vlmxIsMatrix(IN(DATA), dimension, numData)) {
00114 vlmxError (vlmxErrInvalidArgument, "DATA is not a matrix or does not have the correct size.") ;
00115 }
00116
00117 while ((opt = vlmxNextOption (in, nin, options, &next, &optarg)) >= 0) {
00118 switch (opt) {
00119 case opt_verbose : ++ verbosity ; break ;
00120 case opt_normalized: flags |= VL_FISHER_FLAG_NORMALIZED ; break ;
00121 case opt_square_root: flags |= VL_FISHER_FLAG_SQUARE_ROOT ; break ;
00122 case opt_improved: flags |= VL_FISHER_FLAG_IMPROVED ; break ;
00123 case opt_fast: flags |= VL_FISHER_FLAG_FAST ; break ;
00124 default : abort() ;
00125 }
00126 }
00127
00128
00129
00130
00131
00132 data = mxGetPr(IN(DATA)) ;
00133 means = mxGetPr(IN(MEANS)) ;
00134 covariances = mxGetPr(IN(COVARIANCES)) ;
00135 priors = mxGetPr(IN(PRIORS)) ;
00136
00137 if (verbosity) {
00138 mexPrintf("vl_fisher: num data: %d\n", numData) ;
00139 mexPrintf("vl_fisher: num clusters: %d\n", numClusters) ;
00140 mexPrintf("vl_fisher: data dimension: %d\n", dimension) ;
00141 mexPrintf("vl_fisher: code dimension: %d\n", numClusters * dimension) ;
00142 mexPrintf("vl_fisher: square root: %s\n", VL_YESNO(flags & VL_FISHER_FLAG_SQUARE_ROOT)) ;
00143 mexPrintf("vl_fisher: normalized: %s\n", VL_YESNO(flags & VL_FISHER_FLAG_NORMALIZED)) ;
00144 mexPrintf("vl_fisher: fast: %s\n", VL_YESNO(flags & VL_FISHER_FLAG_FAST)) ;
00145 }
00146
00147
00148
00149
00150
00151 OUT(ENC) = mxCreateNumericMatrix (dimension * numClusters * 2, 1, classID, mxREAL) ;
00152
00153 numTerms = vl_fisher_encode (mxGetData(OUT(ENC)), dataType,
00154 means, dimension, numClusters,
00155 covariances,
00156 priors,
00157 data, numData,
00158 flags) ;
00159
00160 if (verbosity) {
00161 mexPrintf("vl_fisher: sparsity of assignments: %.2f%% (%d non-negligible assignments)\n",
00162 100.0 * (1.0 - (double)numTerms/((double)numData*(double)numClusters+1e-12)),
00163 numTerms) ;
00164 }
00165 }