Fix FFT when destination does not have unit stride.

This commit is contained in:
Antonio Sánchez 2024-05-07 17:18:29 +00:00 committed by Rasmus Munk Larsen
parent 99c18bce6e
commit e16d70bd4e
2 changed files with 67 additions and 9 deletions

View File

@ -232,10 +232,11 @@ class FFT {
if (nfft < 1) nfft = src.size();
if (NumTraits<src_type>::IsComplex == 0 && HasFlag(HalfSpectrum))
dst.derived().resize((nfft >> 1) + 1);
else
dst.derived().resize(nfft);
Index dst_size = nfft;
if (NumTraits<src_type>::IsComplex == 0 && HasFlag(HalfSpectrum)) {
dst_size = (nfft >> 1) + 1;
}
dst.derived().resize(dst_size);
if (src.innerStride() != 1 || src.size() < nfft) {
Matrix<src_type, 1, Dynamic> tmp;
@ -245,11 +246,23 @@ class FFT {
} else {
tmp = src;
}
if (dst.innerStride() != 1) {
Matrix<dst_type, 1, Dynamic> out(1, dst_size);
fwd(&out[0], &tmp[0], nfft);
dst.derived() = out;
} else {
fwd(&dst[0], &tmp[0], nfft);
}
} else {
if (dst.innerStride() != 1) {
Matrix<dst_type, 1, Dynamic> out(1, dst_size);
fwd(&out[0], &src[0], nfft);
dst.derived() = out;
} else {
fwd(&dst[0], &src[0], nfft);
}
}
}
template <typename InputDerived>
inline fft_fwd_proxy<MatrixBase<InputDerived>, FFT<T_Scalar, T_Impl> > fwd(const MatrixBase<InputDerived>& src,
@ -326,11 +339,24 @@ class FFT {
} else {
tmp = src;
}
if (dst.innerStride() != 1) {
Matrix<dst_type, 1, Dynamic> out(1, nfft);
inv(&out[0], &tmp[0], nfft);
dst.derived() = out;
} else {
inv(&dst[0], &tmp[0], nfft);
}
} else {
if (dst.innerStride() != 1) {
Matrix<dst_type, 1, Dynamic> out(1, nfft);
inv(&out[0], &src[0], nfft);
dst.derived() = out;
} else {
inv(&dst[0], &src[0], nfft);
}
}
}
template <typename Output_>
inline void inv(std::vector<Output_>& dst, const std::vector<Complex>& src, Index nfft = -1) {

View File

@ -163,10 +163,42 @@ void test_complex_generic(int nfft) {
VERIFY(T(dif_rmse(inbuf, buf3)) < test_precision<T>()); // gross check
}
template <typename T>
void test_complex_strided(int nfft) {
typedef typename FFT<T>::Complex Complex;
typedef typename Eigen::Vector<Complex, Dynamic> ComplexVector;
constexpr int kInputStride = 3;
constexpr int kOutputStride = 7;
constexpr int kInvOutputStride = 13;
FFT<T> fft;
ComplexVector inbuf(nfft * kInputStride);
inbuf.setRandom();
ComplexVector outbuf(nfft * kOutputStride);
outbuf.setRandom();
ComplexVector invoutbuf(nfft * kInvOutputStride);
invoutbuf.setRandom();
using StridedComplexVector = Map<ComplexVector, /*MapOptions=*/0, InnerStride<Dynamic>>;
StridedComplexVector input(inbuf.data(), nfft, InnerStride<Dynamic>(kInputStride));
StridedComplexVector output(outbuf.data(), nfft, InnerStride<Dynamic>(kOutputStride));
StridedComplexVector inv_output(invoutbuf.data(), nfft, InnerStride<Dynamic>(kInvOutputStride));
for (int k = 0; k < nfft; ++k)
input[k] = Complex((T)(rand() / (double)RAND_MAX - .5), (T)(rand() / (double)RAND_MAX - .5));
fft.fwd(output, input);
VERIFY(T(fft_rmse(output, input)) < test_precision<T>()); // gross check
fft.inv(inv_output, output);
VERIFY(T(dif_rmse(inv_output, input)) < test_precision<T>()); // gross check
}
template <typename T>
void test_complex(int nfft) {
test_complex_generic<StdVectorContainer, T>(nfft);
test_complex_generic<EigenVectorContainer, T>(nfft);
test_complex_strided<T>(nfft);
}
template <typename T, int nrows, int ncols>