00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00016
00017
00018
00019 template<typename eT>
00020 class arma_ascend_sort_helper
00021 {
00022 public:
00023
00024 arma_inline
00025 bool
00026 operator() (eT a, eT b) const
00027 {
00028 return (a < b);
00029 }
00030 };
00031
00032
00033
00034 template<typename eT>
00035 class arma_descend_sort_helper
00036 {
00037 public:
00038
00039 arma_inline
00040 bool
00041 operator() (eT a, eT b) const
00042 {
00043 return (a > b);
00044 }
00045 };
00046
00047
00048
00049 template<typename T>
00050 class arma_ascend_sort_helper< std::complex<T> >
00051 {
00052 public:
00053
00054 typedef typename std::complex<T> eT;
00055
00056 inline
00057 bool
00058 operator() (const eT& a, const eT& b) const
00059 {
00060 return (std::abs(a) < std::abs(b));
00061 }
00062 };
00063
00064
00065
00066 template<typename T>
00067 class arma_descend_sort_helper< std::complex<T> >
00068 {
00069 public:
00070
00071 typedef typename std::complex<T> eT;
00072
00073 inline
00074 bool
00075 operator() (const eT& a, const eT& b) const
00076 {
00077 return (std::abs(a) > std::abs(b));
00078 }
00079 };
00080
00081
00082
00083 template<typename eT>
00084 inline
00085 void
00086 op_sort::direct_sort(eT* X, const uword n_elem, const uword sort_type)
00087 {
00088 arma_extra_debug_sigprint();
00089
00090 if(sort_type == 0)
00091 {
00092 arma_ascend_sort_helper<eT> comparator;
00093
00094 std::sort(&X[0], &X[n_elem], comparator);
00095 }
00096 else
00097 {
00098 arma_descend_sort_helper<eT> comparator;
00099
00100 std::sort(&X[0], &X[n_elem], comparator);
00101 }
00102 }
00103
00104
00105
00106 template<typename eT>
00107 inline
00108 void
00109 op_sort::copy_row(eT* X, const Mat<eT>& A, const uword row)
00110 {
00111 const uword N = A.n_cols;
00112
00113 uword i,j;
00114
00115 for(i=0, j=1; j<N; i+=2, j+=2)
00116 {
00117 X[i] = A.at(row,i);
00118 X[j] = A.at(row,j);
00119 }
00120
00121 if(i < N)
00122 {
00123 X[i] = A.at(row,i);
00124 }
00125 }
00126
00127
00128
00129 template<typename eT>
00130 inline
00131 void
00132 op_sort::copy_row(Mat<eT>& A, const eT* X, const uword row)
00133 {
00134 const uword N = A.n_cols;
00135
00136 uword i,j;
00137
00138 for(i=0, j=1; j<N; i+=2, j+=2)
00139 {
00140 A.at(row,i) = X[i];
00141 A.at(row,j) = X[j];
00142 }
00143
00144 if(i < N)
00145 {
00146 A.at(row,i) = X[i];
00147 }
00148 }
00149
00150
00151
00152 template<typename T1>
00153 inline
00154 void
00155 op_sort::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_sort>& in)
00156 {
00157 arma_extra_debug_sigprint();
00158
00159 typedef typename T1::elem_type eT;
00160
00161 const unwrap<T1> tmp(in.m);
00162 const Mat<eT>& X = tmp.M;
00163
00164 const uword sort_type = in.aux_uword_a;
00165 const uword dim = in.aux_uword_b;
00166
00167 arma_debug_check( (sort_type > 1), "sort(): incorrect usage. sort_type must be 0 or 1");
00168 arma_debug_check( (dim > 1), "sort(): incorrect usage. dim must be 0 or 1" );
00169 arma_debug_check( (X.is_finite() == false), "sort(): given object has non-finite elements" );
00170
00171 if( (X.n_rows * X.n_cols) <= 1 )
00172 {
00173 out = X;
00174 return;
00175 }
00176
00177
00178 if(dim == 0)
00179 {
00180 arma_extra_debug_print("op_sort::apply(), dim = 0");
00181
00182 out = X;
00183
00184 const uword n_rows = out.n_rows;
00185 const uword n_cols = out.n_cols;
00186
00187 for(uword col=0; col < n_cols; ++col)
00188 {
00189 op_sort::direct_sort( out.colptr(col), n_rows, sort_type );
00190 }
00191 }
00192 else
00193 if(dim == 1)
00194 {
00195 if(X.n_rows == 1)
00196 {
00197 arma_extra_debug_print("op_sort::apply(), dim = 1, vector specific");
00198
00199 out = X;
00200 op_sort::direct_sort(out.memptr(), out.n_elem, sort_type);
00201 }
00202 else
00203 {
00204 arma_extra_debug_print("op_sort::apply(), dim = 1, generic");
00205
00206 out.copy_size(X);
00207
00208 const uword n_rows = out.n_rows;
00209 const uword n_cols = out.n_cols;
00210
00211 podarray<eT> tmp_array(n_cols);
00212
00213 for(uword row=0; row < n_rows; ++row)
00214 {
00215 op_sort::copy_row(tmp_array.memptr(), X, row);
00216
00217 op_sort::direct_sort( tmp_array.memptr(), n_cols, sort_type );
00218
00219 op_sort::copy_row(out, tmp_array.memptr(), row);
00220 }
00221 }
00222 }
00223
00224 }
00225
00226