12 #include <Eigen/CXX11/Tensor> 19 template<
int DataLayout>
33 typedef TensorEvaluator<decltype(mat1.contract(mat2, dims3)),
DefaultDevice> Evaluator;
35 eval.evalTo(mat4.
data());
37 VERIFY_IS_EQUAL(eval.dimensions()[0], 3);
38 VERIFY_IS_EQUAL(eval.dimensions()[1], 3);
40 VERIFY_IS_APPROX(mat4(0,0), mat1(0,0)*mat2(0,0) + mat1(1,0)*mat2(1,0));
41 VERIFY_IS_APPROX(mat4(0,1), mat1(0,0)*mat2(0,1) + mat1(1,0)*mat2(1,1));
42 VERIFY_IS_APPROX(mat4(0,2), mat1(0,0)*mat2(0,2) + mat1(1,0)*mat2(1,2));
43 VERIFY_IS_APPROX(mat4(1,0), mat1(0,1)*mat2(0,0) + mat1(1,1)*mat2(1,0));
44 VERIFY_IS_APPROX(mat4(1,1), mat1(0,1)*mat2(0,1) + mat1(1,1)*mat2(1,1));
45 VERIFY_IS_APPROX(mat4(1,2), mat1(0,1)*mat2(0,2) + mat1(1,1)*mat2(1,2));
46 VERIFY_IS_APPROX(mat4(2,0), mat1(0,2)*mat2(0,0) + mat1(1,2)*mat2(1,0));
47 VERIFY_IS_APPROX(mat4(2,1), mat1(0,2)*mat2(0,1) + mat1(1,2)*mat2(1,1));
48 VERIFY_IS_APPROX(mat4(2,2), mat1(0,2)*mat2(0,2) + mat1(1,2)*mat2(1,2));
53 typedef TensorEvaluator<decltype(mat1.contract(mat2, dims4)),
DefaultDevice> Evaluator2;
54 Evaluator2 eval2(mat1.contract(mat2, dims4),
DefaultDevice());
55 eval2.evalTo(mat5.
data());
57 VERIFY_IS_EQUAL(eval2.dimensions()[0], 2);
58 VERIFY_IS_EQUAL(eval2.dimensions()[1], 2);
60 VERIFY_IS_APPROX(mat5(0,0), mat1(0,0)*mat2(0,0) + mat1(0,1)*mat2(0,1) + mat1(0,2)*mat2(0,2));
61 VERIFY_IS_APPROX(mat5(0,1), mat1(0,0)*mat2(1,0) + mat1(0,1)*mat2(1,1) + mat1(0,2)*mat2(1,2));
62 VERIFY_IS_APPROX(mat5(1,0), mat1(1,0)*mat2(0,0) + mat1(1,1)*mat2(0,1) + mat1(1,2)*mat2(0,2));
63 VERIFY_IS_APPROX(mat5(1,1), mat1(1,0)*mat2(1,0) + mat1(1,1)*mat2(1,1) + mat1(1,2)*mat2(1,2));
68 typedef TensorEvaluator<decltype(mat1.contract(mat3, dims6)),
DefaultDevice> Evaluator3;
69 Evaluator3 eval3(mat1.contract(mat3, dims6),
DefaultDevice());
70 eval3.evalTo(mat6.
data());
72 VERIFY_IS_EQUAL(eval3.dimensions()[0], 2);
73 VERIFY_IS_EQUAL(eval3.dimensions()[1], 2);
75 VERIFY_IS_APPROX(mat6(0,0), mat1(0,0)*mat3(0,0) + mat1(0,1)*mat3(1,0) + mat1(0,2)*mat3(2,0));
76 VERIFY_IS_APPROX(mat6(0,1), mat1(0,0)*mat3(0,1) + mat1(0,1)*mat3(1,1) + mat1(0,2)*mat3(2,1));
77 VERIFY_IS_APPROX(mat6(1,0), mat1(1,0)*mat3(0,0) + mat1(1,1)*mat3(1,0) + mat1(1,2)*mat3(2,0));
78 VERIFY_IS_APPROX(mat6(1,1), mat1(1,0)*mat3(0,1) + mat1(1,1)*mat3(1,1) + mat1(1,2)*mat3(2,1));
81 template<
int DataLayout>
93 float expected = 0.0f;
94 for (
int i = 0; i < 6; ++i) {
95 expected += vec1(i) * vec2(i);
97 VERIFY_IS_APPROX(scalar(), expected);
100 template<
int DataLayout>
112 typedef TensorEvaluator<decltype(mat1.contract(mat2, dims)),
DefaultDevice> Evaluator;
114 eval.evalTo(mat3.
data());
116 VERIFY_IS_EQUAL(eval.dimensions()[0], 2);
117 VERIFY_IS_EQUAL(eval.dimensions()[1], 2);
118 VERIFY_IS_EQUAL(eval.dimensions()[2], 2);
120 VERIFY_IS_APPROX(mat3(0,0,0), mat1(0,0,0)*mat2(0,0,0,0) + mat1(0,1,0)*mat2(0,0,1,0) +
121 mat1(0,0,1)*mat2(0,0,0,1) + mat1(0,1,1)*mat2(0,0,1,1));
122 VERIFY_IS_APPROX(mat3(0,0,1), mat1(0,0,0)*mat2(0,1,0,0) + mat1(0,1,0)*mat2(0,1,1,0) +
123 mat1(0,0,1)*mat2(0,1,0,1) + mat1(0,1,1)*mat2(0,1,1,1));
124 VERIFY_IS_APPROX(mat3(0,1,0), mat1(0,0,0)*mat2(1,0,0,0) + mat1(0,1,0)*mat2(1,0,1,0) +
125 mat1(0,0,1)*mat2(1,0,0,1) + mat1(0,1,1)*mat2(1,0,1,1));
126 VERIFY_IS_APPROX(mat3(0,1,1), mat1(0,0,0)*mat2(1,1,0,0) + mat1(0,1,0)*mat2(1,1,1,0) +
127 mat1(0,0,1)*mat2(1,1,0,1) + mat1(0,1,1)*mat2(1,1,1,1));
128 VERIFY_IS_APPROX(mat3(1,0,0), mat1(1,0,0)*mat2(0,0,0,0) + mat1(1,1,0)*mat2(0,0,1,0) +
129 mat1(1,0,1)*mat2(0,0,0,1) + mat1(1,1,1)*mat2(0,0,1,1));
130 VERIFY_IS_APPROX(mat3(1,0,1), mat1(1,0,0)*mat2(0,1,0,0) + mat1(1,1,0)*mat2(0,1,1,0) +
131 mat1(1,0,1)*mat2(0,1,0,1) + mat1(1,1,1)*mat2(0,1,1,1));
132 VERIFY_IS_APPROX(mat3(1,1,0), mat1(1,0,0)*mat2(1,0,0,0) + mat1(1,1,0)*mat2(1,0,1,0) +
133 mat1(1,0,1)*mat2(1,0,0,1) + mat1(1,1,1)*mat2(1,0,1,1));
134 VERIFY_IS_APPROX(mat3(1,1,1), mat1(1,0,0)*mat2(1,1,0,0) + mat1(1,1,0)*mat2(1,1,1,0) +
135 mat1(1,0,1)*mat2(1,1,0,1) + mat1(1,1,1)*mat2(1,1,1,1));
146 typedef TensorEvaluator<decltype(mat4.contract(mat5, dims2)),
DefaultDevice> Evaluator2;
147 Evaluator2 eval2(mat4.contract(mat5, dims2),
DefaultDevice());
148 eval2.evalTo(mat6.
data());
150 VERIFY_IS_EQUAL(eval2.dimensions()[0], 2);
152 VERIFY_IS_APPROX(mat6(0), mat4(0,0)*mat5(0,0,0) + mat4(1,0)*mat5(0,1,0) +
153 mat4(0,1)*mat5(1,0,0) + mat4(1,1)*mat5(1,1,0));
154 VERIFY_IS_APPROX(mat6(1), mat4(0,0)*mat5(0,0,1) + mat4(1,0)*mat5(0,1,1) +
155 mat4(0,1)*mat5(1,0,1) + mat4(1,1)*mat5(1,1,1));
158 template<
int DataLayout>
170 VERIFY_IS_EQUAL(result.
dimension(3), 11);
171 VERIFY_IS_EQUAL(result.
dimension(4), 13);
173 for (
int i = 0; i < 5; ++i) {
174 for (
int j = 0; j < 5; ++j) {
175 for (
int k = 0; k < 5; ++k) {
176 for (
int l = 0; l < 5; ++l) {
177 for (
int m = 0; m < 5; ++m) {
178 VERIFY_IS_APPROX(result(i, j, k, l, m),
179 t1(0, i, j, 0) * t2(0, k, l, m, 0) +
180 t1(1, i, j, 0) * t2(1, k, l, m, 0) +
181 t1(0, i, j, 1) * t2(0, k, l, m, 1) +
182 t1(1, i, j, 1) * t2(1, k, l, m, 1) +
183 t1(0, i, j, 2) * t2(0, k, l, m, 2) +
184 t1(1, i, j, 2) * t2(1, k, l, m, 2));
192 template<
int DataLayout>
203 VERIFY_IS_APPROX(result(0), t1(0, 0) * t2(0, 0, 0) + t1(1, 0) * t2(1, 0, 0)
204 + t1(0, 1) * t2(0, 1, 0) + t1(1, 1) * t2(1, 1, 0));
205 VERIFY_IS_APPROX(result(1), t1(0, 0) * t2(0, 0, 1) + t1(1, 0) * t2(1, 0, 1)
206 + t1(0, 1) * t2(0, 1, 1) + t1(1, 1) * t2(1, 1, 1));
210 result = t2.contract(t1, dims);
212 VERIFY_IS_APPROX(result(0), t1(0, 0) * t2(0, 0, 0) + t1(1, 0) * t2(0, 1, 0)
213 + t1(0, 1) * t2(0, 0, 1) + t1(1, 1) * t2(0, 1, 1));
214 VERIFY_IS_APPROX(result(1), t1(0, 0) * t2(1, 0, 0) + t1(1, 0) * t2(1, 1, 0)
215 + t1(0, 1) * t2(1, 0, 1) + t1(1, 1) * t2(1, 1, 1));
218 template<
int DataLayout>
231 auto contract1 = t1.contract(t2, dims);
232 auto diff = t3 - contract1;
233 auto contract2 = t1.contract(t4, dims);
240 m1(t1.
data(), 2, 2), m2(t2.
data(), 2, 2), m3(t3.
data(), 2, 2),
243 expected = (m1 * m4) * (m3 - m1 * m2);
245 VERIFY_IS_APPROX(result(0, 0), expected(0, 0));
246 VERIFY_IS_APPROX(result(0, 1), expected(0, 1));
247 VERIFY_IS_APPROX(result(1, 0), expected(1, 0));
248 VERIFY_IS_APPROX(result(1, 1), expected(1, 1));
251 template<
int DataLayout>
262 mat3 = mat1.contract(mat2, dims);
264 VERIFY_IS_APPROX(mat3(0,0), mat1(0,0)*mat2(0,0) + mat1(0,1)*mat2(1,0) + mat1(0,2)*mat2(2,0));
265 VERIFY_IS_APPROX(mat3(0,1), mat1(0,0)*mat2(0,1) + mat1(0,1)*mat2(1,1) + mat1(0,2)*mat2(2,1));
266 VERIFY_IS_APPROX(mat3(1,0), mat1(1,0)*mat2(0,0) + mat1(1,1)*mat2(1,0) + mat1(1,2)*mat2(2,0));
267 VERIFY_IS_APPROX(mat3(1,1), mat1(1,0)*mat2(0,1) + mat1(1,1)*mat2(1,1) + mat1(1,2)*mat2(2,1));
270 template<
int DataLayout>
282 mat3 = mat1.contract(mat2, dims);
284 VERIFY_IS_APPROX(mat3(0, 0),
285 mat1(0,0,0)*mat2(0,0,0) + mat1(1,0,0)*mat2(0,0,1) +
286 mat1(0,0,1)*mat2(1,0,0) + mat1(1,0,1)*mat2(1,0,1));
287 VERIFY_IS_APPROX(mat3(1, 0),
288 mat1(0,1,0)*mat2(0,0,0) + mat1(1,1,0)*mat2(0,0,1) +
289 mat1(0,1,1)*mat2(1,0,0) + mat1(1,1,1)*mat2(1,0,1));
290 VERIFY_IS_APPROX(mat3(0, 1),
291 mat1(0,0,0)*mat2(0,1,0) + mat1(1,0,0)*mat2(0,1,1) +
292 mat1(0,0,1)*mat2(1,1,0) + mat1(1,0,1)*mat2(1,1,1));
293 VERIFY_IS_APPROX(mat3(1, 1),
294 mat1(0,1,0)*mat2(0,1,0) + mat1(1,1,0)*mat2(0,1,1) +
295 mat1(0,1,1)*mat2(1,1,0) + mat1(1,1,1)*mat2(1,1,1));
298 mat3 = mat1.contract(mat2, dims2);
300 VERIFY_IS_APPROX(mat3(0, 0),
301 mat1(0,0,0)*mat2(0,0,0) + mat1(1,0,0)*mat2(0,0,1) +
302 mat1(0,0,1)*mat2(1,0,0) + mat1(1,0,1)*mat2(1,0,1));
303 VERIFY_IS_APPROX(mat3(1, 0),
304 mat1(0,1,0)*mat2(0,0,0) + mat1(1,1,0)*mat2(0,0,1) +
305 mat1(0,1,1)*mat2(1,0,0) + mat1(1,1,1)*mat2(1,0,1));
306 VERIFY_IS_APPROX(mat3(0, 1),
307 mat1(0,0,0)*mat2(0,1,0) + mat1(1,0,0)*mat2(0,1,1) +
308 mat1(0,0,1)*mat2(1,1,0) + mat1(1,0,1)*mat2(1,1,1));
309 VERIFY_IS_APPROX(mat3(1, 1),
310 mat1(0,1,0)*mat2(0,1,0) + mat1(1,1,0)*mat2(0,1,1) +
311 mat1(0,1,1)*mat2(1,1,0) + mat1(1,1,1)*mat2(1,1,1));
315 template<
int DataLayout>
332 mat3 = mat1.contract(mat2, dims1);
333 mat4 = mat2.contract(mat1, dims2);
337 for (
size_t i = 0; i < 5; i++) {
338 for (
size_t j = 0; j < 10; j++) {
339 VERIFY_IS_APPROX(mat3.
data()[i + 5 * j], mat4.
data()[j + 10 * i]);
344 for (
size_t i = 0; i < 5; i++) {
345 for (
size_t j = 0; j < 10; j++) {
346 VERIFY_IS_APPROX(mat3.
data()[10 * i + j], mat4.
data()[i + 5 * j]);
352 template<
int DataLayout>
363 t_left += t_left.constant(1.0
f);
364 t_right += t_right.constant(1.0
f);
366 typedef Map<Eigen::Matrix<float, Dynamic, Dynamic, DataLayout>> MapXf;
367 MapXf m_left(t_left.
data(), 1500, 248);
368 MapXf m_right(t_right.
data(), 248, 1400);
375 t_result = t_left.contract(t_right, dims);
376 m_result = m_left * m_right;
378 for (
int i = 0; i < t_result.
dimensions().TotalSize(); i++) {
379 VERIFY(&t_result.
data()[i] != &m_result.data()[i]);
380 VERIFY_IS_APPROX(t_result.
data()[i], m_result.data()[i]);
384 template<
int DataLayout>
394 typedef Map<Eigen::Matrix<float, Dynamic, Dynamic, DataLayout>> MapXf;
395 MapXf m_left(t_left.
data(), 30, 50);
396 MapXf m_right(t_right.
data(), 50, 1);
403 t_result = t_left.contract(t_right, dims);
404 m_result = m_left * m_right;
406 for (
int i = 0; i < t_result.
dimensions().TotalSize(); i++) {
407 VERIFY(internal::isApprox(t_result(i), m_result(i, 0), 1));
412 template<
int DataLayout>
425 typedef Map<Eigen::Matrix<float, Dynamic, Dynamic, DataLayout>> MapXf;
426 MapXf m_left(t_left.
data(), 7, 13*17);
427 MapXf m_right(t_right.
data(), 1, 7);
430 for (
int i = 0; i < t_result.
dimensions().TotalSize(); i++) {
431 VERIFY(internal::isApprox(t_result(i), m_result(i, 0), 1));
436 template<
int DataLayout>
445 t_left += t_left.constant(1.0
f);
446 t_right += t_right.constant(1.0
f);
454 t_result = t_left.contract(t_right, dims);
457 Map<Eigen::Matrix<float, Dynamic, Dynamic, DataLayout>> m_left(t_left.
data(), 150, 93);
458 Map<Eigen::Matrix<float, Dynamic, Dynamic, DataLayout>> m_right(t_right.
data(), 93, 140);
461 for (
int i = 0; i < t_result.
dimensions().TotalSize(); i++) {
462 VERIFY_IS_APPROX(t_result.
data()[i], m_result.data()[i]);
466 template<
int DataLayout>
480 for (
int i = 0; i < result.
dimension(0); ++i) {
481 for (
int j = 0; j < result.
dimension(1); ++j) {
482 for (
int k = 0; k < result.
dimension(2); ++k) {
483 for (
int l = 0; l < result.
dimension(3); ++l) {
484 VERIFY_IS_APPROX(result(i, j, k, l), mat1(i, j) * mat2(k, l) );
492 template<
int DataLayout>
500 TensorMap<Tensor<const float, 2, DataLayout> > mat1(in1.
data(), 2, 3);
501 TensorMap<Tensor<const float, 2, DataLayout> > mat2(in2.
data(), 3, 2);
505 mat3 = mat1.contract(mat2, dims);
507 VERIFY_IS_APPROX(mat3(0,0), mat1(0,0)*mat2(0,0) + mat1(0,1)*mat2(1,0) + mat1(0,2)*mat2(2,0));
508 VERIFY_IS_APPROX(mat3(0,1), mat1(0,0)*mat2(0,1) + mat1(0,1)*mat2(1,1) + mat1(0,2)*mat2(2,1));
509 VERIFY_IS_APPROX(mat3(1,0), mat1(1,0)*mat2(0,0) + mat1(1,1)*mat2(1,0) + mat1(1,2)*mat2(2,0));
510 VERIFY_IS_APPROX(mat3(1,1), mat1(1,0)*mat2(0,1) + mat1(1,1)*mat2(1,1) + mat1(1,2)*mat2(2,1));
515 CALL_SUBTEST(test_evals<ColMajor>());
516 CALL_SUBTEST(test_evals<RowMajor>());
517 CALL_SUBTEST(test_scalar<ColMajor>());
518 CALL_SUBTEST(test_scalar<RowMajor>());
519 CALL_SUBTEST(test_multidims<ColMajor>());
520 CALL_SUBTEST(test_multidims<RowMajor>());
521 CALL_SUBTEST(test_holes<ColMajor>());
522 CALL_SUBTEST(test_holes<RowMajor>());
523 CALL_SUBTEST(test_full_redux<ColMajor>());
524 CALL_SUBTEST(test_full_redux<RowMajor>());
525 CALL_SUBTEST(test_contraction_of_contraction<ColMajor>());
526 CALL_SUBTEST(test_contraction_of_contraction<RowMajor>());
527 CALL_SUBTEST(test_expr<ColMajor>());
528 CALL_SUBTEST(test_expr<RowMajor>());
529 CALL_SUBTEST(test_out_of_order_contraction<ColMajor>());
530 CALL_SUBTEST(test_out_of_order_contraction<RowMajor>());
531 CALL_SUBTEST(test_consistency<ColMajor>());
532 CALL_SUBTEST(test_consistency<RowMajor>());
533 CALL_SUBTEST(test_large_contraction<ColMajor>());
534 CALL_SUBTEST(test_large_contraction<RowMajor>());
535 CALL_SUBTEST(test_matrix_vector<ColMajor>());
536 CALL_SUBTEST(test_matrix_vector<RowMajor>());
537 CALL_SUBTEST(test_tensor_vector<ColMajor>());
538 CALL_SUBTEST(test_tensor_vector<RowMajor>());
539 CALL_SUBTEST(test_small_blocking_factors<ColMajor>());
540 CALL_SUBTEST(test_small_blocking_factors<RowMajor>());
541 CALL_SUBTEST(test_tensor_product<ColMajor>());
542 CALL_SUBTEST(test_tensor_product<RowMajor>());
543 CALL_SUBTEST(test_const_inputs<ColMajor>());
544 CALL_SUBTEST(test_const_inputs<RowMajor>());
void setCpuCacheSizes(std::ptrdiff_t l1, std::ptrdiff_t l2, std::ptrdiff_t l3)
static void test_scalar()
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(std::size_t n) const
void test_cxx11_tensor_contraction()
static void test_consistency()
static void test_large_contraction()
A matrix or vector expression mapping an existing array of data.
static void test_matrix_vector()
static int f(const TensorMap< Tensor< int, 3 > > &tensor)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Tensor< Scalar_, NumIndices_, Options_, IndexType_ > & setRandom()
static void test_small_blocking_factors()
static void test_tensor_product()
#define EIGEN_STATIC_ASSERT(CONDITION, MSG)
static void test_contraction_of_contraction()
Tensor< float, 1 >::DimensionPair DimPair
static void test_out_of_order_contraction()
static void test_const_inputs()
static void test_multidims()
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar * data()
static void test_tensor_vector()
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions & dimensions() const
The matrix class, also used for vectors and row-vectors.
static void test_full_redux()
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Tensor< Scalar_, NumIndices_, Options_, IndexType_ > & setZero()