Iterative solvers: unify and fix handling of multiple rhs.

m_info was not properly computed and the logic was repeated in several places.
This commit is contained in:
Gael Guennebaud 2018-10-15 23:47:46 +02:00
parent 2747b98cfc
commit f0fb95135d
7 changed files with 92 additions and 130 deletions

View File

@ -191,32 +191,16 @@ public:
/** \internal */
template<typename Rhs,typename Dest>
void _solve_with_guess_impl(const Rhs& b, Dest& x) const
void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const
{
bool failed = false;
for(Index j=0; j<b.cols(); ++j)
{
m_iterations = Base::maxIterations();
m_error = Base::m_tolerance;
typename Dest::ColXpr xj(x,j);
if(!internal::bicgstab(matrix(), b.col(j), xj, Base::m_preconditioner, m_iterations, m_error))
failed = true;
}
m_info = failed ? NumericalIssue
m_iterations = Base::maxIterations();
m_error = Base::m_tolerance;
bool ret = internal::bicgstab(matrix(), b, x, Base::m_preconditioner, m_iterations, m_error);
m_info = (!ret) ? NumericalIssue
: m_error <= Base::m_tolerance ? Success
: NoConvergence;
m_isInitialized = true;
}
/** \internal */
using Base::_solve_impl;
template<typename Rhs,typename Dest>
void _solve_impl(const MatrixBase<Rhs>& b, Dest& x) const
{
x.resize(this->rows(),b.cols());
x.setZero();
_solve_with_guess_impl(b,x);
}
protected:

View File

@ -195,7 +195,7 @@ public:
/** \internal */
template<typename Rhs,typename Dest>
void _solve_with_guess_impl(const Rhs& b, Dest& x) const
void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const
{
typedef typename Base::MatrixWrapper MatrixWrapper;
typedef typename Base::ActualMatrixType ActualMatrixType;
@ -211,31 +211,14 @@ public:
RowMajorWrapper,
typename MatrixWrapper::template ConstSelfAdjointViewReturnType<UpLo>::Type
>::type SelfAdjointWrapper;
m_iterations = Base::maxIterations();
m_error = Base::m_tolerance;
for(Index j=0; j<b.cols(); ++j)
{
m_iterations = Base::maxIterations();
m_error = Base::m_tolerance;
typename Dest::ColXpr xj(x,j);
RowMajorWrapper row_mat(matrix());
internal::conjugate_gradient(SelfAdjointWrapper(row_mat), b.col(j), xj, Base::m_preconditioner, m_iterations, m_error);
}
m_isInitialized = true;
RowMajorWrapper row_mat(matrix());
internal::conjugate_gradient(SelfAdjointWrapper(row_mat), b, x, Base::m_preconditioner, m_iterations, m_error);
m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
}
/** \internal */
using Base::_solve_impl;
template<typename Rhs,typename Dest>
void _solve_impl(const MatrixBase<Rhs>& b, Dest& x) const
{
x.setZero();
_solve_with_guess_impl(b.derived(),x);
}
protected:

View File

@ -331,7 +331,7 @@ public:
/** \internal */
template<typename Rhs, typename DestDerived>
void _solve_impl(const Rhs& b, SparseMatrixBase<DestDerived> &aDest) const
void _solve_with_guess_impl(const Rhs& b, SparseMatrixBase<DestDerived> &aDest) const
{
eigen_assert(rows()==b.rows());
@ -344,15 +344,66 @@ public:
// We do not directly fill dest because sparse expressions have to be free of aliasing issue.
// For non square least-square problems, b and dest might not have the same size whereas they might alias each-other.
typename DestDerived::PlainObject tmp(cols(),rhsCols);
ComputationInfo global_info = Success;
for(Index k=0; k<rhsCols; ++k)
{
tb = b.col(k);
tx = derived().solve(tb);
tx = dest.col(k);
derived()._solve_vector_with_guess_impl(tb,tx);
tmp.col(k) = tx.sparseView(0);
// The call to _solve_vector_with_guess_impl updates m_info, so if it failed for a previous column
// we need to restore it to the worst value.
if(m_info==NumericalIssue)
global_info = NumericalIssue;
else if(m_info==NoConvergence)
global_info = NoConvergence;
}
m_info = global_info;
dest.swap(tmp);
}
template<typename Rhs, typename DestDerived>
typename internal::enable_if<Rhs::ColsAtCompileTime!=1 && DestDerived::ColsAtCompileTime!=1>::type
_solve_with_guess_impl(const Rhs& b, MatrixBase<DestDerived> &aDest) const
{
eigen_assert(rows()==b.rows());
Index rhsCols = b.cols();
DestDerived& dest(aDest.derived());
ComputationInfo global_info = Success;
for(Index k=0; k<rhsCols; ++k)
{
typename DestDerived::ColXpr xk(dest,k);
typename Rhs::ConstColXpr bk(b,k);
derived()._solve_vector_with_guess_impl(bk,xk);
// The call to _solve_vector_with_guess updates m_info, so if it failed for a previous column
// we need to restore it to the worst value.
if(m_info==NumericalIssue)
global_info = NumericalIssue;
else if(m_info==NoConvergence)
global_info = NoConvergence;
}
m_info = global_info;
}
template<typename Rhs, typename DestDerived>
typename internal::enable_if<Rhs::ColsAtCompileTime==1 || DestDerived::ColsAtCompileTime==1>::type
_solve_with_guess_impl(const Rhs& b, MatrixBase<DestDerived> &dest) const
{
derived()._solve_vector_with_guess_impl(b,dest.derived());
}
/** \internal default initial guess = 0 */
template<typename Rhs,typename Dest>
void _solve_impl(const Rhs& b, Dest& x) const
{
x.resize(this->rows(),b.cols());
x.setZero();
derived()._solve_with_guess_impl(b,x);
}
protected:
void init()
{

View File

@ -182,32 +182,14 @@ public:
/** \internal */
template<typename Rhs,typename Dest>
void _solve_with_guess_impl(const Rhs& b, Dest& x) const
void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const
{
m_iterations = Base::maxIterations();
m_error = Base::m_tolerance;
for(Index j=0; j<b.cols(); ++j)
{
m_iterations = Base::maxIterations();
m_error = Base::m_tolerance;
typename Dest::ColXpr xj(x,j);
internal::least_square_conjugate_gradient(matrix(), b.col(j), xj, Base::m_preconditioner, m_iterations, m_error);
}
m_isInitialized = true;
internal::least_square_conjugate_gradient(matrix(), b, x, Base::m_preconditioner, m_iterations, m_error);
m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
}
/** \internal */
using Base::_solve_impl;
template<typename Rhs,typename Dest>
void _solve_impl(const MatrixBase<Rhs>& b, Dest& x) const
{
x.setZero();
_solve_with_guess_impl(b.derived(),x);
}
};

View File

