move partial-pivoting lu to ei_solve_impl

This commit is contained in:
Benoit Jacob 2009-11-03 03:06:34 -05:00
parent da363d997f
commit a77872dd6c

View File

@ -26,8 +26,6 @@
#ifndef EIGEN_PARTIALLU_H #ifndef EIGEN_PARTIALLU_H
#define EIGEN_PARTIALLU_H #define EIGEN_PARTIALLU_H
template<typename MatrixType, typename Rhs> struct ei_partialpivlu_solve_impl;
/** \ingroup LU_Module /** \ingroup LU_Module
* *
* \class PartialPivLU * \class PartialPivLU
@ -59,10 +57,11 @@ template<typename MatrixType, typename Rhs> struct ei_partialpivlu_solve_impl;
* *
* \sa MatrixBase::partialPivLu(), MatrixBase::determinant(), MatrixBase::inverse(), MatrixBase::computeInverse(), class FullPivLU * \sa MatrixBase::partialPivLu(), MatrixBase::determinant(), MatrixBase::inverse(), MatrixBase::computeInverse(), class FullPivLU
*/ */
template<typename MatrixType> class PartialPivLU template<typename _MatrixType> class PartialPivLU
{ {
public: public:
typedef _MatrixType MatrixType;
typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::Scalar Scalar;
typedef typename NumTraits<typename MatrixType::Scalar>::Real RealScalar; typedef typename NumTraits<typename MatrixType::Scalar>::Real RealScalar;
typedef Matrix<int, 1, MatrixType::ColsAtCompileTime> IntRowVectorType; typedef Matrix<int, 1, MatrixType::ColsAtCompileTime> IntRowVectorType;
@ -134,11 +133,11 @@ template<typename MatrixType> class PartialPivLU
* \sa TriangularView::solve(), inverse(), computeInverse() * \sa TriangularView::solve(), inverse(), computeInverse()
*/ */
template<typename Rhs> template<typename Rhs>
inline const ei_partialpivlu_solve_impl<MatrixType, Rhs> inline const ei_solve_return_value<PartialPivLU, Rhs>
solve(const MatrixBase<Rhs>& b) const solve(const MatrixBase<Rhs>& b) const
{ {
ei_assert(m_isInitialized && "PartialPivLU is not initialized."); ei_assert(m_isInitialized && "PartialPivLU is not initialized.");
return ei_partialpivlu_solve_impl<MatrixType, Rhs>(*this, b.derived()); return ei_solve_return_value<PartialPivLU, Rhs>(*this, b.derived());
} }
/** \returns the inverse of the matrix of which *this is the LU decomposition. /** \returns the inverse of the matrix of which *this is the LU decomposition.
@ -148,10 +147,10 @@ template<typename MatrixType> class PartialPivLU
* *
* \sa MatrixBase::inverse(), LU::inverse() * \sa MatrixBase::inverse(), LU::inverse()
*/ */
inline const ei_partialpivlu_solve_impl<MatrixType,NestByValue<typename MatrixType::IdentityReturnType> > inverse() const inline const ei_solve_return_value<PartialPivLU,NestByValue<typename MatrixType::IdentityReturnType> > inverse() const
{ {
ei_assert(m_isInitialized && "PartialPivLU is not initialized."); ei_assert(m_isInitialized && "PartialPivLU is not initialized.");
return ei_partialpivlu_solve_impl<MatrixType,NestByValue<typename MatrixType::IdentityReturnType> > return ei_solve_return_value<PartialPivLU,NestByValue<typename MatrixType::IdentityReturnType> >
(*this, MatrixType::Identity(m_lu.rows(), m_lu.cols()).nestByValue()); (*this, MatrixType::Identity(m_lu.rows(), m_lu.cols()).nestByValue());
} }
@ -170,6 +169,9 @@ template<typename MatrixType> class PartialPivLU
*/ */
typename ei_traits<MatrixType>::Scalar determinant() const; typename ei_traits<MatrixType>::Scalar determinant() const;
inline int rows() const { return m_lu.rows(); }
inline int cols() const { return m_lu.cols(); }
protected: protected:
MatrixType m_lu; MatrixType m_lu;
IntColVectorType m_p; IntColVectorType m_p;
@ -407,33 +409,11 @@ typename ei_traits<MatrixType>::Scalar PartialPivLU<MatrixType>::determinant() c
/***** Implementation of solve() *****************************************************/ /***** Implementation of solve() *****************************************************/
template<typename MatrixType,typename Rhs> template<typename MatrixType, typename Rhs, typename Dest>
struct ei_traits<ei_partialpivlu_solve_impl<MatrixType,Rhs> > struct ei_solve_impl<PartialPivLU<MatrixType>, Rhs, Dest>
: ei_solve_return_value<PartialPivLU<MatrixType>, Rhs>
{ {
typedef Matrix<typename Rhs::Scalar, void evalTo(Dest& dst) const
MatrixType::ColsAtCompileTime,
Rhs::ColsAtCompileTime,
Rhs::PlainMatrixType::Options,
MatrixType::MaxColsAtCompileTime,
Rhs::MaxColsAtCompileTime> ReturnMatrixType;
};
template<typename MatrixType, typename Rhs>
struct ei_partialpivlu_solve_impl : public ReturnByValue<ei_partialpivlu_solve_impl<MatrixType, Rhs> >
{
typedef typename ei_cleantype<typename Rhs::Nested>::type RhsNested;
typedef PartialPivLU<MatrixType> LUType;
const LUType& m_lu;
const typename Rhs::Nested m_rhs;
ei_partialpivlu_solve_impl(const LUType& lu, const Rhs& rhs)
: m_lu(lu), m_rhs(rhs)
{}
inline int rows() const { return m_lu.matrixLU().cols(); }
inline int cols() const { return m_rhs.cols(); }
template<typename Dest> void evalTo(Dest& dst) const
{ {
/* The decomposition PA = LU can be rewritten as A = P^{-1} L U. /* The decomposition PA = LU can be rewritten as A = P^{-1} L U.
* So we proceed as follows: * So we proceed as follows:
@ -442,19 +422,22 @@ struct ei_partialpivlu_solve_impl : public ReturnByValue<ei_partialpivlu_solve_i
* Step 3: replace c by the solution x to Ux = c. * Step 3: replace c by the solution x to Ux = c.
*/ */
const int size = m_lu.matrixLU().rows(); const PartialPivLU<MatrixType>& dec = this->m_dec;
ei_assert(m_rhs.rows() == size); const Rhs& rhs = this->m_rhs;
const int size = dec.matrixLU().rows();
ei_assert(rhs.rows() == size);
dst.resize(size, m_rhs.cols()); dst.resize(size, rhs.cols());
// Step 1 // Step 1
for(int i = 0; i < size; ++i) dst.row(m_lu.permutationP().coeff(i)) = m_rhs.row(i); for(int i = 0; i < size; ++i) dst.row(dec.permutationP().coeff(i)) = rhs.row(i);
// Step 2 // Step 2
m_lu.matrixLU().template triangularView<UnitLowerTriangular>().solveInPlace(dst); dec.matrixLU().template triangularView<UnitLowerTriangular>().solveInPlace(dst);
// Step 3 // Step 3
m_lu.matrixLU().template triangularView<UpperTriangular>().solveInPlace(dst); dec.matrixLU().template triangularView<UpperTriangular>().solveInPlace(dst);
} }
}; };