From c01ed935dd6fd536134c26c2eda7511269c28a6f Mon Sep 17 00:00:00 2001 From: Jitse Niesen Date: Thu, 25 Aug 2011 07:42:21 +0100 Subject: [PATCH] Split code for (quasi)triangular matrices from MatrixSquareRoot. This way, (quasi)triangular matrices can avoid the costly Schur decomposition. --- .../src/MatrixFunctions/MatrixSquareRoot.h | 269 +++++++++++++----- 1 file changed, 190 insertions(+), 79 deletions(-) diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h b/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h index b56fcf06b..ed1b5ee35 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h @@ -26,76 +26,71 @@ #define EIGEN_MATRIX_SQUARE_ROOT /** \ingroup MatrixFunctions_Module - * \brief Class for computing matrix square roots. - * \tparam MatrixType type of the argument of the matrix square root, - * expected to be an instantiation of the Matrix class template. + * \brief Class for computing matrix square roots of upper quasi-triangular matrices. + * \tparam MatrixType type of the argument of the matrix square root, + * expected to be an instantiation of the Matrix class template. + * + * This class computes the square root of the upper quasi-triangular + * matrix stored in the upper Hessenberg part of the matrix passed to + * the constructor. + * + * \sa MatrixSquareRoot, MatrixSquareRootTriangular */ -template ::Scalar>::IsComplex> -class MatrixSquareRoot -{ +template +class MatrixSquareRootQuasiTriangular +{ public: /** \brief Constructor. * - * \param[in] A matrix whose square root is to be computed. + * \param[in] A upper quasi-triangular matrix whose square root + * is to be computed. * * The class stores a reference to \p A, so it should not be * changed (or destroyed) before compute() is called. */ - MatrixSquareRoot(const MatrixType& A); - + MatrixSquareRootQuasiTriangular(const MatrixType& A) + : m_A(A) + { + eigen_assert(A.rows() == A.cols()); + } + /** \brief Compute the matrix square root * * \param[out] result square root of \p A, as specified in the constructor. * - * See MatrixBase::sqrt() for details on how this computation - * is implemented. + * Only the upper Hessenberg part of \p result is updated, the + * rest is not touched. See MatrixBase::sqrt() for details on + * how this computation is implemented. */ - template - void compute(ResultType &result); -}; - - -// ********** Partial specialization for real matrices ********** - -template -class MatrixSquareRoot -{ -public: - MatrixSquareRoot(const MatrixType& A) - : m_A(A) - { - eigen_assert(A.rows() == A.cols()); - } + template void compute(ResultType &result); + + private: + typedef typename MatrixType::Index Index; + typedef typename MatrixType::Scalar Scalar; + + void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T); + void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T); + void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i); + void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, + typename MatrixType::Index i, typename MatrixType::Index j); + void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, + typename MatrixType::Index i, typename MatrixType::Index j); + void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, + typename MatrixType::Index i, typename MatrixType::Index j); + void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, + typename MatrixType::Index i, typename MatrixType::Index j); - template void compute(ResultType &result); + template + static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A, + const SmallMatrixType& B, const SmallMatrixType& C); -private: - typedef typename MatrixType::Index Index; - typedef typename MatrixType::Scalar Scalar; - - void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T); - void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T); - void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i); - void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, - typename MatrixType::Index i, typename MatrixType::Index j); - void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, - typename MatrixType::Index i, typename MatrixType::Index j); - void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, - typename MatrixType::Index i, typename MatrixType::Index j); - void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, - typename MatrixType::Index i, typename MatrixType::Index j); - - template - static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A, - const SmallMatrixType& B, const SmallMatrixType& C); - - const MatrixType& m_A; + const MatrixType& m_A; }; template template -void MatrixSquareRoot::compute(ResultType &result) +void MatrixSquareRootQuasiTriangular::compute(ResultType &result) { // Compute Schur decomposition of m_A const RealSchur schurOfA(m_A); @@ -114,7 +109,8 @@ void MatrixSquareRoot::compute(ResultType &result) // pre: T is quasi-upper-triangular and sqrtT is a zero matrix of the same size // post: the diagonal blocks of sqrtT are the square roots of the diagonal blocks of T template -void MatrixSquareRoot::computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T) +void MatrixSquareRootQuasiTriangular::computeDiagonalPartOfSqrt(MatrixType& sqrtT, + const MatrixType& T) { const Index size = m_A.rows(); for (Index i = 0; i < size; i++) { @@ -132,7 +128,8 @@ void MatrixSquareRoot::computeDiagonalPartOfSqrt(MatrixType& sqrt // pre: T is quasi-upper-triangular and diagonal blocks of sqrtT are square root of diagonal blocks of T. // post: sqrtT is the square root of T. template -void MatrixSquareRoot::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T) +void MatrixSquareRootQuasiTriangular::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, + const MatrixType& T) { const Index size = m_A.rows(); for (Index j = 1; j < size; j++) { @@ -158,9 +155,8 @@ void MatrixSquareRoot::computeOffDiagonalPartOfSqrt(MatrixType& s // pre: T.block(i,i,2,2) has complex conjugate eigenvalues // post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2) template -void MatrixSquareRoot::compute2x2diagonalBlock(MatrixType& sqrtT, - const MatrixType& T, - typename MatrixType::Index i) +void MatrixSquareRootQuasiTriangular + ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i) { // TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere // in EigenSolver. If we expose it, we could call it directly from here. @@ -174,10 +170,9 @@ void MatrixSquareRoot::compute2x2diagonalBlock(MatrixType& sqrtT, // all blocks of sqrtT to left of and below (i,j) are correct // post: sqrtT(i,j) has the correct value template -void MatrixSquareRoot::compute1x1offDiagonalBlock(MatrixType& sqrtT, - const MatrixType& T, - typename MatrixType::Index i, - typename MatrixType::Index j) +void MatrixSquareRootQuasiTriangular + ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, + typename MatrixType::Index i, typename MatrixType::Index j) { Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value(); sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j)); @@ -185,10 +180,9 @@ void MatrixSquareRoot::compute1x1offDiagonalBlock(MatrixType& sqr // similar to compute1x1offDiagonalBlock() template -void MatrixSquareRoot::compute1x2offDiagonalBlock(MatrixType& sqrtT, - const MatrixType& T, - typename MatrixType::Index i, - typename MatrixType::Index j) +void MatrixSquareRootQuasiTriangular + ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, + typename MatrixType::Index i, typename MatrixType::Index j) { Matrix rhs = T.template block<1,2>(i,j); if (j-i > 1) @@ -200,10 +194,9 @@ void MatrixSquareRoot::compute1x2offDiagonalBlock(MatrixType& sqr // similar to compute1x1offDiagonalBlock() template -void MatrixSquareRoot::compute2x1offDiagonalBlock(MatrixType& sqrtT, - const MatrixType& T, - typename MatrixType::Index i, - typename MatrixType::Index j) +void MatrixSquareRootQuasiTriangular + ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, + typename MatrixType::Index i, typename MatrixType::Index j) { Matrix rhs = T.template block<2,1>(i,j); if (j-i > 2) @@ -215,10 +208,9 @@ void MatrixSquareRoot::compute2x1offDiagonalBlock(MatrixType& sqr // similar to compute1x1offDiagonalBlock() template -void MatrixSquareRoot::compute2x2offDiagonalBlock(MatrixType& sqrtT, - const MatrixType& T, - typename MatrixType::Index i, - typename MatrixType::Index j) +void MatrixSquareRootQuasiTriangular + ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T, + typename MatrixType::Index i, typename MatrixType::Index j) { Matrix A = sqrtT.template block<2,2>(i,i); Matrix B = sqrtT.template block<2,2>(j,j); @@ -233,10 +225,9 @@ void MatrixSquareRoot::compute2x2offDiagonalBlock(MatrixType& sqr // solves the equation A X + X B = C where all matrices are 2-by-2 template template -void MatrixSquareRoot::solveAuxiliaryEquation(SmallMatrixType& X, - const SmallMatrixType& A, - const SmallMatrixType& B, - const SmallMatrixType& C) +void MatrixSquareRootQuasiTriangular + ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A, + const SmallMatrixType& B, const SmallMatrixType& C) { EIGEN_STATIC_ASSERT((internal::is_same >::value), EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT); @@ -270,18 +261,37 @@ void MatrixSquareRoot::solveAuxiliaryEquation(SmallMatrixType& X, X.coeffRef(1,1) = result.coeff(3); } -// ********** Partial specialization for complex matrices ********** +/** \ingroup MatrixFunctions_Module + * \brief Class for computing matrix square roots of upper triangular matrices. + * \tparam MatrixType type of the argument of the matrix square root, + * expected to be an instantiation of the Matrix class template. + * + * This class computes the square root of the upper triangular matrix + * stored in the upper triangular part (including the diagonal) of + * the matrix passed to the constructor. + * + * \sa MatrixSquareRoot, MatrixSquareRootQuasiTriangular + */ template -class MatrixSquareRoot +class MatrixSquareRootTriangular { public: - MatrixSquareRoot(const MatrixType& A) + MatrixSquareRootTriangular(const MatrixType& A) : m_A(A) { eigen_assert(A.rows() == A.cols()); } + /** \brief Compute the matrix square root + * + * \param[out] result square root of \p A, as specified in the constructor. + * + * Only the upper triangular part (including the diagonal) of + * \p result is updated, the rest is not touched. See + * MatrixBase::sqrt() for details on how this computation is + * implemented. + */ template void compute(ResultType &result); private: @@ -290,7 +300,7 @@ class MatrixSquareRoot template template -void MatrixSquareRoot::compute(ResultType &result) +void MatrixSquareRootTriangular::compute(ResultType &result) { // Compute Schur decomposition of m_A const ComplexSchur schurOfA(m_A); @@ -320,6 +330,107 @@ void MatrixSquareRoot::compute(ResultType &result) result.noalias() = tmp * U.adjoint(); } + +/** \ingroup MatrixFunctions_Module + * \brief Class for computing matrix square roots of general matrices. + * \tparam MatrixType type of the argument of the matrix square root, + * expected to be an instantiation of the Matrix class template. + * + * \sa MatrixSquareRootTriangular, MatrixSquareRootQuasiTriangular, MatrixBase::sqrt() + */ +template ::Scalar>::IsComplex> +class MatrixSquareRoot +{ + public: + + /** \brief Constructor. + * + * \param[in] A matrix whose square root is to be computed. + * + * The class stores a reference to \p A, so it should not be + * changed (or destroyed) before compute() is called. + */ + MatrixSquareRoot(const MatrixType& A); + + /** \brief Compute the matrix square root + * + * \param[out] result square root of \p A, as specified in the constructor. + * + * See MatrixBase::sqrt() for details on how this computation is + * implemented. + */ + template void compute(ResultType &result); +}; + + +// ********** Partial specialization for real matrices ********** + +template +class MatrixSquareRoot +{ + public: + + MatrixSquareRoot(const MatrixType& A) + : m_A(A) + { + eigen_assert(A.rows() == A.cols()); + } + + template void compute(ResultType &result) + { + // Compute Schur decomposition of m_A + const RealSchur schurOfA(m_A); + const MatrixType& T = schurOfA.matrixT(); + const MatrixType& U = schurOfA.matrixU(); + + // Compute square root of T + MatrixSquareRootQuasiTriangular tmp(T); + MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows()); + tmp.compute(sqrtT); + + // Compute square root of m_A + result = U * sqrtT * U.adjoint(); + } + + private: + const MatrixType& m_A; +}; + + +// ********** Partial specialization for complex matrices ********** + +template +class MatrixSquareRoot +{ + public: + + MatrixSquareRoot(const MatrixType& A) + : m_A(A) + { + eigen_assert(A.rows() == A.cols()); + } + + template void compute(ResultType &result) + { + // Compute Schur decomposition of m_A + const ComplexSchur schurOfA(m_A); + const MatrixType& T = schurOfA.matrixT(); + const MatrixType& U = schurOfA.matrixU(); + + // Compute square root of T + MatrixSquareRootTriangular tmp(T); + MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows()); + tmp.compute(sqrtT); + + // Compute square root of m_A + result = U * sqrtT * U.adjoint(); + } + + private: + const MatrixType& m_A; +}; + + /** \ingroup MatrixFunctions_Module * * \brief Proxy for the matrix square root of some matrix (expression).