vl_alldist2.c
Go to the documentation of this file.
00001 /* file:        vl_alldist2.c
00002 ** description: All pairwise distances between the columns of one or two matrices
00003 ** author:      Andrea Vedaldi
00004 **/
00005 
00006 /*
00007 Copyright (C) 2007-12 Andrea Vedaldi and Brian Fulkerson.
00008 All rights reserved.
00009 
00010 This file is part of the VLFeat library and is made available under
00011 the terms of the BSD license (see the COPYING file).
00012 */
00013 
00014 #include <mexutils.h>
00015 
00016 #include <vl/mathop.h>
00017 #include <vl/generic.h>
00018 
00019 #include<stdio.h>
00020 #include<stdlib.h>
00021 #include<math.h>
00022 #include<string.h>
00023 #include<assert.h>
00024 
/* Codes identifying the supported comparison functions.
 * vlmxNextOption() maps an option string from the options[] table to one
 * of these codes, and mexFunction() dispatches on the result.  The values
 * are written out explicitly; they are the same consecutive values
 * (starting at zero) that implicit enumeration would assign. */
enum {
  /* distances */
  opt_LINF  = 0,
  opt_L2    = 1,
  opt_L1    = 2,
  opt_L0    = 3,
  opt_CHI2  = 4,
  opt_HELL  = 5,

  /* kernels (accumulated products / minima / square roots) */
  opt_KL2   = 6,
  opt_KL1   = 7,
  opt_KCHI2 = 8,
  opt_KHELL = 9,

  /* sum of component-wise minima */
  opt_MIN   = 10
} ;
00040 
00041 vlmxOption  options [] = {
00042   {"linf",         0,   opt_LINF          },
00043   {"l2",           0,   opt_L2            },
00044   {"l1",           0,   opt_L1            },
00045   {"l0",           0,   opt_L0            },
00046   {"chi2",         0,   opt_CHI2          },
00047   {"hell",         0,   opt_HELL          },
00048 
00049   {"kl2",          0,   opt_KL2           },
00050   {"kl1",          0,   opt_KL1           },
00051   {"kchi2",        0,   opt_KCHI2         },
00052   {"khell",        0,   opt_KHELL         },
00053 
00054   {"min",          0,   opt_MIN           },
00055   {0,              0,   0                 }
00056 } ;
00057 
00058 
/* Local helper macros.  Any same-named definitions that the headers above
 * may provide are discarded first, so the versions below are the ones in
 * effect.  MIN, MAX and ABS_DIFF evaluate their arguments more than once:
 * pass only side-effect-free expressions. */
#undef MIN
#undef MAX
#undef ABS_DIFF
#undef CORE

#define MIN(x,y)      ((x) <= (y) ? (x) :  (y))
#define MAX(x,y)      ((x) >= (y) ? (x) :  (y))

/* |x - y| computed without forming a negative intermediate, so it is also
   correct for the unsigned accumulator types. */
#define ABS_DIFF(x,y) ((x) >= (y) ? ((x) - (y)) : ((y) - (x)))

/* sqrtf is normally a function rather than a macro, so this definition is
   usually active and routes sqrtf through the double-precision sqrt.
   Presumably a portability fallback for compilers without C99 sqrtf. */
#ifndef sqrtf
#define sqrtf(x) ((float)sqrt(x))
#endif

/* Map the class tokens used by CORE / DISPATCH_CLASS to C element types:
   TOKEN ## _t is the C type holding one element of that MATLAB class. */
