10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
24 template<
typename Dimensions,
typename LhsXprType,
typename RhsXprType>
49 template<
typename Dimensions,
typename LhsXprType,
typename RhsXprType>
55 template<
typename Dimensions,
typename LhsXprType,
typename RhsXprType>
61 template<
typename Indices_,
typename LeftArgType_,
typename RightArgType_,
typename Device_>
74 template<
typename Indices,
typename LhsXprType,
typename RhsXprType>
86 const LhsXprType& lhs,
const RhsXprType& rhs,
const Indices& dims)
108 template<
typename Derived>
155 op.lhsExpression(), op.rhsExpression()), device),
157 op.rhsExpression(), op.lhsExpression()), device),
162 YOU_MADE_A_PROGRAMMING_MISTAKE);
170 for (
int i = 0; i <
LDims; i++) {
171 eval_left_dims[i] =
m_leftImpl.dimensions()[i];
173 for (
int i = 0; i <
RDims; i++) {
178 eval_op_indices[i].first = op.
indices()[i].first;
179 eval_op_indices[i].second = op.
indices()[i].second;
183 for (
int i = 0; i <
LDims; i++) {
186 for (
int i = 0; i <
RDims; i++) {
201 eigen_assert(eval_op_indices[j].first != eval_op_indices[i].first &&
202 eval_op_indices[j].second != eval_op_indices[i].second &&
203 "contraction axes should be unique");
204 if (eval_op_indices[j].first < eval_op_indices[i].first) {
212 for (
int i = 0; i <
LDims-1; ++i) {
213 lhs_strides[i+1] = lhs_strides[i] * eval_left_dims[i];
218 for (
int i = 0; i <
RDims-1; ++i) {
219 rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i];
236 unsigned int nocontract_idx = 0;
238 for (
int i = 0; i <
LDims; i++) {
240 bool contracting =
false;
242 if (eval_op_indices[j].first == i) {
266 for (
int i = 0; i <
RDims; i++) {
267 bool contracting =
false;
270 if (eval_op_indices[j].second == i) {
297 Index left = eval_op_indices[i].first;
298 Index right = eval_op_indices[i].second;
302 "Contraction axes must be same size");
312 if (i > 0 && right < eval_op_indices[i-1].second) {
322 for (
int i = 0, j =
NumDims - 1; i < j; i++, j--) {
347 static_cast<const Derived*
>(
this)->
template evalProduct<true, true, true, Unaligned>(buffer);
350 static_cast<const Derived*
>(
this)->
template evalProduct<true, true, false, Unaligned>(buffer);
355 static_cast<const Derived*
>(
this)->
template evalProduct<true, false, true, Unaligned>(buffer);
358 static_cast<const Derived*
>(
this)->
template evalProduct<true, false, false, Unaligned>(buffer);
365 static_cast<const Derived*
>(
this)->
template evalProduct<false, true, true, Unaligned>(buffer);
368 static_cast<const Derived*
>(
this)->
template evalProduct<false, true, false, Unaligned>(buffer);
373 static_cast<const Derived*
>(
this)->
template evalProduct<false, false, true, Unaligned>(buffer);
376 static_cast<const Derived*
>(
this)->
template evalProduct<false, false, false, Unaligned>(buffer);
382 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered,
int Alignment>
398 lhs_inner_dim_contiguous,
399 false, lhs_alignment> LhsMapper;
404 rhs_inner_dim_contiguous,
405 rhs_inner_dim_reordered, rhs_alignment> RhsMapper;
413 const Index resIncr(1);
419 rows, cols, lhs, rhs,
420 buffer, resIncr,
alpha);
423 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered,
int Alignment>
442 const Index nr = Traits::nr;
443 const Index mr = Traits::mr;
454 lhs_inner_dim_contiguous,
460 rhs_inner_dim_contiguous,
461 rhs_inner_dim_reordered,
Unaligned> RhsMapper;
478 OutputMapper output(buffer, m);
482 const Index kc = blocking.
kc();
485 const Index sizeA = mc * kc;
486 const Index sizeB = kc * nc;
488 LhsScalar* blockA =
static_cast<LhsScalar *
>(this->
m_device.allocate(sizeA *
sizeof(LhsScalar)));
489 RhsScalar* blockB =
static_cast<RhsScalar *
>(this->
m_device.allocate(sizeB *
sizeof(RhsScalar)));
491 for(
Index i2=0; i2<m; i2+=mc)
494 for (
Index k2 = 0; k2 < k; k2 += kc) {
497 pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc, 0, 0);
500 for (
Index j2 = 0; j2 <
n; j2 += nc) {
503 pack_rhs(blockB, rhs.getSubMapper(k2, j2), actual_kc, actual_nc, 0, 0);
507 gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc,
Scalar(1), -1, -1, 0, 0);
534 template<
int LoadMode>
536 return internal::ploadt<PacketReturnType, LoadMode>(
m_result + index);
571 template<
typename Indices,
typename LeftArgType,
typename RightArgType,
typename Device>
574 TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> > {
597 static const int LDims =
599 static const int RDims =
607 static const int NumDims = LDims + RDims - 2 * ContractDims;
615 template <
bool lhs_inner_dim_contiguous,
bool rhs_inner_dim_contiguous,
bool rhs_inner_dim_reordered,
int Alignment>
617 if (this->m_j_size == 1) {
618 this->
template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
622 this->
template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
628 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H