cxx11_tensor_inflation.cpp
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2015 Ke Yang <yangke@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #include "main.h"
11 
12 #include <Eigen/CXX11/Tensor>
13 
14 using Eigen::Tensor;
15 
16 template<int DataLayout>
17 static void test_simple_inflation()
18 {
19  Tensor<float, 4, DataLayout> tensor(2,3,5,7);
20  tensor.setRandom();
21  array<ptrdiff_t, 4> strides;
22 
23  strides[0] = 1;
24  strides[1] = 1;
25  strides[2] = 1;
26  strides[3] = 1;
27 
29  no_stride = tensor.inflate(strides);
30 
31  VERIFY_IS_EQUAL(no_stride.dimension(0), 2);
32  VERIFY_IS_EQUAL(no_stride.dimension(1), 3);
33  VERIFY_IS_EQUAL(no_stride.dimension(2), 5);
34  VERIFY_IS_EQUAL(no_stride.dimension(3), 7);
35 
36  for (int i = 0; i < 2; ++i) {
37  for (int j = 0; j < 3; ++j) {
38  for (int k = 0; k < 5; ++k) {
39  for (int l = 0; l < 7; ++l) {
40  VERIFY_IS_EQUAL(tensor(i,j,k,l), no_stride(i,j,k,l));
41  }
42  }
43  }
44  }
45 
46  strides[0] = 2;
47  strides[1] = 4;
48  strides[2] = 2;
49  strides[3] = 3;
51  inflated = tensor.inflate(strides);
52 
53  VERIFY_IS_EQUAL(inflated.dimension(0), 3);
54  VERIFY_IS_EQUAL(inflated.dimension(1), 9);
55  VERIFY_IS_EQUAL(inflated.dimension(2), 9);
56  VERIFY_IS_EQUAL(inflated.dimension(3), 19);
57 
58  for (int i = 0; i < 3; ++i) {
59  for (int j = 0; j < 9; ++j) {
60  for (int k = 0; k < 9; ++k) {
61  for (int l = 0; l < 19; ++l) {
62  if (i % 2 == 0 &&
63  j % 4 == 0 &&
64  k % 2 == 0 &&
65  l % 3 == 0) {
66  VERIFY_IS_EQUAL(inflated(i,j,k,l),
67  tensor(i/2, j/4, k/2, l/3));
68  } else {
69  VERIFY_IS_EQUAL(0, inflated(i,j,k,l));
70  }
71  }
72  }
73  }
74  }
75 }
76 
78 {
79  CALL_SUBTEST(test_simple_inflation<ColMajor>());
80  CALL_SUBTEST(test_simple_inflation<RowMajor>());
81 }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(std::size_t n) const
Definition: Tensor.h:101
void test_cxx11_tensor_inflation()
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Tensor< Scalar_, NumIndices_, Options_, IndexType_ > & setRandom()
Definition: TensorBase.h:848
static void test_simple_inflation()
The tensor class.
Definition: Tensor.h:63


hebiros
Author(s): Xavier Artache , Matthew Tesch
autogenerated on Thu Sep 3 2020 04:08:09