Fix FFT when destination does not have unit stride.

This commit is contained in:
Antonio Sánchez 2024-05-07 17:18:29 +00:00 committed by Rasmus Munk Larsen
parent 99c18bce6e
commit e16d70bd4e
2 changed files with 67 additions and 9 deletions

View File

@ -232,10 +232,11 @@ class FFT {
if (nfft < 1) nfft = src.size(); if (nfft < 1) nfft = src.size();
if (NumTraits<src_type>::IsComplex == 0 && HasFlag(HalfSpectrum)) Index dst_size = nfft;
dst.derived().resize((nfft >> 1) + 1); if (NumTraits<src_type>::IsComplex == 0 && HasFlag(HalfSpectrum)) {
else dst_size = (nfft >> 1) + 1;
dst.derived().resize(nfft); }
dst.derived().resize(dst_size);
if (src.innerStride() != 1 || src.size() < nfft) { if (src.innerStride() != 1 || src.size() < nfft) {
Matrix<src_type, 1, Dynamic> tmp; Matrix<src_type, 1, Dynamic> tmp;
@ -245,9 +246,21 @@ class FFT {
} else { } else {
tmp = src; tmp = src;
} }
fwd(&dst[0], &tmp[0], nfft); if (dst.innerStride() != 1) {
Matrix<dst_type, 1, Dynamic> out(1, dst_size);
fwd(&out[0], &tmp[0], nfft);
dst.derived() = out;
} else {
fwd(&dst[0], &tmp[0], nfft);
}
} else { } else {
fwd(&dst[0], &src[0], nfft); if (dst.innerStride() != 1) {
Matrix<dst_type, 1, Dynamic> out(1, dst_size);
fwd(&out[0], &src[0], nfft);
dst.derived() = out;
} else {
fwd(&dst[0], &src[0], nfft);
}
} }
} }
@ -326,9 +339,22 @@ class FFT {
} else { } else {
tmp = src; tmp = src;
} }
inv(&dst[0], &tmp[0], nfft);
if (dst.innerStride() != 1) {
Matrix<dst_type, 1, Dynamic> out(1, nfft);
inv(&out[0], &tmp[0], nfft);
dst.derived() = out;
} else {
inv(&dst[0], &tmp[0], nfft);
}
} else { } else {
inv(&dst[0], &src[0], nfft); if (dst.innerStride() != 1) {
Matrix<dst_type, 1, Dynamic> out(1, nfft);
inv(&out[0], &src[0], nfft);
dst.derived() = out;
} else {
inv(&dst[0], &src[0], nfft);
}
} }
} }

View File

@ -163,10 +163,42 @@ void test_complex_generic(int nfft) {
VERIFY(T(dif_rmse(inbuf, buf3)) < test_precision<T>()); // gross check VERIFY(T(dif_rmse(inbuf, buf3)) < test_precision<T>()); // gross check
} }
template <typename T>
void test_complex_strided(int nfft) {
typedef typename FFT<T>::Complex Complex;
typedef typename Eigen::Vector<Complex, Dynamic> ComplexVector;
constexpr int kInputStride = 3;
constexpr int kOutputStride = 7;
constexpr int kInvOutputStride = 13;
FFT<T> fft;
ComplexVector inbuf(nfft * kInputStride);
inbuf.setRandom();
ComplexVector outbuf(nfft * kOutputStride);
outbuf.setRandom();
ComplexVector invoutbuf(nfft * kInvOutputStride);
invoutbuf.setRandom();
using StridedComplexVector = Map<ComplexVector, /*MapOptions=*/0, InnerStride<Dynamic>>;
StridedComplexVector input(inbuf.data(), nfft, InnerStride<Dynamic>(kInputStride));
StridedComplexVector output(outbuf.data(), nfft, InnerStride<Dynamic>(kOutputStride));
StridedComplexVector inv_output(invoutbuf.data(), nfft, InnerStride<Dynamic>(kInvOutputStride));
for (int k = 0; k < nfft; ++k)
input[k] = Complex((T)(rand() / (double)RAND_MAX - .5), (T)(rand() / (double)RAND_MAX - .5));
fft.fwd(output, input);
VERIFY(T(fft_rmse(output, input)) < test_precision<T>()); // gross check
fft.inv(inv_output, output);
VERIFY(T(dif_rmse(inv_output, input)) < test_precision<T>()); // gross check
}
template <typename T> template <typename T>
void test_complex(int nfft) { void test_complex(int nfft) {
test_complex_generic<StdVectorContainer, T>(nfft); test_complex_generic<StdVectorContainer, T>(nfft);
test_complex_generic<EigenVectorContainer, T>(nfft); test_complex_generic<EigenVectorContainer, T>(nfft);
test_complex_strided<T>(nfft);
} }
template <typename T, int nrows, int ncols> template <typename T, int nrows, int ncols>