move also inverse() to ReturnByValue, by doing a solve on NestByValue<Identity>.

also: adding resize() to MatrixBase was really needed ;)
This commit is contained in:
Benoit Jacob 2009-09-26 11:40:29 -04:00
parent 176c26feb5
commit e82ab8a5dd
4 changed files with 15 additions and 33 deletions

View File

@ -200,7 +200,7 @@ struct ei_compute_inverse
{ {
static inline void run(const MatrixType& matrix, MatrixType* result) static inline void run(const MatrixType& matrix, MatrixType* result)
{ {
matrix.partialLu().computeInverse(result); result = matrix.partialLu().inverse();
} }
}; };
@ -281,9 +281,7 @@ inline void MatrixBase<Derived>::computeInverse(PlainMatrixType *result) const
template<typename Derived> template<typename Derived>
inline const typename MatrixBase<Derived>::PlainMatrixType MatrixBase<Derived>::inverse() const inline const typename MatrixBase<Derived>::PlainMatrixType MatrixBase<Derived>::inverse() const
{ {
PlainMatrixType result(rows(), cols()); return inverse(*this);
computeInverse(&result);
return result;
} }
@ -299,7 +297,7 @@ struct ei_compute_inverse_with_check
typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::Scalar Scalar;
LU<MatrixType> lu( matrix ); LU<MatrixType> lu( matrix );
if( !lu.isInvertible() ) return false; if( !lu.isInvertible() ) return false;
lu.computeInverse(result); *result = lu.inverse();
return true; return true;
} }
}; };

View File

@ -58,7 +58,7 @@ template<typename MatrixType> struct ei_lu_image_impl;
* \include class_LU.cpp * \include class_LU.cpp
* Output: \verbinclude class_LU.out * Output: \verbinclude class_LU.out
* *
* \sa MatrixBase::lu(), MatrixBase::determinant(), MatrixBase::inverse(), MatrixBase::computeInverse() * \sa MatrixBase::lu(), MatrixBase::determinant(), MatrixBase::inverse()
*/ */
template<typename MatrixType> class LU template<typename MatrixType> class LU
{ {
@ -193,7 +193,7 @@ template<typename MatrixType> class LU
* Example: \include LU_solve.cpp * Example: \include LU_solve.cpp
* Output: \verbinclude LU_solve.out * Output: \verbinclude LU_solve.out
* *
* \sa TriangularView::solve(), kernel(), inverse(), computeInverse() * \sa TriangularView::solve(), kernel(), inverse()
*/ */
template<typename Rhs> template<typename Rhs>
inline const ei_lu_solve_impl<MatrixType, Rhs> inline const ei_lu_solve_impl<MatrixType, Rhs>
@ -277,34 +277,19 @@ template<typename MatrixType> class LU
return isInjective() && isSurjective(); return isInjective() && isSurjective();
} }
/** Computes the inverse of the matrix of which *this is the LU decomposition.
*
* \param result a pointer to the matrix into which to store the inverse. Resized if needed.
*
* \note If this matrix is not invertible, *result is left with undefined coefficients.
* Use isInvertible() to first determine whether this matrix is invertible.
*
* \sa MatrixBase::computeInverse(), inverse()
*/
inline void computeInverse(MatrixType *result) const
{
ei_assert(m_originalMatrix != 0 && "LU is not initialized.");
ei_assert(m_lu.rows() == m_lu.cols() && "You can't take the inverse of a non-square matrix!");
*result = solve(MatrixType::Identity(m_lu.rows(), m_lu.cols()));
}
/** \returns the inverse of the matrix of which *this is the LU decomposition. /** \returns the inverse of the matrix of which *this is the LU decomposition.
* *
* \note If this matrix is not invertible, the returned matrix has undefined coefficients. * \note If this matrix is not invertible, the returned matrix has undefined coefficients.
* Use isInvertible() to first determine whether this matrix is invertible. * Use isInvertible() to first determine whether this matrix is invertible.
* *
* \sa computeInverse(), MatrixBase::inverse() * \sa MatrixBase::inverse()
*/ */
inline MatrixType inverse() const inline const ei_lu_solve_impl<MatrixType,NestByValue<typename MatrixType::IdentityReturnType> > inverse() const
{ {
MatrixType result; ei_assert(m_originalMatrix != 0 && "LU is not initialized.");
computeInverse(&result); ei_assert(m_lu.rows() == m_lu.cols() && "You can't take the inverse of a non-square matrix!");
return result; return ei_lu_solve_impl<MatrixType,NestByValue<typename MatrixType::IdentityReturnType> >
(*this, MatrixType::Identity(m_lu.rows(), m_lu.cols()).nestByValue());
} }
protected: protected:

