added real-optimized inverse FFT (NFFT must be multiple of 4)

This commit is contained in:
Mark Borgerding 2009-05-25 23:52:21 -04:00
parent 03ed6f9bfb
commit 09b4733255
2 changed files with 370 additions and 348 deletions

View File

@ -53,7 +53,7 @@ template <> string nameof<long double>() {return "long double";}
using namespace Eigen; using namespace Eigen;
template <typename T> template <typename T>
void bench(int nfft) void bench(int nfft,bool fwd)
{ {
typedef typename NumTraits<T>::Real Scalar; typedef typename NumTraits<T>::Real Scalar;
typedef typename std::complex<Scalar> Complex; typedef typename std::complex<Scalar> Complex;
@ -69,7 +69,10 @@ void bench(int nfft)
for (int k=0;k<8;++k) { for (int k=0;k<8;++k) {
timer.start(); timer.start();
for(int i = 0; i < nits; i++) for(int i = 0; i < nits; i++)
if (fwd)
fft.fwd( outbuf , inbuf); fft.fwd( outbuf , inbuf);
else
fft.inv(inbuf,outbuf);
timer.stop(); timer.stop();
} }
@ -82,16 +85,27 @@ void bench(int nfft)
mflops /= 2; mflops /= 2;
} }
if (fwd)
cout << " fwd";
else
cout << " inv";
cout << " NFFT=" << nfft << " " << (double(1e-6*nfft*nits)/timer.value()) << " MS/s " << mflops << "MFLOPS\n"; cout << " NFFT=" << nfft << " " << (double(1e-6*nfft*nits)/timer.value()) << " MS/s " << mflops << "MFLOPS\n";
} }
int main(int argc,char ** argv) int main(int argc,char ** argv)
{ {
bench<complex<float> >(NFFT); bench<complex<float> >(NFFT,true);
bench<float>(NFFT); bench<complex<float> >(NFFT,false);
bench<complex<double> >(NFFT); bench<float>(NFFT,true);
bench<double>(NFFT); bench<float>(NFFT,false);
bench<complex<long double> >(NFFT); bench<complex<double> >(NFFT,true);
bench<long double>(NFFT); bench<complex<double> >(NFFT,false);
bench<double>(NFFT,true);
bench<double>(NFFT,false);
bench<complex<long double> >(NFFT,true);
bench<complex<long double> >(NFFT,false);
bench<long double>(NFFT,true);
bench<long double>(NFFT,false);
return 0; return 0;
} }

View File

