From a25c9b1e46d110ab896936f085a2986d335d578b Mon Sep 17 00:00:00 2001 From: Jitse Niesen Date: Sun, 27 Dec 2009 18:09:50 +0000 Subject: [PATCH] Simplify and document Sylvester equation solver in MatrixFunction. --- .../src/MatrixFunctions/MatrixFunction.h | 85 +++++++++++++------ 1 file changed, 57 insertions(+), 28 deletions(-) diff --git a/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h b/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h index 7103ac541..963244771 100644 --- a/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h +++ b/unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h @@ -168,7 +168,7 @@ class MatrixFunction void swapEntriesInSchur(int index, MatrixType& T, MatrixType& U); void computeTriangular(const MatrixType& T, MatrixType& result, const IntVectorType& blockSize); void computeBlockAtomic(const MatrixType& T, MatrixType& result, const IntVectorType& blockSize); - MatrixType solveSylvester(const MatrixType& A, const MatrixType& B, const MatrixType& C); + MatrixType solveTriangularSylvester(const MatrixType& A, const MatrixType& B, const MatrixType& C); MatrixType computeAtomic(const MatrixType& T); void divideInBlocks(const VectorType& v, listOfLists* result); void constructPermutation(const VectorType& diag, const listOfLists& blocks, @@ -264,47 +264,76 @@ void MatrixFunction::computeTriangular(const MatrixType& T, Matr C += result.block(blockStart(blockIndex), blockStart(k), blockSize(blockIndex), blockSize(k)) * T.block(blockStart(k), blockStart(blockIndex+diagIndex), blockSize(k), blockSize(blockIndex+diagIndex)); C -= T.block(blockStart(blockIndex), blockStart(k), blockSize(blockIndex), blockSize(k)) * result.block(blockStart(k), blockStart(blockIndex+diagIndex), blockSize(k), blockSize(blockIndex+diagIndex)); } - result.block(blockStart(blockIndex), blockStart(blockIndex+diagIndex), blockSize(blockIndex), blockSize(blockIndex+diagIndex)) = solveSylvester(A, B, C); + result.block(blockStart(blockIndex), blockStart(blockIndex+diagIndex), blockSize(blockIndex), blockSize(blockIndex+diagIndex)) = solveTriangularSylvester(A, B, C); } } } -// solve AX + XB = C <=> U* A' U X V V* + U* U X V B' V* = U* U C V V* <=> A' U X V + U X V B' = U C V -// Schur: A* = U A'* U* (so A = U* A' U), B = V B' V*, define: X' = U X V, C' = U C V, to get: A' X' + X' B' = C' -// A is m-by-m, B is n-by-n, X is m-by-n, C is m-by-n, U is m-by-m, V is n-by-n +/** \brief Solve a triangular Sylvester equation AX + XB = C + * + * \param[in] A The matrix A; should be square and upper triangular + * \param[in] B The matrix B; should be square and upper triangular + * \param[in] C The matrix C; should have correct size. + * + * \returns The solution X. + * + * If A is m-by-m and B is n-by-n, then both C and X are m-by-n. + * The (i,j)-th component of the Sylvester equation is + * \f[ + * \sum_{k=i}^m A_{ik} X_{kj} + \sum_{k=1}^j X_{ik} B_{kj} = C_{ij}. + * \f] + * This can be re-arranged to yield: + * \f[ + * X_{ij} = \frac{1}{A_{ii} + B_{jj}} \Bigl( C_{ij} + * - \sum_{k=i+1}^m A_{ik} X_{kj} - \sum_{k=1}^{j-1} X_{ik} B_{kj} \Bigr). + * \f] + * It is assumed that A and B are such that the numerator is never + * zero (otherwise the Sylvester equation does not have a unique + * solution). In that case, these equations can be evaluated in the + * order \f$ i=m,\ldots,1 \f$ and \f$ j=1,\ldots,n \f$. + */ template -MatrixType MatrixFunction::solveSylvester(const MatrixType& A, const MatrixType& B, const MatrixType& C) +MatrixType MatrixFunction::solveTriangularSylvester( + const MatrixType& A, + const MatrixType& B, + const MatrixType& C) { - MatrixType U = MatrixType::Zero(A.rows(), A.rows()); - for (int i = 0; i < A.rows(); i++) { - U(i, A.rows() - 1 - i) = static_cast(1); - } - MatrixType Aprime = U * A * U; + ei_assert(A.rows() == A.cols()); + ei_assert(A.isUpperTriangular()); + ei_assert(B.rows() == B.cols()); + ei_assert(B.isUpperTriangular()); + ei_assert(C.rows() == A.rows()); + ei_assert(C.cols() == B.rows()); - MatrixType Bprime = B; - MatrixType V = MatrixType::Identity(B.rows(), B.rows()); + int m = A.rows(); + int n = B.rows(); + MatrixType X(m, n); - MatrixType Cprime = U * C * V; - MatrixType Xprime(A.rows(), B.rows()); - for (int l = 0; l < B.rows(); l++) { - for (int k = 0; k < A.rows(); k++) { - Scalar tmp1, tmp2; - if (k == 0) { - tmp1 = 0; + for (int i = m - 1; i >= 0; --i) { + for (int j = 0; j < n; ++j) { + + // Compute AX = \sum_{k=i+1}^m A_{ik} X_{kj} + Scalar AX; + if (i == m - 1) { + AX = 0; } else { - Matrix tmp1matrix = Aprime.row(k).start(k) * Xprime.col(l).start(k); - tmp1 = tmp1matrix(0,0); + Matrix AXmatrix = A.row(i).end(m-1-i) * X.col(j).end(m-1-i); + AX = AXmatrix(0,0); } - if (l == 0) { - tmp2 = 0; + + // Compute XB = \sum_{k=1}^{j-1} X_{ik} B_{kj} + Scalar XB; + if (j == 0) { + XB = 0; } else { - Matrix tmp2matrix = Xprime.row(k).start(l) * Bprime.col(l).start(l); - tmp2 = tmp2matrix(0,0); + Matrix XBmatrix = X.row(i).start(j) * B.col(j).start(j); + XB = XBmatrix(0,0); } - Xprime(k,l) = (Cprime(k,l) - tmp1 - tmp2) / (Aprime(k,k) + Bprime(l,l)); + + X(i,j) = (C(i,j) - AX - XB) / (A(i,i) + B(j,j)); } } - return U.adjoint() * Xprime * V.adjoint(); + return X; }