fix 168 : now TriangularView::solve returns by value making TriangularView::solveInPlace less important.

Also fix the very outdated documentation of this function.
This commit is contained in:
Gael Guennebaud 2011-02-01 17:21:20 +01:00
parent 59af20b390
commit 8915d5bd22
3 changed files with 65 additions and 30 deletions

View File

@ -173,8 +173,6 @@ struct triangular_solver_selector<Lhs,Rhs,OnTheRight,Mode,CompleteUnrolling,1> {
***************************************************************************/ ***************************************************************************/
/** "in-place" version of TriangularView::solve() where the result is written in \a other /** "in-place" version of TriangularView::solve() where the result is written in \a other
*
*
* *
* \warning The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here. * \warning The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here.
* This function will const_cast it, so constness isn't honored here. * This function will const_cast it, so constness isn't honored here.
@ -205,43 +203,68 @@ void TriangularView<MatrixType,Mode>::solveInPlace(const MatrixBase<OtherDerived
/** \returns the product of the inverse of \c *this with \a other, \a *this being triangular. /** \returns the product of the inverse of \c *this with \a other, \a *this being triangular.
* *
* This function computes the inverse-matrix matrix product inverse(\c *this) * \a other if
* \a Side==OnTheLeft (the default), or the right-inverse-multiply \a other * inverse(\c *this) if
* \a Side==OnTheRight.
* *
*
* This function computes the inverse-matrix matrix product inverse(\c *this) * \a other.
* The matrix \c *this must be triangular and invertible (i.e., all the coefficients of the * The matrix \c *this must be triangular and invertible (i.e., all the coefficients of the
* diagonal must be non zero). It works as a forward (resp. backward) substitution if \c *this * diagonal must be non zero). It works as a forward (resp. backward) substitution if \c *this
* is an upper (resp. lower) triangular matrix. * is an upper (resp. lower) triangular matrix.
* *
* It is required that \c *this be marked as either an upper or a lower triangular matrix, which
* can be done by marked(), and that is automatically the case with expressions such as those returned
* by extract().
*
* Example: \include MatrixBase_marked.cpp * Example: \include MatrixBase_marked.cpp
* Output: \verbinclude MatrixBase_marked.out * Output: \verbinclude MatrixBase_marked.out
* *
* This function is essentially a wrapper to the faster solveTriangularInPlace() function creating * This function returns an expression of the inverse-multiply and can works in-place if it is assigned
* a temporary copy of \a other, calling solveTriangularInPlace() on the copy and returning it. * to the same matrix or vector \a other.
* Therefore, if \a other is not needed anymore, it is quite faster to call solveTriangularInPlace()
* instead of solveTriangular().
* *
* For users coming from BLAS, this function (and more specifically solveTriangularInPlace()) offer * For users coming from BLAS, this function (and more specifically solveInPlace()) offer
* all the operations supported by the \c *TRSV and \c *TRSM BLAS routines. * all the operations supported by the \c *TRSV and \c *TRSM BLAS routines.
* *
* \b Tips: to perform a \em "right-inverse-multiply" you can simply transpose the operation, e.g.:
* \code
* M * T^1 <=> T.transpose().solveInPlace(M.transpose());
* \endcode
*
* \sa TriangularView::solveInPlace() * \sa TriangularView::solveInPlace()
*/ */
template<typename Derived, unsigned int Mode> template<typename Derived, unsigned int Mode>
template<int Side, typename RhsDerived> template<int Side, typename Other>
typename internal::plain_matrix_type_column_major<RhsDerived>::type const internal::triangular_solve_retval<Side,TriangularView<Derived,Mode>,Other>
TriangularView<Derived,Mode>::solve(const MatrixBase<RhsDerived>& rhs) const TriangularView<Derived,Mode>::solve(const MatrixBase<Other>& other) const
{ {
typename internal::plain_matrix_type_column_major<RhsDerived>::type res(rhs); return internal::triangular_solve_retval<Side,TriangularView,Other>(*this, other.derived());
solveInPlace<Side>(res);
return res;
} }
namespace internal {
template<int Side, typename TriangularType, typename Rhs>
struct traits<triangular_solve_retval<Side, TriangularType, Rhs> >
{
typedef typename internal::plain_matrix_type_column_major<Rhs>::type ReturnType;
};
template<int Side, typename TriangularType, typename Rhs> struct triangular_solve_retval
: public ReturnByValue<triangular_solve_retval<Side, TriangularType, Rhs> >
{
typedef typename remove_all<typename Rhs::Nested>::type RhsNestedCleaned;
typedef ReturnByValue<triangular_solve_retval> Base;
typedef typename Base::Index Index;
triangular_solve_retval(const TriangularType& tri, const Rhs& rhs)
: m_triangularMatrix(tri), m_rhs(rhs)
{}
inline Index rows() const { return m_rhs.rows(); }
inline Index cols() const { return m_rhs.cols(); }
template<typename Dest> inline void evalTo(Dest& dst) const
{
if(!(is_same<RhsNestedCleaned,Dest>::value && extract_data(dst) == extract_data(m_rhs)))
dst = m_rhs;
m_triangularMatrix.template solveInPlace<Side>(dst);
}
protected:
const TriangularType& m_triangularMatrix;
const typename Rhs::Nested m_rhs;
};
} // namespace internal
#endif // EIGEN_SOLVETRIANGULAR_H #endif // EIGEN_SOLVETRIANGULAR_H

View File

@ -26,6 +26,12 @@
#ifndef EIGEN_TRIANGULARMATRIX_H #ifndef EIGEN_TRIANGULARMATRIX_H
#define EIGEN_TRIANGULARMATRIX_H #define EIGEN_TRIANGULARMATRIX_H
namespace internal {
template<int Side, typename TriangularType, typename Rhs> struct triangular_solve_retval;
}
/** \internal /** \internal
* *
* \class TriangularBase * \class TriangularBase
@ -332,16 +338,16 @@ template<typename _MatrixType, unsigned int _Mode> class TriangularView
} }
#endif // EIGEN2_SUPPORT #endif // EIGEN2_SUPPORT
template<int Side, typename OtherDerived> template<int Side, typename Other>
typename internal::plain_matrix_type_column_major<OtherDerived>::type inline const internal::triangular_solve_retval<Side,TriangularView, Other>
solve(const MatrixBase<OtherDerived>& other) const; solve(const MatrixBase<Other>& other) const;
template<int Side, typename OtherDerived> template<int Side, typename OtherDerived>
void solveInPlace(const MatrixBase<OtherDerived>& other) const; void solveInPlace(const MatrixBase<OtherDerived>& other) const;
template<typename OtherDerived> template<typename Other>
typename internal::plain_matrix_type_column_major<OtherDerived>::type inline const internal::triangular_solve_retval<OnTheLeft,TriangularView, Other>
solve(const MatrixBase<OtherDerived>& other) const solve(const MatrixBase<Other>& other) const
{ return solve<OnTheLeft>(other); } { return solve<OnTheLeft>(other); }
template<typename OtherDerived> template<typename OtherDerived>

View File

@ -28,12 +28,18 @@
(XB).setRandom(); ref = (XB); \ (XB).setRandom(); ref = (XB); \
(TRI).solveInPlace(XB); \ (TRI).solveInPlace(XB); \
VERIFY_IS_APPROX((TRI).toDenseMatrix() * (XB), ref); \ VERIFY_IS_APPROX((TRI).toDenseMatrix() * (XB), ref); \
(XB).setRandom(); ref = (XB); \
(XB) = (TRI).solve(XB); \
VERIFY_IS_APPROX((TRI).toDenseMatrix() * (XB), ref); \
} }
#define VERIFY_TRSM_ONTHERIGHT(TRI,XB) { \ #define VERIFY_TRSM_ONTHERIGHT(TRI,XB) { \
(XB).setRandom(); ref = (XB); \ (XB).setRandom(); ref = (XB); \
(TRI).transpose().template solveInPlace<OnTheRight>(XB.transpose()); \ (TRI).transpose().template solveInPlace<OnTheRight>(XB.transpose()); \
VERIFY_IS_APPROX((XB).transpose() * (TRI).transpose().toDenseMatrix(), ref.transpose()); \ VERIFY_IS_APPROX((XB).transpose() * (TRI).transpose().toDenseMatrix(), ref.transpose()); \
(XB).setRandom(); ref = (XB); \
(XB).transpose() = (TRI).transpose().template solve<OnTheRight>(XB.transpose()); \
VERIFY_IS_APPROX((XB).transpose() * (TRI).transpose().toDenseMatrix(), ref.transpose()); \
} }
template<typename Scalar,int Size, int Cols> void trsolve(int size=Size,int cols=Cols) template<typename Scalar,int Size, int Cols> void trsolve(int size=Size,int cols=Cols)