cxx11_tensor_trace.cpp
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2017 Gagan Goel <gagan.nith@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #include "main.h"
11 
12 #include <Eigen/CXX11/Tensor>
13 
14 using Eigen::Tensor;
15 using Eigen::array;
16 
17 template <int DataLayout>
18 static void test_0D_trace() {
20  tensor.setRandom();
22  Tensor<float, 0, DataLayout> result = tensor.trace(dims);
23  VERIFY_IS_EQUAL(result(), tensor());
24 }
25 
26 
27 template <int DataLayout>
29  Tensor<float, 3, DataLayout> tensor1(5, 5, 5);
30  tensor1.setRandom();
31  Tensor<float, 0, DataLayout> result1 = tensor1.trace();
32  VERIFY_IS_EQUAL(result1.rank(), 0);
33  float sum = 0.0f;
34  for (int i = 0; i < 5; ++i) {
35  sum += tensor1(i, i, i);
36  }
37  VERIFY_IS_EQUAL(result1(), sum);
38 
39  Tensor<float, 5, DataLayout> tensor2(7, 7, 7, 7, 7);
40  tensor2.setRandom();
41  array<ptrdiff_t, 5> dims = { { 2, 1, 0, 3, 4 } };
42  Tensor<float, 0, DataLayout> result2 = tensor2.trace(dims);
43  VERIFY_IS_EQUAL(result2.rank(), 0);
44  sum = 0.0f;
45  for (int i = 0; i < 7; ++i) {
46  sum += tensor2(i, i, i, i, i);
47  }
48  VERIFY_IS_EQUAL(result2(), sum);
49 }
50 
51 
52 template <int DataLayout>
53 static void test_simple_trace() {
54  Tensor<float, 3, DataLayout> tensor1(3, 5, 3);
55  tensor1.setRandom();
56  array<ptrdiff_t, 2> dims1 = { { 0, 2 } };
57  Tensor<float, 1, DataLayout> result1 = tensor1.trace(dims1);
58  VERIFY_IS_EQUAL(result1.rank(), 1);
59  VERIFY_IS_EQUAL(result1.dimension(0), 5);
60  float sum = 0.0f;
61  for (int i = 0; i < 5; ++i) {
62  sum = 0.0f;
63  for (int j = 0; j < 3; ++j) {
64  sum += tensor1(j, i, j);
65  }
66  VERIFY_IS_EQUAL(result1(i), sum);
67  }
68 
69  Tensor<float, 4, DataLayout> tensor2(5, 5, 7, 7);
70  tensor2.setRandom();
71  array<ptrdiff_t, 2> dims2 = { { 2, 3 } };
72  Tensor<float, 2, DataLayout> result2 = tensor2.trace(dims2);
73  VERIFY_IS_EQUAL(result2.rank(), 2);
74  VERIFY_IS_EQUAL(result2.dimension(0), 5);
75  VERIFY_IS_EQUAL(result2.dimension(1), 5);
76  for (int i = 0; i < 5; ++i) {
77  for (int j = 0; j < 5; ++j) {
78  sum = 0.0f;
79  for (int k = 0; k < 7; ++k) {
80  sum += tensor2(i, j, k, k);
81  }
82  VERIFY_IS_EQUAL(result2(i, j), sum);
83  }
84  }
85 
86  array<ptrdiff_t, 2> dims3 = { { 1, 0 } };
87  Tensor<float, 2, DataLayout> result3 = tensor2.trace(dims3);
88  VERIFY_IS_EQUAL(result3.rank(), 2);
89  VERIFY_IS_EQUAL(result3.dimension(0), 7);
90  VERIFY_IS_EQUAL(result3.dimension(1), 7);
91  for (int i = 0; i < 7; ++i) {
92  for (int j = 0; j < 7; ++j) {
93  sum = 0.0f;
94  for (int k = 0; k < 5; ++k) {
95  sum += tensor2(k, k, i, j);
96  }
97  VERIFY_IS_EQUAL(result3(i, j), sum);
98  }
99  }
100 
101  Tensor<float, 5, DataLayout> tensor3(3, 7, 3, 7, 3);
102  tensor3.setRandom();
103  array<ptrdiff_t, 3> dims4 = { { 0, 2, 4 } };
104  Tensor<float, 2, DataLayout> result4 = tensor3.trace(dims4);
105  VERIFY_IS_EQUAL(result4.rank(), 2);
106  VERIFY_IS_EQUAL(result4.dimension(0), 7);
107  VERIFY_IS_EQUAL(result4.dimension(1), 7);
108  for (int i = 0; i < 7; ++i) {
109  for (int j = 0; j < 7; ++j) {
110  sum = 0.0f;
111  for (int k = 0; k < 3; ++k) {
112  sum += tensor3(k, i, k, j, k);
113  }
114  VERIFY_IS_EQUAL(result4(i, j), sum);
115  }
116  }
117 
118  Tensor<float, 5, DataLayout> tensor4(3, 7, 4, 7, 5);
119  tensor4.setRandom();
120  array<ptrdiff_t, 2> dims5 = { { 1, 3 } };
121  Tensor<float, 3, DataLayout> result5 = tensor4.trace(dims5);
122  VERIFY_IS_EQUAL(result5.rank(), 3);
123  VERIFY_IS_EQUAL(result5.dimension(0), 3);
124  VERIFY_IS_EQUAL(result5.dimension(1), 4);
125  VERIFY_IS_EQUAL(result5.dimension(2), 5);
126  for (int i = 0; i < 3; ++i) {
127  for (int j = 0; j < 4; ++j) {
128  for (int k = 0; k < 5; ++k) {
129  sum = 0.0f;
130  for (int l = 0; l < 7; ++l) {
131  sum += tensor4(i, l, j, l, k);
132  }
133  VERIFY_IS_EQUAL(result5(i, j, k), sum);
134  }
135  }
136  }
137 }
138 
139 
140 template<int DataLayout>
141 static void test_trace_in_expr() {
142  Tensor<float, 4, DataLayout> tensor(2, 3, 5, 3);
143  tensor.setRandom();
144  array<ptrdiff_t, 2> dims = { { 1, 3 } };
146  result = result.constant(1.0f) - tensor.trace(dims);
147  VERIFY_IS_EQUAL(result.rank(), 2);
148  VERIFY_IS_EQUAL(result.dimension(0), 2);
149  VERIFY_IS_EQUAL(result.dimension(1), 5);
150  float sum = 0.0f;
151  for (int i = 0; i < 2; ++i) {
152  for (int j = 0; j < 5; ++j) {
153  sum = 0.0f;
154  for (int k = 0; k < 3; ++k) {
155  sum += tensor(i, k, j, k);
156  }
157  VERIFY_IS_EQUAL(result(i, j), 1.0f - sum);
158  }
159  }
160 }
161 
162 
// Test entry point for cxx11_tensor_trace.
//
// Invokes each of the four trace subtests defined in this file twice, once
// instantiated with ColMajor and once with RowMajor, in the order listed.
// EIGEN_DECLARE_TEST and CALL_SUBTEST are macros supplied by main.h.
163 EIGEN_DECLARE_TEST(cxx11_tensor_trace) {
164  CALL_SUBTEST(test_0D_trace<ColMajor>());
165  CALL_SUBTEST(test_0D_trace<RowMajor>());
166  CALL_SUBTEST(test_all_dimensions_trace<ColMajor>());
167  CALL_SUBTEST(test_all_dimensions_trace<RowMajor>());
168  CALL_SUBTEST(test_simple_trace<ColMajor>());
169  CALL_SUBTEST(test_simple_trace<RowMajor>());
170  CALL_SUBTEST(test_trace_in_expr<ColMajor>());
171  CALL_SUBTEST(test_trace_in_expr<RowMajor>());
172 }
Eigen::Tensor::dimension
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(std::size_t n) const
Definition: Tensor.h:101
Eigen::Tensor
The tensor class.
Definition: Tensor.h:63
test_0D_trace
static void test_0D_trace()
Definition: cxx11_tensor_trace.cpp:18
array
int array[24]
Definition: Map_general_stride.cpp:1
Eigen::array
Definition: EmulateArray.h:21
VERIFY_IS_EQUAL
#define VERIFY_IS_EQUAL(a, b)
Definition: main.h:386
result
Values result
Definition: OdometryOptimize.cpp:8
Eigen::Tensor::rank
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rank() const
Definition: Tensor.h:100
test_trace_in_expr
static void test_trace_in_expr()
Definition: cxx11_tensor_trace.cpp:141
j
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2
l
static const Line3 l(Rot3(), 1, 1)
tree::f
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Definition: testExpression.cpp:218
Eigen::TensorBase< Tensor< Scalar_, NumIndices_, Options_, IndexType_ > >::setRandom
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Tensor< Scalar_, NumIndices_, Options_, IndexType_ > & setRandom()
Definition: TensorBase.h:996
test_all_dimensions_trace
static void test_all_dimensions_trace()
Definition: cxx11_tensor_trace.cpp:28
main.h
test_simple_trace
static void test_simple_trace()
Definition: cxx11_tensor_trace.cpp:53
EIGEN_DECLARE_TEST
EIGEN_DECLARE_TEST(cxx11_tensor_trace)
Definition: cxx11_tensor_trace.cpp:163
i
int i
Definition: BiCGSTAB_step_by_step.cpp:9
CALL_SUBTEST
#define CALL_SUBTEST(FUNC)
Definition: main.h:399


gtsam
Author(s):
autogenerated on Wed Jan 1 2025 04:01:24