From 082f7ddc3745160c57d8a5a185a2a22e4d781b5f Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Tue, 11 Mar 2014 13:33:44 +0100 Subject: [PATCH] Port Cholesky module to evaluators --- Eigen/src/Cholesky/LDLT.h | 86 +++++++++++++++++++++----------- Eigen/src/Cholesky/LLT.h | 32 +++++++++++- Eigen/src/Core/SelfAdjointView.h | 5 +- Eigen/src/Core/util/XprHelper.h | 2 +- 4 files changed, 93 insertions(+), 32 deletions(-) diff --git a/Eigen/src/Cholesky/LDLT.h b/Eigen/src/Cholesky/LDLT.h index f34d26465..c5ae2c87e 100644 --- a/Eigen/src/Cholesky/LDLT.h +++ b/Eigen/src/Cholesky/LDLT.h @@ -181,6 +181,17 @@ template class LDLT * * \sa MatrixBase::ldlt() */ +#ifdef EIGEN_TEST_EVALUATORS + template + inline const Solve + solve(const MatrixBase& b) const + { + eigen_assert(m_isInitialized && "LDLT is not initialized."); + eigen_assert(m_matrix.rows()==b.rows() + && "LDLT::solve(): invalid number of rows of the right hand side matrix b"); + return Solve(*this, b.derived()); + } +#else template inline const internal::solve_retval solve(const MatrixBase& b) const @@ -190,6 +201,7 @@ template class LDLT && "LDLT::solve(): invalid number of rows of the right hand side matrix b"); return internal::solve_retval(*this, b.derived()); } +#endif #ifdef EIGEN2_SUPPORT template @@ -233,6 +245,12 @@ template class LDLT eigen_assert(m_isInitialized && "LDLT is not initialized."); return Success; } + + #ifndef EIGEN_PARSED_BY_DOXYGEN + template + EIGEN_DEVICE_FUNC + void _solve_impl(const RhsType &rhs, DstType &dst) const; + #endif protected: @@ -492,7 +510,44 @@ LDLT& LDLT::rankUpdate(const MatrixBase +template +void LDLT<_MatrixType,_UpLo>::_solve_impl(const RhsType &rhs, DstType &dst) const +{ + eigen_assert(rhs.rows() == rows()); + // dst = P b + dst = m_transpositions * rhs; + + // dst = L^-1 (P b) + matrixL().solveInPlace(dst); + + // dst = D^-1 (L^-1 P b) + // more precisely, use pseudo-inverse of D (see bug 241) + using std::abs; + EIGEN_USING_STD_MATH(max); + const Diagonal vecD = vectorD(); + RealScalar tolerance = (max)( vecD.array().abs().maxCoeff() * NumTraits::epsilon(), + RealScalar(1) / NumTraits::highest()); // motivated by LAPACK's xGELSS + + for (Index i = 0; i < vecD.size(); ++i) + { + if(abs(vecD(i)) > tolerance) + dst.row(i) /= vecD(i); + else + dst.row(i).setZero(); + } + + // dst = L^-T (D^-1 L^-1 P b) + matrixU().solveInPlace(dst); + + // dst = P^-1 (L^-T D^-1 L^-1 P b) = A^-1 b + dst = m_transpositions.transpose() * dst; +} +#endif + namespace internal { +#ifndef EIGEN_TEST_EVALUATORS template struct solve_retval, Rhs> : solve_retval_base, Rhs> @@ -502,37 +557,10 @@ struct solve_retval, Rhs> template void evalTo(Dest& dst) const { - eigen_assert(rhs().rows() == dec().matrixLDLT().rows()); - // dst = P b - dst = dec().transpositionsP() * rhs(); - - // dst = L^-1 (P b) - dec().matrixL().solveInPlace(dst); - - // dst = D^-1 (L^-1 P b) - // more precisely, use pseudo-inverse of D (see bug 241) - using std::abs; - EIGEN_USING_STD_MATH(max); - typedef typename LDLTType::MatrixType MatrixType; - typedef typename LDLTType::Scalar Scalar; - typedef typename LDLTType::RealScalar RealScalar; - const Diagonal vectorD = dec().vectorD(); - RealScalar tolerance = (max)(vectorD.array().abs().maxCoeff() * NumTraits::epsilon(), - RealScalar(1) / NumTraits::highest()); // motivated by LAPACK's xGELSS - for (Index i = 0; i < vectorD.size(); ++i) { - if(abs(vectorD(i)) > tolerance) - dst.row(i) /= vectorD(i); - else - dst.row(i).setZero(); - } - - // dst = L^-T (D^-1 L^-1 P b) - dec().matrixU().solveInPlace(dst); - - // dst = P^-1 (L^-T D^-1 L^-1 P b) = A^-1 b - dst = dec().transpositionsP().transpose() * dst; + dec()._solve_impl(rhs(),dst); } }; +#endif } /** \internal use x = ldlt_object.solve(x); diff --git a/Eigen/src/Cholesky/LLT.h b/Eigen/src/Cholesky/LLT.h index 2201c641e..d9a8ef1fb 100644 --- a/Eigen/src/Cholesky/LLT.h +++ b/Eigen/src/Cholesky/LLT.h @@ -117,6 +117,17 @@ template class LLT * * \sa solveInPlace(), MatrixBase::llt() */ +#ifdef EIGEN_TEST_EVALUATORS + template + inline const Solve + solve(const MatrixBase& b) const + { + eigen_assert(m_isInitialized && "LLT is not initialized."); + eigen_assert(m_matrix.rows()==b.rows() + && "LLT::solve(): invalid number of rows of the right hand side matrix b"); + return Solve(*this, b.derived()); + } +#else template inline const internal::solve_retval solve(const MatrixBase& b) const @@ -126,6 +137,7 @@ template class LLT && "LLT::solve(): invalid number of rows of the right hand side matrix b"); return internal::solve_retval(*this, b.derived()); } +#endif #ifdef EIGEN2_SUPPORT template @@ -172,6 +184,12 @@ template class LLT template LLT rankUpdate(const VectorType& vec, const RealScalar& sigma = 1); + + #ifndef EIGEN_PARSED_BY_DOXYGEN + template + EIGEN_DEVICE_FUNC + void _solve_impl(const RhsType &rhs, DstType &dst) const; + #endif protected: /** \internal @@ -415,8 +433,19 @@ LLT<_MatrixType,_UpLo> LLT<_MatrixType,_UpLo>::rankUpdate(const VectorType& v, c return *this; } - + +#ifndef EIGEN_PARSED_BY_DOXYGEN +template +template +void LLT<_MatrixType,_UpLo>::_solve_impl(const RhsType &rhs, DstType &dst) const +{ + dst = rhs; + solveInPlace(dst); +} +#endif + namespace internal { +#ifndef EIGEN_TEST_EVALUATORS template struct solve_retval, Rhs> : solve_retval_base, Rhs> @@ -430,6 +459,7 @@ struct solve_retval, Rhs> dec().solveInPlace(dst); } }; +#endif } /** \internal use x = llt_object.solve(x); diff --git a/Eigen/src/Core/SelfAdjointView.h b/Eigen/src/Core/SelfAdjointView.h index f7f512cf4..b300c6a48 100644 --- a/Eigen/src/Core/SelfAdjointView.h +++ b/Eigen/src/Core/SelfAdjointView.h @@ -39,8 +39,11 @@ struct traits > : traits enum { Mode = UpLo | SelfAdjoint, Flags = MatrixTypeNestedCleaned::Flags & (HereditaryBits) - & (~(PacketAccessBit | DirectAccessBit | LinearAccessBit)), // FIXME these flags should be preserved + & (~(PacketAccessBit | DirectAccessBit | LinearAccessBit)) // FIXME these flags should be preserved +#ifndef EIGEN_TEST_EVALUATORS + , CoeffReadCost = MatrixTypeNestedCleaned::CoeffReadCost +#endif }; }; } diff --git a/Eigen/src/Core/util/XprHelper.h b/Eigen/src/Core/util/XprHelper.h index 8931c5a2d..bcd6183e2 100644 --- a/Eigen/src/Core/util/XprHelper.h +++ b/Eigen/src/Core/util/XprHelper.h @@ -348,7 +348,7 @@ template::type> str // When using evaluators, we never evaluate when assembling the expression!! // TODO: get rid of this nested class since it's just an alias for ref_selector. -template::type> struct nested +template struct nested { typedef typename ref_selector::type type; };