@ -109,6 +109,7 @@ class DGMRES : public IterativeSolverBase<DGMRES<_MatrixType,_Preconditioner> >
using Base::m_tolerance;
public:
using Base::_solve_impl;
using Base::_solve_with_guess_impl;
typedef _MatrixType MatrixType;
typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::StorageIndex StorageIndex;
@ -141,30 +142,16 @@ class DGMRES : public IterativeSolverBase<DGMRES<_MatrixType,_Preconditioner> >
/** \internal */
template<typename Rhs,typename Dest>
void _solve_with_guess_impl(const Rhs& b, Dest& x) const
{
bool failed = false;
for(Index j=0; j<b.cols(); ++j)
{
m_iterations = Base::maxIterations();
m_error = Base::m_tolerance;
typename Dest::ColXpr xj(x,j);
dgmres(matrix(), b.col(j), xj, Base::m_preconditioner);
}
m_info = failed ? NumericalIssue
: m_error <= Base::m_tolerance ? Success
: NoConvergence;
m_isInitialized = true;
void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const
{
EIGEN_STATIC_ASSERT(Rhs::ColsAtCompileTime==1 || Dest::ColsAtCompileTime==1, YOU_TRIED_CALLING_A_VECTOR_METHOD_ON_A_MATRIX);
m_iterations = Base::maxIterations();
m_error = Base::m_tolerance;
dgmres(matrix(), b, x, Base::m_preconditioner);
}
/** \internal */
template<typename Rhs,typename Dest>
void _solve_impl(const Rhs& b, MatrixBase<Dest>& x) const
{
x = b;
_solve_with_guess_impl(b,x.derived());
}
/**
* Get the restart value
*/

View File

@ -64,6 +64,15 @@ bool gmres(const MatrixType & mat, const Rhs & rhs, Dest & x, const Precondition
typedef Matrix < Scalar, Dynamic, 1 > VectorType;
typedef Matrix < Scalar, Dynamic, Dynamic, ColMajor> FMatrixType;
const RealScalar considerAsZero = (std::numeric_limits<RealScalar>::min)();
if(rhs.norm() <= considerAsZero)
{
x.setZero();
tol_error = 0;
return true;
}
RealScalar tol = tol_error;
const Index maxIters = iters;
iters = 0;
@ -307,31 +316,14 @@ public:
/** \internal */
template<typename Rhs,typename Dest>
void _solve_with_guess_impl(const Rhs& b, Dest& x) const
void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const
{
bool failed = false;
for(Index j=0; j<b.cols(); ++j)
{
m_iterations = Base::maxIterations();
m_error = Base::m_tolerance;
typename Dest::ColXpr xj(x,j);
if(!internal::gmres(matrix(), b.col(j), xj, Base::m_preconditioner, m_iterations, m_restart, m_error))
failed = true;
}
m_info = failed ? NumericalIssue
m_iterations = Base::maxIterations();
m_error = Base::m_tolerance;
bool ret = internal::gmres(matrix(), b, x, Base::m_preconditioner, m_iterations, m_restart, m_error);
m_info = (!ret) ? NumericalIssue
: m_error <= Base::m_tolerance ? Success
: NoConvergence;
m_isInitialized = true;
}
/** \internal */
template<typename Rhs,typename Dest>
void _solve_impl(const Rhs& b, MatrixBase<Dest> &x) const
{
x = b;
if(x.squaredNorm() == 0) return; // Check Zero right hand side
_solve_with_guess_impl(b,x.derived());
}
protected:

View File

@ -233,7 +233,7 @@ namespace Eigen {
/** \internal */
template<typename Rhs,typename Dest>
void _solve_with_guess_impl(const Rhs& b, Dest& x) const
void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const
{
typedef typename Base::MatrixWrapper MatrixWrapper;
typedef typename Base::ActualMatrixType ActualMatrixType;
@ -253,28 +253,11 @@ namespace Eigen {
m_iterations = Base::maxIterations();
m_error = Base::m_tolerance;
RowMajorWrapper row_mat(matrix());
for(int j=0; j<b.cols(); ++j)
{
m_iterations = Base::maxIterations();
m_error = Base::m_tolerance;
typename Dest::ColXpr xj(x,j);
internal::minres(SelfAdjointWrapper(row_mat), b.col(j), xj,
Base::m_preconditioner, m_iterations, m_error);
}
m_isInitialized = true;
internal::minres(SelfAdjointWrapper(row_mat), b, x,
Base::m_preconditioner, m_iterations, m_error);
m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
}
/** \internal */
template<typename Rhs,typename Dest>
void _solve_impl(const Rhs& b, MatrixBase<Dest> &x) const
{
x.setZero();
_solve_with_guess_impl(b,x.derived());
}
protected:
};