@ -38,8 +38,6 @@ namespace Eigen {
std::vector<int> m_stageRemainder; std::vector<int> m_stageRemainder;
bool m_inverse; bool m_inverse;
ei_kiss_cpx_fft() { }
void make_twiddles(int nfft,bool inverse) void make_twiddles(int nfft,bool inverse)
{ {
m_inverse = inverse; m_inverse = inverse;
@ -49,7 +47,7 @@ namespace Eigen {
m_twiddles[i] = exp( Complex(0,i*phinc) ); m_twiddles[i] = exp( Complex(0,i*phinc) );
} }
void invert() void conjugate()
{ {
m_inverse = !m_inverse; m_inverse = !m_inverse;
for ( size_t i=0;i<m_twiddles.size() ;++i) for ( size_t i=0;i<m_twiddles.size() ;++i)
@ -58,11 +56,6 @@ namespace Eigen {
void factorize(int nfft) void factorize(int nfft)
{ {
if (m_stageRadix.size()==0 || m_stageRadix[0] * m_stageRemainder[0] != nfft)
{
m_stageRadix.resize(0);
m_stageRemainder.resize(0);
//factorize
//start factoring out 4's, then 2's, then 3,5,7,9,... //start factoring out 4's, then 2's, then 3,5,7,9,...
int n= nfft; int n= nfft;
int p=4; int p=4;
@ -81,7 +74,6 @@ namespace Eigen {
m_stageRemainder.push_back(n); m_stageRemainder.push_back(n);
}while(n>1); }while(n>1);
} }
}
template <typename _Src> template <typename _Src>
void work( int stage,Complex * xout, const _Src * xin, size_t fstride,size_t in_stride) void work( int stage,Complex * xout, const _Src * xin, size_t fstride,size_t in_stride)
@ -279,13 +271,11 @@ namespace Eigen {
} }
}; };
template <typename _Scalar> template <typename _Scalar>
struct ei_kissfft_impl struct ei_kissfft_impl
{ {
typedef _Scalar Scalar; typedef _Scalar Scalar;
typedef std::complex<Scalar> Complex; typedef std::complex<Scalar> Complex;
ei_kissfft_impl() {}
void clear() void clear()
{ {
@ -324,7 +314,6 @@ namespace Eigen {
Complex f1k = fpk + fpnk; Complex f1k = fpk + fpnk;
Complex f2k = fpk - fpnk; Complex f2k = fpk - fpnk;
Complex tw= f2k * rtw[k-1]; Complex tw= f2k * rtw[k-1];
dst[k] = (f1k + tw) * Scalar(.5); dst[k] = (f1k + tw) * Scalar(.5);
dst[ncfft-k] = conj(f1k -tw)*Scalar(.5); dst[ncfft-k] = conj(f1k -tw)*Scalar(.5);
} }
@ -333,28 +322,47 @@ namespace Eigen {
// TODO: make this configurable ( opt-out ) // TODO: make this configurable ( opt-out )
for ( k=1;k < ncfft ; ++k ) for ( k=1;k < ncfft ; ++k )
dst[nfft-k] = conj(dst[k]); dst[nfft-k] = conj(dst[k]);
dst[0] = dc; dst[0] = dc;
dst[ncfft] = nyquist; dst[ncfft] = nyquist;
} }
} }
// half-complex to scalar // inverse complex-to-complex
void inv( Scalar * dst,const Complex * src,int nfft)
{
// TODO add optimized version for even numbers
std::vector<Complex> tmp(nfft);
inv(&tmp[0],src,nfft);
for (int k=0;k<nfft;++k)
dst[k] = tmp[k].real();
}
void inv(Complex * dst,const Complex *src,int nfft) void inv(Complex * dst,const Complex *src,int nfft)
{ {
get_plan(nfft,true).work(0, dst, src, 1,1); get_plan(nfft,true).work(0, dst, src, 1,1);
scale(dst, nfft, Scalar(1)/nfft ); scale(dst, nfft, Scalar(1)/nfft );
} }
// half-complex to scalar
void inv( Scalar * dst,const Complex * src,int nfft)
{
if (nfft&3) {
m_scratchBuf.resize(nfft);
inv(&m_scratchBuf[0],src,nfft);
for (int k=0;k<nfft;++k)
dst[k] = m_scratchBuf[k].real();
}else{
// optimized version for multiple of 4
int ncfft = nfft>>1;
int ncfft2 = nfft>>2;
Complex * rtw = real_twiddles(ncfft2);
m_scratchBuf.resize(ncfft);
m_scratchBuf[0] = Complex( src[0].real() + src[ncfft].real(), src[0].real() - src[ncfft].real() );
for (int k = 1; k <= ncfft / 2; ++k) {
Complex fk = src[k];
Complex fnkc = conj(src[ncfft-k]);
Complex fek = fk + fnkc;
Complex tmp = fk - fnkc;
Complex fok = tmp * conj(rtw[k-1]);
m_scratchBuf[k] = fek + fok;
m_scratchBuf[ncfft-k] = conj(fek - fok);
}
scale(&m_scratchBuf[0], ncfft, Scalar(1)/nfft );
get_plan(ncfft,true).work(0, reinterpret_cast<Complex*>(dst), &m_scratchBuf[0], 1,1);
}
}
private: private:
typedef ei_kiss_cpx_fft<Scalar> PlanData; typedef ei_kiss_cpx_fft<Scalar> PlanData;
@ -362,16 +370,16 @@ namespace Eigen {
typedef std::map<int,PlanData> PlanMap; typedef std::map<int,PlanData> PlanMap;
PlanMap m_plans; PlanMap m_plans;
std::map<int, std::vector<Complex> > m_realTwiddles; std::map<int, std::vector<Complex> > m_realTwiddles;
std::vector<Complex> m_scratchBuf;
int PlanKey(int nfft,bool isinverse) const { return (nfft<<1) | isinverse; } int PlanKey(int nfft,bool isinverse) const { return (nfft<<1) | isinverse; }
PlanData & get_plan(int nfft,bool inverse) PlanData & get_plan(int nfft,bool inverse)
{ {
/* /* TODO: figure out why this does not work (g++ 4.3.2)
* for some reason this does not work * for some reason this does not work
* *
typedef typename std::map<int,PlanData>::iterator MapIt; PlanMap::iterator it;
MapIt it;
it = m_plans.find( PlanKey(nfft,inverse) ); it = m_plans.find( PlanKey(nfft,inverse) );
if (it == m_plans.end() ) { if (it == m_plans.end() ) {
// create new entry // create new entry
@ -379,7 +387,7 @@ namespace Eigen {
MapIt it2 = m_plans.find( PlanKey(nfft,!inverse) ); MapIt it2 = m_plans.find( PlanKey(nfft,!inverse) );
if (it2 != m_plans.end() ) { if (it2 != m_plans.end() ) {
it->second = it2.second; it->second = it2.second;
it->second.invert(); it->second.conjugate();
}else{ }else{
it->second.make_twiddles(nfft,inverse); it->second.make_twiddles(nfft,inverse);
it->second.factorize(nfft); it->second.factorize(nfft);