cxx11_tensor_argmax.cpp
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2015 Eugene Brevdo <ebrevdo@google.com>
5 // Benoit Steiner <benoit.steiner.goog@gmail.com>
6 //
7 // This Source Code Form is subject to the terms of the Mozilla
8 // Public License v. 2.0. If a copy of the MPL was not distributed
9 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
10 
11 #include "main.h"
12 
13 #include <Eigen/CXX11/Tensor>
14 
15 using Eigen::Tensor;
16 using Eigen::array;
17 using Eigen::Tuple;
18 
19 template <int DataLayout>
21 {
22  Tensor<float, 4, DataLayout> tensor(2,3,5,7);
23  tensor.setRandom();
24  tensor = (tensor + tensor.constant(0.5)).log();
25 
26  Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7);
27  index_tuples = tensor.index_tuples();
28 
29  for (DenseIndex n = 0; n < 2*3*5*7; ++n) {
30  const Tuple<DenseIndex, float>& v = index_tuples.coeff(n);
32  VERIFY_IS_EQUAL(v.second, tensor.coeff(n));
33  }
34 }
35 
36 template <int DataLayout>
37 static void test_index_tuples_dim()
38 {
39  Tensor<float, 4, DataLayout> tensor(2,3,5,7);
40  tensor.setRandom();
41  tensor = (tensor + tensor.constant(0.5)).log();
42 
43  Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7);
44 
45  index_tuples = tensor.index_tuples();
46 
47  for (Eigen::DenseIndex n = 0; n < tensor.size(); ++n) {
48  const Tuple<DenseIndex, float>& v = index_tuples(n); //(i, j, k, l);
50  VERIFY_IS_EQUAL(v.second, tensor(n));
51  }
52 }
53 
54 template <int DataLayout>
56 {
57  Tensor<float, 4, DataLayout> tensor(2,3,5,7);
58  tensor.setRandom();
59  tensor = (tensor + tensor.constant(0.5)).log();
60 
61  Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7);
62  index_tuples = tensor.index_tuples();
63 
66  reduced = index_tuples.reduce(
67  dims, internal::ArgMaxTupleReducer<Tuple<DenseIndex, float> >());
68 
69  Tensor<float, 0, DataLayout> maxi = tensor.maximum();
70 
71  VERIFY_IS_EQUAL(maxi(), reduced(0).second);
72 
73  array<DenseIndex, 3> reduce_dims;
74  for (int d = 0; d < 3; ++d) reduce_dims[d] = d;
75  Tensor<Tuple<DenseIndex, float>, 1, DataLayout> reduced_by_dims(7);
76  reduced_by_dims = index_tuples.reduce(
77  reduce_dims, internal::ArgMaxTupleReducer<Tuple<DenseIndex, float> >());
78 
79  Tensor<float, 1, DataLayout> max_by_dims = tensor.maximum(reduce_dims);
80 
81  for (int l = 0; l < 7; ++l) {
82  VERIFY_IS_EQUAL(max_by_dims(l), reduced_by_dims(l).second);
83  }
84 }
85 
86 template <int DataLayout>
88 {
89  Tensor<float, 4, DataLayout> tensor(2,3,5,7);
90  tensor.setRandom();
91  tensor = (tensor + tensor.constant(0.5)).log();
92 
93  Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7);
94  index_tuples = tensor.index_tuples();
95 
98  reduced = index_tuples.reduce(
99  dims, internal::ArgMinTupleReducer<Tuple<DenseIndex, float> >());
100 
101  Tensor<float, 0, DataLayout> mini = tensor.minimum();
102 
103  VERIFY_IS_EQUAL(mini(), reduced(0).second);
104 
105  array<DenseIndex, 3> reduce_dims;
106  for (int d = 0; d < 3; ++d) reduce_dims[d] = d;
107  Tensor<Tuple<DenseIndex, float>, 1, DataLayout> reduced_by_dims(7);
108  reduced_by_dims = index_tuples.reduce(
109  reduce_dims, internal::ArgMinTupleReducer<Tuple<DenseIndex, float> >());
110 
111  Tensor<float, 1, DataLayout> min_by_dims = tensor.minimum(reduce_dims);
112 
113  for (int l = 0; l < 7; ++l) {
114  VERIFY_IS_EQUAL(min_by_dims(l), reduced_by_dims(l).second);
115  }
116 }
117 
118 template <int DataLayout>
119 static void test_simple_argmax()
120 {
121  Tensor<float, 4, DataLayout> tensor(2,3,5,7);
122  tensor.setRandom();
123  tensor = (tensor + tensor.constant(0.5)).log();
124  tensor(0,0,0,0) = 10.0;
125 
126  Tensor<DenseIndex, 0, DataLayout> tensor_argmax;
127 
128  tensor_argmax = tensor.argmax();
129 
130  VERIFY_IS_EQUAL(tensor_argmax(0), 0);
131 
132  tensor(1,2,4,6) = 20.0;
133 
134  tensor_argmax = tensor.argmax();
135 
136  VERIFY_IS_EQUAL(tensor_argmax(0), 2*3*5*7 - 1);
137 }
138 
139 template <int DataLayout>
140 static void test_simple_argmin()
141 {
142  Tensor<float, 4, DataLayout> tensor(2,3,5,7);
143  tensor.setRandom();
144  tensor = (tensor + tensor.constant(0.5)).log();
145  tensor(0,0,0,0) = -10.0;
146 
147  Tensor<DenseIndex, 0, DataLayout> tensor_argmin;
148 
149  tensor_argmin = tensor.argmin();
150 
151  VERIFY_IS_EQUAL(tensor_argmin(0), 0);
152 
153  tensor(1,2,4,6) = -20.0;
154 
155  tensor_argmin = tensor.argmin();
156 
157  VERIFY_IS_EQUAL(tensor_argmin(0), 2*3*5*7 - 1);
158 }
159 
160 template <int DataLayout>
161 static void test_argmax_dim()
162 {
163  Tensor<float, 4, DataLayout> tensor(2,3,5,7);
164  std::vector<int> dims {2, 3, 5, 7};
165 
166  for (int dim = 0; dim < 4; ++dim) {
167  tensor.setRandom();
168  tensor = (tensor + tensor.constant(0.5)).log();
169 
170  Tensor<DenseIndex, 3, DataLayout> tensor_argmax;
172  for (int i = 0; i < 2; ++i) {
173  for (int j = 0; j < 3; ++j) {
174  for (int k = 0; k < 5; ++k) {
175  for (int l = 0; l < 7; ++l) {
176  ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l;
177  if (ix[dim] != 0) continue;
178  // suppose dim == 1, then for all i, k, l, set tensor(i, 0, k, l) = 10.0
179  tensor(ix) = 10.0;
180  }
181  }
182  }
183  }
184 
185  tensor_argmax = tensor.argmax(dim);
186 
187  VERIFY_IS_EQUAL(tensor_argmax.size(),
188  ptrdiff_t(2*3*5*7 / tensor.dimension(dim)));
189  for (ptrdiff_t n = 0; n < tensor_argmax.size(); ++n) {
190  // Expect max to be in the first index of the reduced dimension
191  VERIFY_IS_EQUAL(tensor_argmax.data()[n], 0);
192  }
193 
194  for (int i = 0; i < 2; ++i) {
195  for (int j = 0; j < 3; ++j) {
196  for (int k = 0; k < 5; ++k) {
197  for (int l = 0; l < 7; ++l) {
198  ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l;
199  if (ix[dim] != tensor.dimension(dim) - 1) continue;
200  // suppose dim == 1, then for all i, k, l, set tensor(i, 2, k, l) = 20.0
201  tensor(ix) = 20.0;
202  }
203  }
204  }
205  }
206 
207  tensor_argmax = tensor.argmax(dim);
208 
209  VERIFY_IS_EQUAL(tensor_argmax.size(),
210  ptrdiff_t(2*3*5*7 / tensor.dimension(dim)));
211  for (ptrdiff_t n = 0; n < tensor_argmax.size(); ++n) {
212  // Expect max to be in the last index of the reduced dimension
213  VERIFY_IS_EQUAL(tensor_argmax.data()[n], tensor.dimension(dim) - 1);
214  }
215  }
216 }
217 
218 template <int DataLayout>
219 static void test_argmin_dim()
220 {
221  Tensor<float, 4, DataLayout> tensor(2,3,5,7);
222  std::vector<int> dims {2, 3, 5, 7};
223 
224  for (int dim = 0; dim < 4; ++dim) {
225  tensor.setRandom();
226  tensor = (tensor + tensor.constant(0.5)).log();
227 
228  Tensor<DenseIndex, 3, DataLayout> tensor_argmin;
230  for (int i = 0; i < 2; ++i) {
231  for (int j = 0; j < 3; ++j) {
232  for (int k = 0; k < 5; ++k) {
233  for (int l = 0; l < 7; ++l) {
234  ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l;
235  if (ix[dim] != 0) continue;
236  // suppose dim == 1, then for all i, k, l, set tensor(i, 0, k, l) = -10.0
237  tensor(ix) = -10.0;
238  }
239  }
240  }
241  }
242 
243  tensor_argmin = tensor.argmin(dim);
244 
245  VERIFY_IS_EQUAL(tensor_argmin.size(),
246  ptrdiff_t(2*3*5*7 / tensor.dimension(dim)));
247  for (ptrdiff_t n = 0; n < tensor_argmin.size(); ++n) {
248  // Expect min to be in the first index of the reduced dimension
249  VERIFY_IS_EQUAL(tensor_argmin.data()[n], 0);
250  }
251 
252  for (int i = 0; i < 2; ++i) {
253  for (int j = 0; j < 3; ++j) {
254  for (int k = 0; k < 5; ++k) {
255  for (int l = 0; l < 7; ++l) {
256  ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l;
257  if (ix[dim] != tensor.dimension(dim) - 1) continue;
258  // suppose dim == 1, then for all i, k, l, set tensor(i, 2, k, l) = -20.0
259  tensor(ix) = -20.0;
260  }
261  }
262  }
263  }
264 
265  tensor_argmin = tensor.argmin(dim);
266 
267  VERIFY_IS_EQUAL(tensor_argmin.size(),
268  ptrdiff_t(2*3*5*7 / tensor.dimension(dim)));
269  for (ptrdiff_t n = 0; n < tensor_argmin.size(); ++n) {
270  // Expect min to be in the last index of the reduced dimension
271  VERIFY_IS_EQUAL(tensor_argmin.data()[n], tensor.dimension(dim) - 1);
272  }
273  }
274 }
275 
276 EIGEN_DECLARE_TEST(cxx11_tensor_argmax)
277 {
278  CALL_SUBTEST(test_simple_index_tuples<RowMajor>());
279  CALL_SUBTEST(test_simple_index_tuples<ColMajor>());
280  CALL_SUBTEST(test_index_tuples_dim<RowMajor>());
281  CALL_SUBTEST(test_index_tuples_dim<ColMajor>());
282  CALL_SUBTEST(test_argmax_tuple_reducer<RowMajor>());
283  CALL_SUBTEST(test_argmax_tuple_reducer<ColMajor>());
284  CALL_SUBTEST(test_argmin_tuple_reducer<RowMajor>());
285  CALL_SUBTEST(test_argmin_tuple_reducer<ColMajor>());
286  CALL_SUBTEST(test_simple_argmax<RowMajor>());
287  CALL_SUBTEST(test_simple_argmax<ColMajor>());
288  CALL_SUBTEST(test_simple_argmin<RowMajor>());
289  CALL_SUBTEST(test_simple_argmin<ColMajor>());
290  CALL_SUBTEST(test_argmax_dim<RowMajor>());
291  CALL_SUBTEST(test_argmax_dim<ColMajor>());
292  CALL_SUBTEST(test_argmin_dim<RowMajor>());
293  CALL_SUBTEST(test_argmin_dim<ColMajor>());
294 }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const
Definition: Tensor.h:103
int array[24]
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar & coeff(const array< Index, NumIndices > &indices) const
Definition: Tensor.h:124
static void test_simple_index_tuples()
static void test_argmax_dim()
int n
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Tensor< Scalar_, NumIndices_, Options_, IndexType_ > & setRandom()
Definition: TensorBase.h:996
static void test_index_tuples_dim()
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T maxi(const T &x, const T &y)
EIGEN_DEVICE_FUNC const LogReturnType log() const
static const Line3 l(Rot3(), 1, 1)
#define VERIFY_IS_EQUAL(a, b)
Definition: main.h:386
static void test_simple_argmin()
static void test_argmin_tuple_reducer()
Array< int, Dynamic, 1 > v
static void test_simple_argmax()
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T mini(const T &x, const T &y)
static void test_argmax_tuple_reducer()
EIGEN_DECLARE_TEST(cxx11_tensor_argmax)
EIGEN_DEFAULT_DENSE_INDEX_TYPE DenseIndex
Definition: Meta.h:66
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar * data()
Definition: Tensor.h:104
#define CALL_SUBTEST(FUNC)
Definition: main.h:399
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(std::size_t n) const
Definition: Tensor.h:101
std::ptrdiff_t j
static const int DataLayout
The tensor class.
Definition: Tensor.h:63
static void test_argmin_dim()


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:34:06