$search
00001 // Copyright (C) 2009-2011 NICTA (www.nicta.com.au) 00002 // Copyright (C) 2009-2011 Conrad Sanderson 00003 // Copyright (C) 2009-2010 Dimitrios Bouzas 00004 // 00005 // This file is part of the Armadillo C++ library. 00006 // It is provided without any warranty of fitness 00007 // for any purpose. You can redistribute this file 00008 // and/or modify it under the terms of the GNU 00009 // Lesser General Public License (LGPL) as published 00010 // by the Free Software Foundation, either version 3 00011 // of the License or (at your option) any later version. 00012 // (see http://www.opensource.org/licenses for more info) 00013 00014 00015 00018 00019 00020 00021 template<typename T1> 00022 inline 00023 void 00024 op_shuffle::apply(Mat<typename T1::elem_type>& out, const Op<T1,op_shuffle>& in) 00025 { 00026 arma_extra_debug_sigprint(); 00027 00028 typedef typename T1::elem_type eT; 00029 00030 const unwrap<T1> tmp(in.m); 00031 const Mat<eT>& X = tmp.M; 00032 00033 if(X.is_empty()) 00034 { 00035 out.copy_size(X); 00036 return; 00037 } 00038 00039 const uword dim = in.aux_uword_a; 00040 const uword N = (dim == 0) ? X.n_rows : X.n_cols; 00041 00042 // see "fn_sort_index.hpp" for the definition of "arma_sort_index_packet_ascend" 00043 // and the associated "operator<" 00044 std::vector< arma_sort_index_packet_ascend<int,uword> > packet_vec(N); 00045 00046 for(uword i=0; i<N; ++i) 00047 { 00048 packet_vec[i].val = std::rand(); 00049 packet_vec[i].index = i; 00050 } 00051 00052 std::sort( packet_vec.begin(), packet_vec.end() ); 00053 00054 if(X.is_vec() == false) 00055 { 00056 if(&out != &X) 00057 { 00058 arma_extra_debug_print("op_shuffle::apply(): matrix"); 00059 00060 out.copy_size(X); 00061 00062 if(dim == 0) 00063 { 00064 for(uword i=0; i<N; ++i) 00065 { 00066 out.row(i) = X.row(packet_vec[i].index); 00067 } 00068 } 00069 else 00070 { 00071 for(uword i=0; i<N; ++i) 00072 { 00073 out.col(i) = X.col(packet_vec[i].index); 00074 } 00075 } 00076 } 00077 else // in-place shuffle 00078 { 00079 arma_extra_debug_print("op_shuffle::apply(): in-place matrix"); 00080 00081 // reuse the val member variable of packet_vec 00082 // to indicate whether a particular row or column 00083 // has already been shuffled 00084 00085 for(uword i=0; i<N; ++i) 00086 { 00087 packet_vec[i].val = 0; 00088 } 00089 00090 if(dim == 0) 00091 { 00092 for(uword i=0; i<N; ++i) 00093 { 00094 if(packet_vec[i].val == 0) 00095 { 00096 const uword j = packet_vec[i].index; 00097 00098 out.swap_rows(i, j); 00099 00100 packet_vec[j].val = 1; 00101 } 00102 } 00103 } 00104 else 00105 { 00106 for(uword i=0; i<N; ++i) 00107 { 00108 if(packet_vec[i].val == 0) 00109 { 00110 const uword j = packet_vec[i].index; 00111 00112 out.swap_cols(i, j); 00113 00114 packet_vec[j].val = 1; 00115 } 00116 } 00117 } 00118 } 00119 } 00120 else // we're dealing with a vector 00121 { 00122 if(&out != &X) 00123 { 00124 arma_extra_debug_print("op_shuffle::apply(): vector"); 00125 00126 out.copy_size(X); 00127 00128 if(dim == 0) 00129 { 00130 if(X.n_rows > 1) // i.e. column vector 00131 { 00132 for(uword i=0; i<N; ++i) 00133 { 00134 out[i] = X[ packet_vec[i].index ]; 00135 } 00136 } 00137 else 00138 { 00139 out = X; 00140 } 00141 } 00142 else 00143 { 00144 if(X.n_cols > 1) // i.e. row vector 00145 { 00146 for(uword i=0; i<N; ++i) 00147 { 00148 out[i] = X[ packet_vec[i].index ]; 00149 } 00150 } 00151 else 00152 { 00153 out = X; 00154 } 00155 } 00156 } 00157 else // in-place shuffle 00158 { 00159 arma_extra_debug_print("op_shuffle::apply(): in-place vector"); 00160 00161 // reuse the val member variable of packet_vec 00162 // to indicate whether a particular row or column 00163 // has already been shuffled 00164 00165 for(uword i=0; i<N; ++i) 00166 { 00167 packet_vec[i].val = 0; 00168 } 00169 00170 if(dim == 0) 00171 { 00172 if(X.n_rows > 1) // i.e. column vector 00173 { 00174 for(uword i=0; i<N; ++i) 00175 { 00176 if(packet_vec[i].val == 0) 00177 { 00178 const uword j = packet_vec[i].index; 00179 00180 std::swap(out[i], out[j]); 00181 00182 packet_vec[j].val = 1; 00183 } 00184 } 00185 } 00186 } 00187 else 00188 { 00189 if(X.n_cols > 1) // i.e. row vector 00190 { 00191 for(uword i=0; i<N; ++i) 00192 { 00193 if(packet_vec[i].val == 0) 00194 { 00195 const uword j = packet_vec[i].index; 00196 00197 std::swap(out[i], out[j]); 00198 00199 packet_vec[j].val = 1; 00200 } 00201 } 00202 } 00203 } 00204 } 00205 } 00206 00207 } 00208 00209