10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H
26 template <
typename Tensor,
bool HasRawAccess>
struct CoeffLoader {
42 return m_tensor.template packet<LoadMode>(index);
66 return internal::ploadt_ro<typename Tensor::PacketReturnType, LoadMode>(m_data + index);
75 typename nocontract_t,
typename contract_t,
76 int packet_size,
bool inner_dim_contiguous,
int Alignment>
81 const nocontract_t& nocontract_strides,
82 const nocontract_t& ij_strides,
83 const contract_t& contract_strides,
84 const contract_t& k_strides) :
115 const bool left = (side ==
Lhs);
124 if (side ==
Lhs && inner_dim_contiguous) {
126 linidx += nocontract_val;
140 if (side ==
Rhs && inner_dim_contiguous) {
142 linidx += contract_val;
153 const bool left = (side ==
Lhs);
155 Index linidx[2] = {0, 0};
165 if (side ==
Lhs && inner_dim_contiguous) {
167 linidx[0] += nocontract_val[0];
168 linidx[1] += nocontract_val[1];
186 if (side ==
Rhs && inner_dim_contiguous) {
188 linidx[0] += contract_val[0];
189 linidx[1] += contract_val[1];
202 return (Alignment ==
Aligned) && (side ==
Lhs) && inner_dim_contiguous ? 0 :
size;
217 template<
typename Scalar,
typename Index,
int side,
219 typename nocontract_t,
typename contract_t,
220 int packet_size,
bool inner_dim_contiguous,
221 bool inner_dim_reordered,
int Alignment>
229 const nocontract_t& nocontract_strides,
230 const nocontract_t& ij_strides,
231 const contract_t& contract_strides,
232 const contract_t& k_strides) :
233 ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
235 typedef typename Tensor::PacketReturnType
Packet;
238 template <
int AlignmentType>
247 if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) {
250 return this->
m_tensor.template packet<AlignmentType>(index);
261 if (Tensor::PacketAccess &&
263 (last - first) == (packet_size - 1)) {
265 return this->
m_tensor.template packet<AlignmentType>(first);
270 data[0] = this->
m_tensor.coeff(first);
271 for (
Index k = 1; k < packet_size - 1; k += 2) {
276 data[packet_size - 1] = this->
m_tensor.coeff(last);
278 return pload<Packet>(data);
281 template <
int AlignmentType>
288 if (half_packet_size == packet_size) {
289 return loadPacket<AlignmentType>(i, j);
292 for (
Index k = 0; k < half_packet_size; k++) {
295 return pload<HalfPacket>(data);
300 template<
typename Scalar,
typename Index,
int side,
302 typename nocontract_t,
typename contract_t,
303 bool inner_dim_contiguous,
304 bool inner_dim_reordered,
int Alignment>
305 class BaseTensorContractionMapper<
Scalar,
Index, side,
Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> :
public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment>
312 const nocontract_t& nocontract_strides,
313 const nocontract_t& ij_strides,
314 const contract_t& contract_strides,
315 const contract_t& k_strides) :
316 ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
318 typedef typename Tensor::PacketReturnType
Packet;
319 template <
int> EIGEN_DEVICE_FUNC
323 return pload<typename Tensor::PacketReturnType>(data);
325 template <
int> EIGEN_DEVICE_FUNC
332 template<
typename Scalar,
typename Index,
int side,
334 typename nocontract_t,
typename contract_t,
336 bool inner_dim_contiguous,
bool inner_dim_reordered,
int Alignment>
339 typedef typename Tensor::PacketReturnType
Packet;
390 return m_base_mapper.template loadHalfPacket<Alignment>(i, 0);
409 template <
typename PacketT,
int AlignmentType>
414 return m_base_mapper.template loadPacket<ActualAlignment>(i, 0);
419 template <
typename Packet>
431 template<
typename Scalar_,
typename Index,
int side,
433 typename nocontract_t,
typename contract_t,
435 bool inner_dim_contiguous,
bool inner_dim_reordered,
int Alignment>
437 :
public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {
446 const nocontract_t& nocontract_strides,
447 const nocontract_t& ij_strides,
448 const contract_t& contract_strides,
449 const contract_t& k_strides)
450 :
Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
467 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H