MatrixBase:

* support resize() to same size (nop). The case of FFT was another case where that make one's life far easier.
   hope that's ok with you Gael. but indeed, i don't use it in the ReturnByValue stuff.

FFT:
 * Support MatrixBase (well, in the case with direct memory access such as Map)
 * adapt unit test
This commit is contained in:
Benoit Jacob 2009-10-20 23:25:49 -04:00
parent 471b4d5092
commit c3180b7ffb
4 changed files with 148 additions and 33 deletions

View File

@ -190,6 +190,25 @@ template<typename Derived> class MatrixBase
* i.e., the number of rows for a columns major matrix, and the number of cols otherwise */ * i.e., the number of rows for a columns major matrix, and the number of cols otherwise */
int innerSize() const { return (int(Flags)&RowMajorBit) ? this->cols() : this->rows(); } int innerSize() const { return (int(Flags)&RowMajorBit) ? this->cols() : this->rows(); }
/** Only plain matrices, not expressions may be resized; therefore the only useful resize method is
* Matrix::resize(). The present method only asserts that the new size equals the old size, and does
* nothing else.
*/
void resize(int size)
{
ei_assert(size == this->size()
&& "MatrixBase::resize() does not actually allow to resize.");
}
/** Only plain matrices, not expressions may be resized; therefore the only useful resize method is
* Matrix::resize(). The present method only asserts that the new size equals the old size, and does
* nothing else.
*/
void resize(int rows, int cols)
{
ei_assert(rows == this->rows() && cols == this->cols()
&& "MatrixBase::resize() does not actually allow to resize.");
}
#ifndef EIGEN_PARSED_BY_DOXYGEN #ifndef EIGEN_PARSED_BY_DOXYGEN
/** \internal the plain matrix type corresponding to this expression. Note that is not necessarily /** \internal the plain matrix type corresponding to this expression. Note that is not necessarily
* exactly the return type of eval(): in the case of plain matrices, the return type of eval() is a const * exactly the return type of eval(): in the case of plain matrices, the return type of eval() is a const

View File

@ -78,7 +78,8 @@
INVALID_MATRIX_TEMPLATE_PARAMETERS, INVALID_MATRIX_TEMPLATE_PARAMETERS,
BOTH_MATRICES_MUST_HAVE_THE_SAME_STORAGE_ORDER, BOTH_MATRICES_MUST_HAVE_THE_SAME_STORAGE_ORDER,
THIS_METHOD_IS_ONLY_FOR_DIAGONAL_MATRIX, THIS_METHOD_IS_ONLY_FOR_DIAGONAL_MATRIX,
THE_MATRIX_OR_EXPRESSION_THAT_YOU_PASSED_DOES_NOT_HAVE_THE_EXPECTED_TYPE THE_MATRIX_OR_EXPRESSION_THAT_YOU_PASSED_DOES_NOT_HAVE_THE_EXPECTED_TYPE,
THIS_METHOD_IS_ONLY_FOR_EXPRESSIONS_WITH_DIRECT_MEMORY_ACCESS_SUCH_AS_MAP_OR_PLAIN_MATRICES
}; };
}; };

View File

@ -36,7 +36,7 @@
#define DEFAULT_FFT_IMPL ei_fftw_impl #define DEFAULT_FFT_IMPL ei_fftw_impl
#endif #endif
// intel Math Kernel Library: fastest, commerical -- incompatible with Eigen in GPL form // intel Math Kernel Library: fastest, commercial -- incompatible with Eigen in GPL form
#ifdef _MKL_DFTI_H_ // mkl_dfti.h has been included, we can use MKL FFT routines #ifdef _MKL_DFTI_H_ // mkl_dfti.h has been included, we can use MKL FFT routines
// TODO // TODO
// #include "src/FFT/ei_imkl_impl.h" // #include "src/FFT/ei_imkl_impl.h"
@ -70,6 +70,20 @@ class FFT
fwd( &dst[0],&src[0],src.size() ); fwd( &dst[0],&src[0],src.size() );
} }
template<typename InputDerived, typename ComplexDerived>
void fwd( MatrixBase<ComplexDerived> & dst, const MatrixBase<InputDerived> & src)
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(InputDerived)
EIGEN_STATIC_ASSERT_VECTOR_ONLY(ComplexDerived)
EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(ComplexDerived,InputDerived) // size at compile-time
EIGEN_STATIC_ASSERT((ei_is_same_type<typename ComplexDerived::Scalar, Complex>::ret),
YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
EIGEN_STATIC_ASSERT(int(InputDerived::Flags)&int(ComplexDerived::Flags)&DirectAccessBit,
THIS_METHOD_IS_ONLY_FOR_EXPRESSIONS_WITH_DIRECT_MEMORY_ACCESS_SUCH_AS_MAP_OR_PLAIN_MATRICES)
dst.derived().resize( src.size() );
fwd( &dst[0],&src[0],src.size() );
}
template <typename _Output> template <typename _Output>
void inv( _Output * dst, const Complex * src, int nfft) void inv( _Output * dst, const Complex * src, int nfft)
{ {
@ -83,8 +97,24 @@ class FFT
inv( &dst[0],&src[0],src.size() ); inv( &dst[0],&src[0],src.size() );
} }
template<typename OutputDerived, typename ComplexDerived>
void inv( MatrixBase<OutputDerived> & dst, const MatrixBase<ComplexDerived> & src)
{
EIGEN_STATIC_ASSERT_VECTOR_ONLY(OutputDerived)
EIGEN_STATIC_ASSERT_VECTOR_ONLY(ComplexDerived)
EIGEN_STATIC_ASSERT_SAME_VECTOR_SIZE(ComplexDerived,OutputDerived) // size at compile-time
EIGEN_STATIC_ASSERT((ei_is_same_type<typename ComplexDerived::Scalar, Complex>::ret),
YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
EIGEN_STATIC_ASSERT(int(OutputDerived::Flags)&int(ComplexDerived::Flags)&DirectAccessBit,
THIS_METHOD_IS_ONLY_FOR_EXPRESSIONS_WITH_DIRECT_MEMORY_ACCESS_SUCH_AS_MAP_OR_PLAIN_MATRICES)
dst.derived().resize( src.size() );
inv( &dst[0],&src[0],src.size() );
}
// TODO: multi-dimensional FFTs // TODO: multi-dimensional FFTs
// TODO: handle Eigen MatrixBase // TODO: handle Eigen MatrixBase
// ---> i added fwd and inv specializations above + unit test, is this enough? (bjacob)
traits_type & traits() {return m_traits;} traits_type & traits() {return m_traits;}
private: private:

View File

@ -39,16 +39,16 @@ complex<long double> promote(double x) { return complex<long double>( x); }
complex<long double> promote(long double x) { return complex<long double>( x); } complex<long double> promote(long double x) { return complex<long double>( x); }
template <typename T1,typename T2> template <typename VectorType1,typename VectorType2>
long double fft_rmse( const vector<T1> & fftbuf,const vector<T2> & timebuf) long double fft_rmse( const VectorType1 & fftbuf,const VectorType2 & timebuf)
{ {
long double totalpower=0; long double totalpower=0;
long double difpower=0; long double difpower=0;
cerr <<"idx\ttruth\t\tvalue\t|dif|=\n"; cerr <<"idx\ttruth\t\tvalue\t|dif|=\n";
for (size_t k0=0;k0<fftbuf.size();++k0) { for (size_t k0=0;k0<size_t(fftbuf.size());++k0) {
complex<long double> acc = 0; complex<long double> acc = 0;
long double phinc = -2.*k0* M_PIl / timebuf.size(); long double phinc = -2.*k0* M_PIl / timebuf.size();
for (size_t k1=0;k1<timebuf.size();++k1) { for (size_t k1=0;k1<size_t(timebuf.size());++k1) {
acc += promote( timebuf[k1] ) * exp( complex<long double>(0,k1*phinc) ); acc += promote( timebuf[k1] ) * exp( complex<long double>(0,k1*phinc) );
} }
totalpower += norm(acc); totalpower += norm(acc);
@ -61,8 +61,8 @@ complex<long double> promote(long double x) { return complex<long double>( x);
return sqrt(difpower/totalpower); return sqrt(difpower/totalpower);
} }
template <typename T1,typename T2> template <typename VectorType1,typename VectorType2>
long double dif_rmse( const vector<T1> buf1,const vector<T2> buf2) long double dif_rmse( const VectorType1& buf1,const VectorType2& buf2)
{ {
long double totalpower=0; long double totalpower=0;
long double difpower=0; long double difpower=0;
@ -74,35 +74,59 @@ complex<long double> promote(long double x) { return complex<long double>( x);
return sqrt(difpower/totalpower); return sqrt(difpower/totalpower);
} }
template <class T> enum { StdVectorContainer, EigenVectorContainer };
void test_scalar(int nfft)
template<int Container, typename Scalar> struct VectorType;
template<typename Scalar> struct VectorType<StdVectorContainer,Scalar>
{ {
typedef typename Eigen::FFT<T>::Complex Complex; typedef vector<Scalar> type;
typedef typename Eigen::FFT<T>::Scalar Scalar; };
template<typename Scalar> struct VectorType<EigenVectorContainer,Scalar>
{
typedef Matrix<Scalar,Dynamic,1> type;
};
template <int Container, typename T>
void test_scalar_generic(int nfft)
{
typedef typename FFT<T>::Complex Complex;
typedef typename FFT<T>::Scalar Scalar;
typedef typename VectorType<Container,Scalar>::type ScalarVector;
typedef typename VectorType<Container,Complex>::type ComplexVector;
FFT<T> fft; FFT<T> fft;
vector<Scalar> inbuf(nfft); ScalarVector inbuf(nfft);
vector<Complex> outbuf; ComplexVector outbuf;
for (int k=0;k<nfft;++k) for (int k=0;k<nfft;++k)
inbuf[k]= (T)(rand()/(double)RAND_MAX - .5); inbuf[k]= (T)(rand()/(double)RAND_MAX - .5);
fft.fwd( outbuf,inbuf); fft.fwd( outbuf,inbuf);
VERIFY( fft_rmse(outbuf,inbuf) < test_precision<T>() );// gross check VERIFY( fft_rmse(outbuf,inbuf) < test_precision<T>() );// gross check
vector<Scalar> buf3; ScalarVector buf3;
fft.inv( buf3 , outbuf); fft.inv( buf3 , outbuf);
VERIFY( dif_rmse(inbuf,buf3) < test_precision<T>() );// gross check VERIFY( dif_rmse(inbuf,buf3) < test_precision<T>() );// gross check
} }
template <class T> template <typename T>
void test_complex(int nfft) void test_scalar(int nfft)
{ {
typedef typename Eigen::FFT<T>::Complex Complex; test_scalar_generic<StdVectorContainer,T>(nfft);
test_scalar_generic<EigenVectorContainer,T>(nfft);
}
template <int Container, typename T>
void test_complex_generic(int nfft)
{
typedef typename FFT<T>::Complex Complex;
typedef typename VectorType<Container,Complex>::type ComplexVector;
FFT<T> fft; FFT<T> fft;
vector<Complex> inbuf(nfft); ComplexVector inbuf(nfft);
vector<Complex> outbuf; ComplexVector outbuf;
vector<Complex> buf3; ComplexVector buf3;
for (int k=0;k<nfft;++k) for (int k=0;k<nfft;++k)
inbuf[k]= Complex( (T)(rand()/(double)RAND_MAX - .5), (T)(rand()/(double)RAND_MAX - .5) ); inbuf[k]= Complex( (T)(rand()/(double)RAND_MAX - .5), (T)(rand()/(double)RAND_MAX - .5) );
fft.fwd( outbuf , inbuf); fft.fwd( outbuf , inbuf);
@ -114,22 +138,63 @@ void test_complex(int nfft)
VERIFY( dif_rmse(inbuf,buf3) < test_precision<T>() );// gross check VERIFY( dif_rmse(inbuf,buf3) < test_precision<T>() );// gross check
} }
template <typename T>
void test_complex(int nfft)
{
test_complex_generic<StdVectorContainer,T>(nfft);
test_complex_generic<EigenVectorContainer,T>(nfft);
}
void test_FFT() void test_FFT()
{ {
CALL_SUBTEST( test_complex<float>(32) ); CALL_SUBTEST( test_complex<double>(32) ); CALL_SUBTEST( test_complex<long double>(32) ); CALL_SUBTEST( test_complex<float>(32) );
CALL_SUBTEST( test_complex<float>(256) ); CALL_SUBTEST( test_complex<double>(256) ); CALL_SUBTEST( test_complex<long double>(256) ); CALL_SUBTEST( test_complex<double>(32) );
CALL_SUBTEST( test_complex<float>(3*8) ); CALL_SUBTEST( test_complex<double>(3*8) ); CALL_SUBTEST( test_complex<long double>(3*8) ); CALL_SUBTEST( test_complex<long double>(32) );
CALL_SUBTEST( test_complex<float>(5*32) ); CALL_SUBTEST( test_complex<double>(5*32) ); CALL_SUBTEST( test_complex<long double>(5*32) );
CALL_SUBTEST( test_complex<float>(2*3*4) ); CALL_SUBTEST( test_complex<double>(2*3*4) ); CALL_SUBTEST( test_complex<long double>(2*3*4) ); CALL_SUBTEST( test_complex<float>(256) );
CALL_SUBTEST( test_complex<float>(2*3*4*5) ); CALL_SUBTEST( test_complex<double>(2*3*4*5) ); CALL_SUBTEST( test_complex<long double>(2*3*4*5) ); CALL_SUBTEST( test_complex<double>(256) );
CALL_SUBTEST( test_complex<float>(2*3*4*5*7) ); CALL_SUBTEST( test_complex<double>(2*3*4*5*7) ); CALL_SUBTEST( test_complex<long double>(2*3*4*5*7) ); CALL_SUBTEST( test_complex<long double>(256) );
CALL_SUBTEST( test_complex<float>(3*8) );
CALL_SUBTEST( test_complex<double>(3*8) );
CALL_SUBTEST( test_complex<long double>(3*8) );
CALL_SUBTEST( test_complex<float>(5*32) );
CALL_SUBTEST( test_complex<double>(5*32) );
CALL_SUBTEST( test_complex<long double>(5*32) );
CALL_SUBTEST( test_complex<float>(2*3*4) );
CALL_SUBTEST( test_complex<double>(2*3*4) );
CALL_SUBTEST( test_complex<long double>(2*3*4) );
CALL_SUBTEST( test_complex<float>(2*3*4*5) );
CALL_SUBTEST( test_complex<double>(2*3*4*5) );
CALL_SUBTEST( test_complex<long double>(2*3*4*5) );
CALL_SUBTEST( test_complex<float>(2*3*4*5*7) );
CALL_SUBTEST( test_complex<double>(2*3*4*5*7) );
CALL_SUBTEST( test_complex<long double>(2*3*4*5*7) );
CALL_SUBTEST( test_scalar<float>(32) ); CALL_SUBTEST( test_scalar<double>(32) ); CALL_SUBTEST( test_scalar<long double>(32) ); CALL_SUBTEST( test_scalar<float>(32) );
CALL_SUBTEST( test_scalar<float>(45) ); CALL_SUBTEST( test_scalar<double>(45) ); CALL_SUBTEST( test_scalar<long double>(45) ); CALL_SUBTEST( test_scalar<double>(32) );
CALL_SUBTEST( test_scalar<float>(50) ); CALL_SUBTEST( test_scalar<double>(50) ); CALL_SUBTEST( test_scalar<long double>(50) ); CALL_SUBTEST( test_scalar<long double>(32) );
CALL_SUBTEST( test_scalar<float>(256) ); CALL_SUBTEST( test_scalar<double>(256) ); CALL_SUBTEST( test_scalar<long double>(256) );
CALL_SUBTEST( test_scalar<float>(2*3*4*5*7) ); CALL_SUBTEST( test_scalar<double>(2*3*4*5*7) ); CALL_SUBTEST( test_scalar<long double>(2*3*4*5*7) ); CALL_SUBTEST( test_scalar<float>(45) );
CALL_SUBTEST( test_scalar<double>(45) );
CALL_SUBTEST( test_scalar<long double>(45) );
CALL_SUBTEST( test_scalar<float>(50) );
CALL_SUBTEST( test_scalar<double>(50) );
CALL_SUBTEST( test_scalar<long double>(50) );
CALL_SUBTEST( test_scalar<float>(256) );
CALL_SUBTEST( test_scalar<double>(256) );
CALL_SUBTEST( test_scalar<long double>(256) );
CALL_SUBTEST( test_scalar<float>(2*3*4*5*7) );
CALL_SUBTEST( test_scalar<double>(2*3*4*5*7) );
CALL_SUBTEST( test_scalar<long double>(2*3*4*5*7) );
} }