moved half-spectrum logic to Eigen::FFT

This commit is contained in:
Mark Borgerding 2009-10-30 23:38:13 -04:00
parent d659fd9b14
commit 4c3345364e
2 changed files with 56 additions and 18 deletions

View File

@ -283,12 +283,11 @@
m_realTwiddles.clear(); m_realTwiddles.clear();
} }
template <typename _Src>
inline inline
void fwd( Complex * dst,const _Src *src,int nfft) void fwd( Complex * dst,const Complex *src,int nfft)
{ {
get_plan(nfft,false).work(0, dst, src, 1,1); get_plan(nfft,false).work(0, dst, src, 1,1);
} }
// real-to-complex forward FFT // real-to-complex forward FFT
// perform two FFTs of src even and src odd // perform two FFTs of src even and src odd
@ -299,7 +298,9 @@
{ {
if ( nfft&3 ) { if ( nfft&3 ) {
// use generic mode for odd // use generic mode for odd
get_plan(nfft,false).work(0, dst, src, 1,1); m_tmpBuf1.resize(nfft);
get_plan(nfft,false).work(0, &m_tmpBuf1[0], src, 1,1);
std::copy(m_tmpBuf1.begin(),m_tmpBuf1.begin()+(nfft>>1)+1,dst );
}else{ }else{
int ncfft = nfft>>1; int ncfft = nfft>>1;
int ncfft2 = nfft>>2; int ncfft2 = nfft>>2;
@ -319,9 +320,6 @@
dst[k] = (f1k + tw) * Scalar(.5); dst[k] = (f1k + tw) * Scalar(.5);
dst[ncfft-k] = conj(f1k -tw)*Scalar(.5); dst[ncfft-k] = conj(f1k -tw)*Scalar(.5);
} }
// place conjugate-symmetric half at the end for completeness
// TODO: make this configurable ( opt-out )
dst[0] = dc; dst[0] = dc;
dst[ncfft] = nyquist; dst[ncfft] = nyquist;
} }
@ -339,27 +337,31 @@
void inv( Scalar * dst,const Complex * src,int nfft) void inv( Scalar * dst,const Complex * src,int nfft)
{ {
if (nfft&3) { if (nfft&3) {
m_tmpBuf.resize(nfft); m_tmpBuf1.resize(nfft);
inv(&m_tmpBuf[0],src,nfft); m_tmpBuf2.resize(nfft);
std::copy(src,src+(nfft>>1)+1,m_tmpBuf1.begin() );
for (int k=1;k<(nfft>>1)+1;++k)
m_tmpBuf1[nfft-k] = conj(m_tmpBuf1[k]);
inv(&m_tmpBuf2[0],&m_tmpBuf1[0],nfft);
for (int k=0;k<nfft;++k) for (int k=0;k<nfft;++k)
dst[k] = m_tmpBuf[k].real(); dst[k] = m_tmpBuf2[k].real();
}else{ }else{
// optimized version for multiple of 4 // optimized version for multiple of 4
int ncfft = nfft>>1; int ncfft = nfft>>1;
int ncfft2 = nfft>>2; int ncfft2 = nfft>>2;
Complex * rtw = real_twiddles(ncfft2); Complex * rtw = real_twiddles(ncfft2);
m_tmpBuf.resize(ncfft); m_tmpBuf1.resize(ncfft);
m_tmpBuf[0] = Complex( src[0].real() + src[ncfft].real(), src[0].real() - src[ncfft].real() ); m_tmpBuf1[0] = Complex( src[0].real() + src[ncfft].real(), src[0].real() - src[ncfft].real() );
for (int k = 1; k <= ncfft / 2; ++k) { for (int k = 1; k <= ncfft / 2; ++k) {
Complex fk = src[k]; Complex fk = src[k];
Complex fnkc = conj(src[ncfft-k]); Complex fnkc = conj(src[ncfft-k]);
Complex fek = fk + fnkc; Complex fek = fk + fnkc;
Complex tmp = fk - fnkc; Complex tmp = fk - fnkc;
Complex fok = tmp * conj(rtw[k-1]); Complex fok = tmp * conj(rtw[k-1]);
m_tmpBuf[k] = fek + fok; m_tmpBuf1[k] = fek + fok;
m_tmpBuf[ncfft-k] = conj(fek - fok); m_tmpBuf1[ncfft-k] = conj(fek - fok);
} }
get_plan(ncfft,true).work(0, reinterpret_cast<Complex*>(dst), &m_tmpBuf[0], 1,1); get_plan(ncfft,true).work(0, reinterpret_cast<Complex*>(dst), &m_tmpBuf1[0], 1,1);
} }
} }
@ -369,7 +371,8 @@
PlanMap m_plans; PlanMap m_plans;
std::map<int, std::vector<Complex> > m_realTwiddles; std::map<int, std::vector<Complex> > m_realTwiddles;
std::vector<Complex> m_tmpBuf; std::vector<Complex> m_tmpBuf1;
std::vector<Complex> m_tmpBuf2;
inline inline
int PlanKey(int nfft,bool isinverse) const { return (nfft<<1) | isinverse; } int PlanKey(int nfft,bool isinverse) const { return (nfft<<1) | isinverse; }

