From e7f8e939e282a64025203a7a22e511165e1e3647 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Wed, 22 Jul 2009 18:04:16 +0200 Subject: [PATCH] * GEMM enhancement: no need to pre-transpose the rhs => faster a * b.transpose() product => this also fixes a bug in a so far untested situation * SYMM is now ready for use => still have to write the high-level stuff to convert natural expressions into a call to SYMM --- Eigen/src/Core/Product.h | 114 +----- Eigen/src/Core/products/GeneralMatrixMatrix.h | 73 +++- .../Core/products/SelfadjointMatrixMatrix.h | 342 ++++++++++++++---- Eigen/src/Core/util/BlasUtil.h | 77 +++- 4 files changed, 418 insertions(+), 188 deletions(-) diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h index 78cb88c33..754ce4c24 100644 --- a/Eigen/src/Core/Product.h +++ b/Eigen/src/Core/Product.h @@ -73,79 +73,6 @@ struct ProductReturnType typedef Product Type; }; -/* Helper class to analyze the factors of a Product expression. - * In particular it allows to pop out operator-, scalar multiples, - * and conjugate */ -template struct ei_blas_traits -{ - typedef typename ei_traits::Scalar Scalar; - typedef XprType ActualXprType; - enum { - IsComplex = NumTraits::IsComplex, - NeedToConjugate = false, - ActualAccess = int(ei_traits::Flags)&DirectAccessBit ? 
HasDirectAccess : NoDirectAccess - }; - typedef typename ei_meta_if::ret DirectLinearAccessType; - static inline const ActualXprType& extract(const XprType& x) { return x; } - static inline Scalar extractScalarFactor(const XprType&) { return Scalar(1); } -}; - -// pop conjugate -template struct ei_blas_traits, NestedXpr> > - : ei_blas_traits -{ - typedef ei_blas_traits Base; - typedef CwiseUnaryOp, NestedXpr> XprType; - typedef typename Base::ActualXprType ActualXprType; - - enum { - IsComplex = NumTraits::IsComplex, - NeedToConjugate = IsComplex - }; - static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); } - static inline Scalar extractScalarFactor(const XprType& x) { return ei_conj(Base::extractScalarFactor(x._expression())); } -}; - -// pop scalar multiple -template struct ei_blas_traits, NestedXpr> > - : ei_blas_traits -{ - typedef ei_blas_traits Base; - typedef CwiseUnaryOp, NestedXpr> XprType; - typedef typename Base::ActualXprType ActualXprType; - static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); } - static inline Scalar extractScalarFactor(const XprType& x) - { return x._functor().m_other * Base::extractScalarFactor(x._expression()); } -}; - -// pop opposite -template struct ei_blas_traits, NestedXpr> > - : ei_blas_traits -{ - typedef ei_blas_traits Base; - typedef CwiseUnaryOp, NestedXpr> XprType; - typedef typename Base::ActualXprType ActualXprType; - static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); } - static inline Scalar extractScalarFactor(const XprType& x) - { return - Base::extractScalarFactor(x._expression()); } -}; - -// pop opposite -template struct ei_blas_traits > - : ei_blas_traits -{ - typedef typename NestedXpr::Scalar Scalar; - typedef ei_blas_traits Base; - typedef NestByValue XprType; - typedef typename Base::ActualXprType ActualXprType; - static inline const ActualXprType& extract(const 
XprType& x) { return Base::extract(static_cast(x)); } - static inline Scalar extractScalarFactor(const XprType& x) - { return Base::extractScalarFactor(static_cast(x)); } -}; - /* Helper class to determine the type of the product, can be either: * - NormalProduct * - CacheFriendlyProduct @@ -869,25 +796,6 @@ inline Derived& MatrixBase::lazyAssign(const Product struct ei_product_copy_rhs -{ - typedef typename ei_meta_if< - (ei_traits::Flags & RowMajorBit) - || (!(ei_traits::Flags & DirectAccessBit)), - typename ei_plain_matrix_type_column_major::type, - const T& - >::ret type; -}; - -template struct ei_product_copy_lhs -{ - typedef typename ei_meta_if< - (!(int(ei_traits::Flags) & DirectAccessBit)), - typename ei_plain_matrix_type::type, - const T& - >::ret type; -}; - template template inline void Product::_cacheFriendlyEvalAndAdd(DestDerived& res, Scalar alpha) const @@ -895,26 +803,22 @@ inline void Product::_cacheFriendlyEvalAndAdd(DestDerived& typedef ei_blas_traits<_LhsNested> LhsProductTraits; typedef ei_blas_traits<_RhsNested> RhsProductTraits; - typedef typename LhsProductTraits::ActualXprType ActualLhsType; - typedef typename RhsProductTraits::ActualXprType ActualRhsType; + typedef typename LhsProductTraits::DirectLinearAccessType ActualLhsType; + typedef typename RhsProductTraits::DirectLinearAccessType ActualRhsType; - const ActualLhsType& actualLhs = LhsProductTraits::extract(m_lhs); - const ActualRhsType& actualRhs = RhsProductTraits::extract(m_rhs); + typedef typename ei_cleantype::type _ActualLhsType; + typedef typename ei_cleantype::type _ActualRhsType; + + const ActualLhsType lhs = LhsProductTraits::extract(m_lhs); + const ActualRhsType rhs = RhsProductTraits::extract(m_rhs); Scalar actualAlpha = alpha * LhsProductTraits::extractScalarFactor(m_lhs) * RhsProductTraits::extractScalarFactor(m_rhs); - typedef typename ei_product_copy_lhs::type LhsCopy; - typedef typename ei_unref::type _LhsCopy; - typedef typename ei_product_copy_rhs::type RhsCopy; - 
typedef typename ei_unref::type _RhsCopy; - LhsCopy lhs(actualLhs); - RhsCopy rhs(actualRhs); - ei_general_matrix_matrix_product< Scalar, - (_LhsCopy::Flags&RowMajorBit)?RowMajor:ColMajor, bool(LhsProductTraits::NeedToConjugate), - (_RhsCopy::Flags&RowMajorBit)?RowMajor:ColMajor, bool(RhsProductTraits::NeedToConjugate), + (_ActualLhsType::Flags&RowMajorBit)?RowMajor:ColMajor, bool(LhsProductTraits::NeedToConjugate), + (_ActualRhsType::Flags&RowMajorBit)?RowMajor:ColMajor, bool(RhsProductTraits::NeedToConjugate), (DestDerived::Flags&RowMajorBit)?RowMajor:ColMajor> ::run( rows(), cols(), lhs.cols(), diff --git a/Eigen/src/Core/products/GeneralMatrixMatrix.h b/Eigen/src/Core/products/GeneralMatrixMatrix.h index 68949499a..1c48a5ed4 100644 --- a/Eigen/src/Core/products/GeneralMatrixMatrix.h +++ b/Eigen/src/Core/products/GeneralMatrixMatrix.h @@ -89,16 +89,16 @@ static void run(int rows, int cols, int depth, // we have selected one row panel of rhs and one column panel of lhs // pack rhs's panel into a sequential chunk of memory // and expand each coeff to a constant packet for further reuse - ei_gemm_pack_rhs()(blockB, &rhs(k2,0), rhsStride, alpha, actual_kc, packet_cols, cols); + ei_gemm_pack_rhs()(blockB, &rhs(k2,0), rhsStride, alpha, actual_kc, packet_cols, cols); // => GEPP_VAR1 for(int i2=0; i2()(blockA, &lhs(i2,k2), lhsStride, actual_kc, actual_mc); - ei_gebp_kernel >() + ei_gebp_kernel >() (res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols); } } @@ -110,11 +110,13 @@ static void run(int rows, int cols, int depth, }; // optimized GEneral packed Block * packed Panel product kernel -template +template struct ei_gebp_kernel { void operator()(Scalar* res, int resStride, const Scalar* blockA, const Scalar* blockB, int actual_mc, int actual_kc, int packet_cols, int i2, int cols) { + typedef typename ei_packet_traits::type PacketType; + enum { PacketSize = ei_packet_traits::size }; Conj cj; const int peeled_mc = (actual_mc/mr)*mr; // loops on 
each cache friendly block of the result/rhs @@ -276,7 +278,7 @@ struct ei_gebp_kernel if(nr==4) res[(j2+3)*resStride + i2 + i] += C3; } } - + // process remaining rhs/res columns one at a time // => do the same but with nr==1 for(int j2=packet_cols; j2 -struct ei_gemm_pack_rhs +// this version is optimized for column major matrices +template +struct ei_gemm_pack_rhs { + enum { PacketSize = ei_packet_traits::size }; void operator()(Scalar* blockB, const Scalar* rhs, int rhsStride, Scalar alpha, int actual_kc, int packet_cols, int cols) { bool hasAlpha = alpha != Scalar(1); @@ -419,6 +423,61 @@ struct ei_gemm_pack_rhs } }; +// this version is optimized for row major matrices +template +struct ei_gemm_pack_rhs +{ + enum { PacketSize = ei_packet_traits::size }; + void operator()(Scalar* blockB, const Scalar* rhs, int rhsStride, Scalar alpha, int actual_kc, int packet_cols, int cols) + { + bool hasAlpha = alpha != Scalar(1); + int count = 0; + for(int j2=0; j2 lhs(_lhs,lhsStride); + ei_const_blas_data_mapper lhs(_lhs,lhsStride); int count = 0; const int peeled_mc = (actual_mc/mr)*mr; for(int i=0; i +struct ei_symm_pack_rhs +{ + enum { PacketSize = ei_packet_traits::size }; + void operator()(Scalar* blockB, const Scalar* _rhs, int rhsStride, Scalar alpha, int actual_kc, int packet_cols, int cols, int k2) + { + int end_k = k2 + actual_kc; + int count = 0; + ei_const_blas_data_mapper rhs(_rhs,rhsStride); + + // first part: standard case + for(int j2=0; j2 the same with nr==1) + for(int j2=packet_cols; j2 -static EIGEN_DONT_INLINE void ei_product_selfadjoint_matrix( - int size, - const Scalar* _lhs, int lhsStride, - const Scalar* _rhs, int rhsStride, bool rhsRowMajor, int cols, - Scalar* res, int resStride, - Scalar alpha) +template +struct ei_product_selfadjoint_matrix; + +template +struct ei_product_selfadjoint_matrix { - typedef typename ei_packet_traits::type Packet; - ei_const_blas_data_mapper lhs(_lhs,lhsStride); - ei_const_blas_data_mapper rhs(_rhs,rhsStride); - - if 
(ConjugateRhs) - alpha = ei_conj(alpha); - - typedef typename ei_packet_traits::type PacketType; - - const bool lhsRowMajor = (StorageOrder==RowMajor); - - typedef ei_product_blocking_traits Blocking; - - int kc = std::min(Blocking::Max_kc,size); // cache block size along the K direction - int mc = std::min(Blocking::Max_mc,size); // cache block size along the M direction - - Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc); - Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize); - - // number of columns which can be processed by packet of nr columns - int packet_cols = (cols/Blocking::nr)*Blocking::nr; - - ei_gebp_kernel > gebp_kernel; - - for(int k2=0; k2() - (blockB, &rhs(k2,0), rhsStride, alpha, actual_kc, packet_cols, cols); - - // the select lhs's panel has to be split in three different parts: - // 1 - the transposed panel above the diagonal block => transposed packed copy - // 2 - the diagonal block => special packed copy - // 3 - the panel below the diagonal block => generic packed copy - for(int i2=0; i2() - (blockA, &lhs(k2,i2), lhsStride, actual_kc, actual_mc); - - gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols); - } - // the block diagonal - { - const int actual_mc = std::min(k2+kc,size)-k2; - // symmetric packed copy - ei_symm_pack_lhs() - (blockA, &lhs(k2,k2), lhsStride, actual_kc, actual_mc); - gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, k2, cols); - } - - for(int i2=k2+kc; i2() - (blockA, &lhs(i2,k2), lhsStride, actual_kc, actual_mc); - gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols); - } + ei_product_selfadjoint_matrix + ::run(rows, cols, rhs, rhsStride, lhs, lhsStride, res, resStride, alpha); } +}; - ei_aligned_stack_delete(Scalar, blockA, kc*mc); - ei_aligned_stack_delete(Scalar, blockB, kc*cols*Blocking::PacketSize); -} +template +struct ei_product_selfadjoint_matrix +{ + + static EIGEN_DONT_INLINE void 
run( + int rows, int cols, + const Scalar* _lhs, int lhsStride, + const Scalar* _rhs, int rhsStride, + Scalar* res, int resStride, + Scalar alpha) + { + int size = rows; + + ei_const_blas_data_mapper lhs(_lhs,lhsStride); + ei_const_blas_data_mapper rhs(_rhs,rhsStride); + + if (ConjugateRhs) + alpha = ei_conj(alpha); + + typedef ei_product_blocking_traits Blocking; + + int kc = std::min(Blocking::Max_kc,size); // cache block size along the K direction + int mc = std::min(Blocking::Max_mc,rows); // cache block size along the M direction + + Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc); + Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize); + + // number of columns which can be processed by packet of nr columns + int packet_cols = (cols/Blocking::nr)*Blocking::nr; + + ei_gebp_kernel > gebp_kernel; + + for(int k2=0; k2() + (blockB, &rhs(k2,0), rhsStride, alpha, actual_kc, packet_cols, cols); + + // the select lhs's panel has to be split in three different parts: + // 1 - the transposed panel above the diagonal block => transposed packed copy + // 2 - the diagonal block => special packed copy + // 3 - the panel below the diagonal block => generic packed copy + for(int i2=0; i2() + (blockA, &lhs(k2, i2), lhsStride, actual_kc, actual_mc); + + gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols); + } + // the block diagonal + { + const int actual_mc = std::min(k2+kc,size)-k2; + // symmetric packed copy + ei_symm_pack_lhs() + (blockA, &lhs(k2,k2), lhsStride, actual_kc, actual_mc); + + gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, k2, cols); + } + + for(int i2=k2+kc; i2() + (blockA, &lhs(i2, k2), lhsStride, actual_kc, actual_mc); + + gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols); + } + } + + ei_aligned_stack_delete(Scalar, blockA, kc*mc); + ei_aligned_stack_delete(Scalar, blockB, kc*cols*Blocking::PacketSize); + } +}; + +// matrix * 
selfadjoint product +template +struct ei_product_selfadjoint_matrix +{ + + static EIGEN_DONT_INLINE void run( + int rows, int cols, + const Scalar* _lhs, int lhsStride, + const Scalar* _rhs, int rhsStride, + Scalar* res, int resStride, + Scalar alpha) + { + int size = cols; + + ei_const_blas_data_mapper lhs(_lhs,lhsStride); + ei_const_blas_data_mapper rhs(_rhs,rhsStride); + + if (ConjugateRhs) + alpha = ei_conj(alpha); + + typedef ei_product_blocking_traits Blocking; + + int kc = std::min(Blocking::Max_kc,size); // cache block size along the K direction + int mc = std::min(Blocking::Max_mc,rows); // cache block size along the M direction + + Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc); + Scalar* blockB = ei_aligned_stack_new(Scalar, kc*cols*Blocking::PacketSize); + + // number of columns which can be processed by packet of nr columns + int packet_cols = (cols/Blocking::nr)*Blocking::nr; + + ei_gebp_kernel > gebp_kernel; + + for(int k2=0; k2() + (blockB, _rhs, rhsStride, alpha, actual_kc, packet_cols, cols, k2); + + // => GEPP + for(int i2=0; i2() + (blockA, &lhs(i2, k2), lhsStride, actual_kc, actual_mc); + + gebp_kernel(res, resStride, blockA, blockB, actual_mc, actual_kc, packet_cols, i2, cols); + } + } + + ei_aligned_stack_delete(Scalar, blockA, kc*mc); + ei_aligned_stack_delete(Scalar, blockB, kc*cols*Blocking::PacketSize); + } +}; #endif // EIGEN_SELFADJOINT_MATRIX_MATRIX_H diff --git a/Eigen/src/Core/util/BlasUtil.h b/Eigen/src/Core/util/BlasUtil.h index 6e4b21e6a..25829652f 100644 --- a/Eigen/src/Core/util/BlasUtil.h +++ b/Eigen/src/Core/util/BlasUtil.h @@ -29,10 +29,10 @@ // implement and control fast level 2 and level 3 BLAS-like routines. // forward declarations -template +template struct ei_gebp_kernel; -template +template struct ei_gemm_pack_rhs; template @@ -154,4 +154,77 @@ struct ei_product_blocking_traits }; }; +/* Helper class to analyze the factors of a Product expression. 
+ * In particular it allows to pop out operator-, scalar multiples, + * and conjugate */ +template struct ei_blas_traits +{ + typedef typename ei_traits::Scalar Scalar; + typedef XprType ActualXprType; + enum { + IsComplex = NumTraits::IsComplex, + NeedToConjugate = false, + ActualAccess = int(ei_traits::Flags)&DirectAccessBit ? HasDirectAccess : NoDirectAccess + }; + typedef typename ei_meta_if::ret DirectLinearAccessType; + static inline const ActualXprType& extract(const XprType& x) { return x; } + static inline Scalar extractScalarFactor(const XprType&) { return Scalar(1); } +}; + +// pop conjugate +template struct ei_blas_traits, NestedXpr> > + : ei_blas_traits +{ + typedef ei_blas_traits Base; + typedef CwiseUnaryOp, NestedXpr> XprType; + typedef typename Base::ActualXprType ActualXprType; + + enum { + IsComplex = NumTraits::IsComplex, + NeedToConjugate = IsComplex + }; + static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); } + static inline Scalar extractScalarFactor(const XprType& x) { return ei_conj(Base::extractScalarFactor(x._expression())); } +}; + +// pop scalar multiple +template struct ei_blas_traits, NestedXpr> > + : ei_blas_traits +{ + typedef ei_blas_traits Base; + typedef CwiseUnaryOp, NestedXpr> XprType; + typedef typename Base::ActualXprType ActualXprType; + static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); } + static inline Scalar extractScalarFactor(const XprType& x) + { return x._functor().m_other * Base::extractScalarFactor(x._expression()); } +}; + +// pop opposite +template struct ei_blas_traits, NestedXpr> > + : ei_blas_traits +{ + typedef ei_blas_traits Base; + typedef CwiseUnaryOp, NestedXpr> XprType; + typedef typename Base::ActualXprType ActualXprType; + static inline const ActualXprType& extract(const XprType& x) { return Base::extract(x._expression()); } + static inline Scalar extractScalarFactor(const XprType& x) + { return - 
Base::extractScalarFactor(x._expression()); } +}; + +// pop NestByValue +template struct ei_blas_traits > + : ei_blas_traits +{ + typedef typename NestedXpr::Scalar Scalar; + typedef ei_blas_traits Base; + typedef NestByValue XprType; + typedef typename Base::ActualXprType ActualXprType; + static inline const ActualXprType& extract(const XprType& x) { return Base::extract(static_cast(x)); } + static inline Scalar extractScalarFactor(const XprType& x) + { return Base::extractScalarFactor(static_cast(x)); } +}; + #endif // EIGEN_BLASUTIL_H