13 #include <Eigen/CXX11/Tensor> 19 template <
int DataLayout>
24 tensor = (tensor + tensor.constant(0.5)).
log();
27 index_tuples = tensor.index_tuples();
31 VERIFY_IS_EQUAL(v.
first, n);
36 template <
int DataLayout>
41 tensor = (tensor + tensor.constant(0.5)).
log();
45 index_tuples = tensor.index_tuples();
49 VERIFY_IS_EQUAL(v.
first, n);
50 VERIFY_IS_EQUAL(v.
second, tensor(n));
54 template <
int DataLayout>
59 tensor = (tensor + tensor.constant(0.5)).
log();
62 index_tuples = tensor.index_tuples();
65 DimensionList<DenseIndex, 4> dims;
66 reduced = index_tuples.reduce(
71 VERIFY_IS_EQUAL(maxi(), reduced(0).second);
74 for (
int d = 0;
d < 3; ++
d) reduce_dims[
d] =
d;
76 reduced_by_dims = index_tuples.reduce(
81 for (
int l = 0; l < 7; ++l) {
82 VERIFY_IS_EQUAL(max_by_dims(l), reduced_by_dims(l).second);
86 template <
int DataLayout>
91 tensor = (tensor + tensor.constant(0.5)).
log();
94 index_tuples = tensor.index_tuples();
97 DimensionList<DenseIndex, 4> dims;
98 reduced = index_tuples.reduce(
103 VERIFY_IS_EQUAL(mini(), reduced(0).second);
106 for (
int d = 0;
d < 3; ++
d) reduce_dims[
d] =
d;
108 reduced_by_dims = index_tuples.reduce(
113 for (
int l = 0; l < 7; ++l) {
114 VERIFY_IS_EQUAL(min_by_dims(l), reduced_by_dims(l).second);
118 template <
int DataLayout>
123 tensor = (tensor + tensor.constant(0.5)).
log();
124 tensor(0,0,0,0) = 10.0;
128 tensor_argmax = tensor.argmax();
130 VERIFY_IS_EQUAL(tensor_argmax(0), 0);
132 tensor(1,2,4,6) = 20.0;
134 tensor_argmax = tensor.argmax();
136 VERIFY_IS_EQUAL(tensor_argmax(0), 2*3*5*7 - 1);
139 template <
int DataLayout>
144 tensor = (tensor + tensor.constant(0.5)).
log();
145 tensor(0,0,0,0) = -10.0;
149 tensor_argmin = tensor.argmin();
151 VERIFY_IS_EQUAL(tensor_argmin(0), 0);
153 tensor(1,2,4,6) = -20.0;
155 tensor_argmin = tensor.argmin();
157 VERIFY_IS_EQUAL(tensor_argmin(0), 2*3*5*7 - 1);
160 template <
int DataLayout>
164 std::vector<int> dims {2, 3, 5, 7};
168 tensor = (tensor + tensor.constant(0.5)).
log();
172 for (
int i = 0; i < 2; ++i) {
173 for (
int j = 0; j < 3; ++j) {
174 for (
int k = 0; k < 5; ++k) {
175 for (
int l = 0; l < 7; ++l) {
176 ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l;
177 if (ix[
dim] != 0)
continue;
185 tensor_argmax = tensor.argmax(
dim);
187 VERIFY_IS_EQUAL(tensor_argmax.
size(),
189 for (ptrdiff_t n = 0; n < tensor_argmax.
size(); ++n) {
191 VERIFY_IS_EQUAL(tensor_argmax.
data()[n], 0);
194 for (
int i = 0; i < 2; ++i) {
195 for (
int j = 0; j < 3; ++j) {
196 for (
int k = 0; k < 5; ++k) {
197 for (
int l = 0; l < 7; ++l) {
198 ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l;
207 tensor_argmax = tensor.argmax(
dim);
209 VERIFY_IS_EQUAL(tensor_argmax.
size(),
211 for (ptrdiff_t n = 0; n < tensor_argmax.
size(); ++n) {
218 template <
int DataLayout>
222 std::vector<int> dims {2, 3, 5, 7};
226 tensor = (tensor + tensor.constant(0.5)).
log();
230 for (
int i = 0; i < 2; ++i) {
231 for (
int j = 0; j < 3; ++j) {
232 for (
int k = 0; k < 5; ++k) {
233 for (
int l = 0; l < 7; ++l) {
234 ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l;
235 if (ix[
dim] != 0)
continue;
243 tensor_argmin = tensor.argmin(
dim);
245 VERIFY_IS_EQUAL(tensor_argmin.
size(),
247 for (ptrdiff_t n = 0; n < tensor_argmin.
size(); ++n) {
249 VERIFY_IS_EQUAL(tensor_argmin.
data()[n], 0);
252 for (
int i = 0; i < 2; ++i) {
253 for (
int j = 0; j < 3; ++j) {
254 for (
int k = 0; k < 5; ++k) {
255 for (
int l = 0; l < 7; ++l) {
256 ix[0] = i; ix[1] = j; ix[2] = k; ix[3] = l;
265 tensor_argmin = tensor.argmin(
dim);
267 VERIFY_IS_EQUAL(tensor_argmin.
size(),
269 for (ptrdiff_t n = 0; n < tensor_argmin.
size(); ++n) {
278 CALL_SUBTEST(test_simple_index_tuples<RowMajor>());
279 CALL_SUBTEST(test_simple_index_tuples<ColMajor>());
280 CALL_SUBTEST(test_index_tuples_dim<RowMajor>());
281 CALL_SUBTEST(test_index_tuples_dim<ColMajor>());
282 CALL_SUBTEST(test_argmax_tuple_reducer<RowMajor>());
283 CALL_SUBTEST(test_argmax_tuple_reducer<ColMajor>());
284 CALL_SUBTEST(test_argmin_tuple_reducer<RowMajor>());
285 CALL_SUBTEST(test_argmin_tuple_reducer<ColMajor>());
286 CALL_SUBTEST(test_simple_argmax<RowMajor>());
287 CALL_SUBTEST(test_simple_argmax<ColMajor>());
288 CALL_SUBTEST(test_simple_argmin<RowMajor>());
289 CALL_SUBTEST(test_simple_argmin<ColMajor>());
290 CALL_SUBTEST(test_argmax_dim<RowMajor>());
291 CALL_SUBTEST(test_argmax_dim<ColMajor>());
292 CALL_SUBTEST(test_argmin_dim<RowMajor>());
293 CALL_SUBTEST(test_argmin_dim<ColMajor>());
static void test_simple_index_tuples()
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(std::size_t n) const
static void test_argmax_dim()
EIGEN_DEVICE_FUNC const LogReturnType log() const
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Tensor< Scalar_, NumIndices_, Options_, IndexType_ > & setRandom()
static void test_index_tuples_dim()
void test_cxx11_tensor_argmax()
static void test_simple_argmin()
static void test_argmin_tuple_reducer()
static void test_simple_argmax()
static void test_argmax_tuple_reducer()
const mpreal dim(const mpreal &a, const mpreal &b, mp_rnd_t r=mpreal::get_default_rnd())
EIGEN_DEFAULT_DENSE_INDEX_TYPE DenseIndex
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar * data()
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar & coeff(const array< Index, NumIndices > &indices) const
static void test_argmin_dim()
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const