#define UINT8_t  vl_uint8
#define  INT8_t  vl_int8
#define UINT16_t vl_uint16
#define  INT16_t vl_int16
#define UINT32_t vl_uint32
#define  INT32_t vl_int32
#define SINGLE_t float
#define DOUBLE_t double
00088 
/* CORE(NORM, F, DC, AC) defines
 *
 *   static void distNORM_DC_AC(AC_t *pt, DC_t const *s1_pt,
 *                              DC_t const *s2_pt, vl_size L,
 *                              vl_size N1, vl_size N2, bool self)
 *
 * The function fills the N1 x N2 column-major output pt.  Entry (j1,j2) is
 * F accumulated over the L components of column j1 of s1_pt and column j2
 * of s2_pt.  Elements are read as DC_t and accumulated in AC_t.
 *
 * When self is true the caller passes the same matrix twice (N1 == N2).
 * Only entries with j1 >= j2 are computed.  Each remaining entry is copied
 * from its mirror (j2,j1), which was written earlier because it lies in an
 * earlier column.  In column-major order that mirror is (j2 - j1) * (N1 - 1)
 * elements before pt.  Since j2 > j1 on that path, the offset is a plain
 * positive count.  The previous form (j1 - j2) * (N1 - 1) wrapped around in
 * unsigned arithmetic and relied on out-of-range pointer arithmetic. */
#define CORE(NORM,F,DC,AC)                                              \
  static void                                                           \
  dist ## NORM ## _ ## DC ## _ ## AC                                    \
  (                                                                     \
   AC ## _t * pt, DC ## _t const * s1_pt,                               \
   DC ## _t const * s2_pt,                                              \
   vl_size L, vl_size N1, vl_size N2,                                   \
   bool self)                                                           \
  {                                                                     \
    vl_uindex j1, j2, l ;                                               \
    for (j2 = 0 ; j2 < N2 ; ++j2) {                                     \
      for (j1 = 0 ; j1 < N1 ; ++j1) {                                   \
        if (! self || j1 >= j2) {                                       \
          AC ## _t acc = 0 ;                                            \
          DC ## _t const * s1_it = s1_pt + L*j1 ;                       \
          DC ## _t const * s2_it = s2_pt + L*j2 ;                       \
          for (l = 0 ; l < L ; ++l) {                                   \
            AC ## _t s1 = *s1_it++ ;                                    \
            AC ## _t s2 = *s2_it++ ;                                    \
            F(AC, s1, s2)                                               \
          }                                                             \
          *pt = acc ;                                                   \
        } else {                                                        \
          /* mirror entry (j2,j1), already computed */                  \
          *pt = *(pt - (j2 - j1) * (N1 - 1)) ;                          \
        }                                                               \
        pt++ ;                                                          \
      }                                                                 \
    }                                                                   \
  }

/* CORE_SPARSE(NORM, F) is a statement block expanded inside mexFunction()
 * for sparse (double) inputs.  It uses the caller's in[], out[], IN_S1,
 * IN_S2, OUT_D, self, N1 and N2.  It fills the N1 x N2 double array
 * out[OUT_D], which the caller has already allocated.  Entry (j1,j2) is F
 * accumulated over the rows of column j1 of X and column j2 of Y.
 * NORM is not used.
 *
 * Columns are read in compressed-column form: Pr values, Ir row indices,
 * and Jc column start offsets.  The two row-index lists are merged, which
 * assumes they are in increasing order.  A row stored in only one column
 * contributes F(a, 0) or F(0, b).  Rows stored in neither column are
 * skipped, i.e. F(0, 0) is assumed to contribute nothing.
 *
 * As in CORE, when self is true only j1 >= j2 is computed.  The other
 * entries are copied from their mirror, which lies (j2 - j1) * (N1 - 1)
 * elements before pt.  This offset is positive, so no unsigned wrap-around
 * is involved.  Non-zero counts are kept as mwIndex rather than int so
 * that very long columns cannot truncate. */
