00001
00007
00008
00009
00010
00011
00012
00013
00014
00015 #include <mexutils.h>
00016 #include <vl/kdtree.h>
00017 #include <vl/stringop.h>
00018
00019 #include <assert.h>
00020 #include <string.h>
00021
00022 #include "kdtree.h"
00023
00024
/* Identifiers of the options understood by vl_kdforestbuild. Each value
   is stored in the corresponding entry of the options table and is what
   vlmxNextOption() hands back when that option is parsed. */
enum {
  opt_verbose = 0,
  opt_threshold_method,
  opt_num_trees,
  opt_distance
} ;
00028
00029
00030 vlmxOption options [] = {
00031 {"Verbose", 0, opt_verbose },
00032 {"ThresholdMethod", 1, opt_threshold_method },
00033 {"NumTrees", 1, opt_num_trees },
00034 {"Distance", 1, opt_distance },
00035 {0, 0, 0 }
00036 } ;
00037
00042 void
00043 mexFunction(int nout, mxArray *out[],
00044 int nin, const mxArray *in[])
00045 {
00046 enum {IN_DATA = 0, IN_END} ;
00047 enum {OUT_TREE = 0} ;
00048
00049 int verbose = 0 ;
00050 int opt ;
00051 int next = IN_END ;
00052 mxArray const *optarg ;
00053
00054 VlKDForest * forest ;
00055 void * data ;
00056 vl_size numData ;
00057 vl_size dimension ;
00058 mxClassID dataClass ;
00059 vl_type dataType ;
00060 int thresholdingMethod = VL_KDTREE_MEDIAN ;
00061 VlVectorComparisonType distance = VlDistanceL2;
00062 vl_size numTrees = 1 ;
00063
00064 VL_USE_MATLAB_ENV ;
00065
00066
00067
00068
00069
00070 if (nin < 1) {
00071 vlmxError(vlmxErrInvalidArgument,
00072 "At least one argument required") ;
00073 } else if (nout > 2) {
00074 vlmxError(vlmxErrInvalidArgument,
00075 "Too many output arguments");
00076 }
00077
00078 dataClass = mxGetClassID(IN(DATA)) ;
00079
00080 if (! vlmxIsMatrix (IN(DATA), -1, -1) ||
00081 ! vlmxIsReal (IN(DATA))) {
00082 vlmxError(vlmxErrInvalidArgument,
00083 "DATA must be a real matrix ") ;
00084 }
00085
00086 switch (dataClass) {
00087 case mxSINGLE_CLASS : dataType = VL_TYPE_FLOAT ; break ;
00088 case mxDOUBLE_CLASS : dataType = VL_TYPE_DOUBLE ; break ;
00089 default:
00090 vlmxError(vlmxErrInvalidArgument,
00091 "DATA must be either SINGLE or DOUBLE") ;
00092 }
00093
00094 while ((opt = vlmxNextOption (in, nin, options, &next, &optarg)) >= 0) {
00095 char buffer [1024] ;
00096 switch (opt) {
00097 case opt_threshold_method :
00098 mxGetString (optarg, buffer, sizeof(buffer)/sizeof(buffer[0])) ;
00099 if (! vlmxIsString(optarg, -1)) {
00100 vlmxError(vlmxErrInvalidOption,
00101 "THRESHOLDMETHOD must be a string") ;
00102 }
00103 if (vl_string_casei_cmp(buffer, "median") == 0) {
00104 thresholdingMethod = VL_KDTREE_MEDIAN ;
00105 } else if (vl_string_casei_cmp(buffer, "mean") == 0) {
00106 thresholdingMethod = VL_KDTREE_MEAN ;
00107 } else {
00108 vlmxError(vlmxErrInvalidOption,
00109 "Unknown thresholding method %s", buffer) ;
00110 }
00111 break ;
00112
00113 case opt_num_trees :
00114 if (! vlmxIsScalar(optarg) ||
00115 (numTrees = mxGetScalar(optarg)) < 1) {
00116 vlmxError(vlmxErrInvalidOption,
00117 "NUMTREES must be not smaller than one") ;
00118 }
00119 break ;
00120
00121 case opt_verbose :
00122 ++ verbose ;
00123 break ;
00124
00125 case opt_distance :
00126 if (!vlmxIsString (optarg, -1)) {
00127 vlmxError (vlmxErrInvalidArgument,
00128 "DISTANCE must be a string.") ;
00129 }
00130 if (mxGetString (optarg, buffer, sizeof(buffer))) {
00131 vlmxError (vlmxErrInvalidArgument,
00132 "DISTANCE argument too long.") ;
00133 }
00134 if (vlmxCompareStringsI("l2", buffer) == 0) {
00135 distance = VlDistanceL2 ;
00136 } else if (vlmxCompareStringsI("l1", buffer) == 0) {
00137 distance = VlDistanceL1 ;
00138 } else {
00139 vlmxError (vlmxErrInvalidArgument,
00140 "Invalid value %s for DISTANCE", buffer) ;
00141 }
00142 break ;
00143
00144 }
00145 }
00146
00147 data = mxGetData (IN(DATA)) ;
00148 numData = mxGetN (IN(DATA)) ;
00149 dimension = mxGetM (IN(DATA)) ;
00150
00151 if (dimension < 1) {
00152 vlmxError (vlmxErrInconsistentData,
00153 "DATA must have at least one row.") ;
00154 }
00155
00156 if (numData < 1) {
00157 vlmxError (vlmxErrInconsistentData,
00158 "DATA must have at least one column.") ;
00159 }
00160
00161
00162 forest = vl_kdforest_new (dataType, dimension, numTrees, distance) ;
00163 vl_kdforest_set_thresholding_method (forest, thresholdingMethod) ;
00164
00165 if (verbose) {
00166 char const * str = 0 ;
00167 mexPrintf("vl_kdforestbuild: data %s [%d x %d]\n",
00168 vl_get_type_name (dataType), dimension, numData) ;
00169 switch (vl_kdforest_get_thresholding_method(forest)) {
00170 case VL_KDTREE_MEAN : str = "mean" ; break ;
00171 case VL_KDTREE_MEDIAN : str = "median" ; break ;
00172 default: abort() ;
00173 }
00174 mexPrintf("vl_kdforestbuild: threshold selection method: %s\n", str) ;
00175 mexPrintf("vl_kdforestbuild: number of trees: %d\n",
00176 vl_kdforest_get_num_trees(forest)) ;
00177 }
00178
00179
00180
00181
00182
00183 vl_kdforest_build (forest, numData, data) ;
00184
00185 if (verbose) {
00186 vl_uindex ti ;
00187 for (ti = 0 ; ti < vl_kdforest_get_num_trees(forest) ; ++ ti) {
00188 mexPrintf("vl_kdforestbuild: tree %d: depth %d, num nodes %d\n",
00189 ti,
00190 vl_kdforest_get_depth_of_tree(forest, ti),
00191 vl_kdforest_get_num_nodes_of_tree(forest, ti)) ;
00192 }
00193 }
00194
00195 out[OUT_TREE] = new_array_from_kdforest (forest) ;
00196 vl_kdforest_delete (forest) ;
00197 }