diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h index d186b5a4b..b59cdbf85 100644 --- a/Eigen/src/Core/SolveTriangular.h +++ b/Eigen/src/Core/SolveTriangular.h @@ -173,8 +173,6 @@ struct triangular_solver_selector { ***************************************************************************/ /** "in-place" version of TriangularView::solve() where the result is written in \a other - * - * * * \warning The parameter is only marked 'const' to make the C++ compiler accept a temporary expression here. * This function will const_cast it, so constness isn't honored here. @@ -205,43 +203,68 @@ void TriangularView::solveInPlace(const MatrixBase T.transpose().solveInPlace(M.transpose()); - * \endcode - * * \sa TriangularView::solveInPlace() */ template -template -typename internal::plain_matrix_type_column_major::type -TriangularView::solve(const MatrixBase& rhs) const +template +const internal::triangular_solve_retval,Other> +TriangularView::solve(const MatrixBase& other) const { - typename internal::plain_matrix_type_column_major::type res(rhs); - solveInPlace(res); - return res; + return internal::triangular_solve_retval(*this, other.derived()); } +namespace internal { + + +template +struct traits > +{ + typedef typename internal::plain_matrix_type_column_major::type ReturnType; +}; + +template struct triangular_solve_retval + : public ReturnByValue > +{ + typedef typename remove_all::type RhsNestedCleaned; + typedef ReturnByValue Base; + typedef typename Base::Index Index; + + triangular_solve_retval(const TriangularType& tri, const Rhs& rhs) + : m_triangularMatrix(tri), m_rhs(rhs) + {} + + inline Index rows() const { return m_rhs.rows(); } + inline Index cols() const { return m_rhs.cols(); } + + template inline void evalTo(Dest& dst) const + { + if(!(is_same::value && extract_data(dst) == extract_data(m_rhs))) + dst = m_rhs; + m_triangularMatrix.template solveInPlace(dst); + } + + protected: + const TriangularType& m_triangularMatrix; + const typename Rhs::Nested m_rhs; +}; + +} // namespace internal + #endif // EIGEN_SOLVETRIANGULAR_H diff --git a/Eigen/src/Core/TriangularMatrix.h b/Eigen/src/Core/TriangularMatrix.h index eb82538dd..f9fedcb0f 100644 --- a/Eigen/src/Core/TriangularMatrix.h +++ b/Eigen/src/Core/TriangularMatrix.h @@ -26,6 +26,12 @@ #ifndef EIGEN_TRIANGULARMATRIX_H #define EIGEN_TRIANGULARMATRIX_H +namespace internal { + +template struct triangular_solve_retval; + +} + /** \internal * * \class TriangularBase @@ -332,16 +338,16 @@ template class TriangularView } #endif // EIGEN2_SUPPORT - template - typename internal::plain_matrix_type_column_major::type - solve(const MatrixBase& other) const; + template + inline const internal::triangular_solve_retval + solve(const MatrixBase& other) const; template void solveInPlace(const MatrixBase& other) const; - template - typename internal::plain_matrix_type_column_major::type - solve(const MatrixBase& other) const + template + inline const internal::triangular_solve_retval + solve(const MatrixBase& other) const { return solve(other); } template diff --git a/test/product_trsolve.cpp b/test/product_trsolve.cpp index f9ad049a4..c207cc500 100644 --- a/test/product_trsolve.cpp +++ b/test/product_trsolve.cpp @@ -28,12 +28,18 @@ (XB).setRandom(); ref = (XB); \ (TRI).solveInPlace(XB); \ VERIFY_IS_APPROX((TRI).toDenseMatrix() * (XB), ref); \ + (XB).setRandom(); ref = (XB); \ + (XB) = (TRI).solve(XB); \ + VERIFY_IS_APPROX((TRI).toDenseMatrix() * (XB), ref); \ } #define VERIFY_TRSM_ONTHERIGHT(TRI,XB) { \ (XB).setRandom(); ref = (XB); \ (TRI).transpose().template solveInPlace(XB.transpose()); \ VERIFY_IS_APPROX((XB).transpose() * (TRI).transpose().toDenseMatrix(), ref.transpose()); \ + (XB).setRandom(); ref = (XB); \ + (XB).transpose() = (TRI).transpose().template solve(XB.transpose()); \ + VERIFY_IS_APPROX((XB).transpose() * (TRI).transpose().toDenseMatrix(), ref.transpose()); \ } template void trsolve(int size=Size,int cols=Cols)