#define CORE_SPARSE(NORM, F)                                            \
  {                                                                     \
    double const * s1_pt = mxGetPr(in[IN_S1]) ;                         \
    mwIndex const * s1_ir = mxGetIr(in[IN_S1]) ;                        \
    mwIndex const * s1_jc = mxGetJc(in[IN_S1]) ;                        \
    double const * s2_pt = s1_pt ;                                      \
    mwIndex const * s2_ir = s1_ir ;                                     \
    mwIndex const * s2_jc = s1_jc ;                                     \
    double * pt = mxGetPr(out[OUT_D]) ;                                 \
    vl_uindex j1, j2 ;                                                  \
                                                                        \
    if (! self) {                                                       \
      s2_pt = mxGetPr(in[IN_S2]) ;                                      \
      s2_ir = mxGetIr(in[IN_S2]) ;                                      \
      s2_jc = mxGetJc(in[IN_S2]) ;                                      \
    }                                                                   \
                                                                        \
    for (j2 = 0 ; j2 < N2 ; ++j2) {                                     \
      for (j1 = 0 ; j1 < N1 ; ++j1) {                                   \
        if (! self || j1 >= j2) {                                       \
          mwIndex nz1 = s1_jc[j1+1] - s1_jc[j1] ;                       \
          mwIndex nz2 = s2_jc[j2+1] - s2_jc[j2] ;                       \
          double const * s1_it = s1_pt + s1_jc[j1] ;                    \
          double const * s2_it = s2_pt + s2_jc[j2] ;                    \
          mwIndex const * s1_ir_it = s1_ir + s1_jc[j1] ;                \
          mwIndex const * s2_ir_it = s2_ir + s2_jc[j2] ;                \
          double acc = 0 ;                                              \
          while (nz1 > 0 || nz2 > 0) {                                  \
            if (nz2 == 0 || (nz1 > 0 && *s1_ir_it < *s2_ir_it)) {       \
              /* row stored only in the X column */                     \
              double a = *s1_it++ ;                                     \
              F(DOUBLE, a, 0) ;                                         \
              s1_ir_it++ ;                                              \
              nz1-- ;                                                   \
            } else if (nz1 == 0 || *s2_ir_it < *s1_ir_it) {             \
              /* row stored only in the Y column */                     \
              double b = *s2_it++ ;                                     \
              F(DOUBLE, 0, b) ;                                         \
              s2_ir_it++ ;                                              \
              nz2-- ;                                                   \
            } else {                                                    \
              /* row stored in both columns */                          \
              double a = *s1_it++ ;                                     \
              double b = *s2_it++ ;                                     \
              F(DOUBLE, a, b) ;                                         \
              s1_ir_it++ ;                                              \
              s2_ir_it++ ;                                              \
              nz1-- ;                                                   \
              nz2-- ;                                                   \
            }                                                           \
          }                                                             \
          *pt = acc ;                                                   \
        } else {                                                        \
          /* mirror entry (j2,j1), already computed */                  \
          *pt = *(pt - (j2 - j1) * (N1 - 1)) ;                          \
        }                                                               \
        pt++ ;                                                          \
      }                                                                 \
    }                                                                   \
  }
00201 
00202 
/* DEF_CLASS(NORM, F) generates one dense kernel distNORM_DC_AC for each
 * supported input class DC, paired with its accumulator class AC:
 *   8/16/32-bit unsigned integers -> UINT32,
 *   8/16/32-bit signed integers   -> INT32,
 *   SINGLE and DOUBLE             -> themselves.
 * These (DC, AC) pairs are the ones selected by DISPATCH_CLASS in
 * mexFunction(), listed here in the same order. */
#define DEF_CLASS(NORM,F)                                               \
  CORE (NORM, F,  UINT8, UINT32)                                        \
  CORE (NORM, F,   INT8,  INT32)                                        \
  CORE (NORM, F, UINT16, UINT32)                                        \
  CORE (NORM, F,  INT16,  INT32)                                        \
  CORE (NORM, F, UINT32, UINT32)                                        \
  CORE (NORM, F,  INT32,  INT32)                                        \
  CORE (NORM, F, SINGLE, SINGLE)                                        \
  CORE (NORM, F, DOUBLE, DOUBLE)
