12 #include <Eigen/CXX11/Tensor> 16 template <
int DataLayout>
21 array<ptrdiff_t, 4> broadcasts;
28 no_broadcast = tensor.broadcast(broadcasts);
30 VERIFY_IS_EQUAL(no_broadcast.
dimension(0), 2);
31 VERIFY_IS_EQUAL(no_broadcast.
dimension(1), 3);
32 VERIFY_IS_EQUAL(no_broadcast.
dimension(2), 5);
33 VERIFY_IS_EQUAL(no_broadcast.
dimension(3), 7);
35 for (
int i = 0; i < 2; ++i) {
36 for (
int j = 0; j < 3; ++j) {
37 for (
int k = 0; k < 5; ++k) {
38 for (
int l = 0; l < 7; ++l) {
39 VERIFY_IS_EQUAL(tensor(i,j,k,l), no_broadcast(i,j,k,l));
50 broadcast = tensor.broadcast(broadcasts);
52 VERIFY_IS_EQUAL(broadcast.
dimension(0), 4);
53 VERIFY_IS_EQUAL(broadcast.
dimension(1), 9);
54 VERIFY_IS_EQUAL(broadcast.
dimension(2), 5);
55 VERIFY_IS_EQUAL(broadcast.
dimension(3), 28);
57 for (
int i = 0; i < 4; ++i) {
58 for (
int j = 0; j < 9; ++j) {
59 for (
int k = 0; k < 5; ++k) {
60 for (
int l = 0; l < 28; ++l) {
61 VERIFY_IS_EQUAL(tensor(i%2,j%3,k%5,l%7), broadcast(i,j,k,l));
69 template <
int DataLayout>
74 array<ptrdiff_t, 3> broadcasts;
80 broadcast = tensor.broadcast(broadcasts);
82 VERIFY_IS_EQUAL(broadcast.
dimension(0), 16);
83 VERIFY_IS_EQUAL(broadcast.
dimension(1), 9);
84 VERIFY_IS_EQUAL(broadcast.
dimension(2), 20);
86 for (
int i = 0; i < 16; ++i) {
87 for (
int j = 0; j < 9; ++j) {
88 for (
int k = 0; k < 20; ++k) {
89 VERIFY_IS_EQUAL(tensor(i%8,j%3,k%5), broadcast(i,j,k));
96 broadcast = tensor.broadcast(broadcasts);
98 VERIFY_IS_EQUAL(broadcast.
dimension(0), 22);
99 VERIFY_IS_EQUAL(broadcast.
dimension(1), 9);
100 VERIFY_IS_EQUAL(broadcast.
dimension(2), 20);
102 for (
int i = 0; i < 22; ++i) {
103 for (
int j = 0; j < 9; ++j) {
104 for (
int k = 0; k < 20; ++k) {
105 VERIFY_IS_EQUAL(tensor(i%11,j%3,k%5), broadcast(i,j,k));
112 template <
int DataLayout>
118 #if EIGEN_HAS_CONSTEXPR 119 Eigen::IndexList<Eigen::type2index<2>, Eigen::type2index<3>, Eigen::type2index<4>> broadcasts;
128 broadcast = tensor.broadcast(broadcasts);
130 VERIFY_IS_EQUAL(broadcast.
dimension(0), 16);
131 VERIFY_IS_EQUAL(broadcast.
dimension(1), 9);
132 VERIFY_IS_EQUAL(broadcast.
dimension(2), 20);
134 for (
int i = 0; i < 16; ++i) {
135 for (
int j = 0; j < 9; ++j) {
136 for (
int k = 0; k < 20; ++k) {
137 VERIFY_IS_EQUAL(tensor(i%8,j%3,k%5), broadcast(i,j,k));
144 broadcast = tensor.broadcast(broadcasts);
146 VERIFY_IS_EQUAL(broadcast.
dimension(0), 22);
147 VERIFY_IS_EQUAL(broadcast.
dimension(1), 9);
148 VERIFY_IS_EQUAL(broadcast.
dimension(2), 20);
150 for (
int i = 0; i < 22; ++i) {
151 for (
int j = 0; j < 9; ++j) {
152 for (
int k = 0; k < 20; ++k) {
153 VERIFY_IS_EQUAL(tensor(i%11,j%3,k%5), broadcast(i,j,k));
160 template <
int DataLayout>
167 TensorFixedSize<float, Sizes<1>, DataLayout> t2;
168 t2 = t2.constant(20.0
f);
171 for (
int i = 0; i < 10; ++i) {
172 VERIFY_IS_APPROX(t3(i), t1(i) + t2(0));
175 TensorMap<TensorFixedSize<float, Sizes<1>, DataLayout> > t4(t2.data(), {{1}});
177 for (
int i = 0; i < 10; ++i) {
178 VERIFY_IS_APPROX(t5(i), t1(i) + t2(0));
186 CALL_SUBTEST(test_simple_broadcasting<ColMajor>());
187 CALL_SUBTEST(test_simple_broadcasting<RowMajor>());
188 CALL_SUBTEST(test_vectorized_broadcasting<ColMajor>());
189 CALL_SUBTEST(test_vectorized_broadcasting<RowMajor>());
190 CALL_SUBTEST(test_static_broadcasting<ColMajor>());
191 CALL_SUBTEST(test_static_broadcasting<RowMajor>());
192 CALL_SUBTEST(test_fixed_size_broadcasting<ColMajor>());
193 CALL_SUBTEST(test_fixed_size_broadcasting<RowMajor>());
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(std::size_t n) const
static int f(const TensorMap< Tensor< int, 3 > > &tensor)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Tensor< Scalar_, NumIndices_, Options_, IndexType_ > & setRandom()
void test_cxx11_tensor_broadcasting()
static void test_simple_broadcasting()
static void test_vectorized_broadcasting()
EIGEN_DEVICE_FUNC void resize(const array< Index, NumIndices > &dimensions)
static void test_static_broadcasting()
static void test_fixed_size_broadcasting()