$search
00001 // Copyright (C) 2008-2011 NICTA (www.nicta.com.au) 00002 // Copyright (C) 2008-2011 Conrad Sanderson 00003 // 00004 // This file is part of the Armadillo C++ library. 00005 // It is provided without any warranty of fitness 00006 // for any purpose. You can redistribute this file 00007 // and/or modify it under the terms of the GNU 00008 // Lesser General Public License (LGPL) as published 00009 // by the Free Software Foundation, either version 3 00010 // of the License or (at your option) any later version. 00011 // (see http://www.opensource.org/licenses for more info) 00012 00013 00016 00017 00018 00019 template<typename eT> 00020 class arma_ascend_sort_helper 00021 { 00022 public: 00023 00024 arma_inline 00025 bool 00026 operator() (eT a, eT b) const 00027 { 00028 return (a < b); 00029 } 00030 }; 00031 00032 00033 00034 template<typename eT> 00035 class arma_descend_sort_helper 00036 { 00037 public: 00038 00039 arma_inline 00040 bool 00041 operator() (eT a, eT b) const 00042 { 00043 return (a > b); 00044 } 00045 }; 00046 00047 00048 00049 template<typename T> 00050 class arma_ascend_sort_helper< std::complex<T> > 00051 { 00052 public: 00053 00054 typedef typename std::complex<T> eT; 00055 00056 inline 00057 bool 00058 operator() (const eT& a, const eT& b) const 00059 { 00060 return (std::abs(a) < std::abs(b)); 00061 } 00062 }; 00063 00064 00065 00066 template<typename T> 00067 class arma_descend_sort_helper< std::complex<T> > 00068 { 00069 public: 00070 00071 typedef typename std::complex<T> eT; 00072 00073 inline 00074 bool 00075 operator() (const eT& a, const eT& b) const 00076 { 00077 return (std::abs(a) > std::abs(b)); 00078 } 00079 }; 00080 00081 00082 00083 template<typename eT> 00084 inline 00085 void 00086 op_sort::direct_sort(eT* X, const uword n_elem, const uword sort_type) 00087 { 00088 arma_extra_debug_sigprint(); 00089 00090 if(sort_type == 0) 00091 { 00092 arma_ascend_sort_helper<eT> comparator; 00093 00094 std::sort(&X[0], &X[n_elem], comparator); 00095 } 00096 else 00097 { 00098 arma_descend_sort_helper<eT> comparator; 00099 00100 std::sort(&X[0], &X[n_elem], comparator); 00101 } 00102 } 00103 00104 00105 00106 template<typename eT> 00107 inline 00108 void 00109 op_sort::copy_row(eT* X, const Mat<eT>& A, const uword row) 00110 { 00111 const uword N = A.n_cols; 00112 00113 uword i,j; 00114 00115 for(i=0, j=1; j<N; i+=2, j+=2) 00116 { 00117 X[i] = A.at(row,i); 00118 X[j] = A.at(row,j); 00119 } 00120 00121 if(i < N) 00122 { 00123 X[i] = A.at(row,i); 00124 } 00125 } 00126 00127 00128 00129 template<typename eT> 00130 inline 00131 void 00132 op_sort::copy_row(Mat<eT>& A, const eT* X, const uword row) 00133 { 00134 const uword N = A.n_cols; 00135 00136 uword i,j; 00137 00138 for(i=0, j=1; j<N; i+=2, j+=2) 00139 { 00140 A.at(row,i) = X[i]; 00141 A.at(row,j) = X[j]; 00142 } 00143 00144 if(i < N) 00145 { 00146 A.at(row,i) = X[i]; 00147 } 00148 } 00149 00150 00151 00152 template<typename T1> 00153 inline 00154 void 00155 op_sort::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_sort>& in) 00156 { 00157 arma_extra_debug_sigprint(); 00158 00159 typedef typename T1::elem_type eT; 00160 00161 const unwrap<T1> tmp(in.m); 00162 const Mat<eT>& X = tmp.M; 00163 00164 const uword sort_type = in.aux_uword_a; 00165 const uword dim = in.aux_uword_b; 00166 00167 arma_debug_check( (sort_type > 1), "sort(): incorrect usage. sort_type must be 0 or 1"); 00168 arma_debug_check( (dim > 1), "sort(): incorrect usage. dim must be 0 or 1" ); 00169 arma_debug_check( (X.is_finite() == false), "sort(): given object has non-finite elements" ); 00170 00171 if( (X.n_rows * X.n_cols) <= 1 ) 00172 { 00173 out = X; 00174 return; 00175 } 00176 00177 00178 if(dim == 0) // sort the contents of each column 00179 { 00180 arma_extra_debug_print("op_sort::apply(), dim = 0"); 00181 00182 out = X; 00183 00184 const uword n_rows = out.n_rows; 00185 const uword n_cols = out.n_cols; 00186 00187 for(uword col=0; col < n_cols; ++col) 00188 { 00189 op_sort::direct_sort( out.colptr(col), n_rows, sort_type ); 00190 } 00191 } 00192 else 00193 if(dim == 1) // sort the contents of each row 00194 { 00195 if(X.n_rows == 1) // a row vector 00196 { 00197 arma_extra_debug_print("op_sort::apply(), dim = 1, vector specific"); 00198 00199 out = X; 00200 op_sort::direct_sort(out.memptr(), out.n_elem, sort_type); 00201 } 00202 else // not a row vector 00203 { 00204 arma_extra_debug_print("op_sort::apply(), dim = 1, generic"); 00205 00206 out.copy_size(X); 00207 00208 const uword n_rows = out.n_rows; 00209 const uword n_cols = out.n_cols; 00210 00211 podarray<eT> tmp_array(n_cols); 00212 00213 for(uword row=0; row < n_rows; ++row) 00214 { 00215 op_sort::copy_row(tmp_array.memptr(), X, row); 00216 00217 op_sort::direct_sort( tmp_array.memptr(), n_cols, sort_type ); 00218 00219 op_sort::copy_row(out, tmp_array.memptr(), row); 00220 } 00221 } 00222 } 00223 00224 } 00225 00226