View File

@ -40,6 +40,7 @@ template<typename MatrixType> void lu_non_invertible()
typename ei_lu_kernel_impl<MatrixType>::ReturnMatrixType m1kernel = lu.kernel(); typename ei_lu_kernel_impl<MatrixType>::ReturnMatrixType m1kernel = lu.kernel();
typename ei_lu_image_impl <MatrixType>::ReturnMatrixType m1image = lu.image(); typename ei_lu_image_impl <MatrixType>::ReturnMatrixType m1image = lu.image();
// std::cerr << rank << " " << lu.rank() << std::endl;
VERIFY(rank == lu.rank()); VERIFY(rank == lu.rank());
VERIFY(cols - lu.rank() == lu.dimensionOfKernel()); VERIFY(cols - lu.rank() == lu.dimensionOfKernel());
VERIFY(!lu.isInjective()); VERIFY(!lu.isInjective());
@ -54,7 +55,7 @@ template<typename MatrixType> void lu_non_invertible()
m3 = m1*m2; m3 = m1*m2;
m2 = MatrixType::Random(cols,cols2); m2 = MatrixType::Random(cols,cols2);
// test that the code, which does resize(), may be applied to an xpr // test that the code, which does resize(), may be applied to an xpr
m2.block(0,0,cols,cols2) = lu.solve(m3); m2.block(0,0,m2.rows(),m2.cols()) = lu.solve(m3);
VERIFY_IS_APPROX(m3, m1*m2); VERIFY_IS_APPROX(m3, m1*m2);
typedef Matrix<typename MatrixType::Scalar, MatrixType::RowsAtCompileTime, MatrixType::RowsAtCompileTime> SquareMatrixType; typedef Matrix<typename MatrixType::Scalar, MatrixType::RowsAtCompileTime, MatrixType::RowsAtCompileTime> SquareMatrixType;
@ -111,7 +112,6 @@ template<typename MatrixType> void lu_verify_assert()
VERIFY_RAISES_ASSERT(lu.isInjective()) VERIFY_RAISES_ASSERT(lu.isInjective())
VERIFY_RAISES_ASSERT(lu.isSurjective()) VERIFY_RAISES_ASSERT(lu.isSurjective())
VERIFY_RAISES_ASSERT(lu.isInvertible()) VERIFY_RAISES_ASSERT(lu.isInvertible())
VERIFY_RAISES_ASSERT(lu.computeInverse(&tmp))
VERIFY_RAISES_ASSERT(lu.inverse()) VERIFY_RAISES_ASSERT(lu.inverse())
PartialLU<MatrixType> plu; PartialLU<MatrixType> plu;
@ -119,7 +119,6 @@ template<typename MatrixType> void lu_verify_assert()
VERIFY_RAISES_ASSERT(plu.permutationP()) VERIFY_RAISES_ASSERT(plu.permutationP())
VERIFY_RAISES_ASSERT(plu.solve(tmp,&tmp)) VERIFY_RAISES_ASSERT(plu.solve(tmp,&tmp))
VERIFY_RAISES_ASSERT(plu.determinant()) VERIFY_RAISES_ASSERT(plu.determinant())
VERIFY_RAISES_ASSERT(plu.computeInverse(&tmp))
VERIFY_RAISES_ASSERT(plu.inverse()) VERIFY_RAISES_ASSERT(plu.inverse())
} }