cxx11_tensor_concatenation.cpp
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #include "main.h"
11 
12 #include <Eigen/CXX11/Tensor>
13 
14 using Eigen::Tensor;
15 
16 template<int DataLayout>
18 {
21  left.setRandom();
22  right.setRandom();
23 
24  // Okay; other dimensions are equal.
25  Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
26 
27  // Dimension mismatches.
28  VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 1));
29  VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 2));
30 
31  // Axis > NumDims or < 0.
32  VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, 3));
33  VERIFY_RAISES_ASSERT(concatenation = left.concatenate(right, -1));
34 }
35 
36 template<int DataLayout>
38 {
41 
42 #ifdef CXX11_TENSOR_CONCATENATION_STATIC_DIMENSION_FAILURE
43  // Technically compatible, but we static assert that the inputs have same
44  // NumDims.
45  Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
46 #endif
47 
48  // This can be worked around in this case.
49  Tensor<int, 3, DataLayout> concatenation = left
50  .reshape(Tensor<int, 3>::Dimensions(2, 3, 1))
51  .concatenate(right, 0);
52  Tensor<int, 2, DataLayout> alternative = left
53  // Clang compiler break with {{{}}} with an ambiguous error on copy constructor
54  // the variadic DSize constructor added for #ifndef EIGEN_EMULATE_CXX11_META_H.
55  // Solution:
56  // either the code should change to
57  // Tensor<int, 2>::Dimensions{{2, 3}}
58  // or Tensor<int, 2>::Dimensions{Tensor<int, 2>::Dimensions{{2, 3}}}
59  .concatenate(right.reshape(Tensor<int, 2>::Dimensions(2, 3)), 0);
60 }
61 
62 template<int DataLayout>
64 {
67  left.setRandom();
68  right.setRandom();
69 
70  Tensor<int, 3, DataLayout> concatenation = left.concatenate(right, 0);
71  VERIFY_IS_EQUAL(concatenation.dimension(0), 4);
72  VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
73  VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
74  for (int j = 0; j < 3; ++j) {
75  for (int i = 0; i < 2; ++i) {
76  VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
77  }
78  for (int i = 2; i < 4; ++i) {
79  VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i - 2, j, 0));
80  }
81  }
82 
83  concatenation = left.concatenate(right, 1);
84  VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
85  VERIFY_IS_EQUAL(concatenation.dimension(1), 6);
86  VERIFY_IS_EQUAL(concatenation.dimension(2), 1);
87  for (int i = 0; i < 2; ++i) {
88  for (int j = 0; j < 3; ++j) {
89  VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
90  }
91  for (int j = 3; j < 6; ++j) {
92  VERIFY_IS_EQUAL(concatenation(i, j, 0), right(i, j - 3, 0));
93  }
94  }
95 
96  concatenation = left.concatenate(right, 2);
97  VERIFY_IS_EQUAL(concatenation.dimension(0), 2);
98  VERIFY_IS_EQUAL(concatenation.dimension(1), 3);
99  VERIFY_IS_EQUAL(concatenation.dimension(2), 2);
100  for (int i = 0; i < 2; ++i) {
101  for (int j = 0; j < 3; ++j) {
102  VERIFY_IS_EQUAL(concatenation(i, j, 0), left(i, j, 0));
103  VERIFY_IS_EQUAL(concatenation(i, j, 1), right(i, j, 0));
104  }
105  }
106 }
107 
108 
109 // TODO(phli): Add test once we have a real vectorized implementation.
110 // static void test_vectorized_concatenation() {}
111 
113 {
114  Tensor<int, 2> t1(2, 3);
115  Tensor<int, 2> t2(2, 3);
116  t1.setRandom();
117  t2.setRandom();
118 
119  Tensor<int, 2> result(4, 3);
120  result.setRandom();
121  t1.concatenate(t2, 0) = result;
122 
123  for (int i = 0; i < 2; ++i) {
124  for (int j = 0; j < 3; ++j) {
125  VERIFY_IS_EQUAL(t1(i, j), result(i, j));
126  VERIFY_IS_EQUAL(t2(i, j), result(i+2, j));
127  }
128  }
129 }
130 
131 
132 EIGEN_DECLARE_TEST(cxx11_tensor_concatenation)
133 {
134  CALL_SUBTEST(test_dimension_failures<ColMajor>());
135  CALL_SUBTEST(test_dimension_failures<RowMajor>());
136  CALL_SUBTEST(test_static_dimension_failure<ColMajor>());
137  CALL_SUBTEST(test_static_dimension_failure<RowMajor>());
138  CALL_SUBTEST(test_simple_concatenation<ColMajor>());
139  CALL_SUBTEST(test_simple_concatenation<RowMajor>());
140  // CALL_SUBTEST(test_vectorized_concatenation());
142 
143 }
Eigen::Tensor::dimension
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(std::size_t n) const
Definition: Tensor.h:101
Eigen::Tensor
The tensor class.
Definition: Tensor.h:63
test_static_dimension_failure
static void test_static_dimension_failure()
Definition: cxx11_tensor_concatenation.cpp:37
test_dimension_failures
static void test_dimension_failures()
Definition: cxx11_tensor_concatenation.cpp:17
Eigen::TensorBase< Tensor< Scalar_, NumIndices_, Options_, IndexType_ > >::concatenate
EIGEN_DEVICE_FUNC const EIGEN_STRONG_INLINE TensorConcatenationOp< const Axis, const Tensor< Scalar_, NumIndices_, Options_, IndexType_ >, const OtherDerived > concatenate(const OtherDerived &other, const Axis &axis) const
Definition: TensorBase.h:1044
VERIFY_IS_EQUAL
#define VERIFY_IS_EQUAL(a, b)
Definition: main.h:386
result
Values result
Definition: OdometryOptimize.cpp:8
Eigen::DSizes
Definition: TensorDimensions.h:263
VERIFY_RAISES_ASSERT
#define VERIFY_RAISES_ASSERT(a)
Definition: main.h:340
j
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2
left
static char left
Definition: blas_interface.hh:62
test_simple_concatenation
static void test_simple_concatenation()
Definition: cxx11_tensor_concatenation.cpp:63
EIGEN_DECLARE_TEST
EIGEN_DECLARE_TEST(cxx11_tensor_concatenation)
Definition: cxx11_tensor_concatenation.cpp:132
Eigen::TensorBase< Tensor< Scalar_, NumIndices_, Options_, IndexType_ > >::setRandom
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Tensor< Scalar_, NumIndices_, Options_, IndexType_ > & setRandom()
Definition: TensorBase.h:996
right
static char right
Definition: blas_interface.hh:61
main.h
test_concatenation_as_lvalue
static void test_concatenation_as_lvalue()
Definition: cxx11_tensor_concatenation.cpp:112
i
int i
Definition: BiCGSTAB_step_by_step.cpp:9
CALL_SUBTEST
#define CALL_SUBTEST(FUNC)
Definition: main.h:399


gtsam
Author(s):
autogenerated on Tue Jan 7 2025 04:02:06