From c60818fca8ed58a272fab9f3f62024e04eac1a1c Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Tue, 1 Feb 2011 11:38:46 +0100 Subject: [PATCH] fix trmv regarding strided vectors and static allocation of temporaries --- Eigen/src/Core/ProductBase.h | 2 + .../Core/products/TriangularMatrixVector.h | 194 +++++++++++++++--- test/nomalloc.cpp | 32 ++- 3 files changed, 198 insertions(+), 30 deletions(-) diff --git a/Eigen/src/Core/ProductBase.h b/Eigen/src/Core/ProductBase.h index 1f2b373cd..287ea554f 100644 --- a/Eigen/src/Core/ProductBase.h +++ b/Eigen/src/Core/ProductBase.h @@ -84,12 +84,14 @@ class ProductBase : public MatrixBase typedef internal::blas_traits<_LhsNested> LhsBlasTraits; typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType; typedef typename internal::remove_all::type _ActualLhsType; + typedef typename internal::traits::Scalar LhsScalar; typedef typename Rhs::Nested RhsNested; typedef typename internal::remove_all::type _RhsNested; typedef internal::blas_traits<_RhsNested> RhsBlasTraits; typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType; typedef typename internal::remove_all::type _ActualRhsType; + typedef typename internal::traits::Scalar RhsScalar; // Diagonal of a product: no need to evaluate the arguments because they are going to be evaluated only once typedef CoeffBasedProduct FullyLazyCoeffBaseProductType; diff --git a/Eigen/src/Core/products/TriangularMatrixVector.h b/Eigen/src/Core/products/TriangularMatrixVector.h index c1f64dcea..23aa52ade 100644 --- a/Eigen/src/Core/products/TriangularMatrixVector.h +++ b/Eigen/src/Core/products/TriangularMatrixVector.h @@ -152,6 +152,10 @@ struct traits > : traits, Lhs, Rhs> > {}; + +template +struct trmv_selector; + } // end namespace internal template @@ -165,20 +169,8 @@ struct TriangularProduct template void scaleAndAddTo(Dest& dst, Scalar alpha) const { eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); - - const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs); - const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs); - - Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs) - * RhsBlasTraits::extractScalarFactor(m_rhs); - - internal::product_triangular_matrix_vector - ::Flags)&RowMajorBit) ? RowMajor : ColMajor> - ::run(lhs.rows(),lhs.cols(),lhs.data(),lhs.outerStride(),rhs.data(),rhs.innerStride(), - dst.data(),dst.innerStride(),actualAlpha); + + internal::trmv_selector<(int(internal::traits::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha); } }; @@ -192,23 +184,167 @@ struct TriangularProduct template void scaleAndAddTo(Dest& dst, Scalar alpha) const { - eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols()); - - const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs); - const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs); - - Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs) - * RhsBlasTraits::extractScalarFactor(m_rhs); - - internal::product_triangular_matrix_vector - ::Flags)&RowMajorBit) ? ColMajor : RowMajor> - ::run(rhs.rows(),rhs.cols(),rhs.data(),rhs.outerStride(),lhs.data(),lhs.innerStride(), - dst.data(),dst.innerStride(),actualAlpha); + + typedef TriangularProduct<(Mode & UnitDiag) | ((Mode & Lower) ? Upper : Lower),true,Transpose,false,Transpose,true> TriangularProductTranspose; + Transpose dstT(dst); + internal::trmv_selector<(int(internal::traits::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run( + TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha); } }; +namespace internal { + +// TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same. + +template<> struct trmv_selector +{ + template + static void run(const TriangularProduct& prod, Dest& dest, typename TriangularProduct::Scalar alpha) + { + typedef TriangularProduct ProductType; + typedef typename ProductType::Index Index; + typedef typename ProductType::LhsScalar LhsScalar; + typedef typename ProductType::RhsScalar RhsScalar; + typedef typename ProductType::Scalar ResScalar; + typedef typename ProductType::RealScalar RealScalar; + typedef typename ProductType::ActualLhsType ActualLhsType; + typedef typename ProductType::ActualRhsType ActualRhsType; + typedef typename ProductType::LhsBlasTraits LhsBlasTraits; + typedef typename ProductType::RhsBlasTraits RhsBlasTraits; + typedef Map, Aligned> MappedDest; + + const ActualLhsType actualLhs = LhsBlasTraits::extract(prod.lhs()); + const ActualRhsType actualRhs = RhsBlasTraits::extract(prod.rhs()); + + ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs()) + * RhsBlasTraits::extractScalarFactor(prod.rhs()); + + enum { + // FIXME find a way to allow an inner stride on the result if packet_traits::size==1 + // on, the other hand it is good for the cache to pack the vector anyways... + EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1, + ComplexByReal = (NumTraits::IsComplex) && (!NumTraits::IsComplex), + MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal + }; + + gemv_static_vector_if static_dest; + + bool alphaIsCompatible = (!ComplexByReal) || (imag(actualAlpha)==RealScalar(0)); + bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible; + + RhsScalar compatibleAlpha = get_factor::run(actualAlpha); + + ResScalar* actualDestPtr; + bool freeDestPtr = false; + if (evalToDest) + { + actualDestPtr = dest.data(); + } + else + { + #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN + int size = dest.size(); + EIGEN_DENSE_STORAGE_CTOR_PLUGIN + #endif + if((actualDestPtr = static_dest.data())==0) + { + freeDestPtr = true; + actualDestPtr = ei_aligned_stack_new(ResScalar,dest.size()); + } + if(!alphaIsCompatible) + { + MappedDest(actualDestPtr, dest.size()).setZero(); + compatibleAlpha = RhsScalar(1); + } + else + MappedDest(actualDestPtr, dest.size()) = dest; + } + + internal::product_triangular_matrix_vector + + ::run(actualLhs.rows(),actualLhs.cols(), + actualLhs.data(),actualLhs.outerStride(), + actualRhs.data(),actualRhs.innerStride(), + actualDestPtr,1,compatibleAlpha); + + if (!evalToDest) + { + if(!alphaIsCompatible) + dest += actualAlpha * MappedDest(actualDestPtr, dest.size()); + else + dest = MappedDest(actualDestPtr, dest.size()); + if(freeDestPtr) ei_aligned_stack_delete(ResScalar, actualDestPtr, dest.size()); + } + } +}; + +template<> struct trmv_selector +{ + template + static void run(const TriangularProduct& prod, Dest& dest, typename TriangularProduct::Scalar alpha) + { + typedef TriangularProduct ProductType; + typedef typename ProductType::LhsScalar LhsScalar; + typedef typename ProductType::RhsScalar RhsScalar; + typedef typename ProductType::Scalar ResScalar; + typedef typename ProductType::Index Index; + typedef typename ProductType::ActualLhsType ActualLhsType; + typedef typename ProductType::ActualRhsType ActualRhsType; + typedef typename ProductType::_ActualRhsType _ActualRhsType; + typedef typename ProductType::LhsBlasTraits LhsBlasTraits; + typedef typename ProductType::RhsBlasTraits RhsBlasTraits; + + typename add_const::type actualLhs = LhsBlasTraits::extract(prod.lhs()); + typename add_const::type actualRhs = RhsBlasTraits::extract(prod.rhs()); + + ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs()) + * RhsBlasTraits::extractScalarFactor(prod.rhs()); + + enum { + DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1 + }; + + gemv_static_vector_if static_rhs; + + RhsScalar* actualRhsPtr; + bool freeRhsPtr = false; + if (DirectlyUseRhs) + { + actualRhsPtr = const_cast(actualRhs.data()); + } + else + { + #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN + int size = actualRhs.size(); + EIGEN_DENSE_STORAGE_CTOR_PLUGIN + #endif + if((actualRhsPtr = static_rhs.data())==0) + { + freeRhsPtr = true; + actualRhsPtr = ei_aligned_stack_new(RhsScalar, actualRhs.size()); + } + Map(actualRhsPtr, actualRhs.size()) = actualRhs; + } + + internal::product_triangular_matrix_vector + + ::run(actualLhs.rows(),actualLhs.cols(), + actualLhs.data(),actualLhs.outerStride(), + actualRhsPtr,1, + dest.data(),dest.innerStride(), + actualAlpha); + + if((!DirectlyUseRhs) && freeRhsPtr) ei_aligned_stack_delete(RhsScalar, actualRhsPtr, prod.rhs().size()); + } +}; + +} // end namespace internal + #endif // EIGEN_TRIANGULARMATRIXVECTOR_H diff --git a/test/nomalloc.cpp b/test/nomalloc.cpp index 94c1b0533..7ef71bfcd 100644 --- a/test/nomalloc.cpp +++ b/test/nomalloc.cpp @@ -71,7 +71,7 @@ template void nomalloc(const MatrixType& m) VERIFY_IS_APPROX((m1+m2)(r,c), (m1(r,c))+(m2(r,c))); VERIFY_IS_APPROX(m1.cwiseProduct(m1.block(0,0,rows,cols)), (m1.array()*m1.array()).matrix()); VERIFY_IS_APPROX((m1*m1.transpose())*m2, m1*(m1.transpose()*m2)); - + m2.col(0).noalias() = m1 * m1.col(0); m2.col(0).noalias() -= m1.adjoint() * m1.col(0); m2.col(0).noalias() -= m1 * m1.row(0).adjoint(); @@ -81,6 +81,36 @@ template void nomalloc(const MatrixType& m) m2.row(0).noalias() -= m1.row(0) * m1.adjoint(); m2.row(0).noalias() -= m1.col(0).adjoint() * m1; m2.row(0).noalias() -= m1.col(0).adjoint() * m1.adjoint(); + VERIFY_IS_APPROX(m2,m2); + + m2.col(0).noalias() = m1.template triangularView() * m1.col(0); + m2.col(0).noalias() -= m1.adjoint().template triangularView() * m1.col(0); + m2.col(0).noalias() -= m1.template triangularView() * m1.row(0).adjoint(); + m2.col(0).noalias() -= m1.adjoint().template triangularView() * m1.row(0).adjoint(); + + m2.row(0).noalias() = m1.row(0) * m1.template triangularView(); + m2.row(0).noalias() -= m1.row(0) * m1.adjoint().template triangularView(); + m2.row(0).noalias() -= m1.col(0).adjoint() * m1.template triangularView(); + m2.row(0).noalias() -= m1.col(0).adjoint() * m1.adjoint().template triangularView(); + VERIFY_IS_APPROX(m2,m2); + + m2.col(0).noalias() = m1.template selfadjointView() * m1.col(0); + m2.col(0).noalias() -= m1.adjoint().template selfadjointView() * m1.col(0); + m2.col(0).noalias() -= m1.template selfadjointView() * m1.row(0).adjoint(); + m2.col(0).noalias() -= m1.adjoint().template selfadjointView() * m1.row(0).adjoint(); + + m2.row(0).noalias() = m1.row(0) * m1.template selfadjointView(); + m2.row(0).noalias() -= m1.row(0) * m1.adjoint().template selfadjointView(); + m2.row(0).noalias() -= m1.col(0).adjoint() * m1.template selfadjointView(); + m2.row(0).noalias() -= m1.col(0).adjoint() * m1.adjoint().template selfadjointView(); + VERIFY_IS_APPROX(m2,m2); + + // The following fancy matrix-matrix products are not safe yet regarding static allocation +// m1 += m1.template triangularView() * m2.col(; +// m1.template selfadjointView().rankUpdate(m2); +// m1 += m1.template triangularView() * m2; +// m1 += m1.template selfadjointView() * m2; +// VERIFY_IS_APPROX(m1,m1); } template