* solveTriangularInPlace(): take a const ref and const_cast it, to allow passing temporary xprs.

* improvements, simplifications in LU::solve()
* remove remnant of old norm2()
This commit is contained in:
Benoit Jacob 2009-01-25 23:46:51 +00:00
parent 414ee1db4b
commit 00d7f8e567
3 changed files with 17 additions and 24 deletions

View File

@ -344,13 +344,12 @@ template<typename Derived> class MatrixBase
solveTriangular(const MatrixBase<OtherDerived>& other) const; solveTriangular(const MatrixBase<OtherDerived>& other) const;
template<typename OtherDerived> template<typename OtherDerived>
void solveTriangularInPlace(MatrixBase<OtherDerived>& other) const; void solveTriangularInPlace(const MatrixBase<OtherDerived>& other) const;
template<typename OtherDerived> template<typename OtherDerived>
Scalar dot(const MatrixBase<OtherDerived>& other) const; Scalar dot(const MatrixBase<OtherDerived>& other) const;
RealScalar squaredNorm() const; RealScalar squaredNorm() const;
RealScalar norm2() const;
RealScalar norm() const; RealScalar norm() const;
const PlainMatrixType normalized() const; const PlainMatrixType normalized() const;
void normalize(); void normalize();

View File

@ -221,13 +221,17 @@ struct ei_solve_triangular_selector<Lhs,Rhs,UpLo,ColMajor|IsDense>
}; };
/** "in-place" version of MatrixBase::solveTriangular() where the result is written in \a other /** "in-place" version of MatrixBase::solveTriangular() where the result is written in \a other
*
* The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here.
* This function will const_cast it, so constness isn't honored here.
* *
* See MatrixBase:solveTriangular() for the details. * See MatrixBase:solveTriangular() for the details.
*/ */
template<typename Derived> template<typename Derived>
template<typename OtherDerived> template<typename OtherDerived>
void MatrixBase<Derived>::solveTriangularInPlace(MatrixBase<OtherDerived>& other) const void MatrixBase<Derived>::solveTriangularInPlace(const MatrixBase<OtherDerived>& _other) const
{ {
MatrixBase<OtherDerived>& other = _other.const_cast_derived();
ei_assert(derived().cols() == derived().rows()); ei_assert(derived().cols() == derived().rows());
ei_assert(derived().cols() == other.rows()); ei_assert(derived().cols() == other.rows());
ei_assert(!(Flags & ZeroDiagBit)); ei_assert(!(Flags & ZeroDiagBit));

View File

@ -474,13 +474,13 @@ bool LU<MatrixType>::solve(
* So we proceed as follows: * So we proceed as follows:
* Step 1: compute c = Pb. * Step 1: compute c = Pb.
* Step 2: replace c by the solution x to Lx = c. Exists because L is invertible. * Step 2: replace c by the solution x to Lx = c. Exists because L is invertible.
* Step 3: compute d such that Ud = c. Check if such d really exists. * Step 3: replace c by the solution x to Ux = c. Check if a solution really exists.
* Step 4: result = Qd; * Step 4: result = Qc;
*/ */
const int rows = m_lu.rows(), cols = m_lu.cols(); const int rows = m_lu.rows(), cols = m_lu.cols();
ei_assert(b.rows() == rows); ei_assert(b.rows() == rows);
const int smalldim = std::min(rows, m_lu.cols()); const int smalldim = std::min(rows, cols);
typename OtherDerived::PlainMatrixType c(b.rows(), b.cols()); typename OtherDerived::PlainMatrixType c(b.rows(), b.cols());
@ -488,19 +488,13 @@ bool LU<MatrixType>::solve(
for(int i = 0; i < rows; ++i) c.row(m_p.coeff(i)) = b.row(i); for(int i = 0; i < rows; ++i) c.row(m_p.coeff(i)) = b.row(i);
// Step 2 // Step 2
if(rows <= cols) m_lu.corner(Eigen::TopLeft,smalldim,smalldim).template marked<UnitLowerTriangular>()
m_lu.corner(Eigen::TopLeft,rows,smalldim).template marked<UnitLowerTriangular>().solveTriangularInPlace(c); .solveTriangularInPlace(
else c.corner(Eigen::TopLeft, smalldim, c.cols()));
if(rows>cols)
{ {
// construct the L matrix. We shouldn't do that everytime, it is a very large overhead in the case of vector solving. c.corner(Eigen::BottomLeft, rows-cols, c.cols())
// However the case rows>cols is rather unusual with LU so this is probably not a huge priority. -= m_lu.corner(Eigen::BottomLeft, rows-cols, cols) * c.corner(Eigen::TopLeft, cols, c.cols());
Matrix<Scalar, MatrixType::RowsAtCompileTime, MatrixType::RowsAtCompileTime,
MatrixType::Options,
MatrixType::MaxRowsAtCompileTime,
MatrixType::MaxRowsAtCompileTime> l(rows, rows);
l.setZero();
l.corner(Eigen::TopLeft,rows,smalldim) = m_lu.corner(Eigen::TopLeft,rows,smalldim);
l.template marked<UnitLowerTriangular>().solveTriangularInPlace(c);
} }
// Step 3 // Step 3
@ -513,17 +507,13 @@ bool LU<MatrixType>::solve(
if(!ei_isMuchSmallerThan(c.coeff(row,col), biggest_in_c)) if(!ei_isMuchSmallerThan(c.coeff(row,col), biggest_in_c))
return false; return false;
} }
Matrix<Scalar, Dynamic, OtherDerived::ColsAtCompileTime,
MatrixType::Options,
MatrixType::MaxRowsAtCompileTime, OtherDerived::MaxColsAtCompileTime>
d(c.corner(TopLeft, m_rank, c.cols()));
m_lu.corner(TopLeft, m_rank, m_rank) m_lu.corner(TopLeft, m_rank, m_rank)
.template marked<UpperTriangular>() .template marked<UpperTriangular>()
.solveTriangularInPlace(d); .solveTriangularInPlace(c.corner(TopLeft, m_rank, c.cols()));
// Step 4 // Step 4
result->resize(m_lu.cols(), b.cols()); result->resize(m_lu.cols(), b.cols());
for(int i = 0; i < m_rank; ++i) result->row(m_q.coeff(i)) = d.row(i); for(int i = 0; i < m_rank; ++i) result->row(m_q.coeff(i)) = c.row(i);
for(int i = m_rank; i < m_lu.cols(); ++i) result->row(m_q.coeff(i)).setZero(); for(int i = m_rank; i < m_lu.cols(); ++i) result->row(m_q.coeff(i)).setZero();
return true; return true;
} }