13 #include <Eigen/CXX11/Tensor> 17 template <
int DataLayout,
typename Type=
float,
bool Exclusive = false>
28 for (
int i = 0; i <
size; i++) {
30 VERIFY_IS_EQUAL(result(i), accum);
34 VERIFY_IS_EQUAL(result(i), accum);
39 result = tensor.cumprod(0, Exclusive);
40 for (
int i = 0; i <
size; i++) {
42 VERIFY_IS_EQUAL(result(i), accum);
46 VERIFY_IS_EQUAL(result(i), accum);
51 template <
int DataLayout,
typename Type=
float>
60 result = tensor.cumsum(0);
62 for (
int i = 0; i <
size; i++) {
63 accum += tensor(i, 1, 2, 3);
64 VERIFY_IS_EQUAL(result(i, 1, 2, 3), accum);
66 result = tensor.cumsum(1);
68 for (
int i = 0; i <
size; i++) {
69 accum += tensor(1, i, 2, 3);
70 VERIFY_IS_EQUAL(result(1, i, 2, 3), accum);
72 result = tensor.cumsum(2);
74 for (
int i = 0; i <
size; i++) {
75 accum += tensor(1, 2, i, 3);
76 VERIFY_IS_EQUAL(result(1, 2, i, 3), accum);
78 result = tensor.cumsum(3);
80 for (
int i = 0; i <
size; i++) {
81 accum += tensor(1, 2, 3, i);
82 VERIFY_IS_EQUAL(result(1, 2, 3, i), accum);
86 template <
int DataLayout>
89 TensorMap<Tensor<int, 1, DataLayout> > tensor_map(inputs, 20);
90 tensor_map.setRandom();
95 for (
int i = 0; i < 20; ++i) {
96 accum += tensor_map(i);
97 VERIFY_IS_EQUAL(result(i), accum);
102 CALL_SUBTEST((test_1d_scan<ColMajor, float, true>()));
103 CALL_SUBTEST((test_1d_scan<ColMajor, float, false>()));
104 CALL_SUBTEST((test_1d_scan<RowMajor, float, true>()));
105 CALL_SUBTEST((test_1d_scan<RowMajor, float, false>()));
106 CALL_SUBTEST(test_4d_scan<ColMajor>());
107 CALL_SUBTEST(test_4d_scan<RowMajor>());
108 CALL_SUBTEST(test_tensor_maps<ColMajor>());
109 CALL_SUBTEST(test_tensor_maps<RowMajor>());
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(std::size_t n) const
static void test_1d_scan()
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Tensor< Scalar_, NumIndices_, Options_, IndexType_ > & setRandom()
void test_cxx11_tensor_scan()
static constexpr size_t size(Tuple< Args... > &)
Provides access to the number of elements in a tuple as a compile-time constant expression.
static void test_4d_scan()
static void test_tensor_maps()