mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-08 09:49:03 +08:00
Fix FFT when destination does not have unit stride.
This commit is contained in:
parent
99c18bce6e
commit
e16d70bd4e
@ -232,10 +232,11 @@ class FFT {
|
|||||||
|
|
||||||
if (nfft < 1) nfft = src.size();
|
if (nfft < 1) nfft = src.size();
|
||||||
|
|
||||||
if (NumTraits<src_type>::IsComplex == 0 && HasFlag(HalfSpectrum))
|
Index dst_size = nfft;
|
||||||
dst.derived().resize((nfft >> 1) + 1);
|
if (NumTraits<src_type>::IsComplex == 0 && HasFlag(HalfSpectrum)) {
|
||||||
else
|
dst_size = (nfft >> 1) + 1;
|
||||||
dst.derived().resize(nfft);
|
}
|
||||||
|
dst.derived().resize(dst_size);
|
||||||
|
|
||||||
if (src.innerStride() != 1 || src.size() < nfft) {
|
if (src.innerStride() != 1 || src.size() < nfft) {
|
||||||
Matrix<src_type, 1, Dynamic> tmp;
|
Matrix<src_type, 1, Dynamic> tmp;
|
||||||
@ -245,11 +246,23 @@ class FFT {
|
|||||||
} else {
|
} else {
|
||||||
tmp = src;
|
tmp = src;
|
||||||
}
|
}
|
||||||
|
if (dst.innerStride() != 1) {
|
||||||
|
Matrix<dst_type, 1, Dynamic> out(1, dst_size);
|
||||||
|
fwd(&out[0], &tmp[0], nfft);
|
||||||
|
dst.derived() = out;
|
||||||
|
} else {
|
||||||
fwd(&dst[0], &tmp[0], nfft);
|
fwd(&dst[0], &tmp[0], nfft);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (dst.innerStride() != 1) {
|
||||||
|
Matrix<dst_type, 1, Dynamic> out(1, dst_size);
|
||||||
|
fwd(&out[0], &src[0], nfft);
|
||||||
|
dst.derived() = out;
|
||||||
} else {
|
} else {
|
||||||
fwd(&dst[0], &src[0], nfft);
|
fwd(&dst[0], &src[0], nfft);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename InputDerived>
|
template <typename InputDerived>
|
||||||
inline fft_fwd_proxy<MatrixBase<InputDerived>, FFT<T_Scalar, T_Impl> > fwd(const MatrixBase<InputDerived>& src,
|
inline fft_fwd_proxy<MatrixBase<InputDerived>, FFT<T_Scalar, T_Impl> > fwd(const MatrixBase<InputDerived>& src,
|
||||||
@ -326,11 +339,24 @@ class FFT {
|
|||||||
} else {
|
} else {
|
||||||
tmp = src;
|
tmp = src;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (dst.innerStride() != 1) {
|
||||||
|
Matrix<dst_type, 1, Dynamic> out(1, nfft);
|
||||||
|
inv(&out[0], &tmp[0], nfft);
|
||||||
|
dst.derived() = out;
|
||||||
|
} else {
|
||||||
inv(&dst[0], &tmp[0], nfft);
|
inv(&dst[0], &tmp[0], nfft);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (dst.innerStride() != 1) {
|
||||||
|
Matrix<dst_type, 1, Dynamic> out(1, nfft);
|
||||||
|
inv(&out[0], &src[0], nfft);
|
||||||
|
dst.derived() = out;
|
||||||
} else {
|
} else {
|
||||||
inv(&dst[0], &src[0], nfft);
|
inv(&dst[0], &src[0], nfft);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename Output_>
|
template <typename Output_>
|
||||||
inline void inv(std::vector<Output_>& dst, const std::vector<Complex>& src, Index nfft = -1) {
|
inline void inv(std::vector<Output_>& dst, const std::vector<Complex>& src, Index nfft = -1) {
|
||||||
|
@ -163,10 +163,42 @@ void test_complex_generic(int nfft) {
|
|||||||
VERIFY(T(dif_rmse(inbuf, buf3)) < test_precision<T>()); // gross check
|
VERIFY(T(dif_rmse(inbuf, buf3)) < test_precision<T>()); // gross check
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void test_complex_strided(int nfft) {
|
||||||
|
typedef typename FFT<T>::Complex Complex;
|
||||||
|
typedef typename Eigen::Vector<Complex, Dynamic> ComplexVector;
|
||||||
|
constexpr int kInputStride = 3;
|
||||||
|
constexpr int kOutputStride = 7;
|
||||||
|
constexpr int kInvOutputStride = 13;
|
||||||
|
|
||||||
|
FFT<T> fft;
|
||||||
|
|
||||||
|
ComplexVector inbuf(nfft * kInputStride);
|
||||||
|
inbuf.setRandom();
|
||||||
|
ComplexVector outbuf(nfft * kOutputStride);
|
||||||
|
outbuf.setRandom();
|
||||||
|
ComplexVector invoutbuf(nfft * kInvOutputStride);
|
||||||
|
invoutbuf.setRandom();
|
||||||
|
|
||||||
|
using StridedComplexVector = Map<ComplexVector, /*MapOptions=*/0, InnerStride<Dynamic>>;
|
||||||
|
StridedComplexVector input(inbuf.data(), nfft, InnerStride<Dynamic>(kInputStride));
|
||||||
|
StridedComplexVector output(outbuf.data(), nfft, InnerStride<Dynamic>(kOutputStride));
|
||||||
|
StridedComplexVector inv_output(invoutbuf.data(), nfft, InnerStride<Dynamic>(kInvOutputStride));
|
||||||
|
|
||||||
|
for (int k = 0; k < nfft; ++k)
|
||||||
|
input[k] = Complex((T)(rand() / (double)RAND_MAX - .5), (T)(rand() / (double)RAND_MAX - .5));
|
||||||
|
fft.fwd(output, input);
|
||||||
|
|
||||||
|
VERIFY(T(fft_rmse(output, input)) < test_precision<T>()); // gross check
|
||||||
|
fft.inv(inv_output, output);
|
||||||
|
VERIFY(T(dif_rmse(inv_output, input)) < test_precision<T>()); // gross check
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void test_complex(int nfft) {
|
void test_complex(int nfft) {
|
||||||
test_complex_generic<StdVectorContainer, T>(nfft);
|
test_complex_generic<StdVectorContainer, T>(nfft);
|
||||||
test_complex_generic<EigenVectorContainer, T>(nfft);
|
test_complex_generic<EigenVectorContainer, T>(nfft);
|
||||||
|
test_complex_strided<T>(nfft);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int nrows, int ncols>
|
template <typename T, int nrows, int ncols>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user