cxx11_tensor_shuffling.cpp
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #include "main.h"
11 
12 #include <Eigen/CXX11/Tensor>
13 
14 using Eigen::Tensor;
15 using Eigen::array;
16 
17 template <int DataLayout>
18 static void test_simple_shuffling()
19 {
20  Tensor<float, 4, DataLayout> tensor(2,3,5,7);
21  tensor.setRandom();
22  array<ptrdiff_t, 4> shuffles;
23  shuffles[0] = 0;
24  shuffles[1] = 1;
25  shuffles[2] = 2;
26  shuffles[3] = 3;
27 
29  no_shuffle = tensor.shuffle(shuffles);
30 
31  VERIFY_IS_EQUAL(no_shuffle.dimension(0), 2);
32  VERIFY_IS_EQUAL(no_shuffle.dimension(1), 3);
33  VERIFY_IS_EQUAL(no_shuffle.dimension(2), 5);
34  VERIFY_IS_EQUAL(no_shuffle.dimension(3), 7);
35 
36  for (int i = 0; i < 2; ++i) {
37  for (int j = 0; j < 3; ++j) {
38  for (int k = 0; k < 5; ++k) {
39  for (int l = 0; l < 7; ++l) {
40  VERIFY_IS_EQUAL(tensor(i,j,k,l), no_shuffle(i,j,k,l));
41  }
42  }
43  }
44  }
45 
46  shuffles[0] = 2;
47  shuffles[1] = 3;
48  shuffles[2] = 1;
49  shuffles[3] = 0;
51  shuffle = tensor.shuffle(shuffles);
52 
53  VERIFY_IS_EQUAL(shuffle.dimension(0), 5);
54  VERIFY_IS_EQUAL(shuffle.dimension(1), 7);
55  VERIFY_IS_EQUAL(shuffle.dimension(2), 3);
56  VERIFY_IS_EQUAL(shuffle.dimension(3), 2);
57 
58  for (int i = 0; i < 2; ++i) {
59  for (int j = 0; j < 3; ++j) {
60  for (int k = 0; k < 5; ++k) {
61  for (int l = 0; l < 7; ++l) {
62  VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(k,l,j,i));
63  }
64  }
65  }
66  }
67 }
68 
69 
70 template <int DataLayout>
71 static void test_expr_shuffling()
72 {
73  Tensor<float, 4, DataLayout> tensor(2,3,5,7);
74  tensor.setRandom();
75 
76  array<ptrdiff_t, 4> shuffles;
77  shuffles[0] = 2;
78  shuffles[1] = 3;
79  shuffles[2] = 1;
80  shuffles[3] = 0;
82  expected = tensor.shuffle(shuffles);
83 
85 
86  array<ptrdiff_t, 4> src_slice_dim{{2, 3, 1, 7}};
87  array<ptrdiff_t, 4> src_slice_start{{0, 0, 0, 0}};
88  array<ptrdiff_t, 4> dst_slice_dim{{1, 7, 3, 2}};
89  array<ptrdiff_t, 4> dst_slice_start{{0, 0, 0, 0}};
90 
91  for (int i = 0; i < 5; ++i) {
92  result.slice(dst_slice_start, dst_slice_dim) =
93  tensor.slice(src_slice_start, src_slice_dim).shuffle(shuffles);
94  src_slice_start[2] += 1;
95  dst_slice_start[0] += 1;
96  }
97 
98  VERIFY_IS_EQUAL(result.dimension(0), 5);
99  VERIFY_IS_EQUAL(result.dimension(1), 7);
100  VERIFY_IS_EQUAL(result.dimension(2), 3);
101  VERIFY_IS_EQUAL(result.dimension(3), 2);
102 
103  for (int i = 0; i < expected.dimension(0); ++i) {
104  for (int j = 0; j < expected.dimension(1); ++j) {
105  for (int k = 0; k < expected.dimension(2); ++k) {
106  for (int l = 0; l < expected.dimension(3); ++l) {
107  VERIFY_IS_EQUAL(result(i,j,k,l), expected(i,j,k,l));
108  }
109  }
110  }
111  }
112 
113  dst_slice_start[0] = 0;
114  result.setRandom();
115  for (int i = 0; i < 5; ++i) {
116  result.slice(dst_slice_start, dst_slice_dim) =
117  tensor.shuffle(shuffles).slice(dst_slice_start, dst_slice_dim);
118  dst_slice_start[0] += 1;
119  }
120 
121  for (int i = 0; i < expected.dimension(0); ++i) {
122  for (int j = 0; j < expected.dimension(1); ++j) {
123  for (int k = 0; k < expected.dimension(2); ++k) {
124  for (int l = 0; l < expected.dimension(3); ++l) {
125  VERIFY_IS_EQUAL(result(i,j,k,l), expected(i,j,k,l));
126  }
127  }
128  }
129  }
130 }
131 
132 
133 template <int DataLayout>
135 {
136  Tensor<float, 4, DataLayout> tensor(2,3,5,7);
137  tensor.setRandom();
138  array<ptrdiff_t, 4> shuffles;
139  shuffles[2] = 0;
140  shuffles[3] = 1;
141  shuffles[1] = 2;
142  shuffles[0] = 3;
143  Tensor<float, 4, DataLayout> shuffle(5,7,3,2);
144  shuffle.shuffle(shuffles) = tensor;
145 
146  VERIFY_IS_EQUAL(shuffle.dimension(0), 5);
147  VERIFY_IS_EQUAL(shuffle.dimension(1), 7);
148  VERIFY_IS_EQUAL(shuffle.dimension(2), 3);
149  VERIFY_IS_EQUAL(shuffle.dimension(3), 2);
150 
151  for (int i = 0; i < 2; ++i) {
152  for (int j = 0; j < 3; ++j) {
153  for (int k = 0; k < 5; ++k) {
154  for (int l = 0; l < 7; ++l) {
155  VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(k,l,j,i));
156  }
157  }
158  }
159  }
160 
161  array<ptrdiff_t, 4> no_shuffle;
162  no_shuffle[0] = 0;
163  no_shuffle[1] = 1;
164  no_shuffle[2] = 2;
165  no_shuffle[3] = 3;
167  shuffle2.shuffle(shuffles) = tensor.shuffle(no_shuffle);
168  for (int i = 0; i < 5; ++i) {
169  for (int j = 0; j < 7; ++j) {
170  for (int k = 0; k < 3; ++k) {
171  for (int l = 0; l < 2; ++l) {
172  VERIFY_IS_EQUAL(shuffle2(i,j,k,l), shuffle(i,j,k,l));
173  }
174  }
175  }
176  }
177 }
178 
179 
180 template <int DataLayout>
182 {
183  Tensor<float, 4, DataLayout> tensor(2,3,5,7);
184  tensor.setRandom();
185 
186  // Choose a random permutation.
187  array<ptrdiff_t, 4> shuffles;
188  for (int i = 0; i < 4; ++i) {
189  shuffles[i] = i;
190  }
191  array<ptrdiff_t, 4> shuffles_inverse;
192  for (int i = 0; i < 4; ++i) {
193  const ptrdiff_t index = internal::random<ptrdiff_t>(i, 3);
194  shuffles_inverse[shuffles[index]] = i;
195  std::swap(shuffles[i], shuffles[index]);
196  }
197 
199  shuffle = tensor.shuffle(shuffles).shuffle(shuffles_inverse);
200 
201  VERIFY_IS_EQUAL(shuffle.dimension(0), 2);
202  VERIFY_IS_EQUAL(shuffle.dimension(1), 3);
203  VERIFY_IS_EQUAL(shuffle.dimension(2), 5);
204  VERIFY_IS_EQUAL(shuffle.dimension(3), 7);
205 
206  for (int i = 0; i < 2; ++i) {
207  for (int j = 0; j < 3; ++j) {
208  for (int k = 0; k < 5; ++k) {
209  for (int l = 0; l < 7; ++l) {
210  VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(i,j,k,l));
211  }
212  }
213  }
214  }
215 }
216 
217 
218 template <int DataLayout>
219 static void test_empty_shuffling()
220 {
221  Tensor<float, 4, DataLayout> tensor(2,3,0,7);
222  tensor.setRandom();
223  array<ptrdiff_t, 4> shuffles;
224  shuffles[0] = 0;
225  shuffles[1] = 1;
226  shuffles[2] = 2;
227  shuffles[3] = 3;
228 
229  Tensor<float, 4, DataLayout> no_shuffle;
230  no_shuffle = tensor.shuffle(shuffles);
231 
232  VERIFY_IS_EQUAL(no_shuffle.dimension(0), 2);
233  VERIFY_IS_EQUAL(no_shuffle.dimension(1), 3);
234  VERIFY_IS_EQUAL(no_shuffle.dimension(2), 0);
235  VERIFY_IS_EQUAL(no_shuffle.dimension(3), 7);
236 
237  for (int i = 0; i < 2; ++i) {
238  for (int j = 0; j < 3; ++j) {
239  for (int k = 0; k < 0; ++k) {
240  for (int l = 0; l < 7; ++l) {
241  VERIFY_IS_EQUAL(tensor(i,j,k,l), no_shuffle(i,j,k,l));
242  }
243  }
244  }
245  }
246 
247  shuffles[0] = 2;
248  shuffles[1] = 3;
249  shuffles[2] = 1;
250  shuffles[3] = 0;
252  shuffle = tensor.shuffle(shuffles);
253 
254  VERIFY_IS_EQUAL(shuffle.dimension(0), 0);
255  VERIFY_IS_EQUAL(shuffle.dimension(1), 7);
256  VERIFY_IS_EQUAL(shuffle.dimension(2), 3);
257  VERIFY_IS_EQUAL(shuffle.dimension(3), 2);
258 
259  for (int i = 0; i < 2; ++i) {
260  for (int j = 0; j < 3; ++j) {
261  for (int k = 0; k < 0; ++k) {
262  for (int l = 0; l < 7; ++l) {
263  VERIFY_IS_EQUAL(tensor(i,j,k,l), shuffle(k,l,j,i));
264  }
265  }
266  }
267  }
268 }
269 
270 
271 EIGEN_DECLARE_TEST(cxx11_tensor_shuffling)
272 {
273  CALL_SUBTEST(test_simple_shuffling<ColMajor>());
274  CALL_SUBTEST(test_simple_shuffling<RowMajor>());
275  CALL_SUBTEST(test_expr_shuffling<ColMajor>());
276  CALL_SUBTEST(test_expr_shuffling<RowMajor>());
277  CALL_SUBTEST(test_shuffling_as_value<ColMajor>());
278  CALL_SUBTEST(test_shuffling_as_value<RowMajor>());
279  CALL_SUBTEST(test_shuffle_unshuffle<ColMajor>());
280  CALL_SUBTEST(test_shuffle_unshuffle<RowMajor>());
281  CALL_SUBTEST(test_empty_shuffling<ColMajor>());
282  CALL_SUBTEST(test_empty_shuffling<RowMajor>());
283 }
Eigen::Tensor::dimension
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(std::size_t n) const
Definition: Tensor.h:101
Eigen::Tensor
The tensor class.
Definition: Tensor.h:63
Eigen::internal::shuffle2
EIGEN_STRONG_INLINE Packet4f shuffle2(const Packet4f &m, const Packet4f &n, int mask)
Definition: NEON/PacketMath.h:94
test_shuffle_unshuffle
static void test_shuffle_unshuffle()
Definition: cxx11_tensor_shuffling.cpp:181
array
int array[24]
Definition: Map_general_stride.cpp:1
test_simple_shuffling
static void test_simple_shuffling()
Definition: cxx11_tensor_shuffling.cpp:18
Eigen::array
Definition: EmulateArray.h:21
VERIFY_IS_EQUAL
#define VERIFY_IS_EQUAL(a, b)
Definition: main.h:386
Eigen::TensorBase< Tensor< Scalar_, NumIndices_, Options_, IndexType_ > >::shuffle
EIGEN_DEVICE_FUNC const EIGEN_STRONG_INLINE TensorShufflingOp< const Shuffle, const Tensor< Scalar_, NumIndices_, Options_, IndexType_ > > shuffle(const Shuffle &shfl) const
Definition: TensorBase.h:1123
result
Values result
Definition: OdometryOptimize.cpp:8
j
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2
l
static const Line3 l(Rot3(), 1, 1)
cholesky::expected
Matrix expected
Definition: testMatrix.cpp:971
std::swap
void swap(GeographicLib::NearestNeighbor< dist_t, pos_t, distfun_t > &a, GeographicLib::NearestNeighbor< dist_t, pos_t, distfun_t > &b)
Definition: NearestNeighbor.hpp:827
test_shuffling_as_value
static void test_shuffling_as_value()
Definition: cxx11_tensor_shuffling.cpp:134
EIGEN_DECLARE_TEST
EIGEN_DECLARE_TEST(cxx11_tensor_shuffling)
Definition: cxx11_tensor_shuffling.cpp:271
Eigen::TensorBase< Tensor< Scalar_, NumIndices_, Options_, IndexType_ > >::setRandom
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Tensor< Scalar_, NumIndices_, Options_, IndexType_ > & setRandom()
Definition: TensorBase.h:996
test_expr_shuffling
static void test_expr_shuffling()
Definition: cxx11_tensor_shuffling.cpp:71
main.h
Eigen::TensorBase< Tensor< Scalar_, NumIndices_, Options_, IndexType_ > >::slice
EIGEN_DEVICE_FUNC const EIGEN_STRONG_INLINE TensorSlicingOp< const StartIndices, const Sizes, const Tensor< Scalar_, NumIndices_, Options_, IndexType_ > > slice(const StartIndices &startIndices, const Sizes &sizes) const
Definition: TensorBase.h:1066
test_empty_shuffling
static void test_empty_shuffling()
Definition: cxx11_tensor_shuffling.cpp:219
i
int i
Definition: BiCGSTAB_step_by_step.cpp:9
CALL_SUBTEST
#define CALL_SUBTEST(FUNC)
Definition: main.h:399


gtsam
Author(s):
autogenerated on Wed Jan 1 2025 04:01:24