00212 
/* Per-component update rules.  Each F_<NAME>(AC, x, y) expands to a
 * brace-enclosed statement.  The statement adds the contribution of one
 * pair of components x, y to a variable named `acc`, which the expanding
 * code (CORE or CORE_SPARSE) must declare.  AC is the accumulator class
 * token; AC ## _t is its C type.  Every argument is parenthesized, so the
 * macros are also correct when given expressions rather than plain
 * variables. */

/* ---- distances ---- */
#define  F_L0(AC,x,y)   { acc += (x) != (y) ; }
#define  F_L1(AC,x,y)   { acc += ABS_DIFF(x,y) ; }
#define  F_L2(AC,x,y)   { AC ## _t tmp = ABS_DIFF(x,y) ; acc += tmp * tmp ; }
#define  F_LINF(AC,x,y) { acc = MAX(acc, ABS_DIFF(x,y)) ; }
/* Pairs with x + y == 0 are skipped to avoid dividing by zero. */
#define  F_CHI2(AC,x,y)                                  \
  {                                                      \
    AC ## _t  meant2 = ((x) + (y))  ;                    \
    if (meant2 != 0) {                                   \
      AC ## _t tmp  = ABS_DIFF(x,y) ;                    \
      acc += tmp * tmp / meant2 ;                        \
    }                                                    \
  }
/* NOTE(review): for the integer variants, (x) * (y) is formed in the
   accumulator type.  It can overflow for INT32/UINT32 data (signed overflow
   is undefined behavior).  For signed classes, a negative product is
   converted to vl_fast_sqrt_ui32's argument type, presumably unsigned
   32-bit -- TODO confirm. */
#define F_HELL_SINGLE(x,y) { acc += (x) + (y) - 2 * sqrtf ((x) * (y)) ; }
#define F_HELL_DOUBLE(x,y) { acc += (x) + (y) - 2 * sqrt  ((x) * (y)) ; }
#define F_HELL_UINT32(x,y) { acc += (x) + (y) - 2 * vl_fast_sqrt_ui32 ((x) * (y)) ; }
#define F_HELL_INT32(x,y)  { acc += (x) + (y) - 2 * vl_fast_sqrt_ui32 ((x) * (y)) ; }
#define F_HELL_UINT16(x,y) { acc += (x) + (y) - 2 * vl_fast_sqrt_ui32 ((x) * (y)) ; }
#define F_HELL_INT16(x,y)  { acc += (x) + (y) - 2 * vl_fast_sqrt_ui32 ((x) * (y)) ; }
#define F_HELL_UINT8(x,y)  { acc += (x) + (y) - 2 * vl_fast_sqrt_ui32 ((x) * (y)) ; }
#define F_HELL_INT8(x,y)   { acc += (x) + (y) - 2 * vl_fast_sqrt_ui32 ((x) * (y)) ; }
#define F_HELL(AC,x,y) F_HELL_ ## AC (x,y)

/* ---- kernels ---- */
#define  F_KL2(AC,x,y)  { acc += (x) * (y) ; }
#define  F_KL1(AC,x,y)  { acc += MIN(x,y) ; }
#define  F_MIN(AC,x,y)  { acc += MIN(x,y) ; }
/* Pairs with a zero mean are skipped.  For integer accumulators both the
   halving and the final division truncate. */
#define  F_KCHI2(AC,x,y)                                 \
  {                                                      \
    AC ## _t  mean = ((x) + (y)) / 2  ;                  \
    if (mean != 0) {                                     \
      AC ## _t tmp  = (x) * (y) ;                        \
      acc += tmp / mean ;                                \
    }                                                    \
  }
