factorize solving with guess

This commit is contained in:
Gael Guennebaud 2011-10-24 09:33:24 +02:00
parent 70df09b76d
commit 5d43b4049d
3 changed files with 80 additions and 72 deletions

View File

@ -106,9 +106,6 @@ class BiCGSTAB;
namespace internal {
template<typename CG, typename Rhs, typename Guess>
class bicgstab_solve_retval_with_guess;
template< typename _MatrixType, typename _Preconditioner>
struct traits<BiCGSTAB<_MatrixType,_Preconditioner> >
{
@ -204,19 +201,19 @@ public:
* \sa compute()
*/
template<typename Rhs,typename Guess>
inline const internal::bicgstab_solve_retval_with_guess<BiCGSTAB, Rhs, Guess>
inline const internal::solve_retval_with_guess<BiCGSTAB, Rhs, Guess>
solveWithGuess(const MatrixBase<Rhs>& b, const Guess& x0) const
{
eigen_assert(m_isInitialized && "BiCGSTAB is not initialized.");
eigen_assert(Base::rows()==b.rows()
&& "BiCGSTAB::solve(): invalid number of rows of the right hand side matrix b");
return internal::bicgstab_solve_retval_with_guess
return internal::solve_retval_with_guess
<BiCGSTAB, Rhs, Guess>(*this, b.derived(), x0);
}
/** \internal */
template<typename Rhs,typename Dest>
void _solve(const Rhs& b, Dest& x) const
void _solveWithGuess(const Rhs& b, Dest& x) const
{
for(int j=0; j<b.cols(); ++j)
{
@ -231,6 +228,14 @@ public:
m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
}
/** \internal */
template<typename Rhs,typename Dest>
void _solve(const Rhs& b, Dest& x) const
{
x.setOnes();
_solveWithGuess(b,x);
}
protected:
};
@ -247,33 +252,10 @@ struct solve_retval<BiCGSTAB<_MatrixType, _Preconditioner>, Rhs>
template<typename Dest> void evalTo(Dest& dst) const
{
dst.setOnes();
dec()._solve(rhs(),dst);
}
};
template<typename CG, typename Rhs, typename Guess>
class bicgstab_solve_retval_with_guess
: public solve_retval_base<CG, Rhs>
{
typedef Eigen::internal::solve_retval_base<CG,Rhs> Base;
using Base::dec;
using Base::rhs;
public:
bicgstab_solve_retval_with_guess(const CG& cg, const Rhs& rhs, const Guess& guess)
: Base(cg, rhs), m_guess(guess)
{}
template<typename Dest> void evalTo(Dest& dst) const
{
dst = m_guess;
dec()._solve(rhs(), dst);
}
protected:
const Guess& m_guess;
};
}
#endif // EIGEN_BICGSTAB_H

View File

