refactored ei_kissfft_impl to maintain a cache of cpx fft plans

This commit is contained in:
Mark Borgerding 2009-05-25 23:06:49 -04:00
parent 210092d16c
commit 03ed6f9bfb
2 changed files with 320 additions and 299 deletions

View File

@ -24,112 +24,36 @@
#include <complex>
#include <vector>
#include <map>
namespace Eigen {
template <typename _Scalar>
struct ei_kissfft_impl
struct ei_kiss_cpx_fft
{
typedef _Scalar Scalar;
typedef std::complex<Scalar> Complex;
ei_kissfft_impl() : m_nfft(0) {}
std::vector<Complex> m_twiddles;
std::vector<int> m_stageRadix;
std::vector<int> m_stageRemainder;
bool m_inverse;
template <typename _Src>
void fwd( Complex * dst,const _Src *src,int nfft)
{
prepare(nfft,false);
work(0, dst, src, 1,1);
}
// real-to-complex forward FFT
// perform two FFTs of src even and src odd
// then twiddle to recombine them into the half-spectrum format
// then fill in the conjugate symmetric half
void fwd( Complex * dst,const Scalar * src,int nfft)
{
if ( nfft&1 ) {
// use generic mode for odd
prepare(nfft,false);
work(0, dst, src, 1,1);
}else{
int ncfft = nfft>>1;
int ncfft2 = nfft>>2;
// use optimized mode for even real
fwd( dst, reinterpret_cast<const Complex*> (src),ncfft);
make_real_twiddles(nfft);
Complex dc = dst[0].real() + dst[0].imag();
Complex nyquist = dst[0].real() - dst[0].imag();
int k;
for ( k=1;k <= ncfft2 ; ++k ) {
Complex fpk = dst[k];
Complex fpnk = conj(dst[ncfft-k]);
Complex f1k = fpk + fpnk;
Complex f2k = fpk - fpnk;
//Complex tw = f2k * exp( Complex(0,-3.14159265358979323846264338327 * ((double) (k) / ncfft + .5) ) );
Complex tw= f2k * m_realTwiddles[k-1];
dst[k] = (f1k + tw) * Scalar(.5);
dst[ncfft-k] = conj(f1k -tw)*Scalar(.5);
}
// place conjugate-symmetric half at the end for completeness
// TODO: make this configurable ( opt-out )
for ( k=1;k < ncfft ; ++k )
dst[nfft-k] = conj(dst[k]);
dst[0] = dc;
dst[ncfft] = nyquist;
}
}
// half-complex to scalar
void inv( Scalar * dst,const Complex * src,int nfft)
{
// TODO add optimized version for even numbers
std::vector<Complex> tmp(nfft);
inv(&tmp[0],src,nfft);
for (int k=0;k<nfft;++k)
dst[k] = tmp[k].real();
}
void inv(Complex * dst,const Complex *src,int nfft)
{
prepare(nfft,true);
work(0, dst, src, 1,1);
scale(dst, Scalar(1)/m_nfft );
}
void prepare(int nfft,bool inverse)
{
make_twiddles(nfft,inverse);
factorize(nfft);
}
void make_real_twiddles(int nfft)
{
int ncfft2 = nfft>>2;
if ( m_realTwiddles.size() != ncfft2) {
m_realTwiddles.resize(ncfft2);
int ncfft= nfft>>1;
for (int k=1;k<=ncfft2;++k)
m_realTwiddles[k-1] = exp( Complex(0,-3.14159265358979323846264338327 * ((double) (k) / ncfft + .5) ) );
}
}
ei_kiss_cpx_fft() { }
void make_twiddles(int nfft,bool inverse)
{
if ( m_twiddles.size() == nfft) {
// reuse the twiddles, conjugate if necessary
if (inverse != m_inverse)
for (int i=0;i<nfft;++i)
m_twiddles[i] = conj( m_twiddles[i] );
}else{
m_inverse = inverse;
m_twiddles.resize(nfft);
Scalar phinc = (inverse?2:-2)* acos( (Scalar) -1) / nfft;
for (int i=0;i<nfft;++i)
m_twiddles[i] = exp( Complex(0,i*phinc) );
}
m_inverse = inverse;
void invert()
{
m_inverse = !m_inverse;
for ( size_t i=0;i<m_twiddles.size() ;++i)
m_twiddles[i] = conj( m_twiddles[i] );
}
void factorize(int nfft)
@ -157,17 +81,8 @@ namespace Eigen {
m_stageRemainder.push_back(n);
}while(n>1);
}
m_nfft = nfft;
}
void scale(Complex *dst,Scalar s)
{
for (int k=0;k<m_nfft;++k)
dst[k] *= s;
}
private:
template <typename _Src>
void work( int stage,Complex * xout, const _Src * xin, size_t fstride,size_t in_stride)
{
@ -338,7 +253,7 @@ namespace Eigen {
int u,k,q1,q;
Complex * twiddles = &m_twiddles[0];
Complex t;
int Norig = m_nfft;
int Norig = m_twiddles.size();
Complex * scratchbuf = (Complex*)alloca(p*sizeof(Complex) );
for ( u=0; u<m; ++u ) {
@ -362,36 +277,141 @@ namespace Eigen {
}
}
}
};
int m_nfft;
bool m_inverse;
std::vector<Complex> m_twiddles;
std::vector<Complex> m_realTwiddles;
std::vector<int> m_stageRadix;
std::vector<int> m_stageRemainder;
template <typename _Scalar>
struct ei_kissfft_impl
{
typedef _Scalar Scalar;
typedef std::complex<Scalar> Complex;
ei_kissfft_impl() {}
void clear()
{
m_plans.clear();
m_realTwiddles.clear();
}
template <typename _Src>
void fwd( Complex * dst,const _Src *src,int nfft)
{
get_plan(nfft,false).work(0, dst, src, 1,1);
}
// real-to-complex forward FFT
// perform two FFTs of src even and src odd
// then twiddle to recombine them into the half-spectrum format
// then fill in the conjugate symmetric half
void fwd( Complex * dst,const Scalar * src,int nfft)
{
if ( nfft&3 ) {
// use generic mode for odd
get_plan(nfft,false).work(0, dst, src, 1,1);
}else{
int ncfft = nfft>>1;
int ncfft2 = nfft>>2;
Complex * rtw = real_twiddles(ncfft2);
// use optimized mode for even real
fwd( dst, reinterpret_cast<const Complex*> (src), ncfft);
Complex dc = dst[0].real() + dst[0].imag();
Complex nyquist = dst[0].real() - dst[0].imag();
int k;
for ( k=1;k <= ncfft2 ; ++k ) {
Complex fpk = dst[k];
Complex fpnk = conj(dst[ncfft-k]);
Complex f1k = fpk + fpnk;
Complex f2k = fpk - fpnk;
Complex tw= f2k * rtw[k-1];
dst[k] = (f1k + tw) * Scalar(.5);
dst[ncfft-k] = conj(f1k -tw)*Scalar(.5);
}
// place conjugate-symmetric half at the end for completeness
// TODO: make this configurable ( opt-out )
for ( k=1;k < ncfft ; ++k )
dst[nfft-k] = conj(dst[k]);
dst[0] = dc;
dst[ncfft] = nyquist;
}
}
// half-complex to scalar
void inv( Scalar * dst,const Complex * src,int nfft)
{
// TODO add optimized version for even numbers
std::vector<Complex> tmp(nfft);
inv(&tmp[0],src,nfft);
for (int k=0;k<nfft;++k)
dst[k] = tmp[k].real();
}
void inv(Complex * dst,const Complex *src,int nfft)
{
get_plan(nfft,true).work(0, dst, src, 1,1);
scale(dst, nfft, Scalar(1)/nfft );
}
private:
typedef ei_kiss_cpx_fft<Scalar> PlanData;
typedef std::map<int,PlanData> PlanMap;
PlanMap m_plans;
std::map<int, std::vector<Complex> > m_realTwiddles;
int PlanKey(int nfft,bool isinverse) const { return (nfft<<1) | isinverse; }
PlanData & get_plan(int nfft,bool inverse)
{
/*
enum {FORWARD,INVERSE,REAL,COMPLEX};
struct PlanKey
{
PlanKey(int nfft,bool isinverse,bool iscomplex)
{
_key = (nfft<<2) | (isinverse<<1) | iscomplex;
* for some reason this does not work
*
typedef typename std::map<int,PlanData>::iterator MapIt;
MapIt it;
it = m_plans.find( PlanKey(nfft,inverse) );
if (it == m_plans.end() ) {
// create new entry
it = m_plans.insert( make_pair( PlanKey(nfft,inverse) , PlanData() ) );
MapIt it2 = m_plans.find( PlanKey(nfft,!inverse) );
if (it2 != m_plans.end() ) {
it->second = it2.second;
it->second.invert();
}else{
it->second.make_twiddles(nfft,inverse);
it->second.factorize(nfft);
}
bool operator<(const PlanKey & other) const
{
return this->_key < other._key;
}
int _key;
};
struct PlanData
{
std::vector<Complex> m_twiddles;
};
std::map<PlanKey,
return it->second;
*/
PlanData & pd = m_plans[ PlanKey(nfft,inverse) ];
if ( pd.m_twiddles.size() == 0 ) {
pd.make_twiddles(nfft,inverse);
pd.factorize(nfft);
}
return pd;
}
Complex * real_twiddles(int ncfft2)
{
std::vector<Complex> & twidref = m_realTwiddles[ncfft2];// creates new if not there
if ( (int)twidref.size() != ncfft2 ) {
twidref.resize(ncfft2);
int ncfft= ncfft2<<1;
Scalar pi = acos( Scalar(-1) );
for (int k=1;k<=ncfft2;++k)
twidref[k-1] = exp( Complex(0,-pi * ((double) (k) / ncfft + .5) ) );
}
return &twidref[0];
}
void scale(Complex *dst,int n,Scalar s)
{
for (int k=0;k<n;++k)
dst[k] *= s;
}
};
}

