moved scaling to Eigen::FFT

This commit is contained in:
Mark Borgerding 2009-10-30 19:50:11 -04:00
parent 0fa68b9e50
commit a26b729cc9
3 changed files with 61 additions and 33 deletions

View File

@ -28,6 +28,7 @@
#include <complex> #include <complex>
#include <vector> #include <vector>
#include <map> #include <map>
#include <Eigen/Core>
#ifdef EIGEN_FFTW_DEFAULT #ifdef EIGEN_FFTW_DEFAULT
// FFTW: faster, GPL -- incompatible with Eigen in LGPL form, bigger code size // FFTW: faster, GPL -- incompatible with Eigen in LGPL form, bigger code size
@ -65,10 +66,31 @@ class FFT
typedef typename impl_type::Scalar Scalar; typedef typename impl_type::Scalar Scalar;
typedef typename impl_type::Complex Complex; typedef typename impl_type::Complex Complex;
FFT(const impl_type & impl=impl_type() ) :m_impl(impl) { } enum Flag {
Default=0, // goof proof
Unscaled=1,
HalfSpectrum=2,
// SomeOtherSpeedOptimization=4
Speedy=32767
};
template <typename _Input> FFT( const impl_type & impl=impl_type() , Flag flags=Default ) :m_impl(impl),m_flag(flags) { }
void fwd( Complex * dst, const _Input * src, int nfft)
inline
bool HasFlag(Flag f) const { return (m_flag & (int)f) == f;}
inline
void SetFlag(Flag f) { m_flag |= (int)f;}
inline
void ClearFlag(Flag f) { m_flag &= (~(int)f);}
void fwd( Complex * dst, const Scalar * src, int nfft)
{
m_impl.fwd(dst,src,nfft);
}
void fwd( Complex * dst, const Complex * src, int nfft)
{ {
m_impl.fwd(dst,src,nfft); m_impl.fwd(dst,src,nfft);
} }
@ -76,8 +98,11 @@ class FFT
template <typename _Input> template <typename _Input>
void fwd( std::vector<Complex> & dst, const std::vector<_Input> & src) void fwd( std::vector<Complex> & dst, const std::vector<_Input> & src)
{ {
dst.resize( src.size() ); if ( NumTraits<_Input>::IsComplex == 0 && HasFlag(HalfSpectrum) )
fwd( &dst[0],&src[0],src.size() ); dst.resize( (src.size()>>1)+1);
else
dst.resize(src.size());
fwd(&dst[0],&src[0],src.size());
} }
template<typename InputDerived, typename ComplexDerived> template<typename InputDerived, typename ComplexDerived>
@ -94,17 +119,18 @@ class FFT
fwd( &dst[0],&src[0],src.size() ); fwd( &dst[0],&src[0],src.size() );
} }
template <typename _Output> void inv( Complex * dst, const Complex * src, int nfft)
void inv( _Output * dst, const Complex * src, int nfft)
{ {
m_impl.inv( dst,src,nfft ); m_impl.inv( dst,src,nfft );
if ( HasFlag( Unscaled ) == false)
scale(dst,1./nfft,nfft);
} }
template <typename _Output> void inv( Scalar * dst, const Complex * src, int nfft)
void inv( std::vector<_Output> & dst, const std::vector<Complex> & src)
{ {
dst.resize( src.size() ); m_impl.inv( dst,src,nfft );
inv( &dst[0],&src[0],src.size() ); if ( HasFlag( Unscaled ) == false)
scale(dst,1./nfft,nfft);
} }
template<typename OutputDerived, typename ComplexDerived> template<typename OutputDerived, typename ComplexDerived>
@ -117,10 +143,24 @@ class FFT
YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY) YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
EIGEN_STATIC_ASSERT(int(OutputDerived::Flags)&int(ComplexDerived::Flags)&DirectAccessBit, EIGEN_STATIC_ASSERT(int(OutputDerived::Flags)&int(ComplexDerived::Flags)&DirectAccessBit,
THIS_METHOD_IS_ONLY_FOR_EXPRESSIONS_WITH_DIRECT_MEMORY_ACCESS_SUCH_AS_MAP_OR_PLAIN_MATRICES) THIS_METHOD_IS_ONLY_FOR_EXPRESSIONS_WITH_DIRECT_MEMORY_ACCESS_SUCH_AS_MAP_OR_PLAIN_MATRICES)
dst.derived().resize( src.size() );
int nfft = src.size();
int nout = HasFlag(HalfSpectrum) ? ((nfft>>1)+1) : nfft;
dst.derived().resize( nout );
inv( &dst[0],&src[0],src.size() ); inv( &dst[0],&src[0],src.size() );
} }
template <typename _Output>
void inv( std::vector<_Output> & dst, const std::vector<Complex> & src)
{
if ( NumTraits<_Output>::IsComplex == 0 && HasFlag(HalfSpectrum) )
dst.resize( 2*(src.size()-1) );
else
dst.resize( src.size() );
inv( &dst[0],&src[0],dst.size() );
}
// TODO: multi-dimensional FFTs // TODO: multi-dimensional FFTs
// TODO: handle Eigen MatrixBase // TODO: handle Eigen MatrixBase
@ -128,7 +168,16 @@ class FFT
impl_type & impl() {return m_impl;} impl_type & impl() {return m_impl;}
private: private:
template <typename _It,typename _Val>
void scale(_It x,_Val s,int nx)
{
for (int k=0;k<nx;++k)
*x++ *= s;
}
impl_type m_impl; impl_type m_impl;
int m_flag;
}; };
} }
#endif #endif

