13 #include <Eigen/CXX11/Tensor> 19 template <
int DataLayout>
24 tensor = (tensor + tensor.constant(0.5)).
log();
27 index_tuples = tensor.index_tuples();
36 template <
int DataLayout>
41 tensor = (tensor + tensor.constant(0.5)).
log();
45 index_tuples = tensor.index_tuples();
54 template <
int DataLayout>
59 tensor = (tensor + tensor.constant(0.5)).
log();
62 index_tuples = tensor.index_tuples();
66 reduced = index_tuples.reduce(
74 for (
int d = 0;
d < 3; ++
d) reduce_dims[
d] =
d;
76 reduced_by_dims = index_tuples.reduce(
81 for (
int l = 0;
l < 7; ++
l) {
86 template <
int DataLayout>
91 tensor = (tensor + tensor.constant(0.5)).
log();
94 index_tuples = tensor.index_tuples();
98 reduced = index_tuples.reduce(
106 for (
int d = 0;
d < 3; ++
d) reduce_dims[
d] =
d;
108 reduced_by_dims = index_tuples.reduce(
113 for (
int l = 0;
l < 7; ++
l) {
118 template <
int DataLayout>
123 tensor = (tensor + tensor.constant(0.5)).
log();
124 tensor(0,0,0,0) = 10.0;
128 tensor_argmax = tensor.argmax();
132 tensor(1,2,4,6) = 20.0;
134 tensor_argmax = tensor.argmax();
139 template <
int DataLayout>
144 tensor = (tensor + tensor.constant(0.5)).
log();
145 tensor(0,0,0,0) = -10.0;
149 tensor_argmin = tensor.argmin();
153 tensor(1,2,4,6) = -20.0;
155 tensor_argmin = tensor.argmin();
160 template <
int DataLayout>
164 std::vector<int> dims {2, 3, 5, 7};
166 for (
int dim = 0; dim < 4; ++dim) {
168 tensor = (tensor + tensor.constant(0.5)).
log();
172 for (
int i = 0;
i < 2; ++
i) {
173 for (
int j = 0;
j < 3; ++
j) {
174 for (
int k = 0; k < 5; ++k) {
175 for (
int l = 0;
l < 7; ++
l) {
176 ix[0] =
i; ix[1] =
j; ix[2] = k; ix[3] =
l;
177 if (ix[dim] != 0)
continue;
185 tensor_argmax = tensor.argmax(dim);
188 ptrdiff_t(2*3*5*7 / tensor.
dimension(dim)));
189 for (ptrdiff_t
n = 0;
n < tensor_argmax.
size(); ++
n) {
194 for (
int i = 0;
i < 2; ++
i) {
195 for (
int j = 0;
j < 3; ++
j) {
196 for (
int k = 0; k < 5; ++k) {
197 for (
int l = 0;
l < 7; ++
l) {
198 ix[0] =
i; ix[1] =
j; ix[2] = k; ix[3] =
l;
199 if (ix[dim] != tensor.
dimension(dim) - 1)
continue;
207 tensor_argmax = tensor.argmax(dim);
210 ptrdiff_t(2*3*5*7 / tensor.
dimension(dim)));
211 for (ptrdiff_t
n = 0;
n < tensor_argmax.
size(); ++
n) {
218 template <
int DataLayout>
222 std::vector<int> dims {2, 3, 5, 7};
224 for (
int dim = 0; dim < 4; ++dim) {
226 tensor = (tensor + tensor.constant(0.5)).
log();
230 for (
int i = 0;
i < 2; ++
i) {
231 for (
int j = 0;
j < 3; ++
j) {
232 for (
int k = 0; k < 5; ++k) {
233 for (
int l = 0;
l < 7; ++
l) {
234 ix[0] =
i; ix[1] =
j; ix[2] = k; ix[3] =
l;
235 if (ix[dim] != 0)
continue;
243 tensor_argmin = tensor.argmin(dim);
246 ptrdiff_t(2*3*5*7 / tensor.
dimension(dim)));
247 for (ptrdiff_t
n = 0;
n < tensor_argmin.
size(); ++
n) {
252 for (
int i = 0;
i < 2; ++
i) {
253 for (
int j = 0;
j < 3; ++
j) {
254 for (
int k = 0; k < 5; ++k) {
255 for (
int l = 0;
l < 7; ++
l) {
256 ix[0] =
i; ix[1] =
j; ix[2] = k; ix[3] =
l;
257 if (ix[dim] != tensor.
dimension(dim) - 1)
continue;
265 tensor_argmin = tensor.argmin(dim);
268 ptrdiff_t(2*3*5*7 / tensor.
dimension(dim)));
269 for (ptrdiff_t
n = 0;
n < tensor_argmin.
size(); ++
n) {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar & coeff(const array< Index, NumIndices > &indices) const
static void test_simple_index_tuples()
static void test_argmax_dim()
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Tensor< Scalar_, NumIndices_, Options_, IndexType_ > & setRandom()
static void test_index_tuples_dim()
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T maxi(const T &x, const T &y)
EIGEN_DEVICE_FUNC const LogReturnType log() const
static const Line3 l(Rot3(), 1, 1)
#define VERIFY_IS_EQUAL(a, b)
static void test_simple_argmin()
static void test_argmin_tuple_reducer()
Array< int, Dynamic, 1 > v
static void test_simple_argmax()
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T mini(const T &x, const T &y)
static void test_argmax_tuple_reducer()
EIGEN_DECLARE_TEST(cxx11_tensor_argmax)
EIGEN_DEFAULT_DENSE_INDEX_TYPE DenseIndex
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar * data()
#define CALL_SUBTEST(FUNC)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(std::size_t n) const
static const int DataLayout
static void test_argmin_dim()