Avoid Schur decomposition on (quasi-)triangular matrices. (Huge speed up!)

This commit is contained in:
Chen-Pang He 2012-09-30 16:30:18 +08:00
parent 332eb36436
commit eb33d307af
2 changed files with 13 additions and 34 deletions

View File

@ -129,7 +129,7 @@ void MatrixLogarithmAtomic<MatrixType>::computeBig(const MatrixType& A, MatrixTy
int numberOfSquareRoots = 0; int numberOfSquareRoots = 0;
int numberOfExtraSquareRoots = 0; int numberOfExtraSquareRoots = 0;
int degree; int degree;
MatrixType T = A; MatrixType T = A, sqrtT;
const RealScalar maxNormForPade = maxPadeDegree<= 5? 5.3149729967117310e-1: // single precision const RealScalar maxNormForPade = maxPadeDegree<= 5? 5.3149729967117310e-1: // single precision
maxPadeDegree<= 7? 2.6429608311114350e-1: // double precision maxPadeDegree<= 7? 2.6429608311114350e-1: // double precision
maxPadeDegree<= 8? 2.32777776523703892094e-1L: // extended precision maxPadeDegree<= 8? 2.32777776523703892094e-1L: // extended precision
@ -145,9 +145,8 @@ void MatrixLogarithmAtomic<MatrixType>::computeBig(const MatrixType& A, MatrixTy
break; break;
++numberOfExtraSquareRoots; ++numberOfExtraSquareRoots;
} }
MatrixType sqrtT;
MatrixSquareRootTriangular<MatrixType>(T).compute(sqrtT); MatrixSquareRootTriangular<MatrixType>(T).compute(sqrtT);
T = sqrtT; T = sqrtT.template triangularView<Upper>();
++numberOfSquareRoots; ++numberOfSquareRoots;
} }

View File

@ -79,18 +79,9 @@ template <typename MatrixType>
template <typename ResultType> template <typename ResultType>
void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result) void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result)
{ {
// Compute Schur decomposition of m_A result.resize(m_A.rows(), m_A.cols());
const RealSchur<MatrixType> schurOfA(m_A); computeDiagonalPartOfSqrt(result, m_A);
const MatrixType& T = schurOfA.matrixT(); computeOffDiagonalPartOfSqrt(result, m_A);
const MatrixType& U = schurOfA.matrixU();
// Compute square root of T
MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
computeDiagonalPartOfSqrt(sqrtT, T);
computeOffDiagonalPartOfSqrt(sqrtT, T);
// Compute square root of m_A
result = U * sqrtT * U.adjoint();
} }
// pre: T is quasi-upper-triangular and sqrtT is a zero matrix of the same size // pre: T is quasi-upper-triangular and sqrtT is a zero matrix of the same size
@ -291,17 +282,13 @@ template <typename ResultType>
void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result) void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result)
{ {
using std::sqrt; using std::sqrt;
// Compute Schur decomposition of m_A
const ComplexSchur<MatrixType> schurOfA(m_A);
const MatrixType& T = schurOfA.matrixT();
const MatrixType& U = schurOfA.matrixU();
// Compute square root of T and store it in upper triangular part of result // Compute square root of m_A and store it in upper triangular part of result
// This uses that the square root of triangular matrices can be computed directly. // This uses that the square root of triangular matrices can be computed directly.
result.resize(m_A.rows(), m_A.cols()); result.resize(m_A.rows(), m_A.cols());
typedef typename MatrixType::Index Index; typedef typename MatrixType::Index Index;
for (Index i = 0; i < m_A.rows(); i++) { for (Index i = 0; i < m_A.rows(); i++) {
result.coeffRef(i,i) = sqrt(T.coeff(i,i)); result.coeffRef(i,i) = sqrt(m_A.coeff(i,i));
} }
for (Index j = 1; j < m_A.cols(); j++) { for (Index j = 1; j < m_A.cols(); j++) {
for (Index i = j-1; i >= 0; i--) { for (Index i = j-1; i >= 0; i--) {
@ -309,14 +296,9 @@ void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result)
// if i = j-1, then segment has length 0 so tmp = 0 // if i = j-1, then segment has length 0 so tmp = 0
Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value(); Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
// denominator may be zero if original matrix is singular // denominator may be zero if original matrix is singular
result.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j)); result.coeffRef(i,j) = (m_A.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
} }
} }
// Compute square root of m_A as U * result * U.adjoint()
MatrixType tmp;
tmp.noalias() = U * result.template triangularView<Upper>();
result.noalias() = tmp * U.adjoint();
} }
@ -373,9 +355,8 @@ class MatrixSquareRoot<MatrixType, 0>
const MatrixType& U = schurOfA.matrixU(); const MatrixType& U = schurOfA.matrixU();
// Compute square root of T // Compute square root of T
MatrixSquareRootQuasiTriangular<MatrixType> tmp(T); MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.cols());
MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows()); MatrixSquareRootQuasiTriangular<MatrixType>(T).compute(sqrtT);
tmp.compute(sqrtT);
// Compute square root of m_A // Compute square root of m_A
result = U * sqrtT * U.adjoint(); result = U * sqrtT * U.adjoint();
@ -407,12 +388,11 @@ class MatrixSquareRoot<MatrixType, 1>
const MatrixType& U = schurOfA.matrixU(); const MatrixType& U = schurOfA.matrixU();
// Compute square root of T // Compute square root of T
MatrixSquareRootTriangular<MatrixType> tmp(T); MatrixType sqrtT;
MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows()); MatrixSquareRootTriangular<MatrixType>(T).compute(sqrtT);
tmp.compute(sqrtT);
// Compute square root of m_A // Compute square root of m_A
result = U * sqrtT * U.adjoint(); result = U * (sqrtT.template triangularView<Upper>() * U.adjoint());
} }
private: private: