Fix IterativeSolverBase for expressions as input

This commit is contained in:
Gael Guennebaud 2015-11-05 12:05:31 +01:00
parent 47592d31ea
commit a92681e0d2

View File

@ -49,10 +49,11 @@ public:
* this class becomes invalid. Call compute() to update it with the new
* matrix A, or modify a copy of A.
*/
IterativeSolverBase(const MatrixType& A)
template<typename InputDerived>
IterativeSolverBase(const EigenBase<InputDerived>& A)
{
init();
compute(A);
compute(A.derived());
}
~IterativeSolverBase() {}
@ -62,9 +63,11 @@ public:
* Currently, this function mostly call analyzePattern on the preconditioner. In the future
* we might, for instance, implement column reodering for faster matrix vector products.
*/
Derived& analyzePattern(const MatrixType& A)
template<typename InputDerived>
Derived& analyzePattern(const EigenBase<InputDerived>& A)
{
m_preconditioner.analyzePattern(A);
grabInput(A.derived());
m_preconditioner.analyzePattern(*mp_matrix);
m_isInitialized = true;
m_analysisIsOk = true;
m_info = Success;
@ -80,11 +83,12 @@ public:
* this class becomes invalid. Call compute() to update it with the new
* matrix A, or modify a copy of A.
*/
Derived& factorize(const MatrixType& A)
template<typename InputDerived>
Derived& factorize(const EigenBase<InputDerived>& A)
{
grabInput(A.derived());
eigen_assert(m_analysisIsOk && "You must first call analyzePattern()");
mp_matrix = &A;
m_preconditioner.factorize(A);
m_preconditioner.factorize(*mp_matrix);
m_factorizationIsOk = true;
m_info = Success;
return derived();
@ -100,10 +104,11 @@ public:
* this class becomes invalid. Call compute() to update it with the new
* matrix A, or modify a copy of A.
*/
Derived& compute(const MatrixType& A)
template<typename InputDerived>
Derived& compute(const EigenBase<InputDerived>& A)
{
mp_matrix = &A;
m_preconditioner.compute(A);
grabInput(A.derived());
m_preconditioner.compute(*mp_matrix);
m_isInitialized = true;
m_analysisIsOk = true;
m_factorizationIsOk = true;
@ -212,6 +217,28 @@ public:
}
protected:
template<typename InputDerived>
void grabInput(const EigenBase<InputDerived>& A)
{
// we const cast to prevent the creation of a MatrixType temporary by the compiler.
grabInput_impl(A.const_cast_derived());
}
template<typename InputDerived>
void grabInput_impl(const EigenBase<InputDerived>& A)
{
m_copyMatrix = A;
mp_matrix = &m_copyMatrix;
}
void grabInput_impl(MatrixType& A)
{
if(MatrixType::RowsAtCompileTime==Dynamic && MatrixType::ColsAtCompileTime==Dynamic)
m_copyMatrix.resize(0,0);
mp_matrix = &A;
}
void init()
{
m_isInitialized = false;
@ -220,6 +247,7 @@ protected:
m_maxIterations = -1;
m_tolerance = NumTraits<Scalar>::epsilon();
}
MatrixType m_copyMatrix;
const MatrixType* mp_matrix;
Preconditioner m_preconditioner;