#ifndef EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
#define EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H
template<typename Broadcast, typename XprType>
struct traits<TensorBroadcastingOp<Broadcast, XprType> > : public traits<XprType>
{
  typedef traits<XprType> XprTraits;
  typedef typename XprType::Nested Nested;
  static const int NumDimensions = XprTraits::NumDimensions;
  static const int Layout = XprTraits::Layout;
};
template<typename Broadcast, typename XprType>
struct eval<TensorBroadcastingOp<Broadcast, XprType>, Eigen::Dense>
{
  typedef const TensorBroadcastingOp<Broadcast, XprType>& type;
};
template<typename Broadcast, typename XprType>
struct nested<TensorBroadcastingOp<Broadcast, XprType>, 1, typename eval<TensorBroadcastingOp<Broadcast, XprType> >::type>
{
  typedef TensorBroadcastingOp<Broadcast, XprType> type;
};
template <typename Dims>
struct is_input_scalar {
  static const bool value = false;
};
template <>
struct is_input_scalar<Sizes<> > {
  static const bool value = true;
};
#ifndef EIGEN_EMULATE_CXX11_META_H
template <typename std::ptrdiff_t... Indices>
struct is_input_scalar<Sizes<Indices...> > {
  static const bool value = (Sizes<Indices...>::total_size == 1);
};
#endif
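// Usage sketch (illustrative, not part of this header): broadcasting is
// normally reached through TensorBase::broadcast(), which returns a
// TensorBroadcastingOp expression. Assuming the unsupported Tensor module:
//
//   Eigen::Tensor<float, 2> t(2, 3);
//   t.setRandom();
//   Eigen::array<Eigen::Index, 2> bcast = {{2, 4}};
//   Eigen::Tensor<float, 2> r = t.broadcast(bcast);  // shape (4, 12)
//
// Each output dimension is the input dimension times its broadcast factor.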
template<typename Broadcast, typename XprType>
class TensorBroadcastingOp : public TensorBase<TensorBroadcastingOp<Broadcast, XprType>, ReadOnlyAccessors>
// Eval as rvalue
template<typename Broadcast, typename ArgType, typename Device>
struct TensorEvaluator<const TensorBroadcastingOp<Broadcast, ArgType>, Device>
  EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
      : isCopy(false), nByOne(false), oneByN(false), m_device(device),
        m_broadcast(op.broadcast()), m_impl(op.expression(), device)
  {
    // One can't broadcast a scalar: the result keeps the input's rank,
    // which must be at least 1.
    EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE)
    const InputDimensions& input_dims = m_impl.dimensions();
    isCopy = true;
    for (int i = 0; i < NumDims; ++i) {
      // Each output dimension is the input dimension times its broadcast factor.
      m_dimensions[i] = input_dims[i] * m_broadcast[i];
      if (m_broadcast[i] != 1) {
        isCopy = false;
      }
    }
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      m_inputStrides[0] = 1;
      m_outputStrides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
        m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1];
      }
    } else {
      m_inputStrides[NumDims-1] = 1;
      m_outputStrides[NumDims-1] = 1;
      for (int i = NumDims-2; i >= 0; --i) {
        m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1];
        m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1];
      }
    }
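    // Worked example (col-major, illustrative): input dims (2, 3) with
    // broadcast factors (2, 4) give m_dimensions == (4, 12),
    // m_inputStrides == {1, 2} and m_outputStrides == {1, 4}.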
    if (input_dims[0] == 1) {
      oneByN = true;
      for (int i = 1; i < NumDims; ++i) {
        if (m_broadcast[i] != 1) {
          oneByN = false;
          break;
        }
      }
    } else if (input_dims[NumDims-1] == 1) {
      nByOne = true;
      for (int i = 0; i < NumDims-1; ++i) {
        if (m_broadcast[i] != 1) {
          nByOne = false;
          break;
        }
      }
    }

    // Handle inputs shaped [1, N..., 1] (e.g. NCHW-style layouts) whose
    // broadcast is [N, 1..., N]: both flags are set and the packet path
    // dispatches to packetOneByNByOne().
    if (!oneByN && !nByOne) {
      if (input_dims[0] == 1 && input_dims[NumDims-1] == 1 && NumDims > 2) {
        oneByN = true;
        nByOne = true;
        for (int i = 1; i < NumDims-1; ++i) {
          if (m_broadcast[i] != 1) {
            oneByN = false;
            nByOne = false;
            break;
          }
        }
      }
    }
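    // Illustration: a col-major input of shape (1, N) broadcast by (b, 1)
    // sets oneByN (only the size-1 innermost dimension is replicated), while
    // an input of shape (N, 1) broadcast by (1, b) sets nByOne.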
  EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) {
    m_impl.evalSubExprsIfNeeded(NULL);
    return true;
  }
#ifdef EIGEN_USE_THREADS
  template <typename EvalSubExprsCallback>
  EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(
      EvaluatorPointerType, EvalSubExprsCallback done) {
    m_impl.evalSubExprsIfNeededAsync(nullptr, [done](bool) { done(true); });
  }
#endif  // EIGEN_USE_THREADS
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
      return m_impl.coeff(0);
    }
    if (isCopy) {
      return m_impl.coeff(index);
    }
    return (static_cast<int>(Layout) == static_cast<int>(ColMajor))
               ? coeffColMajor(index)
               : coeffRowMajor(index);
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index indexColMajor(Index index) const {
    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[0]);
      }
    }
    return inputIndex;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffColMajor(Index index) const {
    return m_impl.coeff(indexColMajor(index));
  }
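  // Index mapping sketch: for output dims (4, 12) built from input (2, 3)
  // with broadcast (2, 4), output index 7 (col-major) decomposes as
  // idx = 7 / 4 = 1 along dim 1 and remainder 3 along dim 0, so the input
  // coefficient is (3 % 2) + (1 % 3) * 2 = 3.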
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index indexRowMajor(Index index) const {
    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    if (internal::index_statically_eq<Broadcast>(NumDims - 1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims - 1]);
      inputIndex += index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims - 1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims - 1] == 0);
      } else {
        inputIndex += (index % m_impl.dimensions()[NumDims - 1]);
      }
    }
    return inputIndex;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeffRowMajor(Index index) const {
    return m_impl.coeff(indexRowMajor(index));
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
  {
    if (internal::is_input_scalar<typename internal::remove_all<InputDimensions>::type>::value) {
      return internal::pset1<PacketReturnType>(m_impl.coeff(0));
    }

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      if (isCopy) {
#ifdef EIGEN_GPU_COMPILE_PHASE
        // On GPU, an unaligned load is used here regardless of LoadMode.
        return m_impl.template packet<Unaligned>(index);
#else
        return m_impl.template packet<LoadMode>(index);
#endif
      } else if (oneByN && !nByOne) {
        return packetNByOne<LoadMode>(index);
      } else if (!oneByN && nByOne) {
        return packetOneByN<LoadMode>(index);
      } else if (oneByN && nByOne) {
        return packetOneByNByOne<LoadMode>(index);
      } else {
        return packetColMajor<LoadMode>(index);
      }
    } else {
      if (isCopy) {
#ifdef EIGEN_GPU_COMPILE_PHASE
        return m_impl.template packet<Unaligned>(index);
#else
        return m_impl.template packet<LoadMode>(index);
#endif
      } else if (oneByN && !nByOne) {
        return packetOneByN<LoadMode>(index);
      } else if (!oneByN && nByOne) {
        return packetNByOne<LoadMode>(index);
      } else if (oneByN && nByOne) {
        return packetOneByNByOne<LoadMode>(index);
      } else {
        return packetRowMajor<LoadMode>(index);
      }
    }
  }
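  // Dispatch note: for col-major layout a (1, N) input broadcast along the
  // innermost dimension yields interleaved copies (packetNByOne), an (N, 1)
  // input broadcast along the outermost dimension yields concatenated copies
  // (packetOneByN), and the [1, N..., 1] case combines both
  // (packetOneByNByOne); everything else falls back to the generic
  // per-dimension index computation.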
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByNByOne(Index index) const
  {
    eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());

    EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
    Index startDim, endDim;
    Index inputIndex, outputOffset, batchedIndex;

    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      startDim = NumDims - 1;
      endDim = 1;
    } else {
      startDim = 0;
      endDim = NumDims - 2;
    }

    batchedIndex = index % m_outputStrides[startDim];
    inputIndex = batchedIndex / m_outputStrides[endDim];
    outputOffset = batchedIndex % m_outputStrides[endDim];

    if (outputOffset + PacketSize <= m_outputStrides[endDim]) {
      // The whole packet repeats a single input coefficient.
      values[0] = m_impl.coeff(inputIndex);
      return internal::pload1<PacketReturnType>(values);
    } else {
      EIGEN_UNROLL_LOOP
      for (int i = 0, cur = 0; i < PacketSize; ++i, ++cur) {
        if (outputOffset + cur < m_outputStrides[endDim]) {
          values[i] = m_impl.coeff(inputIndex);
        } else {
          ++inputIndex;
          inputIndex = (inputIndex == m_inputStrides[startDim] ? 0 : inputIndex);
          values[i] = m_impl.coeff(inputIndex);
          outputOffset = 0;
          cur = 0;
        }
      }
      return internal::pload<PacketReturnType>(values);
    }
  }
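  // Note: each input coefficient fills a contiguous run of
  // m_outputStrides[endDim] output slots; when a run ends, inputIndex
  // advances, wrapping back to 0 after m_inputStrides[startDim] input
  // coefficients have been consumed.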
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetOneByN(Index index) const
  {
    // Concatenates m_broadcast[dim] copies of the flattened input:
    //   [v0, ..., vN, v0, ..., vN, ...]
    eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());

    Index dim, inputIndex;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      dim = NumDims - 1;
    } else {
      dim = 0;
    }

    inputIndex = index % m_inputStrides[dim];
    if (inputIndex + PacketSize <= m_inputStrides[dim]) {
      // The whole packet lies inside one copy of the input: a single load.
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      // The packet straddles the end of the input: gather with wrap-around.
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      EIGEN_UNROLL_LOOP
      for (int i = 0; i < PacketSize; ++i) {
        if (inputIndex > m_inputStrides[dim] - 1) {
          inputIndex = 0;
        }
        values[i] = m_impl.coeff(inputIndex++);
      }
      return internal::pload<PacketReturnType>(values);
    }
  }
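  // Example: input length 3 (v0 v1 v2) broadcast x2 flattens to
  // v0 v1 v2 v0 v1 v2; a 4-wide packet at index 1 gathers v1 v2 v0 v1.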
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetNByOne(Index index) const
  {
    // Interleaves m_broadcast[dim] copies of each input coefficient:
    //   [v0, v0, ..., v1, v1, ..., vN, vN, ...]
    eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());

    EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
    Index dim, inputIndex, outputOffset;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      dim = 1;
    } else {
      dim = NumDims - 2;
    }

    inputIndex = index / m_outputStrides[dim];
    outputOffset = index % m_outputStrides[dim];
    if (outputOffset + PacketSize <= m_outputStrides[dim]) {
      // The whole packet repeats a single input coefficient.
      values[0] = m_impl.coeff(inputIndex);
      return internal::pload1<PacketReturnType>(values);
    } else {
      EIGEN_UNROLL_LOOP
      for (int i = 0, cur = 0; i < PacketSize; ++i, ++cur) {
        if (outputOffset + cur < m_outputStrides[dim]) {
          values[i] = m_impl.coeff(inputIndex);
        } else {
          values[i] = m_impl.coeff(++inputIndex);
          outputOffset = 0;
          cur = 0;
        }
      }
      return internal::pload<PacketReturnType>(values);
    }
  }
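  // Example: input v0 v1 broadcast x3 along the inner output dimension
  // flattens to v0 v0 v0 v1 v1 v1; a 4-wide packet at index 1 yields
  // v0 v0 v1 v1.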
  // Ignore the LoadMode and always use unaligned loads, since the alignment
  // cannot be guaranteed at compile time.
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetColMajor(Index index) const
  {
    eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = NumDims - 1; i > 0; --i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(0, 1)) {
      eigen_assert(index < m_impl.dimensions()[0]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(0, 1)) {
        eigen_assert(index % m_impl.dimensions()[0] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[0];
      }
    }
    inputIndex += innermostLoc;

    if (innermostLoc + PacketSize <= m_impl.dimensions()[0]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      EIGEN_UNROLL_LOOP
      for (int i = 1; i < PacketSize; ++i) {
        if (innermostLoc + i < m_impl.dimensions()[0]) {
          values[i] = m_impl.coeff(inputIndex + i);
        } else {
          values[i] = coeffColMajor(originalIndex + i);
        }
      }
      return internal::pload<PacketReturnType>(values);
    }
  }
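  // Fallback behavior: when the packet would cross the input's innermost
  // dimension boundary, the remaining lanes are gathered one coefficient at
  // a time through the full coeffColMajor() index computation.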
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packetRowMajor(Index index) const
  {
    eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());

    const Index originalIndex = index;

    Index inputIndex = 0;
    EIGEN_UNROLL_LOOP
    for (int i = 0; i < NumDims - 1; ++i) {
      const Index idx = index / m_outputStrides[i];
      if (internal::index_statically_eq<Broadcast>(i, 1)) {
        eigen_assert(idx < m_impl.dimensions()[i]);
        inputIndex += idx * m_inputStrides[i];
      } else {
        if (internal::index_statically_eq<InputDimensions>(i, 1)) {
          eigen_assert(idx % m_impl.dimensions()[i] == 0);
        } else {
          inputIndex += (idx % m_impl.dimensions()[i]) * m_inputStrides[i];
        }
      }
      index -= idx * m_outputStrides[i];
    }
    Index innermostLoc;
    if (internal::index_statically_eq<Broadcast>(NumDims-1, 1)) {
      eigen_assert(index < m_impl.dimensions()[NumDims-1]);
      innermostLoc = index;
    } else {
      if (internal::index_statically_eq<InputDimensions>(NumDims-1, 1)) {
        eigen_assert(index % m_impl.dimensions()[NumDims-1] == 0);
        innermostLoc = 0;
      } else {
        innermostLoc = index % m_impl.dimensions()[NumDims-1];
      }
    }
    inputIndex += innermostLoc;

    if (innermostLoc + PacketSize <= m_impl.dimensions()[NumDims-1]) {
      return m_impl.template packet<Unaligned>(inputIndex);
    } else {
      EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
      values[0] = m_impl.coeff(inputIndex);
      EIGEN_UNROLL_LOOP
      for (int i = 1; i < PacketSize; ++i) {
        if (innermostLoc + i < m_impl.dimensions()[NumDims-1]) {
          values[i] = m_impl.coeff(inputIndex + i);
        } else {
          values[i] = coeffRowMajor(originalIndex + i);
        }
      }
      return internal::pload<PacketReturnType>(values);
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
  costPerCoeff(bool vectorized) const {
    double compute_cost = TensorOpCost::AddCost<Index>();
    if (!isCopy && NumDims > 0) {
      EIGEN_UNROLL_LOOP
      for (int i = NumDims - 1; i > 0; --i) {
        compute_cost += TensorOpCost::DivCost<Index>();
        if (internal::index_statically_eq<Broadcast>(i, 1)) {
          compute_cost +=
              TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
        } else {
          if (!internal::index_statically_eq<InputDimensions>(i, 1)) {
            compute_cost += TensorOpCost::MulCost<Index>() +
                            TensorOpCost::ModCost<Index>() +
                            TensorOpCost::AddCost<Index>();
          }
        }
        compute_cost +=
            TensorOpCost::MulCost<Index>() + TensorOpCost::AddCost<Index>();
      }
    }
    return m_impl.costPerCoeff(vectorized) +
           TensorOpCost(0, 0, compute_cost, vectorized, PacketSize);
  }
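  // The model mirrors indexColMajor()/indexRowMajor(): one division per
  // dimension, plus either a multiply-add (statically known non-broadcast
  // dimension) or a multiply, modulo and add, and a final multiply-add for
  // the stride update.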
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  internal::TensorBlockResourceRequirements getResourceRequirements() const {
    // Sizing blocks for the first-level cache was measured upstream to be
    // faster than targeting the last-level cache on large tensors.
    const size_t target_size = m_device.firstLevelCacheSize();
    return internal::TensorBlockResourceRequirements::merge(
        m_impl.getResourceRequirements(),
        internal::TensorBlockResourceRequirements::skewed<Scalar>(target_size));
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock
  block(TensorBlockDesc& desc, TensorBlockScratch& scratch,
        bool /*root_of_expr_ast*/ = false) const {
    BlockBroadcastingParams params = blockBroadcastingParams(desc);

    if (params.inner_dim_size == 0 || params.bcast_dim_size == 0) {
      return emptyBlock();
    }

    // Prepare storage for the materialized broadcasting result.
    const typename TensorBlock::Storage block_storage =
        TensorBlock::prepareStorage(desc, scratch);
    ScalarNoConst* materialized_output = block_storage.data();

    // We may also need to materialize input blocks along the way.
    size_t materialized_input_size = 0;
    ScalarNoConst* materialized_input = NULL;

    // Initialize iterator state for the dimensions outside the broadcast
    // dimension (always stored inner-most to outer-most).
    array<BlockBroadcastingIteratorState, NumDims> it;
    int idx = 0;

    for (int i = params.inner_dim_count + 1; i < NumDims; ++i) {
      const Index dim = IsColMajor ? i : NumDims - 1 - i;
      it[idx].size = params.output_dims[dim];
      it[idx].count = 0;
      it[idx].output_stride = m_outputStrides[dim];
      it[idx].output_span = it[idx].output_stride * (it[idx].size - 1);
      idx++;
    }

    // Fill the output block sequentially, starting from the inner-most
    // dimension.
    Index output_offset = 0;
    const Index output_size = NumDims == 0 ? 1 : params.output_dims.TotalSize();

    for (Index num_output_coeffs = 0; num_output_coeffs < output_size;) {
      ScalarNoConst* bcast_output = materialized_output + num_output_coeffs;
      Index bcast_offset = desc.offset() + output_offset;

      // Broadcast along the bcast dimension.
      num_output_coeffs += BroadcastBlockAlongBcastDim(
          params, bcast_offset, scratch, bcast_output, &materialized_input,
          &materialized_input_size);

      // Switch to the next outer dimension.
      for (int j = 0; j < idx; ++j) {
        if (++it[j].count < it[j].size) {
          output_offset += it[j].output_stride;
          break;
        }
        it[j].count = 0;
        output_offset -= it[j].output_span;
      }
    }

    return block_storage.AsTensorMaterializedBlock();
  }
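  // Traversal sketch: the iterator array behaves like an odometer over the
  // dimensions outside the broadcast dimension; each
  // BroadcastBlockAlongBcastDim() call fills one inner slice of the output
  // block, and the odometer selects where the next slice starts.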
  EIGEN_DEVICE_FUNC Broadcast functor() const { return m_broadcast; }

#ifdef EIGEN_USE_SYCL
  // Bind the placeholder accessors to a SYCL command group handler.
  EIGEN_STRONG_INLINE void bind(cl::sycl::handler& cgh) const {
    m_impl.bind(cgh);
  }
#endif
  static const bool IsColMajor = static_cast<int>(Layout) == static_cast<int>(ColMajor);
  struct BlockBroadcastingParams {
    Dimensions input_dims;                  // input expression dimensions
    Dimensions output_dims;                 // output block sizes
    Dimensions output_strides;              // output block strides

    int inner_dim_count;                    // number of inner dims matching the output
    int bcast_dim;                          // first dimension that must be broadcast
    Index bcast_dim_size;                   // size of the bcast dim in the output block
    Index inner_dim_size;                   // total size of the matching inner dims

    // Sizes and strides of the input block covering the inner dimensions.
    Dimensions input_block_sizes;
    Dimensions input_block_strides;

    // Broadcasting is expressed in a "virtual" 2 * NumDims dimensional space
    // (see blockBroadcastingParams below).
    BroadcastDimensions bcast_block_sizes;
    BroadcastDimensions bcast_block_strides;
    BroadcastDimensions bcast_input_strides;
  };
  struct BlockBroadcastingIteratorState {
    Index size;
    Index count;
    Index output_stride;
    Index output_span;
  };
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE BlockBroadcastingParams
  blockBroadcastingParams(TensorBlockDesc& desc) const {
    BlockBroadcastingParams params;

    params.input_dims = Dimensions(m_impl.dimensions());

    // Output block sizes and strides.
    params.output_dims = desc.dimensions();
    params.output_strides = internal::strides<Layout>(params.output_dims);

    // Find the broadcasting dimension: the first dimension for which the
    // output block size differs from the expression size.
    params.bcast_dim = 0;
    params.bcast_dim_size = 1;
    params.inner_dim_size = 1;

    // Count the inner dimensions fully covered by the output block.
    params.inner_dim_count = 0;

    for (int i = 0; i < NumDims; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;

      if (params.output_dims[dim] == m_dimensions[dim]) {
        params.inner_dim_size *= params.output_dims[dim];
        ++params.inner_dim_count;
        continue;
      }

      // The first non-matching dimension is the broadcasting dimension.
      eigen_assert(params.output_dims[dim] < m_dimensions[dim]);
      params.bcast_dim = dim;
      params.bcast_dim_size = params.output_dims[dim];
      break;
    }

    // Input block sizes: fully cover the inner dimensions, size 1 elsewhere.
    for (int i = 0; i < params.inner_dim_count; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;
      params.input_block_sizes[dim] = params.input_dims[dim];
    }
    for (int i = params.inner_dim_count; i < NumDims; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;
      params.input_block_sizes[dim] = 1;
    }
    params.input_block_strides =
        internal::strides<Layout>(params.input_block_sizes);
    // Broadcast with the 0-stride trick: each inner dimension `dim` is split
    // into a copy dimension of size input_dims[dim] and a broadcast dimension
    // of size m_broadcast[dim] whose input stride is 0.
    for (int i = 0; i < params.inner_dim_count; ++i) {
      const int dim = IsColMajor ? i : NumDims - i - 1;

      const int copy_dim = IsColMajor ? 2 * i : 2 * NumDims - 2 * i - 1;
      const int broadcast_dim = IsColMajor ? copy_dim + 1 : copy_dim - 1;

      params.bcast_block_sizes[copy_dim] = params.input_dims[dim];
      params.bcast_block_sizes[broadcast_dim] = m_broadcast[dim];
      params.bcast_block_strides[copy_dim] = params.output_strides[dim];
      params.bcast_block_strides[broadcast_dim] =
          params.output_strides[dim] * params.input_dims[dim];
      params.bcast_input_strides[copy_dim] = params.input_block_strides[dim];
      params.bcast_input_strides[broadcast_dim] = 0;
    }

    for (int i = 2 * params.inner_dim_count; i < 2 * NumDims; ++i) {
      const int dim = IsColMajor ? i : 2 * NumDims - i - 1;
      params.bcast_block_sizes[dim] = 1;
      params.bcast_block_strides[dim] = 0;
      params.bcast_input_strides[dim] = 0;
    }

    return params;
  }
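  // 0-stride trick in one dimension (illustrative): input [v0, v1] broadcast
  // by 2 becomes a virtual 2-d view with sizes {2 (copy), 2 (broadcast)} and
  // input strides {1, 0}; reading element (c, b) returns input[c], so the
  // flattened output is v0 v1 v0 v1.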
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index BroadcastBlockAlongBcastDim(
      BlockBroadcastingParams params, Index bcast_offset,
      TensorBlockScratch& scratch, ScalarNoConst* materialized_output,
      ScalarNoConst** materialized_input,
      size_t* materialized_input_size) const {
    if (params.bcast_dim_size == 1) {
      // A single slice along the bcast dimension: one block read suffices.
      return BroadcastBlock(
          params.input_block_sizes, params.input_block_strides,
          params.bcast_block_sizes, params.bcast_block_strides,
          params.bcast_input_strides, bcast_offset, 0, scratch,
          materialized_output, materialized_input, materialized_input_size);

    } else if (params.input_dims[params.bcast_dim] == 1) {
      // Broadcast the single input coefficient along the bcast dimension.
      const int broadcast_bcast_dim =
          IsColMajor ? 2 * params.inner_dim_count + 1
                     : 2 * NumDims - 2 * params.inner_dim_count - 2;

      params.bcast_block_sizes[broadcast_bcast_dim] = params.bcast_dim_size;
      params.bcast_input_strides[broadcast_bcast_dim] = 0;
      params.bcast_block_strides[broadcast_bcast_dim] =
          params.output_strides[params.bcast_dim];

      return BroadcastBlock(
          params.input_block_sizes, params.input_block_strides,
          params.bcast_block_sizes, params.bcast_block_strides,
          params.bcast_input_strides, bcast_offset, 0, scratch,
          materialized_output, materialized_input, materialized_input_size);
    } else {
      // The general case: the block covers several copies of the input along
      // the bcast dimension, possibly with a partial copy at each end. Copy
      // the head and the tail, and broadcast the whole copies in the middle.
      Index num_output_coeffs = 0;

      // Position of the block's left edge along the bcast dimension.
      const Index bcast_dim_left_index =
          bcast_offset / m_outputStrides[params.bcast_dim];

      const Index input_bcast_dim_size = params.input_dims[params.bcast_dim];

      // First and last multiples of the input size inside the half-open
      // interval [bcast_dim_left_index, bcast_dim_left_index + bcast_dim_size).
      const Index first_multiple =
          divup<Index>(bcast_dim_left_index, input_bcast_dim_size) *
          input_bcast_dim_size;

      if (first_multiple <= bcast_dim_left_index + params.bcast_dim_size) {
        const Index last_multiple =
            (bcast_dim_left_index + params.bcast_dim_size) /
            input_bcast_dim_size * input_bcast_dim_size;
        const int copy_bcast_dim =
            IsColMajor ? 2 * params.inner_dim_count
                       : 2 * NumDims - 2 * params.inner_dim_count - 1;
        const int broadcast_bcast_dim =
            IsColMajor ? 2 * params.inner_dim_count + 1
                       : 2 * NumDims - 2 * params.inner_dim_count - 2;

        if (first_multiple > bcast_dim_left_index) {
          // Copy the head (partial first copy of the input).
          const Index head_size = first_multiple - bcast_dim_left_index;
          params.input_block_sizes[params.bcast_dim] = head_size;
          params.bcast_block_sizes[copy_bcast_dim] = head_size;
          params.bcast_input_strides[copy_bcast_dim] =
              params.input_block_strides[params.bcast_dim];
          params.bcast_block_strides[copy_bcast_dim] =
              params.output_strides[params.bcast_dim];
          params.bcast_block_sizes[broadcast_bcast_dim] = 1;
          params.bcast_input_strides[broadcast_bcast_dim] = 0;
          params.bcast_block_strides[broadcast_bcast_dim] =
              params.output_strides[params.bcast_dim] *
              params.input_dims[params.bcast_dim];

          num_output_coeffs += BroadcastBlock(
              params.input_block_sizes, params.input_block_strides,
              params.bcast_block_sizes, params.bcast_block_strides,
              params.bcast_input_strides, bcast_offset, 0, scratch,
              materialized_output, materialized_input, materialized_input_size);
        }
        if (first_multiple < last_multiple) {
          // Broadcast the whole input copies in the middle.
          params.input_block_sizes[params.bcast_dim] = input_bcast_dim_size;
          params.bcast_block_sizes[copy_bcast_dim] = input_bcast_dim_size;
          params.bcast_input_strides[copy_bcast_dim] =
              params.input_block_strides[params.bcast_dim];
          params.bcast_block_strides[copy_bcast_dim] =
              params.output_strides[params.bcast_dim];
          params.bcast_block_sizes[broadcast_bcast_dim] =
              (last_multiple - first_multiple) / input_bcast_dim_size;
          params.bcast_input_strides[broadcast_bcast_dim] = 0;
          params.bcast_block_strides[broadcast_bcast_dim] =
              params.output_strides[params.bcast_dim] *
              params.input_dims[params.bcast_dim];
          const Index offset = (first_multiple - bcast_dim_left_index) *
                               m_outputStrides[params.bcast_dim];

          num_output_coeffs += BroadcastBlock(
              params.input_block_sizes, params.input_block_strides,
              params.bcast_block_sizes, params.bcast_block_strides,
              params.bcast_input_strides, bcast_offset, offset, scratch,
              materialized_output, materialized_input, materialized_input_size);
        }
        if (last_multiple < bcast_dim_left_index + params.bcast_dim_size) {
          // Copy the tail (partial last copy of the input).
          const Index tail_size =
              bcast_dim_left_index + params.bcast_dim_size - last_multiple;
          params.input_block_sizes[params.bcast_dim] = tail_size;
          params.bcast_block_sizes[copy_bcast_dim] = tail_size;
          params.bcast_input_strides[copy_bcast_dim] =
              params.input_block_strides[params.bcast_dim];
          params.bcast_block_strides[copy_bcast_dim] =
              params.output_strides[params.bcast_dim];
          params.bcast_block_sizes[broadcast_bcast_dim] = 1;
          params.bcast_input_strides[broadcast_bcast_dim] = 0;
          params.bcast_block_strides[broadcast_bcast_dim] =
              params.output_strides[params.bcast_dim] *
              params.input_dims[params.bcast_dim];
          const Index offset = (last_multiple - bcast_dim_left_index) *
                               m_outputStrides[params.bcast_dim];

          num_output_coeffs += BroadcastBlock(
              params.input_block_sizes, params.input_block_strides,
              params.bcast_block_sizes, params.bcast_block_strides,
              params.bcast_input_strides, bcast_offset, offset, scratch,
              materialized_output, materialized_input, materialized_input_size);
        }
      } else {
        // The block stays within a single copy of the input: a plain copy,
        // with no broadcasting along the bcast dimension.
        const int copy_bcast_dim =
            IsColMajor ? 2 * params.inner_dim_count
                       : 2 * NumDims - 2 * params.inner_dim_count - 1;
        params.input_block_sizes[params.bcast_dim] = params.bcast_dim_size;
        params.bcast_block_sizes[copy_bcast_dim] = params.bcast_dim_size;
        params.bcast_input_strides[copy_bcast_dim] =
            params.input_block_strides[params.bcast_dim];
        params.bcast_block_strides[copy_bcast_dim] =
            params.output_strides[params.bcast_dim];

        num_output_coeffs += BroadcastBlock(
            params.input_block_sizes, params.input_block_strides,
            params.bcast_block_sizes, params.bcast_block_strides,
            params.bcast_input_strides, bcast_offset, 0, scratch,
            materialized_output, materialized_input, materialized_input_size);
      }
      return num_output_coeffs;
    }
  }
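  // Interval example: with input size 5 along the bcast dim and a block
  // covering [3, 12), first_multiple = 5 and last_multiple = 10, so the head
  // copies [3, 5), the middle broadcasts one full copy over [5, 10), and the
  // tail copies [10, 12).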
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index BroadcastBlock(
      const Dimensions& input_block_sizes,
      const Dimensions& input_block_strides,
      const BroadcastDimensions& bcast_block_sizes,
      const BroadcastDimensions& bcast_block_strides,
      const BroadcastDimensions& bcast_input_strides, Index bcast_offset,
      Index offset, TensorBlockScratch& scratch,
      ScalarNoConst* materialized_output, ScalarNoConst** materialized_input,
      size_t* materialized_input_size) const {
    // Read the input block that corresponds to the broadcasted output region.
    const Index input_offset = bcast_offset + offset;
    TensorBlockDesc input_desc(
        IsColMajor ? indexColMajor(input_offset) : indexRowMajor(input_offset),
        input_block_sizes);

    ArgTensorBlock input_block = m_impl.block(input_desc, scratch);

    // If the input block already points at raw memory, read from it directly;
    // otherwise materialize it into a scratch buffer first.
    const ScalarNoConst* input_buffer = NULL;

    if (input_block.data() != NULL) {
      input_buffer = input_block.data();
    } else {
      // (Re)allocate the scratch buffer only when it is too small.
      const size_t input_total_size = input_block_sizes.TotalSize();
      if (*materialized_input == NULL ||
          *materialized_input_size < input_total_size) {
        *materialized_input_size = input_total_size;
        void* mem = scratch.allocate(*materialized_input_size * sizeof(Scalar));
        *materialized_input = static_cast<ScalarNoConst*>(mem);
      }

      typedef internal::TensorBlockAssignment<
          ScalarNoConst, NumDims, typename ArgTensorBlock::XprType, Index>
          TensorBlockAssignment;

      TensorBlockAssignment::Run(
          TensorBlockAssignment::target(input_block_sizes, input_block_strides,
                                        *materialized_input),
          input_block.expr());

      input_buffer = *materialized_input;
    }

    // Copy the input into the output block through the "virtual"
    // 2 * NumDims broadcasting space.
    typedef internal::TensorBlockIO<ScalarNoConst, Index, 2 * NumDims, Layout>
        TensorBlockIO;

    typename TensorBlockIO::Src src(bcast_input_strides, input_buffer);
    typename TensorBlockIO::Dst dst(bcast_block_sizes, bcast_block_strides,
                                    materialized_output + offset);

    return TensorBlockIO::Copy(dst, src);
  }
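  // The actual replication is done by TensorBlockIO::Copy: the source view
  // uses bcast_input_strides (0 strides on the broadcast dims), so the same
  // input coefficient is written to every replicated output slot.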
#endif // EIGEN_CXX11_TENSOR_TENSOR_BROADCASTING_H