#define F_KHELL_SINGLE(x,y) { acc += sqrtf ((x) * (y)) ; }
#define F_KHELL_DOUBLE(x,y) { acc += sqrt  ((x) * (y)) ; }
#define F_KHELL_UINT32(x,y) { acc += vl_fast_sqrt_ui32 ((x) * (y)) ; }
#define F_KHELL_INT32(x,y)  { acc += vl_fast_sqrt_ui32 ((x) * (y)) ; }
#define F_KHELL_UINT16(x,y) { acc += vl_fast_sqrt_ui32 ((x) * (y)) ; }
#define F_KHELL_INT16(x,y)  { acc += vl_fast_sqrt_ui32 ((x) * (y)) ; }
#define F_KHELL_UINT8(x,y)  { acc += vl_fast_sqrt_ui32 ((x) * (y)) ; }
#define F_KHELL_INT8(x,y)   { acc += vl_fast_sqrt_ui32 ((x) * (y)) ; }
#define F_KHELL(AC,x,y) F_KHELL_ ## AC (x,y)
00255 
00256 DEF_CLASS (LINF,  F_LINF )
00257 DEF_CLASS (L2,    F_L2   )
00258 DEF_CLASS (L1,    F_L1   )
00259 DEF_CLASS (L0,    F_L0   )
00260 DEF_CLASS (CHI2,  F_CHI2 )
00261 DEF_CLASS (HELL,  F_HELL )
00262 
00263 DEF_CLASS (KL2,   F_KL2  )
00264 DEF_CLASS (KL1,   F_KL1  )
00265 DEF_CLASS (KCHI2, F_KCHI2)
00266 DEF_CLASS (KHELL, F_KHELL)
00267 
00268 DEF_CLASS (MIN,   F_MIN  )
00269 
00270 /* driver */
00271 void
00272 mexFunction(int nout, mxArray *out[],
00273             int nin, const mxArray *in[])
00274 {
00275 
00276   typedef int  unsigned data_t ;
00277 
00278   /*  mxClassID data_class = mxINT8_CLASS ;*/
00279   enum {IN_S1,IN_S2} ;
00280   enum {OUT_D=0} ;
00281   vl_size L,N1,N2 ;
00282   vl_bool sparse = 0 ;
00283   void const * s1_pt ;
00284   void const * s2_pt ;
00285   mxClassID data_class ;
00286   mxClassID acc_class ;
00287   mwSize dims [2] ;
00288 
00289   /* for option parsing */
00290   bool           self = 1 ;      /* called with one numeric argument? */
00291   int            norm = opt_L2 ; /* type of norm to be computed       */
00292   int            opt ;
00293   int            next = 1 ;
00294   mxArray const *optarg ;
00295 
00300   if (nout > 1) {
00301     mexErrMsgTxt("Too many output arguments.");
00302   }
00303 
00304   if (nin < 1) {
00305     mexErrMsgTxt("At leat one argument required.") ;
00306   }
00307 
00308   if(! mxIsNumeric(in[IN_S1])) {
00309     mexErrMsgTxt ("X must be numeric") ;
00310   }
00311 
00312   if (nin >= 2 && mxIsNumeric(in[IN_S2])) {
00313     self = 0 ;
00314     next = 2 ;
00315   }
00316 
00317   sparse = mxIsSparse(in[IN_S1]) ;
00318 
00319   if (sparse && nin >=2 && mxIsNumeric(in[IN_S2])) {
00320     if (! mxIsSparse(in[IN_S2])) {
00321       mexErrMsgTxt ("X and Y must be either both full or sparse.") ;
00322     }
00323   }
00324 
00325   while ((opt = vlmxNextOption (in, nin, options, &next, &optarg)) >= 0) {
00326     switch (opt) {
00327     case opt_LINF :
00328     case opt_L2 :
00329     case opt_L1 :
00330     case opt_L0 :
00331     case opt_CHI2 :
00332     case opt_HELL :
00333 
00334     case opt_KL2 :
00335     case opt_KL1 :
00336     case opt_KCHI2 :
00337     case opt_KHELL :
00338 
00339     case opt_MIN :
00340       norm = opt ;
00341       break ;
00342 
00343     default:
00344       abort() ;
00345     }
00346   }
00347 
00348   data_class = mxGetClassID(in[IN_S1]) ;
00349   if ((!self) && data_class != mxGetClassID(in[IN_S2])) {
00350     mexErrMsgTxt("X and Y must have the same numeric class") ;
00351   }
00352 
00353   assert ((! sparse) || (data_class == mxDOUBLE_CLASS)) ;
00354 
00355   L  = mxGetM(in[IN_S1]) ;
00356   N1 = mxGetN(in[IN_S1]) ;
00357   N2 = self ?  N1 : mxGetN(in[IN_S2]) ;
00358 
00359   dims[0] = N1 ;
00360   dims[1] = N2 ;
00361 
00362   if ((!self) && L != mxGetM(in[IN_S2])) {
00363     mexErrMsgTxt("X and Y must have the same number of rows") ;
00364   }
00365 
00366   s1_pt = mxGetData(in[IN_S1]) ;
00367   s2_pt = self ? s1_pt : mxGetData(in[IN_S2]) ;
00368 
00369 #define DISPATCH_CLASS(NORM, DC,AC)                                     \
00370   case mx ## DC ## _CLASS :                                             \
00371     acc_class = mx ## AC ## _CLASS ;                                    \
00372   out[OUT_D] = mxCreateNumericArray(2,dims,acc_class,mxREAL) ;          \
00373   dist ## NORM ## _ ## DC ## _ ## AC                                    \
00374     ( (AC ## _t *)mxGetData(out[OUT_D]),                                \
00375       (DC ## _t *)s1_pt,                                                \
00376       (DC ## _t *)s2_pt,                                                \
00377       L, N1, N2,                                                        \
00378       self ) ;                                                          \
00379   break ;
00380 
00381 #define DISPATCH_NORM(NORM)                                             \
00382   case opt_ ## NORM :                                                   \
00383     if (sparse) {                                                       \
00384       out[OUT_D] = mxCreateNumericArray(2,dims,mxDOUBLE_CLASS,mxREAL) ; \
00385       CORE_SPARSE(NORM, VL_XCAT(F_, NORM))                              \
00386     } else {                                                            \
00387       switch (data_class) {                                             \
00388         DISPATCH_CLASS(NORM,  UINT8 , UINT32)                           \
00389           DISPATCH_CLASS(NORM,  INT8 ,  INT32)                          \
00390           DISPATCH_CLASS(NORM, UINT16, UINT32)                          \
00391           DISPATCH_CLASS(NORM,  INT16,  INT32)                          \
00392           DISPATCH_CLASS(NORM, UINT32, UINT32)                          \
00393           DISPATCH_CLASS(NORM,  INT32,  INT32)                          \
00394           DISPATCH_CLASS(NORM, SINGLE, SINGLE)                          \
00395           DISPATCH_CLASS(NORM, DOUBLE,DOUBLE)                           \
00396       default:                                                          \
00397         mexErrMsgTxt("Data class not supported!") ;                     \
00398       }                                                                 \
00399     }                                                                   \
00400   break ;
00401 
00402   switch (norm) {
00403     DISPATCH_NORM(LINF ) ;
00404     DISPATCH_NORM(L2   ) ;
00405     DISPATCH_NORM(L1   ) ;
00406     DISPATCH_NORM(L0   ) ;
00407     DISPATCH_NORM(CHI2 ) ;
00408     DISPATCH_NORM(HELL ) ;
00409 
00410     DISPATCH_NORM(KL2  ) ;
00411     DISPATCH_NORM(KL1  ) ;
00412     DISPATCH_NORM(KCHI2) ;
00413     DISPATCH_NORM(KHELL) ;
00414 
00415     DISPATCH_NORM(MIN  ) ;
00416   default:
00417     abort() ;
00418   }
00419 }


libvlfeat
Author(s): Andrea Vedaldi
autogenerated on Thu Jun 6 2019 20:25:51