Apply Householder U and V in-place.

This commit is contained in:
Gael Guennebaud 2014-09-04 09:17:01 +02:00
parent 8846aa6d1b
commit 15bad3670b

View File

@ -279,7 +279,7 @@ void BDCSVD<MatrixType>::copyUV(const HouseholderU &householderU, const Househol
m_matrixU = MatrixX::Identity(householderU.cols(), Ucols); m_matrixU = MatrixX::Identity(householderU.cols(), Ucols);
Index blockCols = m_computeThinU ? m_nonzeroSingularValues : m_diagSize; Index blockCols = m_computeThinU ? m_nonzeroSingularValues : m_diagSize;
m_matrixU.topLeftCorner(m_diagSize, blockCols) = naiveV.template cast<Scalar>().topLeftCorner(m_diagSize, blockCols); m_matrixU.topLeftCorner(m_diagSize, blockCols) = naiveV.template cast<Scalar>().topLeftCorner(m_diagSize, blockCols);
m_matrixU = householderU * m_matrixU; householderU.applyThisOnTheLeft(m_matrixU);
} }
if (computeV()) if (computeV())
{ {
@ -287,7 +287,7 @@ void BDCSVD<MatrixType>::copyUV(const HouseholderU &householderU, const Househol
m_matrixV = MatrixX::Identity(householderV.cols(), Vcols); m_matrixV = MatrixX::Identity(householderV.cols(), Vcols);
Index blockCols = m_computeThinV ? m_nonzeroSingularValues : m_diagSize; Index blockCols = m_computeThinV ? m_nonzeroSingularValues : m_diagSize;
m_matrixV.topLeftCorner(m_diagSize, blockCols) = naiveU.template cast<Scalar>().topLeftCorner(m_diagSize, blockCols); m_matrixV.topLeftCorner(m_diagSize, blockCols) = naiveU.template cast<Scalar>().topLeftCorner(m_diagSize, blockCols);
m_matrixV = householderV * m_matrixV; householderV.applyThisOnTheLeft(m_matrixV);
} }
} }
@ -314,7 +314,7 @@ void BDCSVD<MatrixType>::divide (Index firstCol, Index lastCol, Index firstRowW,
RealScalar betaK; RealScalar betaK;
RealScalar r0; RealScalar r0;
RealScalar lambda, phi, c0, s0; RealScalar lambda, phi, c0, s0;
MatrixXr l, f; VectorType l, f;
// We use the other algorithm which is more efficient for small // We use the other algorithm which is more efficient for small
// matrices. // matrices.
if (n < m_algoswap) if (n < m_algoswap)
@ -385,7 +385,7 @@ void BDCSVD<MatrixType>::divide (Index firstCol, Index lastCol, Index firstRowW,
// first column = q2 * s0 // first column = q2 * s0
m_naiveU.col(firstCol).segment(firstCol + k + 1, n - k) = m_naiveU.col(lastCol + 1).segment(firstCol + k + 1, n - k) * s0; m_naiveU.col(firstCol).segment(firstCol + k + 1, n - k) = m_naiveU.col(lastCol + 1).segment(firstCol + k + 1, n - k) * s0;
// q2 *= c0 // q2 *= c0
m_naiveU.col(lastCol + 1).segment(firstCol + k + 1, n - k) *= c0; m_naiveU.col(lastCol + 1).segment(firstCol + k + 1, n - k) *= c0;
} }
else else
{ {
@ -408,7 +408,6 @@ void BDCSVD<MatrixType>::divide (Index firstCol, Index lastCol, Index firstRowW,
m_computed.col(firstCol + shift).segment(firstCol + shift + 1, k) = alphaK * l.transpose().real(); m_computed.col(firstCol + shift).segment(firstCol + shift + 1, k) = alphaK * l.transpose().real();
m_computed.col(firstCol + shift).segment(firstCol + shift + k + 1, n - k - 1) = betaK * f.transpose().real(); m_computed.col(firstCol + shift).segment(firstCol + shift + k + 1, n - k - 1) = betaK * f.transpose().real();
// Second part: try to deflate singular values in combined matrix // Second part: try to deflate singular values in combined matrix
deflation(firstCol, lastCol, k, firstRowW, firstColW, shift); deflation(firstCol, lastCol, k, firstRowW, firstColW, shift);
@ -417,7 +416,7 @@ void BDCSVD<MatrixType>::divide (Index firstCol, Index lastCol, Index firstRowW,
VectorType singVals; VectorType singVals;
computeSVDofM(firstCol + shift, n, UofSVD, singVals, VofSVD); computeSVDofM(firstCol + shift, n, UofSVD, singVals, VofSVD);
if (m_compU) m_naiveU.block(firstCol, firstCol, n + 1, n + 1) *= UofSVD; // FIXME this requires a temporary if (m_compU) m_naiveU.block(firstCol, firstCol, n + 1, n + 1) *= UofSVD; // FIXME this requires a temporary
else m_naiveU.block(0, firstCol, 2, n + 1) *= UofSVD; // FIXME this requires a temporary, and exploit that there are 2 rows at compile time else m_naiveU.middleCols(firstCol, n + 1) *= UofSVD; // FIXME this requires a temporary, and exploit that there are 2 rows at compile time
if (m_compV) m_naiveV.block(firstRowW, firstColW, n, n) *= VofSVD; // FIXME this requires a temporary if (m_compV) m_naiveV.block(firstRowW, firstColW, n, n) *= VofSVD; // FIXME this requires a temporary
m_computed.block(firstCol + shift, firstCol + shift, n, n).setZero(); m_computed.block(firstCol + shift, firstCol + shift, n, n).setZero();
m_computed.block(firstCol + shift, firstCol + shift, n, n).diagonal() = singVals; m_computed.block(firstCol + shift, firstCol + shift, n, n).diagonal() = singVals;
@ -434,7 +433,8 @@ template <typename MatrixType>
void BDCSVD<MatrixType>::computeSVDofM(Index firstCol, Index n, MatrixXr& U, VectorType& singVals, MatrixXr& V) void BDCSVD<MatrixType>::computeSVDofM(Index firstCol, Index n, MatrixXr& U, VectorType& singVals, MatrixXr& V)
{ {
// TODO Get rid of these copies (?) // TODO Get rid of these copies (?)
ArrayXr col0 = m_computed.block(firstCol, firstCol, n, 1); // FIXME at least preallocate them
ArrayXr col0 = m_computed.col(firstCol).segment(firstCol, n);
ArrayXr diag = m_computed.block(firstCol, firstCol, n, n).diagonal(); ArrayXr diag = m_computed.block(firstCol, firstCol, n, n).diagonal();
diag(0) = 0; diag(0) = 0;
@ -446,14 +446,15 @@ void BDCSVD<MatrixType>::computeSVDofM(Index firstCol, Index n, MatrixXr& U, Vec
if (col0.hasNaN() || diag.hasNaN()) return; if (col0.hasNaN() || diag.hasNaN()) return;
ArrayXr shifts(n), mus(n), zhat(n); ArrayXr shifts(n), mus(n), zhat(n);
computeSingVals(col0, diag, singVals, shifts, mus); computeSingVals(col0, diag, singVals, shifts, mus);
perturbCol0(col0, diag, singVals, shifts, mus, zhat); perturbCol0(col0, diag, singVals, shifts, mus, zhat);
computeSingVecs(zhat, diag, singVals, shifts, mus, U, V); computeSingVecs(zhat, diag, singVals, shifts, mus, U, V);
// Reverse order so that singular values in increased order // Reverse order so that singular values in increased order
singVals.reverseInPlace(); singVals.reverseInPlace();
U.leftCols(n) = U.leftCols(n).rowwise().reverse().eval(); U.leftCols(n) = U.leftCols(n).rowwise().reverse().eval(); // FIXME this requires a temporary
if (m_compV) V = V.rowwise().reverse().eval(); if (m_compV) V = V.rowwise().reverse().eval(); // FIXME this requires a temporary
} }
template <typename MatrixType> template <typename MatrixType>