Avoid Schur decomposition on (quasi-)triangular matrices. (Huge speed up!)

This commit is contained in:
Chen-Pang He 2012-09-30 16:30:18 +08:00
parent 332eb36436
commit eb33d307af
2 changed files with 13 additions and 34 deletions

View File

@ -129,7 +129,7 @@ void MatrixLogarithmAtomic<MatrixType>::computeBig(const MatrixType& A, MatrixTy
int numberOfSquareRoots = 0;
int numberOfExtraSquareRoots = 0;
int degree;
MatrixType T = A;
MatrixType T = A, sqrtT;
const RealScalar maxNormForPade = maxPadeDegree<= 5? 5.3149729967117310e-1: // single precision
maxPadeDegree<= 7? 2.6429608311114350e-1: // double precision
maxPadeDegree<= 8? 2.32777776523703892094e-1L: // extended precision
@ -145,9 +145,8 @@ void MatrixLogarithmAtomic<MatrixType>::computeBig(const MatrixType& A, MatrixTy
break;
++numberOfExtraSquareRoots;
}
MatrixType sqrtT;
MatrixSquareRootTriangular<MatrixType>(T).compute(sqrtT);
T = sqrtT;
T = sqrtT.template triangularView<Upper>();
++numberOfSquareRoots;
}

View File

@ -79,18 +79,9 @@ template <typename MatrixType>
template <typename ResultType>
void MatrixSquareRootQuasiTriangular<MatrixType>::compute(ResultType &result)
{
// Compute Schur decomposition of m_A
const RealSchur<MatrixType> schurOfA(m_A);
const MatrixType& T = schurOfA.matrixT();
const MatrixType& U = schurOfA.matrixU();
// Compute square root of T
MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
computeDiagonalPartOfSqrt(sqrtT, T);
computeOffDiagonalPartOfSqrt(sqrtT, T);
// Compute square root of m_A
result = U * sqrtT * U.adjoint();
result.resize(m_A.rows(), m_A.cols());
computeDiagonalPartOfSqrt(result, m_A);
computeOffDiagonalPartOfSqrt(result, m_A);
}
// pre: T is quasi-upper-triangular and sqrtT is a zero matrix of the same size
@ -291,17 +282,13 @@ template <typename ResultType>
void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result)
{
using std::sqrt;
// Compute Schur decomposition of m_A
const ComplexSchur<MatrixType> schurOfA(m_A);
const MatrixType& T = schurOfA.matrixT();
const MatrixType& U = schurOfA.matrixU();
// Compute square root of T and store it in upper triangular part of result
// Compute square root of m_A and store it in upper triangular part of result
// This uses that the square root of triangular matrices can be computed directly.
result.resize(m_A.rows(), m_A.cols());
typedef typename MatrixType::Index Index;
for (Index i = 0; i < m_A.rows(); i++) {
result.coeffRef(i,i) = sqrt(T.coeff(i,i));
result.coeffRef(i,i) = sqrt(m_A.coeff(i,i));
}
for (Index j = 1; j < m_A.cols(); j++) {
for (Index i = j-1; i >= 0; i--) {
@ -309,14 +296,9 @@ void MatrixSquareRootTriangular<MatrixType>::compute(ResultType &result)
// if i = j-1, then segment has length 0 so tmp = 0
Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
// denominator may be zero if original matrix is singular
result.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
result.coeffRef(i,j) = (m_A.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
}
}
// Compute square root of m_A as U * result * U.adjoint()
MatrixType tmp;
tmp.noalias() = U * result.template triangularView<Upper>();
result.noalias() = tmp * U.adjoint();
}
@ -373,9 +355,8 @@ class MatrixSquareRoot<MatrixType, 0>
const MatrixType& U = schurOfA.matrixU();
// Compute square root of T
MatrixSquareRootQuasiTriangular<MatrixType> tmp(T);
MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
tmp.compute(sqrtT);
MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.cols());
MatrixSquareRootQuasiTriangular<MatrixType>(T).compute(sqrtT);
// Compute square root of m_A
result = U * sqrtT * U.adjoint();
@ -407,12 +388,11 @@ class MatrixSquareRoot<MatrixType, 1>
const MatrixType& U = schurOfA.matrixU();
// Compute square root of T
MatrixSquareRootTriangular<MatrixType> tmp(T);
MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
tmp.compute(sqrtT);
MatrixType sqrtT;
MatrixSquareRootTriangular<MatrixType>(T).compute(sqrtT);
// Compute square root of m_A
result = U * sqrtT * U.adjoint();
result = U * (sqrtT.template triangularView<Upper>() * U.adjoint());
}
private: