PartialPivLU: port to PermutationMatrix

PermutationMatrix: add resize()
This commit is contained in:
Benoit Jacob 2009-11-16 15:36:07 -05:00
parent eb6df28c6c
commit 76c614f9bd
2 changed files with 11 additions and 10 deletions

View File

@ -159,6 +159,9 @@ class PermutationMatrix : public AnyMatrixBase<PermutationMatrix<SizeAtCompileTi
/** \returns a reference to the stored array representing the permutation. */ /** \returns a reference to the stored array representing the permutation. */
IndicesType& indices() { return m_indices; } IndicesType& indices() { return m_indices; }
/** Resizes to given size. */
inline void resize(int size) { m_indices.resize(size); }
/**** inversion and multiplication helpers to hopefully get RVO ****/ /**** inversion and multiplication helpers to hopefully get RVO ****/
#ifndef EIGEN_PARSED_BY_DOXYGEN #ifndef EIGEN_PARSED_BY_DOXYGEN

View File

@ -64,10 +64,8 @@ template<typename _MatrixType> class PartialPivLU
typedef _MatrixType MatrixType; typedef _MatrixType MatrixType;
typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::Scalar Scalar;
typedef typename NumTraits<typename MatrixType::Scalar>::Real RealScalar; typedef typename NumTraits<typename MatrixType::Scalar>::Real RealScalar;
typedef Matrix<int, 1, MatrixType::ColsAtCompileTime> IntRowVectorType; typedef Matrix<int, MatrixType::RowsAtCompileTime, 1> PermutationVectorType;
typedef Matrix<int, MatrixType::RowsAtCompileTime, 1> IntColVectorType; typedef PermutationMatrix<MatrixType::RowsAtCompileTime> PermutationType;
typedef Matrix<Scalar, 1, MatrixType::ColsAtCompileTime> RowVectorType;
typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> ColVectorType;
enum { MaxSmallDimAtCompileTime = EIGEN_ENUM_MIN( enum { MaxSmallDimAtCompileTime = EIGEN_ENUM_MIN(
MatrixType::MaxColsAtCompileTime, MatrixType::MaxColsAtCompileTime,
@ -109,7 +107,7 @@ template<typename _MatrixType> class PartialPivLU
* representing the P permutation i.e. the permutation of the rows. For its precise meaning, * representing the P permutation i.e. the permutation of the rows. For its precise meaning,
* see the examples given in the documentation of class FullPivLU. * see the examples given in the documentation of class FullPivLU.
*/ */
inline const IntColVectorType& permutationP() const inline const PermutationType& permutationP() const
{ {
ei_assert(m_isInitialized && "PartialPivLU is not initialized."); ei_assert(m_isInitialized && "PartialPivLU is not initialized.");
return m_p; return m_p;
@ -174,7 +172,7 @@ template<typename _MatrixType> class PartialPivLU
protected: protected:
MatrixType m_lu; MatrixType m_lu;
IntColVectorType m_p; PermutationType m_p;
int m_det_p; int m_det_p;
bool m_isInitialized; bool m_isInitialized;
}; };
@ -384,15 +382,15 @@ PartialPivLU<MatrixType>& PartialPivLU<MatrixType>::compute(const MatrixType& ma
ei_assert(matrix.rows() == matrix.cols() && "PartialPivLU is only for square (and moreover invertible) matrices"); ei_assert(matrix.rows() == matrix.cols() && "PartialPivLU is only for square (and moreover invertible) matrices");
const int size = matrix.rows(); const int size = matrix.rows();
IntColVectorType rows_transpositions(size); PermutationVectorType rows_transpositions(size);
int nb_transpositions; int nb_transpositions;
ei_partial_lu_inplace(m_lu, rows_transpositions, nb_transpositions); ei_partial_lu_inplace(m_lu, rows_transpositions, nb_transpositions);
m_det_p = (nb_transpositions%2) ? -1 : 1; m_det_p = (nb_transpositions%2) ? -1 : 1;
for(int k = 0; k < size; ++k) m_p.coeffRef(k) = k; for(int k = 0; k < size; ++k) m_p.indices().coeffRef(k) = k;
for(int k = size-1; k >= 0; --k) for(int k = size-1; k >= 0; --k)
std::swap(m_p.coeffRef(k), m_p.coeffRef(rows_transpositions.coeff(k))); std::swap(m_p.indices().coeffRef(k), m_p.indices().coeffRef(rows_transpositions.coeff(k)));
m_isInitialized = true; m_isInitialized = true;
return *this; return *this;
@ -428,7 +426,7 @@ struct ei_solve_retval<PartialPivLU<_MatrixType>, Rhs>
dst.resize(size, rhs().cols()); dst.resize(size, rhs().cols());
// Step 1 // Step 1
for(int i = 0; i < size; ++i) dst.row(dec().permutationP().coeff(i)) = rhs().row(i); dst = dec().permutationP() * rhs();
// Step 2 // Step 2
dec().matrixLU().template triangularView<UnitLowerTriangular>().solveInPlace(dst); dec().matrixLU().template triangularView<UnitLowerTriangular>().solveInPlace(dst);