View File

@ -187,12 +187,6 @@
void inv(Complex * dst,const Complex *src,int nfft) void inv(Complex * dst,const Complex *src,int nfft)
{ {
get_plan(nfft,true,dst,src).inv(ei_fftw_cast(dst), ei_fftw_cast(src),nfft ); get_plan(nfft,true,dst,src).inv(ei_fftw_cast(dst), ei_fftw_cast(src),nfft );
//TODO move scaling to Eigen::FFT
// scaling
Scalar s = Scalar(1.)/nfft;
for (int k=0;k<nfft;++k)
dst[k] *= s;
} }
// half-complex to scalar // half-complex to scalar
@ -200,11 +194,6 @@
void inv( Scalar * dst,const Complex * src,int nfft) void inv( Scalar * dst,const Complex * src,int nfft)
{ {
get_plan(nfft,true,dst,src).inv(ei_fftw_cast(dst), ei_fftw_cast(src),nfft ); get_plan(nfft,true,dst,src).inv(ei_fftw_cast(dst), ei_fftw_cast(src),nfft );
//TODO move scaling to Eigen::FFT
Scalar s = Scalar(1.)/nfft;
for (int k=0;k<nfft;++k)
dst[k] *= s;
} }
protected: protected:

View File

@ -334,7 +334,6 @@
void inv(Complex * dst,const Complex *src,int nfft) void inv(Complex * dst,const Complex *src,int nfft)
{ {
get_plan(nfft,true).work(0, dst, src, 1,1); get_plan(nfft,true).work(0, dst, src, 1,1);
scale(dst, nfft, Scalar(1)/nfft );
} }
// half-complex to scalar // half-complex to scalar
@ -362,7 +361,6 @@
m_tmpBuf[k] = fek + fok; m_tmpBuf[k] = fek + fok;
m_tmpBuf[ncfft-k] = conj(fek - fok); m_tmpBuf[ncfft-k] = conj(fek - fok);
} }
scale(&m_tmpBuf[0], ncfft, Scalar(1)/nfft );
get_plan(ncfft,true).work(0, reinterpret_cast<Complex*>(dst), &m_tmpBuf[0], 1,1); get_plan(ncfft,true).work(0, reinterpret_cast<Complex*>(dst), &m_tmpBuf[0], 1,1);
} }
} }
@ -403,12 +401,4 @@
} }
return &twidref[0]; return &twidref[0];
} }
// TODO move scaling up into Eigen::FFT
inline
void scale(Complex *dst,int n,Scalar s)
{
for (int k=0;k<n;++k)
dst[k] *= s;
}
}; };