10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
27 template <
typename Tensor,
bool HasRawAccess,
template <
class>
class MakePointer_ =
MakePointer>
31 typename nocontract_t,
typename contract_t,
int packet_size,
32 bool inner_dim_contiguous,
bool inner_dim_reordered,
int Alignment,
36 template <
typename Tensor,
bool HasRawAccess,
template <
class>
class MakePointer_>
59 return m_tensor.template packet<LoadMode>(index);
73 template <
typename Tensor,
template <
class>
class MakePointer_>
95 return internal::ploadt_ro<typename Tensor::PacketReturnType, LoadMode>(m_data + index);
110 template<
typename Scalar,
typename Index,
int side,
112 typename nocontract_t,
typename contract_t,
113 int packet_size,
bool inner_dim_contiguous,
int Alignment,
template <
class>
class MakePointer_ =
MakePointer>
152 const bool left = (side ==
Lhs);
163 if (side ==
Lhs && inner_dim_contiguous) {
165 linidx += nocontract_val;
180 if (side ==
Rhs && inner_dim_contiguous) {
182 linidx += contract_val;
193 const bool left = (side ==
Lhs);
196 Index linidx[2] = {0, 0};
207 if (side ==
Lhs && inner_dim_contiguous) {
209 linidx[0] += nocontract_val[0];
210 linidx[1] += nocontract_val[1];
229 if (side ==
Rhs && inner_dim_contiguous) {
231 linidx[0] += contract_val[0];
232 linidx[1] += contract_val[1];
245 return (Alignment ==
Aligned) && (side ==
Lhs) && inner_dim_contiguous ? 0 :
size;
251 #ifdef EIGEN_USE_SYCL
277 template<
typename Scalar,
typename Index,
int side,
279 typename nocontract_t,
typename contract_t,
280 int packet_size,
bool inner_dim_contiguous,
281 bool inner_dim_reordered,
int Alignment,
template <
class>
class MakePointer_>
289 const nocontract_t& nocontract_strides,
290 const nocontract_t& ij_strides,
291 const contract_t& contract_strides,
292 const contract_t& k_strides) :
293 ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
295 template <
typename PacketT,
int AlignmentType>
306 if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) {
307 const Index index = this->computeIndex(
i,
j);
308 eigen_assert(this->computeIndex(
i+packet_size-1,
j) == index + packet_size-1);
309 return this->m_tensor.template packet<AlignmentType>(index);
320 if (Tensor::PacketAccess &&
322 (lastIdx -
first) == (packet_size - 1)) {
324 return this->m_tensor.template packet<AlignmentType>(
first);
331 for (
Index k = 1; k < packet_size - 1; k += 2) {
333 data[k] = this->m_tensor.coeff(internal_pair.
first);
334 data[k + 1] = this->m_tensor.coeff(internal_pair.
second);
336 data[packet_size - 1] = this->m_tensor.coeff(lastIdx);
338 return pload<PacketT>(
data);
341 template <
typename PacketT,
int AlignmentType>
349 const IndexPair<Index> indexPair = this->computeIndexPair(
i,
j, requested_packet_size - 1);
354 for (
Index k = 1; k < requested_packet_size - 1; k += 2) {
356 data[k] = this->m_tensor.coeff(internal_pair.
first);
357 data[k + 1] = this->m_tensor.coeff(internal_pair.
second);
359 data[requested_packet_size - 1] = this->m_tensor.coeff(lastIdx);
361 return pload<PacketT>(
data);
364 template <
typename PacketT,
int AlignmentType>
367 return this->load<PacketT,AlignmentType>(
i,
j);
372 template<
typename Scalar,
typename Index,
int side,
374 typename nocontract_t,
typename contract_t,
375 bool inner_dim_contiguous,
376 bool inner_dim_reordered,
int Alignment,
template <
class>
class MakePointer_>
377 class BaseTensorContractionMapper<
Scalar,
Index, side,
Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_>
378 :
public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment, MakePointer_>
385 const nocontract_t& nocontract_strides,
386 const nocontract_t& ij_strides,
387 const contract_t& contract_strides,
388 const contract_t& k_strides) :
389 ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
394 data[0] = this->m_tensor.coeff(this->computeIndex(
i,
j));
395 return pload<PacketT>(
data);
400 data[0] = this->m_tensor.coeff(this->computeIndex(
i,
j));
401 return pload<PacketT>(
data);
406 template<
typename Scalar,
typename Index,
int side,
408 typename nocontract_t,
typename contract_t,
410 bool inner_dim_contiguous,
bool inner_dim_reordered,
int Alignment,
template <
class>
class MakePointer_=
MakePointer>
414 typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> ParentMapper;
415 typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> Self;
447 template <
typename PacketT>
450 return m_base_mapper.template loadPacket<PacketT,Alignment>(
i, 0);
455 template <
typename PacketT>
463 template <
typename PacketT,
int AlignmentType>
471 template <
typename PacketT>
486 template <
typename PacketT,
int AlignmentType>
491 return m_base_mapper.template loadPacket<PacketT,ActualAlignment>(
i, 0);
496 template <
typename PacketT>
501 #ifdef EIGEN_USE_SYCL
519 template<
typename Scalar_,
typename Index,
int side,
521 typename nocontract_t,
typename contract_t,
523 bool inner_dim_contiguous,
bool inner_dim_reordered,
int Alignment,
template <
class>
class MakePointer_=
MakePointer>
525 :
public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> {
529 typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> Base;
530 typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> SubMapper;
534 const nocontract_t& nocontract_strides,
535 const nocontract_t& ij_strides,
536 const contract_t& contract_strides,
537 const contract_t& k_strides)
538 :
Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
550 return Base::m_tensor;
557 template<
typename Scalar_,
typename Index_,
int side_,
559 typename nocontract_t_,
typename contract_t_,
561 bool inner_dim_contiguous_,
bool inner_dim_reordered_,
int Alignment_,
template <
class>
class MakePointer_>
563 nocontract_t_, contract_t_, packet_size_, inner_dim_contiguous_,
564 inner_dim_reordered_, Alignment_, MakePointer_> > {
567 static const bool inner_dim_contiguous = inner_dim_contiguous_;
568 static const bool inner_dim_reordered = inner_dim_reordered_;
575 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H