mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-08 09:49:03 +08:00
Fix FFT when destination does not have unit stride.
This commit is contained in:
parent
99c18bce6e
commit
e16d70bd4e
@ -232,10 +232,11 @@ class FFT {
|
||||
|
||||
if (nfft < 1) nfft = src.size();
|
||||
|
||||
if (NumTraits<src_type>::IsComplex == 0 && HasFlag(HalfSpectrum))
|
||||
dst.derived().resize((nfft >> 1) + 1);
|
||||
else
|
||||
dst.derived().resize(nfft);
|
||||
Index dst_size = nfft;
|
||||
if (NumTraits<src_type>::IsComplex == 0 && HasFlag(HalfSpectrum)) {
|
||||
dst_size = (nfft >> 1) + 1;
|
||||
}
|
||||
dst.derived().resize(dst_size);
|
||||
|
||||
if (src.innerStride() != 1 || src.size() < nfft) {
|
||||
Matrix<src_type, 1, Dynamic> tmp;
|
||||
@ -245,9 +246,21 @@ class FFT {
|
||||
} else {
|
||||
tmp = src;
|
||||
}
|
||||
fwd(&dst[0], &tmp[0], nfft);
|
||||
if (dst.innerStride() != 1) {
|
||||
Matrix<dst_type, 1, Dynamic> out(1, dst_size);
|
||||
fwd(&out[0], &tmp[0], nfft);
|
||||
dst.derived() = out;
|
||||
} else {
|
||||
fwd(&dst[0], &tmp[0], nfft);
|
||||
}
|
||||
} else {
|
||||
fwd(&dst[0], &src[0], nfft);
|
||||
if (dst.innerStride() != 1) {
|
||||
Matrix<dst_type, 1, Dynamic> out(1, dst_size);
|
||||
fwd(&out[0], &src[0], nfft);
|
||||
dst.derived() = out;
|
||||
} else {
|
||||
fwd(&dst[0], &src[0], nfft);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -326,9 +339,22 @@ class FFT {
|
||||
} else {
|
||||
tmp = src;
|
||||
}
|
||||
inv(&dst[0], &tmp[0], nfft);
|
||||
|
||||
if (dst.innerStride() != 1) {
|
||||
Matrix<dst_type, 1, Dynamic> out(1, nfft);
|
||||
inv(&out[0], &tmp[0], nfft);
|
||||
dst.derived() = out;
|
||||
} else {
|
||||
inv(&dst[0], &tmp[0], nfft);
|
||||
}
|
||||
} else {
|
||||
inv(&dst[0], &src[0], nfft);
|
||||
if (dst.innerStride() != 1) {
|
||||
Matrix<dst_type, 1, Dynamic> out(1, nfft);
|
||||
inv(&out[0], &src[0], nfft);
|
||||
dst.derived() = out;
|
||||
} else {
|
||||
inv(&dst[0], &src[0], nfft);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -163,10 +163,42 @@ void test_complex_generic(int nfft) {
|
||||
VERIFY(T(dif_rmse(inbuf, buf3)) < test_precision<T>()); // gross check
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void test_complex_strided(int nfft) {
|
||||
typedef typename FFT<T>::Complex Complex;
|
||||
typedef typename Eigen::Vector<Complex, Dynamic> ComplexVector;
|
||||
constexpr int kInputStride = 3;
|
||||
constexpr int kOutputStride = 7;
|
||||
constexpr int kInvOutputStride = 13;
|
||||
|
||||
FFT<T> fft;
|
||||
|
||||
ComplexVector inbuf(nfft * kInputStride);
|
||||
inbuf.setRandom();
|
||||
ComplexVector outbuf(nfft * kOutputStride);
|
||||
outbuf.setRandom();
|
||||
ComplexVector invoutbuf(nfft * kInvOutputStride);
|
||||
invoutbuf.setRandom();
|
||||
|
||||
using StridedComplexVector = Map<ComplexVector, /*MapOptions=*/0, InnerStride<Dynamic>>;
|
||||
StridedComplexVector input(inbuf.data(), nfft, InnerStride<Dynamic>(kInputStride));
|
||||
StridedComplexVector output(outbuf.data(), nfft, InnerStride<Dynamic>(kOutputStride));
|
||||
StridedComplexVector inv_output(invoutbuf.data(), nfft, InnerStride<Dynamic>(kInvOutputStride));
|
||||
|
||||
for (int k = 0; k < nfft; ++k)
|
||||
input[k] = Complex((T)(rand() / (double)RAND_MAX - .5), (T)(rand() / (double)RAND_MAX - .5));
|
||||
fft.fwd(output, input);
|
||||
|
||||
VERIFY(T(fft_rmse(output, input)) < test_precision<T>()); // gross check
|
||||
fft.inv(inv_output, output);
|
||||
VERIFY(T(dif_rmse(inv_output, input)) < test_precision<T>()); // gross check
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void test_complex(int nfft) {
|
||||
test_complex_generic<StdVectorContainer, T>(nfft);
|
||||
test_complex_generic<EigenVectorContainer, T>(nfft);
|
||||
test_complex_strided<T>(nfft);
|
||||
}
|
||||
|
||||
template <typename T, int nrows, int ncols>
|
||||
|
Loading…
x
Reference in New Issue
Block a user