12 #include <Eigen/CXX11/Tensor>
19 template<
int DataLayout>
55 eval2.evalTo(mat5.
data());
70 eval3.evalTo(mat6.
data());
81 template<
int DataLayout>
94 for (
int i = 0;
i < 6; ++
i) {
100 template<
int DataLayout>
120 VERIFY_IS_APPROX(mat3(0,0,0),
mat1(0,0,0)*mat2(0,0,0,0) +
mat1(0,1,0)*mat2(0,0,1,0) +
121 mat1(0,0,1)*mat2(0,0,0,1) +
mat1(0,1,1)*mat2(0,0,1,1));
122 VERIFY_IS_APPROX(mat3(0,0,1),
mat1(0,0,0)*mat2(0,1,0,0) +
mat1(0,1,0)*mat2(0,1,1,0) +
123 mat1(0,0,1)*mat2(0,1,0,1) +
mat1(0,1,1)*mat2(0,1,1,1));
124 VERIFY_IS_APPROX(mat3(0,1,0),
mat1(0,0,0)*mat2(1,0,0,0) +
mat1(0,1,0)*mat2(1,0,1,0) +
125 mat1(0,0,1)*mat2(1,0,0,1) +
mat1(0,1,1)*mat2(1,0,1,1));
126 VERIFY_IS_APPROX(mat3(0,1,1),
mat1(0,0,0)*mat2(1,1,0,0) +
mat1(0,1,0)*mat2(1,1,1,0) +
127 mat1(0,0,1)*mat2(1,1,0,1) +
mat1(0,1,1)*mat2(1,1,1,1));
128 VERIFY_IS_APPROX(mat3(1,0,0),
mat1(1,0,0)*mat2(0,0,0,0) +
mat1(1,1,0)*mat2(0,0,1,0) +
129 mat1(1,0,1)*mat2(0,0,0,1) +
mat1(1,1,1)*mat2(0,0,1,1));
130 VERIFY_IS_APPROX(mat3(1,0,1),
mat1(1,0,0)*mat2(0,1,0,0) +
mat1(1,1,0)*mat2(0,1,1,0) +
131 mat1(1,0,1)*mat2(0,1,0,1) +
mat1(1,1,1)*mat2(0,1,1,1));
132 VERIFY_IS_APPROX(mat3(1,1,0),
mat1(1,0,0)*mat2(1,0,0,0) +
mat1(1,1,0)*mat2(1,0,1,0) +
133 mat1(1,0,1)*mat2(1,0,0,1) +
mat1(1,1,1)*mat2(1,0,1,1));
134 VERIFY_IS_APPROX(mat3(1,1,1),
mat1(1,0,0)*mat2(1,1,0,0) +
mat1(1,1,0)*mat2(1,1,1,0) +
135 mat1(1,0,1)*mat2(1,1,0,1) +
mat1(1,1,1)*mat2(1,1,1,1));
147 Evaluator2 eval2(mat4.contract(mat5, dims2),
DefaultDevice());
148 eval2.evalTo(mat6.
data());
153 mat4(0,1)*mat5(1,0,0) + mat4(1,1)*mat5(1,1,0));
155 mat4(0,1)*mat5(1,0,1) + mat4(1,1)*mat5(1,1,1));
158 template<
int DataLayout>
173 for (
int i = 0;
i < 5; ++
i) {
174 for (
int j = 0;
j < 5; ++
j) {
175 for (
int k = 0; k < 5; ++k) {
176 for (
int l = 0;
l < 5; ++
l) {
177 for (
int m = 0;
m < 5; ++
m) {
179 t1(0,
i,
j, 0) * t2(0, k,
l,
m, 0) +
180 t1(1,
i,
j, 0) * t2(1, k,
l,
m, 0) +
181 t1(0,
i,
j, 1) * t2(0, k,
l,
m, 1) +
182 t1(1,
i,
j, 1) * t2(1, k,
l,
m, 1) +
183 t1(0,
i,
j, 2) * t2(0, k,
l,
m, 2) +
184 t1(1,
i,
j, 2) * t2(1, k,
l,
m, 2));
192 template<
int DataLayout>
204 + t1(0, 1) * t2(0, 1, 0) + t1(1, 1) * t2(1, 1, 0));
206 + t1(0, 1) * t2(0, 1, 1) + t1(1, 1) * t2(1, 1, 1));
210 result = t2.contract(t1, dims);
213 + t1(0, 1) * t2(0, 0, 1) + t1(1, 1) * t2(0, 1, 1));
215 + t1(0, 1) * t2(1, 0, 1) + t1(1, 1) * t2(1, 1, 1));
218 template<
int DataLayout>
231 auto contract1 = t1.contract(t2, dims);
232 auto diff = t3 - contract1;
233 auto contract2 = t1.contract(t4, dims);
251 template<
int DataLayout>
262 mat3 =
mat1.contract(mat2, dims);
270 template<
int DataLayout>
282 mat3 =
mat1.contract(mat2, dims);
285 mat1(0,0,0)*mat2(0,0,0) +
mat1(1,0,0)*mat2(0,0,1) +
286 mat1(0,0,1)*mat2(1,0,0) +
mat1(1,0,1)*mat2(1,0,1));
288 mat1(0,1,0)*mat2(0,0,0) +
mat1(1,1,0)*mat2(0,0,1) +
289 mat1(0,1,1)*mat2(1,0,0) +
mat1(1,1,1)*mat2(1,0,1));
291 mat1(0,0,0)*mat2(0,1,0) +
mat1(1,0,0)*mat2(0,1,1) +
292 mat1(0,0,1)*mat2(1,1,0) +
mat1(1,0,1)*mat2(1,1,1));
294 mat1(0,1,0)*mat2(0,1,0) +
mat1(1,1,0)*mat2(0,1,1) +
295 mat1(0,1,1)*mat2(1,1,0) +
mat1(1,1,1)*mat2(1,1,1));
298 mat3 =
mat1.contract(mat2, dims2);
301 mat1(0,0,0)*mat2(0,0,0) +
mat1(1,0,0)*mat2(0,0,1) +
302 mat1(0,0,1)*mat2(1,0,0) +
mat1(1,0,1)*mat2(1,0,1));
304 mat1(0,1,0)*mat2(0,0,0) +
mat1(1,1,0)*mat2(0,0,1) +
305 mat1(0,1,1)*mat2(1,0,0) +
mat1(1,1,1)*mat2(1,0,1));
307 mat1(0,0,0)*mat2(0,1,0) +
mat1(1,0,0)*mat2(0,1,1) +
308 mat1(0,0,1)*mat2(1,1,0) +
mat1(1,0,1)*mat2(1,1,1));
310 mat1(0,1,0)*mat2(0,1,0) +
mat1(1,1,0)*mat2(0,1,1) +
311 mat1(0,1,1)*mat2(1,1,0) +
mat1(1,1,1)*mat2(1,1,1));
315 template<
int DataLayout>
332 mat3 =
mat1.contract(mat2, dims1);
333 mat4 = mat2.contract(
mat1, dims2);
337 for (
size_t i = 0;
i < 5;
i++) {
338 for (
size_t j = 0;
j < 10;
j++) {
344 for (
size_t i = 0;
i < 5;
i++) {
345 for (
size_t j = 0;
j < 10;
j++) {
352 template<
int DataLayout>
363 t_left += t_left.constant(1.0
f);
364 t_right += t_right.constant(1.0
f);
367 MapXf m_left(t_left.
data(), 1500, 248);
368 MapXf m_right(t_right.
data(), 248, 1400);
375 t_result = t_left.contract(t_right, dims);
376 m_result = m_left * m_right;
378 for (
int i = 0;
i < t_result.
dimensions().TotalSize();
i++) {
384 template<
int DataLayout>
395 MapXf m_left(t_left.
data(), 30, 50);
396 MapXf m_right(t_right.
data(), 50, 1);
403 t_result = t_left.contract(t_right, dims);
404 m_result = m_left * m_right;
406 for (
int i = 0;
i < t_result.
dimensions().TotalSize();
i++) {
412 template<
int DataLayout>
426 MapXf m_left(t_left.
data(), 7, 13*17);
427 MapXf m_right(t_right.
data(), 1, 7);
430 for (
int i = 0;
i < t_result.
dimensions().TotalSize();
i++) {
436 template<
int DataLayout>
445 t_left += t_left.constant(1.0
f);
446 t_right += t_right.constant(1.0
f);
454 t_result = t_left.contract(t_right, dims);
461 for (
int i = 0;
i < t_result.
dimensions().TotalSize();
i++) {
466 template<
int DataLayout>
481 for (
int i = 0;
i <
result.dimension(0); ++
i) {
482 for (
int j = 0;
j <
result.dimension(1); ++
j) {
483 for (
int k = 0; k <
result.dimension(2); ++k) {
484 for (
int l = 0;
l <
result.dimension(3); ++
l) {
493 template<
int DataLayout>
506 mat3 =
mat1.contract(mat2, dims);
516 template <
typename Index,
typename Scalar>
518 const internal::blas_data_mapper<Scalar, Index, ColMajor>& output_mapper,
520 Index num_cols)
const {
521 for (
int i = 0;
i < num_rows; ++
i) {
522 for (
int j = 0;
j < num_cols; ++
j) {
529 template <
int DataLayout>
541 t_left += t_left.constant(1.0
f);
542 t_right += t_right.constant(1.0
f);
545 MapXf m_left(t_left.
data(), 1500, 248);
546 MapXf m_right(t_right.
data(), 248, 1400);
555 m_result = m_left * m_right;
557 for (std::ptrdiff_t
i = 0;
i < t_result.dimensions().TotalSize();
i++) {
595 CALL_SUBTEST_8(test_large_contraction_with_output_kernel<ColMajor>());
596 CALL_SUBTEST_8(test_large_contraction_with_output_kernel<RowMajor>());