View File

@ -44,7 +44,7 @@ complex<long double> promote(long double x) { return complex<long double>( x);
{
long double totalpower=0;
long double difpower=0;
cerr <<"idx\ttruth\t\tvalue\n";
cerr <<"idx\ttruth\t\tvalue\t|dif|=\n";
for (size_t k0=0;k0<fftbuf.size();++k0) {
complex<long double> acc = 0;
long double phinc = -2.*k0* M_PIl / timebuf.size();
@ -55,7 +55,7 @@ complex<long double> promote(long double x) { return complex<long double>( x);
complex<long double> x = promote(fftbuf[k0]);
complex<long double> dif = acc - x;
difpower += norm(dif);
cerr << k0 << "\t" << acc << "\t" << x << endl;
cerr << k0 << "\t" << acc << "\t" << x << "\t" << sqrt(norm(dif)) << endl;
}
cerr << "rmse:" << sqrt(difpower/totalpower) << endl;
return sqrt(difpower/totalpower);
@ -127,8 +127,9 @@ void test_FFT()
#endif
#if 1
CALL_SUBTEST( test_scalar<float>(45) ); CALL_SUBTEST( test_scalar<double>(45) ); CALL_SUBTEST( test_scalar<long double>(45) );
CALL_SUBTEST( test_scalar<float>(32) ); CALL_SUBTEST( test_scalar<double>(32) ); CALL_SUBTEST( test_scalar<long double>(32) );
CALL_SUBTEST( test_scalar<float>(45) ); CALL_SUBTEST( test_scalar<double>(45) ); CALL_SUBTEST( test_scalar<long double>(45) );
CALL_SUBTEST( test_scalar<float>(50) ); CALL_SUBTEST( test_scalar<double>(50) ); CALL_SUBTEST( test_scalar<long double>(50) );
CALL_SUBTEST( test_scalar<float>(256) ); CALL_SUBTEST( test_scalar<double>(256) ); CALL_SUBTEST( test_scalar<long double>(256) );
CALL_SUBTEST( test_scalar<float>(2*3*4*5*7) ); CALL_SUBTEST( test_scalar<double>(2*3*4*5*7) ); CALL_SUBTEST( test_scalar<long double>(2*3*4*5*7) );
#endif