10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_FFT_H
11 #define EIGEN_CXX11_TENSOR_TENSOR_FFT_H
15 #if __cplusplus >= 201103L || EIGEN_COMP_MSVC >= 1900
30 template <
bool NeedUprade>
struct MakeComplex {
33 T operator() (
const T& val)
const {
return val; }
36 template <>
struct MakeComplex<true> {
39 std::complex<T> operator() (
const T& val)
const {
return std::complex<T>(val, 0); }
42 template <>
struct MakeComplex<false> {
45 std::complex<T> operator() (
const std::complex<T>& val)
const {
return val; }
48 template <
int ResultType>
struct PartOf {
// Identity projection used by the primary PartOf<ResultType> template:
// returns `val` unchanged. Applies to any ResultType that is not one of the
// explicit specializations (RealPart, ImagPart) that follow.
49 template <
typename T> T operator() (
const T& val)
const {
return val; }
52 template <>
struct PartOf<
RealPart> {
// Projection for PartOf<RealPart>: returns only the real component of the
// complex value `val` (std::complex<T>::real()).
53 template <
typename T> T operator() (
const std::complex<T>& val)
const {
return val.real(); }
56 template <>
struct PartOf<
ImagPart> {
// Projection for PartOf<ImagPart>: returns only the imaginary component of the
// complex value `val` (std::complex<T>::imag()).
57 template <
typename T> T operator() (
const std::complex<T>& val)
const {
return val.imag(); }
61 template <
typename FFT,
typename XprType,
int FFTResultType,
int FFTDir>
62 struct traits<TensorFFTOp<FFT, XprType,
FFTResultType, FFTDir> > :
public traits<XprType> {
63 typedef traits<XprType> XprTraits;
65 typedef typename std::complex<RealScalar> ComplexScalar;
68 typedef typename XprTraits::StorageKind StorageKind;
70 typedef typename XprType::Nested Nested;
72 static const int NumDimensions = XprTraits::NumDimensions;
73 static const int Layout = XprTraits::Layout;
76 template <
typename FFT,
typename XprType,
int FFTResultType,
int FFTDirection>
78 typedef const TensorFFTOp<FFT, XprType, FFTResultType, FFTDirection>& type;
81 template <
typename FFT,
typename XprType,
int FFTResultType,
int FFTDirection>
83 typedef TensorFFTOp<FFT, XprType, FFTResultType, FFTDirection>
type;
88 template <
typename FFT,
typename XprType,
int FFTResultType,
int FFTDir>
89 class TensorFFTOp :
public TensorBase<TensorFFTOp<FFT, XprType, FFTResultType, FFTDir>, ReadOnlyAccessors> {
93 typedef typename std::complex<RealScalar> ComplexScalar;
95 typedef OutputScalar CoeffReturnType;
101 : m_xpr(expr), m_fft(fft) {}
// Returns, by const reference, the FFT argument that was stored in m_fft by
// the constructor's initializer list (m_fft(fft)).
104 const FFT& fft()
const {
return m_fft; }
112 typename XprType::Nested m_xpr;
117 template <
typename FFT,
typename ArgType,
typename Device,
int FFTResultType,
int FFTDir>
118 struct TensorEvaluator<const TensorFFTOp<FFT, ArgType,
FFTResultType, FFTDir>, Device> {
119 typedef TensorFFTOp<FFT, ArgType, FFTResultType, FFTDir> XprType;
121 static const int NumDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
125 typedef typename std::complex<RealScalar> ComplexScalar;
127 typedef internal::traits<XprType> XprTraits;
145 for (
int i = 0; i < NumDims; ++i) {
147 m_dimensions[i] = input_dims[i];
152 for (
int i = 1; i < NumDims; ++i) {
153 m_strides[i] = m_strides[i - 1] * m_dimensions[i - 1];
156 m_strides[NumDims - 1] = 1;
157 for (
int i = NumDims - 2; i >= 0; --i) {
158 m_strides[i] = m_strides[i + 1] * m_dimensions[i + 1];
161 m_size = m_dimensions.TotalSize();
169 m_impl.evalSubExprsIfNeeded(NULL);
192 template <
int LoadMode>
195 return internal::ploadt<PacketReturnType, LoadMode>(
m_data + index);
200 return TensorOpCost(
sizeof(
CoeffReturnType), 0, 0, vectorized, PacketSize);
209 ComplexScalar* buf = write_to_out ? (ComplexScalar*)
data : (ComplexScalar*)
m_device.allocate(
sizeof(ComplexScalar) * m_size);
211 for (
Index i = 0; i < m_size; ++i) {
212 buf[i] = MakeComplex<internal::is_same<InputScalar, RealScalar>::value>()(
m_impl.coeff(i));
215 for (
size_t i = 0; i < m_fft.size(); ++i) {
216 Index dim = m_fft[i];
218 Index line_len = m_dimensions[dim];
220 ComplexScalar* line_buf = (ComplexScalar*)
m_device.allocate(
sizeof(ComplexScalar) * line_len);
221 const bool is_power_of_two = isPowerOfTwo(line_len);
222 const Index good_composite = is_power_of_two ? 0 : findGoodComposite(line_len);
223 const Index log_len = is_power_of_two ? getLog2(line_len) : getLog2(good_composite);
225 ComplexScalar*
a = is_power_of_two ? NULL : (ComplexScalar*)
m_device.allocate(
sizeof(ComplexScalar) * good_composite);
226 ComplexScalar*
b = is_power_of_two ? NULL : (ComplexScalar*)
m_device.allocate(
sizeof(ComplexScalar) * good_composite);
227 ComplexScalar* pos_j_base_powered = is_power_of_two ? NULL : (ComplexScalar*)
m_device.allocate(
sizeof(ComplexScalar) * (line_len + 1));
228 if (!is_power_of_two) {
233 pos_j_base_powered[0] = ComplexScalar(1, 0);
236 const ComplexScalar pos_j_base = ComplexScalar(
238 pos_j_base_powered[1] = pos_j_base;
240 const ComplexScalar pos_j_base_sq = pos_j_base * pos_j_base;
241 for (
int j = 2; j < line_len + 1; ++j) {
242 pos_j_base_powered[j] = pos_j_base_powered[j - 1] *
243 pos_j_base_powered[j - 1] /
244 pos_j_base_powered[j - 2] * pos_j_base_sq;
250 for (
Index partial_index = 0; partial_index < m_size / line_len; ++partial_index) {
251 const Index base_offset = getBaseOffsetFromIndex(partial_index, dim);
254 const Index stride = m_strides[dim];
256 memcpy(line_buf, &buf[base_offset], line_len*
sizeof(ComplexScalar));
258 Index offset = base_offset;
259 for (
int j = 0; j < line_len; ++j, offset += stride) {
260 line_buf[j] = buf[offset];
265 if (is_power_of_two) {
266 processDataLineCooleyTukey(line_buf, line_len, log_len);
269 processDataLineBluestein(line_buf, line_len, good_composite, log_len,
a,
b, pos_j_base_powered);
274 memcpy(&buf[base_offset], line_buf, line_len*
sizeof(ComplexScalar));
276 Index offset = base_offset;
277 const ComplexScalar div_factor = ComplexScalar(1.0 / line_len, 0);
278 for (
int j = 0; j < line_len; ++j, offset += stride) {
279 buf[offset] = (FFTDir ==
FFT_FORWARD) ? line_buf[j] : line_buf[j] * div_factor;
284 if (!is_power_of_two) {
287 m_device.deallocate(pos_j_base_powered);
292 for (
Index i = 0; i < m_size; ++i) {
293 data[i] = PartOf<FFTResultType>()(buf[i]);
301 return !(
x & (
x - 1));
307 while (i < 2 *
n - 1) i *= 2;
313 while (m >>= 1) log2m++;
320 scramble_FFT(line_buf, line_len);
321 compute_1D_Butterfly<FFTDir>(line_buf, line_len, log_len);
325 EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE void processDataLineBluestein(ComplexScalar* line_buf,
Index line_len,
Index good_composite,
Index log_len, ComplexScalar*
a, ComplexScalar*
b,
const ComplexScalar* pos_j_base_powered) {
327 Index m = good_composite;
328 ComplexScalar*
data = line_buf;
330 for (
Index i = 0; i <
n; ++i) {
335 a[i] =
data[i] * pos_j_base_powered[i];
338 for (
Index i =
n; i < m; ++i) {
339 a[i] = ComplexScalar(0, 0);
342 for (
Index i = 0; i <
n; ++i) {
344 b[i] = pos_j_base_powered[i];
350 for (
Index i =
n; i < m -
n; ++i) {
351 b[i] = ComplexScalar(0, 0);
353 for (
Index i = m -
n; i < m; ++i) {
355 b[i] = pos_j_base_powered[m-i];
363 compute_1D_Butterfly<FFT_FORWARD>(
a, m, log_len);
366 compute_1D_Butterfly<FFT_FORWARD>(
b, m, log_len);
368 for (
Index i = 0; i < m; ++i) {
373 compute_1D_Butterfly<FFT_REVERSE>(
a, m, log_len);
376 for (
Index i = 0; i < m; ++i) {
380 for (
Index i = 0; i <
n; ++i) {
385 data[i] =
a[i] * pos_j_base_powered[i];
393 for (
Index i = 1; i <
n; ++i){
398 while (m >= 2 && j > m) {
408 ComplexScalar tmp =
data[1];
415 ComplexScalar tmp[4];
420 tmp[3] = ComplexScalar(0.0, -1.0) * (
data[2] -
data[3]);
422 tmp[3] = ComplexScalar(0.0, 1.0) * (
data[2] -
data[3]);
424 data[0] = tmp[0] + tmp[2];
425 data[1] = tmp[1] + tmp[3];
426 data[2] = tmp[0] - tmp[2];
427 data[3] = tmp[1] - tmp[3];
432 ComplexScalar tmp_1[8];
433 ComplexScalar tmp_2[8];
439 tmp_1[3] = (
data[2] -
data[3]) * ComplexScalar(0, -1);
441 tmp_1[3] = (
data[2] -
data[3]) * ComplexScalar(0, 1);
447 tmp_1[7] = (
data[6] -
data[7]) * ComplexScalar(0, -1);
449 tmp_1[7] = (
data[6] -
data[7]) * ComplexScalar(0, 1);
451 tmp_2[0] = tmp_1[0] + tmp_1[2];
452 tmp_2[1] = tmp_1[1] + tmp_1[3];
453 tmp_2[2] = tmp_1[0] - tmp_1[2];
454 tmp_2[3] = tmp_1[1] - tmp_1[3];
455 tmp_2[4] = tmp_1[4] + tmp_1[6];
457 #define SQRT2DIV2 0.7071067811865476
459 tmp_2[5] = (tmp_1[5] + tmp_1[7]) * ComplexScalar(SQRT2DIV2, -SQRT2DIV2);
460 tmp_2[6] = (tmp_1[4] - tmp_1[6]) * ComplexScalar(0, -1);
461 tmp_2[7] = (tmp_1[5] - tmp_1[7]) * ComplexScalar(-SQRT2DIV2, -SQRT2DIV2);
463 tmp_2[5] = (tmp_1[5] + tmp_1[7]) * ComplexScalar(SQRT2DIV2, SQRT2DIV2);
464 tmp_2[6] = (tmp_1[4] - tmp_1[6]) * ComplexScalar(0, 1);
465 tmp_2[7] = (tmp_1[5] - tmp_1[7]) * ComplexScalar(-SQRT2DIV2, SQRT2DIV2);
467 data[0] = tmp_2[0] + tmp_2[4];
468 data[1] = tmp_2[1] + tmp_2[5];
469 data[2] = tmp_2[2] + tmp_2[6];
470 data[3] = tmp_2[3] + tmp_2[7];
471 data[4] = tmp_2[0] - tmp_2[4];
472 data[5] = tmp_2[1] - tmp_2[5];
473 data[6] = tmp_2[2] - tmp_2[6];
474 data[7] = tmp_2[3] - tmp_2[7];
483 const RealScalar wtemp = m_sin_PI_div_n_LUT[n_power_of_2];
485 ? m_minus_sin_2_PI_div_n_LUT[n_power_of_2]
486 : -m_minus_sin_2_PI_div_n_LUT[n_power_of_2];
488 const ComplexScalar wp(wtemp, wpi);
489 const ComplexScalar wp_one = wp + ComplexScalar(1, 0);
490 const ComplexScalar wp_one_2 = wp_one * wp_one;
491 const ComplexScalar wp_one_3 = wp_one_2 * wp_one;
492 const ComplexScalar wp_one_4 = wp_one_3 * wp_one;
494 ComplexScalar w(1.0, 0.0);
495 for (
Index i = 0; i < n2; i += 4) {
496 ComplexScalar temp0(
data[i + n2] * w);
497 ComplexScalar temp1(
data[i + 1 + n2] * w * wp_one);
498 ComplexScalar temp2(
data[i + 2 + n2] * w * wp_one_2);
499 ComplexScalar temp3(
data[i + 3 + n2] * w * wp_one_3);
505 data[i + 1 + n2] =
data[i + 1] - temp1;
506 data[i + 1] += temp1;
508 data[i + 2 + n2] =
data[i + 2] - temp2;
509 data[i + 2] += temp2;
511 data[i + 3 + n2] =
data[i + 3] - temp3;
512 data[i + 3] += temp3;
521 compute_1D_Butterfly<Dir>(
data,
n / 2, n_power_of_2 - 1);
522 compute_1D_Butterfly<Dir>(
data +
n / 2,
n / 2, n_power_of_2 - 1);
523 butterfly_1D_merge<Dir>(
data,
n, n_power_of_2);
525 butterfly_8<Dir>(
data);
527 butterfly_4<Dir>(
data);
529 butterfly_2<Dir>(
data);
537 for (
int i = NumDims - 1; i > omitted_dim; --i) {
538 const Index partial_m_stride = m_strides[i] / m_dimensions[omitted_dim];
539 const Index idx = index / partial_m_stride;
540 index -= idx * partial_m_stride;
541 result += idx * m_strides[i];
546 for (
Index i = 0; i < omitted_dim; ++i) {
547 const Index partial_m_stride = m_strides[i] / m_dimensions[omitted_dim];
548 const Index idx = index / partial_m_stride;
549 index -= idx * partial_m_stride;
550 result += idx * m_strides[i];
559 Index result = base + offset * m_strides[omitted_dim] ;
567 array<Index, NumDims> m_strides;
568 TensorEvaluator<ArgType, Device>
m_impl;
610 const RealScalar m_minus_sin_2_PI_div_n_LUT[32] = {
648 #endif // EIGEN_HAS_CONSTEXPR
651 #endif // EIGEN_CXX11_TENSOR_TENSOR_FFT_H