View File

@ -101,12 +101,34 @@ void test_scalar_generic(int nfft)
ComplexVector outbuf; ComplexVector outbuf;
for (int k=0;k<nfft;++k) for (int k=0;k<nfft;++k)
inbuf[k]= (T)(rand()/(double)RAND_MAX - .5); inbuf[k]= (T)(rand()/(double)RAND_MAX - .5);
// make sure it DOESN'T give the right full spectrum answer
// if we've asked for half-spectrum
fft.SetFlag(fft.HalfSpectrum );
fft.fwd( outbuf,inbuf);
VERIFY(outbuf.size() == (nfft>>1)+1);
VERIFY( fft_rmse(outbuf,inbuf) < test_precision<T>() );// gross check
fft.ClearFlag(fft.HalfSpectrum );
fft.fwd( outbuf,inbuf); fft.fwd( outbuf,inbuf);
VERIFY( fft_rmse(outbuf,inbuf) < test_precision<T>() );// gross check VERIFY( fft_rmse(outbuf,inbuf) < test_precision<T>() );// gross check
ScalarVector buf3; ScalarVector buf3;
fft.inv( buf3 , outbuf); fft.inv( buf3 , outbuf);
VERIFY( dif_rmse(inbuf,buf3) < test_precision<T>() );// gross check VERIFY( dif_rmse(inbuf,buf3) < test_precision<T>() );// gross check
// verify that the Unscaled flag takes effect
ComplexVector buf4;
fft.SetFlag(fft.Unscaled);
fft.inv( buf4 , outbuf);
for (int k=0;k<nfft;++k)
buf4[k] *= T(1./nfft);
VERIFY( dif_rmse(inbuf,buf4) < test_precision<T>() );// gross check
// verify that ClearFlag works
fft.ClearFlag(fft.Unscaled);
fft.inv( buf3 , outbuf);
VERIFY( dif_rmse(inbuf,buf3) < test_precision<T>() );// gross check
} }
template <typename T> template <typename T>
@ -136,6 +158,19 @@ void test_complex_generic(int nfft)
fft.inv( buf3 , outbuf); fft.inv( buf3 , outbuf);
VERIFY( dif_rmse(inbuf,buf3) < test_precision<T>() );// gross check VERIFY( dif_rmse(inbuf,buf3) < test_precision<T>() );// gross check
// verify that the Unscaled flag takes effect
ComplexVector buf4;
fft.SetFlag(fft.Unscaled);
fft.inv( buf4 , outbuf);
for (int k=0;k<nfft;++k)
buf4[k] *= T(1./nfft);
VERIFY( dif_rmse(inbuf,buf4) < test_precision<T>() );// gross check
// verify that ClearFlag works
fft.ClearFlag(fft.Unscaled);
fft.inv( buf3 , outbuf);
VERIFY( dif_rmse(inbuf,buf3) < test_precision<T>() );// gross check
} }
template <typename T> template <typename T>