diff --git a/unsupported/Eigen/FFT b/unsupported/Eigen/FFT index 630be1ea7..557fdf6c6 100644 --- a/unsupported/Eigen/FFT +++ b/unsupported/Eigen/FFT @@ -231,11 +231,12 @@ class FFT { THIS_METHOD_IS_ONLY_FOR_EXPRESSIONS_WITH_DIRECT_MEMORY_ACCESS_SUCH_AS_MAP_OR_PLAIN_MATRICES) if (nfft < 1) nfft = src.size(); - - if (NumTraits::IsComplex == 0 && HasFlag(HalfSpectrum)) - dst.derived().resize((nfft >> 1) + 1); - else - dst.derived().resize(nfft); + + Index dst_size = nfft; + if (NumTraits::IsComplex == 0 && HasFlag(HalfSpectrum)) { + dst_size = (nfft >> 1) + 1; + } + dst.derived().resize(dst_size); if (src.innerStride() != 1 || src.size() < nfft) { Matrix tmp; @@ -245,9 +246,21 @@ class FFT { } else { tmp = src; } - fwd(&dst[0], &tmp[0], nfft); + if (dst.innerStride() != 1) { + Matrix out(1, dst_size); + fwd(&out[0], &tmp[0], nfft); + dst.derived() = out; + } else { + fwd(&dst[0], &tmp[0], nfft); + } } else { - fwd(&dst[0], &src[0], nfft); + if (dst.innerStride() != 1) { + Matrix out(1, dst_size); + fwd(&out[0], &src[0], nfft); + dst.derived() = out; + } else { + fwd(&dst[0], &src[0], nfft); + } } } @@ -326,9 +339,22 @@ class FFT { } else { tmp = src; } - inv(&dst[0], &tmp[0], nfft); + + if (dst.innerStride() != 1) { + Matrix out(1, nfft); + inv(&out[0], &tmp[0], nfft); + dst.derived() = out; + } else { + inv(&dst[0], &tmp[0], nfft); + } } else { - inv(&dst[0], &src[0], nfft); + if (dst.innerStride() != 1) { + Matrix out(1, nfft); + inv(&out[0], &src[0], nfft); + dst.derived() = out; + } else { + inv(&dst[0], &src[0], nfft); + } } } diff --git a/unsupported/test/fft_test_shared.h b/unsupported/test/fft_test_shared.h index 0e040ad70..3adcd9098 100644 --- a/unsupported/test/fft_test_shared.h +++ b/unsupported/test/fft_test_shared.h @@ -163,10 +163,42 @@ void test_complex_generic(int nfft) { VERIFY(T(dif_rmse(inbuf, buf3)) < test_precision()); // gross check } +template +void test_complex_strided(int nfft) { + typedef typename FFT::Complex Complex; + typedef typename Eigen::Vector ComplexVector; + constexpr int kInputStride = 3; + constexpr int kOutputStride = 7; + constexpr int kInvOutputStride = 13; + + FFT fft; + + ComplexVector inbuf(nfft * kInputStride); + inbuf.setRandom(); + ComplexVector outbuf(nfft * kOutputStride); + outbuf.setRandom(); + ComplexVector invoutbuf(nfft * kInvOutputStride); + invoutbuf.setRandom(); + + using StridedComplexVector = Map>; + StridedComplexVector input(inbuf.data(), nfft, InnerStride(kInputStride)); + StridedComplexVector output(outbuf.data(), nfft, InnerStride(kOutputStride)); + StridedComplexVector inv_output(invoutbuf.data(), nfft, InnerStride(kInvOutputStride)); + + for (int k = 0; k < nfft; ++k) + input[k] = Complex((T)(rand() / (double)RAND_MAX - .5), (T)(rand() / (double)RAND_MAX - .5)); + fft.fwd(output, input); + + VERIFY(T(fft_rmse(output, input)) < test_precision()); // gross check + fft.inv(inv_output, output); + VERIFY(T(dif_rmse(inv_output, input)) < test_precision()); // gross check +} + template void test_complex(int nfft) { test_complex_generic(nfft); test_complex_generic(nfft); + test_complex_strided(nfft); } template