@ -37,6 +37,7 @@ namespace internal {
* \param tol_error On input the tolerance error, on output an estimation of the relative error.
*/
template<typename MatrixType, typename Rhs, typename Dest, typename Preconditioner>
EIGEN_DONT_INLINE
void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x,
const Preconditioner& precond, int& iters,
typename Dest::RealScalar& tol_error)
@ -59,7 +60,7 @@ void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x,
VectorType z(n), tmp(n);
RealScalar absNew = internal::real(residual.dot(p)); // the square of the absolute value of r scaled by invM
RealScalar absInit = absNew; // the initial absolute value
int i = 0;
while ((i < maxIters) && (absNew > tol*tol*absInit))
{
@ -89,9 +90,6 @@ class ConjugateGradient;
namespace internal {
template<typename CG, typename Rhs, typename Guess>
class conjugate_gradient_solve_retval_with_guess;
template< typename _MatrixType, int _UpLo, typename _Preconditioner>
struct traits<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner> >
{
@ -193,32 +191,43 @@ public:
* \sa compute()
*/
template<typename Rhs,typename Guess>
inline const internal::conjugate_gradient_solve_retval_with_guess<ConjugateGradient, Rhs, Guess>
inline const internal::solve_retval_with_guess<ConjugateGradient, Rhs, Guess>
solveWithGuess(const MatrixBase<Rhs>& b, const Guess& x0) const
{
eigen_assert(m_isInitialized && "ConjugateGradient is not initialized.");
eigen_assert(Base::rows()==b.rows()
&& "ConjugateGradient::solve(): invalid number of rows of the right hand side matrix b");
return internal::conjugate_gradient_solve_retval_with_guess
return internal::solve_retval_with_guess
<ConjugateGradient, Rhs, Guess>(*this, b.derived(), x0);
}
/** \internal */
template<typename Rhs,typename Dest>
void _solveWithGuess(const Rhs& b, Dest& x) const
{
m_iterations = Base::m_maxIterations;
m_error = Base::m_tolerance;
for(int j=0; j<b.cols(); ++j)
{
m_iterations = Base::m_maxIterations;
m_error = Base::m_tolerance;
typename Dest::ColXpr xj(x,j);
internal::conjugate_gradient(mp_matrix->template selfadjointView<UpLo>(), b.col(j), xj,
Base::m_preconditioner, m_iterations, m_error);
}
m_isInitialized = true;
m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
}
/** \internal */
template<typename Rhs,typename Dest>
void _solve(const Rhs& b, Dest& x) const
{
for(int j=0; j<b.cols(); ++j)
{
m_iterations = Base::m_maxIterations;
m_error = Base::m_tolerance;
typename Dest::ColXpr xj(x,j);
internal::conjugate_gradient(mp_matrix->template selfadjointView<UpLo>(), b.col(j), xj,
Base::m_preconditioner, m_iterations, m_error);
}
m_isInitialized = true;
m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
x.setOnes();
_solveWithGuess(b,x);
}
protected:
@ -228,7 +237,7 @@ protected:
namespace internal {
template<typename _MatrixType, int _UpLo, typename _Preconditioner, typename Rhs>
template<typename _MatrixType, int _UpLo, typename _Preconditioner, typename Rhs>
struct solve_retval<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs>
: solve_retval_base<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs>
{
@ -237,33 +246,10 @@ struct solve_retval<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs>
template<typename Dest> void evalTo(Dest& dst) const
{
dst.setOnes();
dec()._solve(rhs(),dst);
}
};
template<typename CG, typename Rhs, typename Guess>
class conjugate_gradient_solve_retval_with_guess
: public solve_retval_base<CG, Rhs>
{
typedef Eigen::internal::solve_retval_base<CG,Rhs> Base;
using Base::dec;
using Base::rhs;
public:
conjugate_gradient_solve_retval_with_guess(const CG& cg, const Rhs& rhs, const Guess& guess)
: Base(cg, rhs), m_guess(guess)
{}
template<typename Dest> void evalTo(Dest& dst) const
{
dst = m_guess;
dec()._solve(rhs(), dst);
}
protected:
const Guess& m_guess;
};
}
#endif // EIGEN_CONJUGATE_GRADIENT_H

View File

@ -76,7 +76,47 @@ template<typename _DecompositionType, typename Rhs> struct sparse_solve_retval_b
using Base::cols; \
sparse_solve_retval(const DecompositionType& dec, const Rhs& rhs) \
: Base(dec, rhs) {}
template<typename DecompositionType, typename Rhs, typename Guess> struct solve_retval_with_guess;
template<typename DecompositionType, typename Rhs, typename Guess>
struct traits<solve_retval_with_guess<DecompositionType, Rhs, Guess> >
{
typedef typename DecompositionType::MatrixType MatrixType;
typedef Matrix<typename Rhs::Scalar,
MatrixType::ColsAtCompileTime,
Rhs::ColsAtCompileTime,
Rhs::PlainObject::Options,
MatrixType::MaxColsAtCompileTime,
Rhs::MaxColsAtCompileTime> ReturnType;
};
template<typename DecompositionType, typename Rhs, typename Guess> struct solve_retval_with_guess
: public ReturnByValue<solve_retval_with_guess<DecompositionType, Rhs, Guess> >
{
typedef typename DecompositionType::Index Index;
solve_retval_with_guess(const DecompositionType& dec, const Rhs& rhs, const Guess& guess)
: m_dec(dec), m_rhs(rhs), m_guess(guess)
{}
inline Index rows() const { return m_dec.cols(); }
inline Index cols() const { return m_rhs.cols(); }
template<typename Dest> inline void evalTo(Dest& dst) const
{
dst = m_guess;
m_dec._solveWithGuess(m_rhs,dst);
}
protected:
const DecompositionType& m_dec;
const typename Rhs::Nested m_rhs;
const typename Guess::Nested m_guess;
};
} // namepsace internal
#endif // EIGEN_SPARSE_SOLVE_H