From a486d5590a7532b698db07eaeba8c402bbf9fd6f Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Tue, 1 Feb 2011 15:49:10 +0100 Subject: [PATCH] implement optimized path for selfadjoint rank 1 update (safe regarding dynamic alloc) --- Eigen/src/Core/products/SelfadjointProduct.h | 119 ++++++++++++++++--- test/product_syrk.cpp | 26 ++++ 2 files changed, 128 insertions(+), 17 deletions(-) diff --git a/Eigen/src/Core/products/SelfadjointProduct.h b/Eigen/src/Core/products/SelfadjointProduct.h index faefc9668..da2e6ee20 100644 --- a/Eigen/src/Core/products/SelfadjointProduct.h +++ b/Eigen/src/Core/products/SelfadjointProduct.h @@ -28,9 +28,109 @@ /********************************************************************** * This file implements a self adjoint product: C += A A^T updating only * half of the selfadjoint matrix C. -* It corresponds to the level 3 SYRK Blas routine. +* It corresponds to the level 3 SYRK and level 2 SYR Blas routines. **********************************************************************/ +template +struct selfadjoint_rank1_update; + +template +struct selfadjoint_rank1_update +{ + static void run(Index size, Scalar* mat, Index stride, const Scalar* vec, Scalar alpha) + { + internal::conj_if cj; + typedef Map > OtherMap; + typedef typename internal::conditional::type ConjRhsType; + for (Index i=0; i >(mat+stride*i+(UpLo==Lower ? i : 0), (UpLo==Lower ? size-i : (i+1))) + += (alpha * cj(vec[i])) * ConjRhsType(OtherMap(vec+(UpLo==Lower ? i : 0),UpLo==Lower ? size-i : (i+1))); + } + } +}; + +template +struct selfadjoint_rank1_update +{ + static void run(Index size, Scalar* mat, Index stride, const Scalar* vec, Scalar alpha) + { + selfadjoint_rank1_update::run(size,mat,stride,vec,alpha); + } +}; + +template +struct selfadjoint_product_selector; + +template +struct selfadjoint_product_selector +{ + static void run(MatrixType& mat, const OtherType& other, typename MatrixType::Scalar alpha) + { + typedef typename MatrixType::Scalar Scalar; + typedef typename MatrixType::Index Index; + typedef internal::blas_traits OtherBlasTraits; + typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType; + typedef typename internal::remove_all::type _ActualOtherType; + const ActualOtherType actualOther = OtherBlasTraits::extract(other.derived()); + + Scalar actualAlpha = alpha * OtherBlasTraits::extractScalarFactor(other.derived()); + + enum { + StorageOrder = (internal::traits::Flags&RowMajorBit) ? RowMajor : ColMajor, + UseOtherDirectly = _ActualOtherType::InnerStrideAtCompileTime==1 + }; + internal::gemv_static_vector_if static_other; + + bool freeOtherPtr = false; + Scalar* actualOtherPtr; + if(UseOtherDirectly) + actualOtherPtr = const_cast(actualOther.data()); + else + { + if((actualOtherPtr=static_other.data())==0) + { + freeOtherPtr = true; + actualOtherPtr = ei_aligned_stack_new(Scalar,other.size()); + } + Map(actualOtherPtr, actualOther.size()) = actualOther; + } + + selfadjoint_rank1_update::IsComplex, + (!OtherBlasTraits::NeedToConjugate) && NumTraits::IsComplex> + ::run(other.size(), mat.data(), mat.outerStride(), actualOtherPtr, actualAlpha); + + if((!UseOtherDirectly) && freeOtherPtr) ei_aligned_stack_delete(Scalar, actualOtherPtr, other.size()); + } +}; + +template +struct selfadjoint_product_selector +{ + static void run(MatrixType& mat, const OtherType& other, typename MatrixType::Scalar alpha) + { + typedef typename MatrixType::Scalar Scalar; + typedef typename MatrixType::Index Index; + typedef internal::blas_traits OtherBlasTraits; + typedef typename OtherBlasTraits::DirectLinearAccessType ActualOtherType; + typedef typename internal::remove_all::type _ActualOtherType; + const ActualOtherType actualOther = OtherBlasTraits::extract(other.derived()); + + Scalar actualAlpha = alpha * OtherBlasTraits::extractScalarFactor(other.derived()); + + enum { IsRowMajor = (internal::traits::Flags&RowMajorBit) ? 1 : 0 }; + + internal::general_matrix_matrix_triangular_product::IsComplex, + Scalar, _ActualOtherType::Flags&RowMajorBit ? ColMajor : RowMajor, (!OtherBlasTraits::NeedToConjugate) && NumTraits::IsComplex, + MatrixType::Flags&RowMajorBit ? RowMajor : ColMajor, UpLo> + ::run(mat.cols(), actualOther.cols(), + &actualOther.coeffRef(0,0), actualOther.outerStride(), &actualOther.coeffRef(0,0), actualOther.outerStride(), + mat.data(), mat.outerStride(), actualAlpha); + } +}; + // high level API template @@ -38,22 +138,7 @@ template SelfAdjointView& SelfAdjointView ::rankUpdate(const MatrixBase& u, Scalar alpha) { - typedef internal::blas_traits UBlasTraits; - typedef typename UBlasTraits::DirectLinearAccessType ActualUType; - typedef typename internal::remove_all::type _ActualUType; - const ActualUType actualU = UBlasTraits::extract(u.derived()); - - Scalar actualAlpha = alpha * UBlasTraits::extractScalarFactor(u.derived()); - - enum { IsRowMajor = (internal::traits::Flags&RowMajorBit) ? 1 : 0 }; - - internal::general_matrix_matrix_triangular_product::IsComplex, - Scalar, _ActualUType::Flags&RowMajorBit ? ColMajor : RowMajor, (!UBlasTraits::NeedToConjugate) && NumTraits::IsComplex, - MatrixType::Flags&RowMajorBit ? RowMajor : ColMajor, UpLo> - ::run(_expression().cols(), actualU.cols(), - &actualU.coeffRef(0,0), actualU.outerStride(), &actualU.coeffRef(0,0), actualU.outerStride(), - _expression().const_cast_derived().data(), _expression().outerStride(), actualAlpha); + selfadjoint_product_selector::run(_expression().const_cast_derived(), u.derived(), alpha); return *this; } diff --git a/test/product_syrk.cpp b/test/product_syrk.cpp index 951a254ed..553410b9b 100644 --- a/test/product_syrk.cpp +++ b/test/product_syrk.cpp @@ -44,6 +44,8 @@ template void syrk(const MatrixType& m) Rhs3 rhs3 = Rhs3::Random(internal::random(1,320), rows); Scalar s1 = internal::random(); + + Index c = internal::random(0,cols-1); m2.setZero(); VERIFY_IS_APPROX((m2.template selfadjointView().rankUpdate(rhs2,s1)._expression()), @@ -68,6 +70,30 @@ template void syrk(const MatrixType& m) m2.setZero(); VERIFY_IS_APPROX(m2.template selfadjointView().rankUpdate(rhs3.adjoint(),s1)._expression(), (s1 * rhs3.adjoint() * rhs3).eval().template triangularView().toDenseMatrix()); + + m2.setZero(); + VERIFY_IS_APPROX((m2.template selfadjointView().rankUpdate(m1.col(c),s1)._expression()), + ((s1 * m1.col(c) * m1.col(c).adjoint()).eval().template triangularView().toDenseMatrix())); + + m2.setZero(); + VERIFY_IS_APPROX((m2.template selfadjointView().rankUpdate(m1.col(c),s1)._expression()), + ((s1 * m1.col(c) * m1.col(c).adjoint()).eval().template triangularView().toDenseMatrix())); + + m2.setZero(); + VERIFY_IS_APPROX((m2.template selfadjointView().rankUpdate(m1.col(c).conjugate(),s1)._expression()), + ((s1 * m1.col(c).conjugate() * m1.col(c).conjugate().adjoint()).eval().template triangularView().toDenseMatrix())); + + m2.setZero(); + VERIFY_IS_APPROX((m2.template selfadjointView().rankUpdate(m1.col(c).conjugate(),s1)._expression()), + ((s1 * m1.col(c).conjugate() * m1.col(c).conjugate().adjoint()).eval().template triangularView().toDenseMatrix())); + + m2.setZero(); + VERIFY_IS_APPROX((m2.template selfadjointView().rankUpdate(m1.row(c),s1)._expression()), + ((s1 * m1.row(c).transpose() * m1.row(c).transpose().adjoint()).eval().template triangularView().toDenseMatrix())); + + m2.setZero(); + VERIFY_IS_APPROX((m2.template selfadjointView().rankUpdate(m1.row(c).adjoint(),s1)._expression()), + ((s1 * m1.row(c).adjoint() * m1.row(c).adjoint().adjoint()).eval().template triangularView().toDenseMatrix())); } void test_product_syrk()