needed different proxy return types for fwd,inv to work around static asserts

This commit is contained in:
Mark Borgerding 2010-03-07 22:05:48 -05:00
parent 5b2c8b77df
commit e31929337e
2 changed files with 48 additions and 21 deletions

View File

@ -111,22 +111,26 @@
namespace Eigen { namespace Eigen {
// //
template<typename T_SrcMat,typename T_FftIfc,bool T_ForwardTransform> template<typename T_SrcMat,typename T_FftIfc> struct fft_fwd_proxy;
struct fft_result_proxy; template<typename T_SrcMat,typename T_FftIfc> struct fft_inv_proxy;
template<typename T_SrcMat,typename T_FftIfc,bool T_ForwardTransform> template<typename T_SrcMat,typename T_FftIfc>
struct ei_traits< fft_result_proxy<T_SrcMat,T_FftIfc,T_ForwardTransform> > struct ei_traits< fft_fwd_proxy<T_SrcMat,T_FftIfc> >
{
typedef typename T_SrcMat::PlainObject ReturnType;
};
template<typename T_SrcMat,typename T_FftIfc>
struct ei_traits< fft_inv_proxy<T_SrcMat,T_FftIfc> >
{ {
typedef typename T_SrcMat::PlainObject ReturnType; typedef typename T_SrcMat::PlainObject ReturnType;
}; };
template<typename T_SrcMat,typename T_FftIfc,bool T_ForwardTransform> template<typename T_SrcMat,typename T_FftIfc>
struct fft_result_proxy struct fft_fwd_proxy
: public ReturnByValue<fft_result_proxy<T_SrcMat,T_FftIfc,T_ForwardTransform> > : public ReturnByValue<fft_fwd_proxy<T_SrcMat,T_FftIfc> >
{ {
fft_result_proxy(const T_SrcMat& src,T_FftIfc & fft,int nfft) : m_src(src),m_ifc(fft), m_nfft(nfft) {} fft_fwd_proxy(const T_SrcMat& src,T_FftIfc & fft,int nfft) : m_src(src),m_ifc(fft), m_nfft(nfft) {}
template<typename T_DestMat> void evalTo(T_DestMat& dst) const; template<typename T_DestMat> void evalTo(T_DestMat& dst) const;
@ -138,6 +142,21 @@ struct fft_result_proxy
int m_nfft; int m_nfft;
}; };
template<typename T_SrcMat,typename T_FftIfc>
struct fft_inv_proxy
: public ReturnByValue<fft_inv_proxy<T_SrcMat,T_FftIfc> >
{
fft_inv_proxy(const T_SrcMat& src,T_FftIfc & fft,int nfft) : m_src(src),m_ifc(fft), m_nfft(nfft) {}
template<typename T_DestMat> void evalTo(T_DestMat& dst) const;
int rows() const { return m_src.rows(); }
int cols() const { return m_src.cols(); }
protected:
const T_SrcMat & m_src;
T_FftIfc & m_ifc;
int m_nfft;
};
template <typename T_Scalar, template <typename T_Scalar,
@ -205,10 +224,12 @@ class FFT
inline inline
void fwd( MatrixBase<ComplexDerived> & dst, const MatrixBase<InputDerived> & src,int nfft=-1) void fwd( MatrixBase<ComplexDerived> & dst, const MatrixBase<InputDerived> & src,int nfft=-1)
{ {
typedef typename ComplexDerived::Scalar dst_type;
typedef typename InputDerived::Scalar src_type;
EIGEN_STATIC_ASSERT_VECTOR_ONLY(InputDerived) EIGEN_STATIC_ASSERT_VECTOR_ONLY(InputDerived)
EIGEN_STATIC_ASSERT_VECTOR_ONLY(ComplexDerived) EIGEN_STATIC_ASSERT_VECTOR_ONLY(ComplexDerived)
EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(ComplexDerived,InputDerived) // size at compile-time EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(ComplexDerived,InputDerived) // size at compile-time
EIGEN_STATIC_ASSERT((ei_is_same_type<typename ComplexDerived::Scalar, Complex>::ret), EIGEN_STATIC_ASSERT((ei_is_same_type<dst_type, Complex>::ret),
YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY) YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
EIGEN_STATIC_ASSERT(int(InputDerived::Flags)&int(ComplexDerived::Flags)&DirectAccessBit, EIGEN_STATIC_ASSERT(int(InputDerived::Flags)&int(ComplexDerived::Flags)&DirectAccessBit,
THIS_METHOD_IS_ONLY_FOR_EXPRESSIONS_WITH_DIRECT_MEMORY_ACCESS_SUCH_AS_MAP_OR_PLAIN_MATRICES) THIS_METHOD_IS_ONLY_FOR_EXPRESSIONS_WITH_DIRECT_MEMORY_ACCESS_SUCH_AS_MAP_OR_PLAIN_MATRICES)
@ -216,13 +237,13 @@ class FFT
if (nfft<1) if (nfft<1)
nfft = src.size(); nfft = src.size();
if ( NumTraits< typename InputDerived::Scalar >::IsComplex == 0 && HasFlag(HalfSpectrum) ) if ( NumTraits< src_type >::IsComplex == 0 && HasFlag(HalfSpectrum) )
dst.derived().resize( (nfft>>1)+1); dst.derived().resize( (nfft>>1)+1);
else else
dst.derived().resize(nfft); dst.derived().resize(nfft);
if ( src.stride() != 1 || src.size() < nfft ) { if ( src.stride() != 1 || src.size() < nfft ) {
Matrix<typename InputDerived::Scalar,1,Dynamic> tmp; Matrix<src_type,1,Dynamic> tmp;
if (src.size()<nfft) { if (src.size()<nfft) {
tmp.setZero(nfft); tmp.setZero(nfft);
tmp.block(0,0,src.size(),1 ) = src; tmp.block(0,0,src.size(),1 ) = src;
@ -237,18 +258,18 @@ class FFT
template<typename InputDerived> template<typename InputDerived>
inline inline
fft_result_proxy< MatrixBase<InputDerived>, FFT<T_Scalar,T_Impl> ,true> fft_fwd_proxy< MatrixBase<InputDerived>, FFT<T_Scalar,T_Impl> >
fwd( const MatrixBase<InputDerived> & src,int nfft=-1) fwd( const MatrixBase<InputDerived> & src,int nfft=-1)
{ {
return fft_result_proxy< MatrixBase<InputDerived> ,FFT<T_Scalar,T_Impl>,true>( src, *this,nfft ); return fft_fwd_proxy< MatrixBase<InputDerived> ,FFT<T_Scalar,T_Impl> >( src, *this,nfft );
} }
template<typename InputDerived> template<typename InputDerived>
inline inline
fft_result_proxy< MatrixBase<InputDerived>, FFT<T_Scalar,T_Impl> ,false> fft_inv_proxy< MatrixBase<InputDerived>, FFT<T_Scalar,T_Impl> >
inv( const MatrixBase<InputDerived> & src,int nfft=-1) inv( const MatrixBase<InputDerived> & src,int nfft=-1)
{ {
return fft_result_proxy< MatrixBase<InputDerived> ,FFT<T_Scalar,T_Impl>,false>( src, *this,nfft ); return fft_inv_proxy< MatrixBase<InputDerived> ,FFT<T_Scalar,T_Impl> >( src, *this,nfft );
} }
inline inline
@ -382,13 +403,17 @@ class FFT
int m_flag; int m_flag;
}; };
template<typename T_SrcMat,typename T_FftIfc,bool T_ForwardTransform> template<typename T_SrcMat,typename T_FftIfc>
template<typename T_DestMat> inline template<typename T_DestMat> inline
void fft_result_proxy<T_SrcMat,T_FftIfc,T_ForwardTransform>::evalTo(T_DestMat& dst) const void fft_fwd_proxy<T_SrcMat,T_FftIfc>::evalTo(T_DestMat& dst) const
{ {
if (T_ForwardTransform)
m_ifc.fwd( dst, m_src, m_nfft); m_ifc.fwd( dst, m_src, m_nfft);
else }
template<typename T_SrcMat,typename T_FftIfc>
template<typename T_DestMat> inline
void fft_inv_proxy<T_SrcMat,T_FftIfc>::evalTo(T_DestMat& dst) const
{
m_ifc.inv( dst, m_src, m_nfft); m_ifc.inv( dst, m_src, m_nfft);
} }

View File

@ -235,6 +235,9 @@ void test_return_by_value()
Matrix<complex<T>,nrows,ncols> out1; Matrix<complex<T>,nrows,ncols> out1;
Matrix<complex<T>,nrows,ncols> out2; Matrix<complex<T>,nrows,ncols> out2;
FFT<T> fft; FFT<T> fft;
fft.SetFlag(fft.HalfSpectrum );
fft.fwd(out1,in); fft.fwd(out1,in);
out2 = fft.fwd(in); out2 = fft.fwd(in);
VERIFY( (out1-out2).norm() < test_precision<T>() ); VERIFY( (out1-out2).norm() < test_precision<T>() );
@ -246,7 +249,6 @@ void test_FFTW()
{ {
test_return_by_value<float,1,32>(); test_return_by_value<float,1,32>();
test_return_by_value<double,1,32>(); test_return_by_value<double,1,32>();
//test_return_by_value<long double,1,32>();
//CALL_SUBTEST( ( test_complex2d<float,4,8> () ) ); CALL_SUBTEST( ( test_complex2d<double,4,8> () ) ); //CALL_SUBTEST( ( test_complex2d<float,4,8> () ) ); CALL_SUBTEST( ( test_complex2d<double,4,8> () ) );
//CALL_SUBTEST( ( test_complex2d<long double,4,8> () ) ); //CALL_SUBTEST( ( test_complex2d<long double,4,8> () ) );
CALL_SUBTEST( test_complex<float>(32) ); CALL_SUBTEST( test_complex<double>(32) ); CALL_SUBTEST( test_complex<long double>(32) ); CALL_SUBTEST( test_complex<float>(32) ); CALL_SUBTEST( test_complex<double>(32) ); CALL_SUBTEST( test_complex<long double>(32) );