Generalize matrix ctor and compute() method of dense decomposition to 1) limit temporaries, 2) forward expressions to nested decompositions, 3) fix ambiguous ctor instanciation for square decomposition

This commit is contained in:
Gael Guennebaud 2015-09-07 10:42:04 +02:00
parent 1702fcb72e
commit 7031a851d4
14 changed files with 162 additions and 90 deletions

View File

@ -99,14 +99,15 @@ template<typename _MatrixType, int _UpLo> class LDLT
* This calculates the decomposition for the input \a matrix. * This calculates the decomposition for the input \a matrix.
* \sa LDLT(Index size) * \sa LDLT(Index size)
*/ */
explicit LDLT(const MatrixType& matrix) template<typename InputType>
explicit LDLT(const EigenBase<InputType>& matrix)
: m_matrix(matrix.rows(), matrix.cols()), : m_matrix(matrix.rows(), matrix.cols()),
m_transpositions(matrix.rows()), m_transpositions(matrix.rows()),
m_temporary(matrix.rows()), m_temporary(matrix.rows()),
m_sign(internal::ZeroSign), m_sign(internal::ZeroSign),
m_isInitialized(false) m_isInitialized(false)
{ {
compute(matrix); compute(matrix.derived());
} }
/** Clear any existing decomposition /** Clear any existing decomposition
@ -188,7 +189,8 @@ template<typename _MatrixType, int _UpLo> class LDLT
template<typename Derived> template<typename Derived>
bool solveInPlace(MatrixBase<Derived> &bAndX) const; bool solveInPlace(MatrixBase<Derived> &bAndX) const;
LDLT& compute(const MatrixType& matrix); template<typename InputType>
LDLT& compute(const EigenBase<InputType>& matrix);
template <typename Derived> template <typename Derived>
LDLT& rankUpdate(const MatrixBase<Derived>& w, const RealScalar& alpha=1); LDLT& rankUpdate(const MatrixBase<Derived>& w, const RealScalar& alpha=1);
@ -427,14 +429,15 @@ template<typename MatrixType> struct LDLT_Traits<MatrixType,Upper>
/** Compute / recompute the LDLT decomposition A = L D L^* = U^* D U of \a matrix /** Compute / recompute the LDLT decomposition A = L D L^* = U^* D U of \a matrix
*/ */
template<typename MatrixType, int _UpLo> template<typename MatrixType, int _UpLo>
LDLT<MatrixType,_UpLo>& LDLT<MatrixType,_UpLo>::compute(const MatrixType& a) template<typename InputType>
LDLT<MatrixType,_UpLo>& LDLT<MatrixType,_UpLo>::compute(const EigenBase<InputType>& a)
{ {
check_template_parameters(); check_template_parameters();
eigen_assert(a.rows()==a.cols()); eigen_assert(a.rows()==a.cols());
const Index size = a.rows(); const Index size = a.rows();
m_matrix = a; m_matrix = a.derived();
m_transpositions.resize(size); m_transpositions.resize(size);
m_isInitialized = false; m_isInitialized = false;

View File

@ -87,11 +87,12 @@ template<typename _MatrixType, int _UpLo> class LLT
explicit LLT(Index size) : m_matrix(size, size), explicit LLT(Index size) : m_matrix(size, size),
m_isInitialized(false) {} m_isInitialized(false) {}
explicit LLT(const MatrixType& matrix) template<typename InputType>
explicit LLT(const EigenBase<InputType>& matrix)
: m_matrix(matrix.rows(), matrix.cols()), : m_matrix(matrix.rows(), matrix.cols()),
m_isInitialized(false) m_isInitialized(false)
{ {
compute(matrix); compute(matrix.derived());
} }
/** \returns a view of the upper triangular matrix U */ /** \returns a view of the upper triangular matrix U */
@ -131,7 +132,8 @@ template<typename _MatrixType, int _UpLo> class LLT
template<typename Derived> template<typename Derived>
void solveInPlace(MatrixBase<Derived> &bAndX) const; void solveInPlace(MatrixBase<Derived> &bAndX) const;
LLT& compute(const MatrixType& matrix); template<typename InputType>
LLT& compute(const EigenBase<InputType>& matrix);
/** \returns the LLT decomposition matrix /** \returns the LLT decomposition matrix
* *
@ -381,14 +383,15 @@ template<typename MatrixType> struct LLT_Traits<MatrixType,Upper>
* Output: \verbinclude TutorialLinAlgComputeTwice.out * Output: \verbinclude TutorialLinAlgComputeTwice.out
*/ */
template<typename MatrixType, int _UpLo> template<typename MatrixType, int _UpLo>
LLT<MatrixType,_UpLo>& LLT<MatrixType,_UpLo>::compute(const MatrixType& a) template<typename InputType>
LLT<MatrixType,_UpLo>& LLT<MatrixType,_UpLo>::compute(const EigenBase<InputType>& a)
{ {
check_template_parameters(); check_template_parameters();
eigen_assert(a.rows()==a.cols()); eigen_assert(a.rows()==a.cols());
const Index size = a.rows(); const Index size = a.rows();
m_matrix.resize(size, size); m_matrix.resize(size, size);
m_matrix = a; m_matrix = a.derived();
m_isInitialized = true; m_isInitialized = true;
bool ok = Traits::inplace_decomposition(m_matrix); bool ok = Traits::inplace_decomposition(m_matrix);

View File

@ -122,7 +122,8 @@ template<typename _MatrixType> class ComplexEigenSolver
* *
* This constructor calls compute() to compute the eigendecomposition. * This constructor calls compute() to compute the eigendecomposition.
*/ */
explicit ComplexEigenSolver(const MatrixType& matrix, bool computeEigenvectors = true) template<typename InputType>
explicit ComplexEigenSolver(const EigenBase<InputType>& matrix, bool computeEigenvectors = true)
: m_eivec(matrix.rows(),matrix.cols()), : m_eivec(matrix.rows(),matrix.cols()),
m_eivalues(matrix.cols()), m_eivalues(matrix.cols()),
m_schur(matrix.rows()), m_schur(matrix.rows()),
@ -130,7 +131,7 @@ template<typename _MatrixType> class ComplexEigenSolver
m_eigenvectorsOk(false), m_eigenvectorsOk(false),
m_matX(matrix.rows(),matrix.cols()) m_matX(matrix.rows(),matrix.cols())
{ {
compute(matrix, computeEigenvectors); compute(matrix.derived(), computeEigenvectors);
} }
/** \brief Returns the eigenvectors of given matrix. /** \brief Returns the eigenvectors of given matrix.
@ -208,7 +209,8 @@ template<typename _MatrixType> class ComplexEigenSolver
* Example: \include ComplexEigenSolver_compute.cpp * Example: \include ComplexEigenSolver_compute.cpp
* Output: \verbinclude ComplexEigenSolver_compute.out * Output: \verbinclude ComplexEigenSolver_compute.out
*/ */
ComplexEigenSolver& compute(const MatrixType& matrix, bool computeEigenvectors = true); template<typename InputType>
ComplexEigenSolver& compute(const EigenBase<InputType>& matrix, bool computeEigenvectors = true);
/** \brief Reports whether previous computation was successful. /** \brief Reports whether previous computation was successful.
* *
@ -254,8 +256,9 @@ template<typename _MatrixType> class ComplexEigenSolver
template<typename MatrixType> template<typename MatrixType>
template<typename InputType>
ComplexEigenSolver<MatrixType>& ComplexEigenSolver<MatrixType>&
ComplexEigenSolver<MatrixType>::compute(const MatrixType& matrix, bool computeEigenvectors) ComplexEigenSolver<MatrixType>::compute(const EigenBase<InputType>& matrix, bool computeEigenvectors)
{ {
check_template_parameters(); check_template_parameters();
@ -264,13 +267,13 @@ ComplexEigenSolver<MatrixType>::compute(const MatrixType& matrix, bool computeEi
// Do a complex Schur decomposition, A = U T U^* // Do a complex Schur decomposition, A = U T U^*
// The eigenvalues are on the diagonal of T. // The eigenvalues are on the diagonal of T.
m_schur.compute(matrix, computeEigenvectors); m_schur.compute(matrix.derived(), computeEigenvectors);
if(m_schur.info() == Success) if(m_schur.info() == Success)
{ {
m_eivalues = m_schur.matrixT().diagonal(); m_eivalues = m_schur.matrixT().diagonal();
if(computeEigenvectors) if(computeEigenvectors)
doComputeEigenvectors(matrix.norm()); doComputeEigenvectors(m_schur.matrixT().norm());
sortEigenvalues(computeEigenvectors); sortEigenvalues(computeEigenvectors);
} }

View File

@ -109,7 +109,8 @@ template<typename _MatrixType> class ComplexSchur
* *
* \sa matrixT() and matrixU() for examples. * \sa matrixT() and matrixU() for examples.
*/ */
explicit ComplexSchur(const MatrixType& matrix, bool computeU = true) template<typename InputType>
explicit ComplexSchur(const EigenBase<InputType>& matrix, bool computeU = true)
: m_matT(matrix.rows(),matrix.cols()), : m_matT(matrix.rows(),matrix.cols()),
m_matU(matrix.rows(),matrix.cols()), m_matU(matrix.rows(),matrix.cols()),
m_hess(matrix.rows()), m_hess(matrix.rows()),
@ -117,7 +118,7 @@ template<typename _MatrixType> class ComplexSchur
m_matUisUptodate(false), m_matUisUptodate(false),
m_maxIters(-1) m_maxIters(-1)
{ {
compute(matrix, computeU); compute(matrix.derived(), computeU);
} }
/** \brief Returns the unitary matrix in the Schur decomposition. /** \brief Returns the unitary matrix in the Schur decomposition.
@ -186,7 +187,8 @@ template<typename _MatrixType> class ComplexSchur
* *
* \sa compute(const MatrixType&, bool, Index) * \sa compute(const MatrixType&, bool, Index)
*/ */
ComplexSchur& compute(const MatrixType& matrix, bool computeU = true); template<typename InputType>
ComplexSchur& compute(const EigenBase<InputType>& matrix, bool computeU = true);
/** \brief Compute Schur decomposition from a given Hessenberg matrix /** \brief Compute Schur decomposition from a given Hessenberg matrix
* \param[in] matrixH Matrix in Hessenberg form H * \param[in] matrixH Matrix in Hessenberg form H
@ -313,14 +315,15 @@ typename ComplexSchur<MatrixType>::ComplexScalar ComplexSchur<MatrixType>::compu
template<typename MatrixType> template<typename MatrixType>
ComplexSchur<MatrixType>& ComplexSchur<MatrixType>::compute(const MatrixType& matrix, bool computeU) template<typename InputType>
ComplexSchur<MatrixType>& ComplexSchur<MatrixType>::compute(const EigenBase<InputType>& matrix, bool computeU)
{ {
m_matUisUptodate = false; m_matUisUptodate = false;
eigen_assert(matrix.cols() == matrix.rows()); eigen_assert(matrix.cols() == matrix.rows());
if(matrix.cols() == 1) if(matrix.cols() == 1)
{ {
m_matT = matrix.template cast<ComplexScalar>(); m_matT = matrix.derived().template cast<ComplexScalar>();
if(computeU) m_matU = ComplexMatrixType::Identity(1,1); if(computeU) m_matU = ComplexMatrixType::Identity(1,1);
m_info = Success; m_info = Success;
m_isInitialized = true; m_isInitialized = true;
@ -328,7 +331,7 @@ ComplexSchur<MatrixType>& ComplexSchur<MatrixType>::compute(const MatrixType& ma
return *this; return *this;
} }
internal::complex_schur_reduce_to_hessenberg<MatrixType, NumTraits<Scalar>::IsComplex>::run(*this, matrix, computeU); internal::complex_schur_reduce_to_hessenberg<MatrixType, NumTraits<Scalar>::IsComplex>::run(*this, matrix.derived(), computeU);
computeFromHessenberg(m_matT, m_matU, computeU); computeFromHessenberg(m_matT, m_matU, computeU);
return *this; return *this;
} }

View File

@ -143,7 +143,8 @@ template<typename _MatrixType> class EigenSolver
* *
* \sa compute() * \sa compute()
*/ */
explicit EigenSolver(const MatrixType& matrix, bool computeEigenvectors = true) template<typename InputType>
explicit EigenSolver(const EigenBase<InputType>& matrix, bool computeEigenvectors = true)
: m_eivec(matrix.rows(), matrix.cols()), : m_eivec(matrix.rows(), matrix.cols()),
m_eivalues(matrix.cols()), m_eivalues(matrix.cols()),
m_isInitialized(false), m_isInitialized(false),
@ -152,7 +153,7 @@ template<typename _MatrixType> class EigenSolver
m_matT(matrix.rows(), matrix.cols()), m_matT(matrix.rows(), matrix.cols()),
m_tmp(matrix.cols()) m_tmp(matrix.cols())
{ {
compute(matrix, computeEigenvectors); compute(matrix.derived(), computeEigenvectors);
} }
/** \brief Returns the eigenvectors of given matrix. /** \brief Returns the eigenvectors of given matrix.
@ -273,7 +274,8 @@ template<typename _MatrixType> class EigenSolver
* Example: \include EigenSolver_compute.cpp * Example: \include EigenSolver_compute.cpp
* Output: \verbinclude EigenSolver_compute.out * Output: \verbinclude EigenSolver_compute.out
*/ */
EigenSolver& compute(const MatrixType& matrix, bool computeEigenvectors = true); template<typename InputType>
EigenSolver& compute(const EigenBase<InputType>& matrix, bool computeEigenvectors = true);
/** \returns NumericalIssue if the input contains INF or NaN values or overflow occured. Returns Success otherwise. */ /** \returns NumericalIssue if the input contains INF or NaN values or overflow occured. Returns Success otherwise. */
ComputationInfo info() const ComputationInfo info() const
@ -370,8 +372,9 @@ typename EigenSolver<MatrixType>::EigenvectorsType EigenSolver<MatrixType>::eige
} }
template<typename MatrixType> template<typename MatrixType>
template<typename InputType>
EigenSolver<MatrixType>& EigenSolver<MatrixType>&
EigenSolver<MatrixType>::compute(const MatrixType& matrix, bool computeEigenvectors) EigenSolver<MatrixType>::compute(const EigenBase<InputType>& matrix, bool computeEigenvectors)
{ {
check_template_parameters(); check_template_parameters();
@ -381,7 +384,7 @@ EigenSolver<MatrixType>::compute(const MatrixType& matrix, bool computeEigenvect
eigen_assert(matrix.cols() == matrix.rows()); eigen_assert(matrix.cols() == matrix.rows());
// Reduce to real Schur form. // Reduce to real Schur form.
m_realSchur.compute(matrix, computeEigenvectors); m_realSchur.compute(matrix.derived(), computeEigenvectors);
m_info = m_realSchur.info(); m_info = m_realSchur.info();

View File

@ -115,8 +115,9 @@ template<typename _MatrixType> class HessenbergDecomposition
* *
* \sa matrixH() for an example. * \sa matrixH() for an example.
*/ */
explicit HessenbergDecomposition(const MatrixType& matrix) template<typename InputType>
: m_matrix(matrix), explicit HessenbergDecomposition(const EigenBase<InputType>& matrix)
: m_matrix(matrix.derived()),
m_temp(matrix.rows()), m_temp(matrix.rows()),
m_isInitialized(false) m_isInitialized(false)
{ {
@ -147,9 +148,10 @@ template<typename _MatrixType> class HessenbergDecomposition
* Example: \include HessenbergDecomposition_compute.cpp * Example: \include HessenbergDecomposition_compute.cpp
* Output: \verbinclude HessenbergDecomposition_compute.out * Output: \verbinclude HessenbergDecomposition_compute.out
*/ */
HessenbergDecomposition& compute(const MatrixType& matrix) template<typename InputType>
HessenbergDecomposition& compute(const EigenBase<InputType>& matrix)
{ {
m_matrix = matrix; m_matrix = matrix.derived();
if(matrix.rows()<2) if(matrix.rows()<2)
{ {
m_isInitialized = true; m_isInitialized = true;

View File

@ -100,7 +100,8 @@ template<typename _MatrixType> class RealSchur
* Example: \include RealSchur_RealSchur_MatrixType.cpp * Example: \include RealSchur_RealSchur_MatrixType.cpp
* Output: \verbinclude RealSchur_RealSchur_MatrixType.out * Output: \verbinclude RealSchur_RealSchur_MatrixType.out
*/ */
explicit RealSchur(const MatrixType& matrix, bool computeU = true) template<typename InputType>
explicit RealSchur(const EigenBase<InputType>& matrix, bool computeU = true)
: m_matT(matrix.rows(),matrix.cols()), : m_matT(matrix.rows(),matrix.cols()),
m_matU(matrix.rows(),matrix.cols()), m_matU(matrix.rows(),matrix.cols()),
m_workspaceVector(matrix.rows()), m_workspaceVector(matrix.rows()),
@ -109,7 +110,7 @@ template<typename _MatrixType> class RealSchur
m_matUisUptodate(false), m_matUisUptodate(false),
m_maxIters(-1) m_maxIters(-1)
{ {
compute(matrix, computeU); compute(matrix.derived(), computeU);
} }
/** \brief Returns the orthogonal matrix in the Schur decomposition. /** \brief Returns the orthogonal matrix in the Schur decomposition.
@ -165,7 +166,8 @@ template<typename _MatrixType> class RealSchur
* *
* \sa compute(const MatrixType&, bool, Index) * \sa compute(const MatrixType&, bool, Index)
*/ */
RealSchur& compute(const MatrixType& matrix, bool computeU = true); template<typename InputType>
RealSchur& compute(const EigenBase<InputType>& matrix, bool computeU = true);
/** \brief Computes Schur decomposition of a Hessenberg matrix H = Z T Z^T /** \brief Computes Schur decomposition of a Hessenberg matrix H = Z T Z^T
* \param[in] matrixH Matrix in Hessenberg form H * \param[in] matrixH Matrix in Hessenberg form H
@ -243,7 +245,8 @@ template<typename _MatrixType> class RealSchur
template<typename MatrixType> template<typename MatrixType>
RealSchur<MatrixType>& RealSchur<MatrixType>::compute(const MatrixType& matrix, bool computeU) template<typename InputType>
RealSchur<MatrixType>& RealSchur<MatrixType>::compute(const EigenBase<InputType>& matrix, bool computeU)
{ {
eigen_assert(matrix.cols() == matrix.rows()); eigen_assert(matrix.cols() == matrix.rows());
Index maxIters = m_maxIters; Index maxIters = m_maxIters;
@ -251,7 +254,7 @@ RealSchur<MatrixType>& RealSchur<MatrixType>::compute(const MatrixType& matrix,
maxIters = m_maxIterationsPerRow * matrix.rows(); maxIters = m_maxIterationsPerRow * matrix.rows();
// Step 1. Reduce to Hessenberg form // Step 1. Reduce to Hessenberg form
m_hess.compute(matrix); m_hess.compute(matrix.derived());
// Step 2. Reduce to real Schur form // Step 2. Reduce to real Schur form
computeFromHessenberg(m_hess.matrixH(), m_hess.matrixQ(), computeU); computeFromHessenberg(m_hess.matrixH(), m_hess.matrixQ(), computeU);

View File

@ -158,13 +158,14 @@ template<typename _MatrixType> class SelfAdjointEigenSolver
* \sa compute(const MatrixType&, int) * \sa compute(const MatrixType&, int)
*/ */
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
explicit SelfAdjointEigenSolver(const MatrixType& matrix, int options = ComputeEigenvectors) template<typename InputType>
explicit SelfAdjointEigenSolver(const EigenBase<InputType>& matrix, int options = ComputeEigenvectors)
: m_eivec(matrix.rows(), matrix.cols()), : m_eivec(matrix.rows(), matrix.cols()),
m_eivalues(matrix.cols()), m_eivalues(matrix.cols()),
m_subdiag(matrix.rows() > 1 ? matrix.rows() - 1 : 1), m_subdiag(matrix.rows() > 1 ? matrix.rows() - 1 : 1),
m_isInitialized(false) m_isInitialized(false)
{ {
compute(matrix, options); compute(matrix.derived(), options);
} }
/** \brief Computes eigendecomposition of given matrix. /** \brief Computes eigendecomposition of given matrix.
@ -198,7 +199,8 @@ template<typename _MatrixType> class SelfAdjointEigenSolver
* \sa SelfAdjointEigenSolver(const MatrixType&, int) * \sa SelfAdjointEigenSolver(const MatrixType&, int)
*/ */
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
SelfAdjointEigenSolver& compute(const MatrixType& matrix, int options = ComputeEigenvectors); template<typename InputType>
SelfAdjointEigenSolver& compute(const EigenBase<InputType>& matrix, int options = ComputeEigenvectors);
/** \brief Computes eigendecomposition of given matrix using a closed-form algorithm /** \brief Computes eigendecomposition of given matrix using a closed-form algorithm
* *
@ -389,12 +391,15 @@ static void tridiagonal_qr_step(RealScalar* diag, RealScalar* subdiag, Index sta
} }
template<typename MatrixType> template<typename MatrixType>
template<typename InputType>
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
SelfAdjointEigenSolver<MatrixType>& SelfAdjointEigenSolver<MatrixType> SelfAdjointEigenSolver<MatrixType>& SelfAdjointEigenSolver<MatrixType>
::compute(const MatrixType& matrix, int options) ::compute(const EigenBase<InputType>& a_matrix, int options)
{ {
check_template_parameters(); check_template_parameters();
const InputType &matrix(a_matrix);
using std::abs; using std::abs;
eigen_assert(matrix.cols() == matrix.rows()); eigen_assert(matrix.cols() == matrix.rows());
eigen_assert((options&~(EigVecMask|GenEigMask))==0 eigen_assert((options&~(EigVecMask|GenEigMask))==0

View File

@ -126,8 +126,9 @@ template<typename _MatrixType> class Tridiagonalization
* Example: \include Tridiagonalization_Tridiagonalization_MatrixType.cpp * Example: \include Tridiagonalization_Tridiagonalization_MatrixType.cpp
* Output: \verbinclude Tridiagonalization_Tridiagonalization_MatrixType.out * Output: \verbinclude Tridiagonalization_Tridiagonalization_MatrixType.out
*/ */
explicit Tridiagonalization(const MatrixType& matrix) template<typename InputType>
: m_matrix(matrix), explicit Tridiagonalization(const EigenBase<InputType>& matrix)
: m_matrix(matrix.derived()),
m_hCoeffs(matrix.cols() > 1 ? matrix.cols()-1 : 1), m_hCoeffs(matrix.cols() > 1 ? matrix.cols()-1 : 1),
m_isInitialized(false) m_isInitialized(false)
{ {
@ -152,9 +153,10 @@ template<typename _MatrixType> class Tridiagonalization
* Example: \include Tridiagonalization_compute.cpp * Example: \include Tridiagonalization_compute.cpp
* Output: \verbinclude Tridiagonalization_compute.out * Output: \verbinclude Tridiagonalization_compute.out
*/ */
Tridiagonalization& compute(const MatrixType& matrix) template<typename InputType>
Tridiagonalization& compute(const EigenBase<InputType>& matrix)
{ {
m_matrix = matrix; m_matrix = matrix.derived();
m_hCoeffs.resize(matrix.rows()-1, 1); m_hCoeffs.resize(matrix.rows()-1, 1);
internal::tridiagonalization_inplace(m_matrix, m_hCoeffs); internal::tridiagonalization_inplace(m_matrix, m_hCoeffs);
m_isInitialized = true; m_isInitialized = true;

View File

@ -95,7 +95,8 @@ template<typename _MatrixType> class FullPivLU
* \param matrix the matrix of which to compute the LU decomposition. * \param matrix the matrix of which to compute the LU decomposition.
* It is required to be nonzero. * It is required to be nonzero.
*/ */
explicit FullPivLU(const MatrixType& matrix); template<typename InputType>
explicit FullPivLU(const EigenBase<InputType>& matrix);
/** Computes the LU decomposition of the given matrix. /** Computes the LU decomposition of the given matrix.
* *
@ -104,7 +105,8 @@ template<typename _MatrixType> class FullPivLU
* *
* \returns a reference to *this * \returns a reference to *this
*/ */
FullPivLU& compute(const MatrixType& matrix); template<typename InputType>
FullPivLU& compute(const EigenBase<InputType>& matrix);
/** \returns the LU decomposition matrix: the upper-triangular part is U, the /** \returns the LU decomposition matrix: the upper-triangular part is U, the
* unit-lower-triangular part is L (at least for square matrices; in the non-square * unit-lower-triangular part is L (at least for square matrices; in the non-square
@ -396,6 +398,8 @@ template<typename _MatrixType> class FullPivLU
EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar); EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar);
} }
void computeInPlace();
MatrixType m_lu; MatrixType m_lu;
PermutationPType m_p; PermutationPType m_p;
PermutationQType m_q; PermutationQType m_q;
@ -425,7 +429,8 @@ FullPivLU<MatrixType>::FullPivLU(Index rows, Index cols)
} }
template<typename MatrixType> template<typename MatrixType>
FullPivLU<MatrixType>::FullPivLU(const MatrixType& matrix) template<typename InputType>
FullPivLU<MatrixType>::FullPivLU(const EigenBase<InputType>& matrix)
: m_lu(matrix.rows(), matrix.cols()), : m_lu(matrix.rows(), matrix.cols()),
m_p(matrix.rows()), m_p(matrix.rows()),
m_q(matrix.cols()), m_q(matrix.cols()),
@ -434,11 +439,12 @@ FullPivLU<MatrixType>::FullPivLU(const MatrixType& matrix)
m_isInitialized(false), m_isInitialized(false),
m_usePrescribedThreshold(false) m_usePrescribedThreshold(false)
{ {
compute(matrix); compute(matrix.derived());
} }
template<typename MatrixType> template<typename MatrixType>
FullPivLU<MatrixType>& FullPivLU<MatrixType>::compute(const MatrixType& matrix) template<typename InputType>
FullPivLU<MatrixType>& FullPivLU<MatrixType>::compute(const EigenBase<InputType>& matrix)
{ {
check_template_parameters(); check_template_parameters();
@ -446,16 +452,24 @@ FullPivLU<MatrixType>& FullPivLU<MatrixType>::compute(const MatrixType& matrix)
eigen_assert(matrix.rows()<=NumTraits<int>::highest() && matrix.cols()<=NumTraits<int>::highest()); eigen_assert(matrix.rows()<=NumTraits<int>::highest() && matrix.cols()<=NumTraits<int>::highest());
m_isInitialized = true; m_isInitialized = true;
m_lu = matrix; m_lu = matrix.derived();
const Index size = matrix.diagonalSize(); computeInPlace();
const Index rows = matrix.rows();
const Index cols = matrix.cols(); return *this;
}
template<typename MatrixType>
void FullPivLU<MatrixType>::computeInPlace()
{
const Index size = m_lu.diagonalSize();
const Index rows = m_lu.rows();
const Index cols = m_lu.cols();
// will store the transpositions, before we accumulate them at the end. // will store the transpositions, before we accumulate them at the end.
// can't accumulate on-the-fly because that will be done in reverse order for the rows. // can't accumulate on-the-fly because that will be done in reverse order for the rows.
m_rowsTranspositions.resize(matrix.rows()); m_rowsTranspositions.resize(m_lu.rows());
m_colsTranspositions.resize(matrix.cols()); m_colsTranspositions.resize(m_lu.cols());
Index number_of_transpositions = 0; // number of NONTRIVIAL transpositions, i.e. m_rowsTranspositions[i]!=i Index number_of_transpositions = 0; // number of NONTRIVIAL transpositions, i.e. m_rowsTranspositions[i]!=i
m_nonzero_pivots = size; // the generic case is that in which all pivots are nonzero (invertible case) m_nonzero_pivots = size; // the generic case is that in which all pivots are nonzero (invertible case)
@ -527,7 +541,6 @@ FullPivLU<MatrixType>& FullPivLU<MatrixType>::compute(const MatrixType& matrix)
m_q.applyTranspositionOnTheRight(k, m_colsTranspositions.coeff(k)); m_q.applyTranspositionOnTheRight(k, m_colsTranspositions.coeff(k));
m_det_pq = (number_of_transpositions%2) ? -1 : 1; m_det_pq = (number_of_transpositions%2) ? -1 : 1;
return *this;
} }
template<typename MatrixType> template<typename MatrixType>

View File

@ -102,9 +102,11 @@ template<typename _MatrixType> class PartialPivLU
* \warning The matrix should have full rank (e.g. if it's square, it should be invertible). * \warning The matrix should have full rank (e.g. if it's square, it should be invertible).
* If you need to deal with non-full rank, use class FullPivLU instead. * If you need to deal with non-full rank, use class FullPivLU instead.
*/ */
explicit PartialPivLU(const MatrixType& matrix); template<typename InputType>
explicit PartialPivLU(const EigenBase<InputType>& matrix);
PartialPivLU& compute(const MatrixType& matrix); template<typename InputType>
PartialPivLU& compute(const EigenBase<InputType>& matrix);
/** \returns the LU decomposition matrix: the upper-triangular part is U, the /** \returns the LU decomposition matrix: the upper-triangular part is U, the
* unit-lower-triangular part is L (at least for square matrices; in the non-square * unit-lower-triangular part is L (at least for square matrices; in the non-square
@ -243,14 +245,15 @@ PartialPivLU<MatrixType>::PartialPivLU(Index size)
} }
template<typename MatrixType> template<typename MatrixType>
PartialPivLU<MatrixType>::PartialPivLU(const MatrixType& matrix) template<typename InputType>
PartialPivLU<MatrixType>::PartialPivLU(const EigenBase<InputType>& matrix)
: m_lu(matrix.rows(), matrix.rows()), : m_lu(matrix.rows(), matrix.rows()),
m_p(matrix.rows()), m_p(matrix.rows()),
m_rowsTranspositions(matrix.rows()), m_rowsTranspositions(matrix.rows()),
m_det_p(0), m_det_p(0),
m_isInitialized(false) m_isInitialized(false)
{ {
compute(matrix); compute(matrix.derived());
} }
namespace internal { namespace internal {
@ -429,14 +432,15 @@ void partial_lu_inplace(MatrixType& lu, TranspositionType& row_transpositions, t
} // end namespace internal } // end namespace internal
template<typename MatrixType> template<typename MatrixType>
PartialPivLU<MatrixType>& PartialPivLU<MatrixType>::compute(const MatrixType& matrix) template<typename InputType>
PartialPivLU<MatrixType>& PartialPivLU<MatrixType>::compute(const EigenBase<InputType>& matrix)
{ {
check_template_parameters(); check_template_parameters();
// the row permutation is stored as int indices, so just to be sure: // the row permutation is stored as int indices, so just to be sure:
eigen_assert(matrix.rows()<NumTraits<int>::highest()); eigen_assert(matrix.rows()<NumTraits<int>::highest());
m_lu = matrix; m_lu = matrix.derived();
eigen_assert(matrix.rows() == matrix.cols() && "PartialPivLU is only for square (and moreover invertible) matrices"); eigen_assert(matrix.rows() == matrix.cols() && "PartialPivLU is only for square (and moreover invertible) matrices");
const Index size = matrix.rows(); const Index size = matrix.rows();

View File

@ -118,7 +118,8 @@ template<typename _MatrixType> class ColPivHouseholderQR
* *
* \sa compute() * \sa compute()
*/ */
explicit ColPivHouseholderQR(const MatrixType& matrix) template<typename InputType>
explicit ColPivHouseholderQR(const EigenBase<InputType>& matrix)
: m_qr(matrix.rows(), matrix.cols()), : m_qr(matrix.rows(), matrix.cols()),
m_hCoeffs((std::min)(matrix.rows(),matrix.cols())), m_hCoeffs((std::min)(matrix.rows(),matrix.cols())),
m_colsPermutation(PermIndexType(matrix.cols())), m_colsPermutation(PermIndexType(matrix.cols())),
@ -128,7 +129,7 @@ template<typename _MatrixType> class ColPivHouseholderQR
m_isInitialized(false), m_isInitialized(false),
m_usePrescribedThreshold(false) m_usePrescribedThreshold(false)
{ {
compute(matrix); compute(matrix.derived());
} }
/** This method finds a solution x to the equation Ax=b, where A is the matrix of which /** This method finds a solution x to the equation Ax=b, where A is the matrix of which
@ -185,7 +186,8 @@ template<typename _MatrixType> class ColPivHouseholderQR
return m_qr; return m_qr;
} }
ColPivHouseholderQR& compute(const MatrixType& matrix); template<typename InputType>
ColPivHouseholderQR& compute(const EigenBase<InputType>& matrix);
/** \returns a const reference to the column permutation matrix */ /** \returns a const reference to the column permutation matrix */
const PermutationType& colsPermutation() const const PermutationType& colsPermutation() const
@ -404,6 +406,8 @@ template<typename _MatrixType> class ColPivHouseholderQR
EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar); EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar);
} }
void computeInPlace();
MatrixType m_qr; MatrixType m_qr;
HCoeffsType m_hCoeffs; HCoeffsType m_hCoeffs;
PermutationType m_colsPermutation; PermutationType m_colsPermutation;
@ -440,24 +444,34 @@ typename MatrixType::RealScalar ColPivHouseholderQR<MatrixType>::logAbsDetermina
* \sa class ColPivHouseholderQR, ColPivHouseholderQR(const MatrixType&) * \sa class ColPivHouseholderQR, ColPivHouseholderQR(const MatrixType&)
*/ */
template<typename MatrixType> template<typename MatrixType>
ColPivHouseholderQR<MatrixType>& ColPivHouseholderQR<MatrixType>::compute(const MatrixType& matrix) template<typename InputType>
ColPivHouseholderQR<MatrixType>& ColPivHouseholderQR<MatrixType>::compute(const EigenBase<InputType>& matrix)
{ {
check_template_parameters(); check_template_parameters();
using std::abs;
Index rows = matrix.rows();
Index cols = matrix.cols();
Index size = matrix.diagonalSize();
// the column permutation is stored as int indices, so just to be sure: // the column permutation is stored as int indices, so just to be sure:
eigen_assert(cols<=NumTraits<int>::highest()); eigen_assert(matrix.cols()<=NumTraits<int>::highest());
m_qr = matrix; m_qr = matrix;
computeInPlace();
return *this;
}
template<typename MatrixType>
void ColPivHouseholderQR<MatrixType>::computeInPlace()
{
using std::abs;
Index rows = m_qr.rows();
Index cols = m_qr.cols();
Index size = m_qr.diagonalSize();
m_hCoeffs.resize(size); m_hCoeffs.resize(size);
m_temp.resize(cols); m_temp.resize(cols);
m_colsTranspositions.resize(matrix.cols()); m_colsTranspositions.resize(m_qr.cols());
Index number_of_transpositions = 0; Index number_of_transpositions = 0;
m_colSqNorms.resize(cols); m_colSqNorms.resize(cols);
@ -522,8 +536,6 @@ ColPivHouseholderQR<MatrixType>& ColPivHouseholderQR<MatrixType>::compute(const
m_det_pq = (number_of_transpositions%2) ? -1 : 1; m_det_pq = (number_of_transpositions%2) ? -1 : 1;
m_isInitialized = true; m_isInitialized = true;
return *this;
} }
#ifndef EIGEN_PARSED_BY_DOXYGEN #ifndef EIGEN_PARSED_BY_DOXYGEN

View File

@ -121,7 +121,8 @@ template<typename _MatrixType> class FullPivHouseholderQR
* *
* \sa compute() * \sa compute()
*/ */
explicit FullPivHouseholderQR(const MatrixType& matrix) template<typename InputType>
explicit FullPivHouseholderQR(const EigenBase<InputType>& matrix)
: m_qr(matrix.rows(), matrix.cols()), : m_qr(matrix.rows(), matrix.cols()),
m_hCoeffs((std::min)(matrix.rows(), matrix.cols())), m_hCoeffs((std::min)(matrix.rows(), matrix.cols())),
m_rows_transpositions((std::min)(matrix.rows(), matrix.cols())), m_rows_transpositions((std::min)(matrix.rows(), matrix.cols())),
@ -131,7 +132,7 @@ template<typename _MatrixType> class FullPivHouseholderQR
m_isInitialized(false), m_isInitialized(false),
m_usePrescribedThreshold(false) m_usePrescribedThreshold(false)
{ {
compute(matrix); compute(matrix.derived());
} }
/** This method finds a solution x to the equation Ax=b, where A is the matrix of which /** This method finds a solution x to the equation Ax=b, where A is the matrix of which
@ -172,7 +173,8 @@ template<typename _MatrixType> class FullPivHouseholderQR
return m_qr; return m_qr;
} }
FullPivHouseholderQR& compute(const MatrixType& matrix); template<typename InputType>
FullPivHouseholderQR& compute(const EigenBase<InputType>& matrix);
/** \returns a const reference to the column permutation matrix */ /** \returns a const reference to the column permutation matrix */
const PermutationType& colsPermutation() const const PermutationType& colsPermutation() const
@ -386,6 +388,8 @@ template<typename _MatrixType> class FullPivHouseholderQR
EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar); EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar);
} }
void computeInPlace();
MatrixType m_qr; MatrixType m_qr;
HCoeffsType m_hCoeffs; HCoeffsType m_hCoeffs;
IntDiagSizeVectorType m_rows_transpositions; IntDiagSizeVectorType m_rows_transpositions;
@ -423,16 +427,27 @@ typename MatrixType::RealScalar FullPivHouseholderQR<MatrixType>::logAbsDetermin
* \sa class FullPivHouseholderQR, FullPivHouseholderQR(const MatrixType&) * \sa class FullPivHouseholderQR, FullPivHouseholderQR(const MatrixType&)
*/ */
template<typename MatrixType> template<typename MatrixType>
FullPivHouseholderQR<MatrixType>& FullPivHouseholderQR<MatrixType>::compute(const MatrixType& matrix) template<typename InputType>
FullPivHouseholderQR<MatrixType>& FullPivHouseholderQR<MatrixType>::compute(const EigenBase<InputType>& matrix)
{ {
check_template_parameters(); check_template_parameters();
m_qr = matrix.derived();
computeInPlace();
return *this;
}
template<typename MatrixType>
void FullPivHouseholderQR<MatrixType>::computeInPlace()
{
using std::abs; using std::abs;
Index rows = matrix.rows(); Index rows = m_qr.rows();
Index cols = matrix.cols(); Index cols = m_qr.cols();
Index size = (std::min)(rows,cols); Index size = (std::min)(rows,cols);
m_qr = matrix;
m_hCoeffs.resize(size); m_hCoeffs.resize(size);
m_temp.resize(cols); m_temp.resize(cols);
@ -503,8 +518,6 @@ FullPivHouseholderQR<MatrixType>& FullPivHouseholderQR<MatrixType>::compute(cons
m_det_pq = (number_of_transpositions%2) ? -1 : 1; m_det_pq = (number_of_transpositions%2) ? -1 : 1;
m_isInitialized = true; m_isInitialized = true;
return *this;
} }
#ifndef EIGEN_PARSED_BY_DOXYGEN #ifndef EIGEN_PARSED_BY_DOXYGEN

View File

@ -92,13 +92,14 @@ template<typename _MatrixType> class HouseholderQR
* *
* \sa compute() * \sa compute()
*/ */
explicit HouseholderQR(const MatrixType& matrix) template<typename InputType>
explicit HouseholderQR(const EigenBase<InputType>& matrix)
: m_qr(matrix.rows(), matrix.cols()), : m_qr(matrix.rows(), matrix.cols()),
m_hCoeffs((std::min)(matrix.rows(),matrix.cols())), m_hCoeffs((std::min)(matrix.rows(),matrix.cols())),
m_temp(matrix.cols()), m_temp(matrix.cols()),
m_isInitialized(false) m_isInitialized(false)
{ {
compute(matrix); compute(matrix.derived());
} }
/** This method finds a solution x to the equation Ax=b, where A is the matrix of which /** This method finds a solution x to the equation Ax=b, where A is the matrix of which
@ -149,7 +150,8 @@ template<typename _MatrixType> class HouseholderQR
return m_qr; return m_qr;
} }
HouseholderQR& compute(const MatrixType& matrix); template<typename InputType>
HouseholderQR& compute(const EigenBase<InputType>& matrix);
/** \returns the absolute value of the determinant of the matrix of which /** \returns the absolute value of the determinant of the matrix of which
* *this is the QR decomposition. It has only linear complexity * *this is the QR decomposition. It has only linear complexity
@ -352,7 +354,8 @@ void HouseholderQR<_MatrixType>::_solve_impl(const RhsType &rhs, DstType &dst) c
* \sa class HouseholderQR, HouseholderQR(const MatrixType&) * \sa class HouseholderQR, HouseholderQR(const MatrixType&)
*/ */
template<typename MatrixType> template<typename MatrixType>
HouseholderQR<MatrixType>& HouseholderQR<MatrixType>::compute(const MatrixType& matrix) template<typename InputType>
HouseholderQR<MatrixType>& HouseholderQR<MatrixType>::compute(const EigenBase<InputType>& matrix)
{ {
check_template_parameters(); check_template_parameters();
@ -360,7 +363,7 @@ HouseholderQR<MatrixType>& HouseholderQR<MatrixType>::compute(const MatrixType&
Index cols = matrix.cols(); Index cols = matrix.cols();
Index size = (std::min)(rows,cols); Index size = (std::min)(rows,cols);
m_qr = matrix; m_qr = matrix.derived();
m_hCoeffs.resize(size); m_hCoeffs.resize(size);
m_temp.resize(cols); m_temp.resize(cols);