mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 03:09:01 +08:00
Iterative solvers: unify and fix handling of multiple rhs.
m_info was not properly computed and the logic was repeated in several places.
This commit is contained in:
parent
2747b98cfc
commit
f0fb95135d
@ -191,32 +191,16 @@ public:
|
||||
|
||||
/** \internal */
|
||||
template<typename Rhs,typename Dest>
|
||||
void _solve_with_guess_impl(const Rhs& b, Dest& x) const
|
||||
void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const
|
||||
{
|
||||
bool failed = false;
|
||||
for(Index j=0; j<b.cols(); ++j)
|
||||
{
|
||||
m_iterations = Base::maxIterations();
|
||||
m_error = Base::m_tolerance;
|
||||
|
||||
typename Dest::ColXpr xj(x,j);
|
||||
if(!internal::bicgstab(matrix(), b.col(j), xj, Base::m_preconditioner, m_iterations, m_error))
|
||||
failed = true;
|
||||
}
|
||||
m_info = failed ? NumericalIssue
|
||||
m_iterations = Base::maxIterations();
|
||||
m_error = Base::m_tolerance;
|
||||
|
||||
bool ret = internal::bicgstab(matrix(), b, x, Base::m_preconditioner, m_iterations, m_error);
|
||||
|
||||
m_info = (!ret) ? NumericalIssue
|
||||
: m_error <= Base::m_tolerance ? Success
|
||||
: NoConvergence;
|
||||
m_isInitialized = true;
|
||||
}
|
||||
|
||||
/** \internal */
|
||||
using Base::_solve_impl;
|
||||
template<typename Rhs,typename Dest>
|
||||
void _solve_impl(const MatrixBase<Rhs>& b, Dest& x) const
|
||||
{
|
||||
x.resize(this->rows(),b.cols());
|
||||
x.setZero();
|
||||
_solve_with_guess_impl(b,x);
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -195,7 +195,7 @@ public:
|
||||
|
||||
/** \internal */
|
||||
template<typename Rhs,typename Dest>
|
||||
void _solve_with_guess_impl(const Rhs& b, Dest& x) const
|
||||
void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const
|
||||
{
|
||||
typedef typename Base::MatrixWrapper MatrixWrapper;
|
||||
typedef typename Base::ActualMatrixType ActualMatrixType;
|
||||
@ -211,31 +211,14 @@ public:
|
||||
RowMajorWrapper,
|
||||
typename MatrixWrapper::template ConstSelfAdjointViewReturnType<UpLo>::Type
|
||||
>::type SelfAdjointWrapper;
|
||||
|
||||
m_iterations = Base::maxIterations();
|
||||
m_error = Base::m_tolerance;
|
||||
|
||||
for(Index j=0; j<b.cols(); ++j)
|
||||
{
|
||||
m_iterations = Base::maxIterations();
|
||||
m_error = Base::m_tolerance;
|
||||
|
||||
typename Dest::ColXpr xj(x,j);
|
||||
RowMajorWrapper row_mat(matrix());
|
||||
internal::conjugate_gradient(SelfAdjointWrapper(row_mat), b.col(j), xj, Base::m_preconditioner, m_iterations, m_error);
|
||||
}
|
||||
|
||||
m_isInitialized = true;
|
||||
RowMajorWrapper row_mat(matrix());
|
||||
internal::conjugate_gradient(SelfAdjointWrapper(row_mat), b, x, Base::m_preconditioner, m_iterations, m_error);
|
||||
m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
|
||||
}
|
||||
|
||||
/** \internal */
|
||||
using Base::_solve_impl;
|
||||
template<typename Rhs,typename Dest>
|
||||
void _solve_impl(const MatrixBase<Rhs>& b, Dest& x) const
|
||||
{
|
||||
x.setZero();
|
||||
_solve_with_guess_impl(b.derived(),x);
|
||||
}
|
||||
|
||||
protected:
|
||||
|
||||
|
@ -331,7 +331,7 @@ public:
|
||||
|
||||
/** \internal */
|
||||
template<typename Rhs, typename DestDerived>
|
||||
void _solve_impl(const Rhs& b, SparseMatrixBase<DestDerived> &aDest) const
|
||||
void _solve_with_guess_impl(const Rhs& b, SparseMatrixBase<DestDerived> &aDest) const
|
||||
{
|
||||
eigen_assert(rows()==b.rows());
|
||||
|
||||
@ -344,15 +344,66 @@ public:
|
||||
// We do not directly fill dest because sparse expressions have to be free of aliasing issue.
|
||||
// For non square least-square problems, b and dest might not have the same size whereas they might alias each-other.
|
||||
typename DestDerived::PlainObject tmp(cols(),rhsCols);
|
||||
ComputationInfo global_info = Success;
|
||||
for(Index k=0; k<rhsCols; ++k)
|
||||
{
|
||||
tb = b.col(k);
|
||||
tx = derived().solve(tb);
|
||||
tx = dest.col(k);
|
||||
derived()._solve_vector_with_guess_impl(tb,tx);
|
||||
tmp.col(k) = tx.sparseView(0);
|
||||
|
||||
// The call to _solve_vector_with_guess_impl updates m_info, so if it failed for a previous column
|
||||
// we need to restore it to the worst value.
|
||||
if(m_info==NumericalIssue)
|
||||
global_info = NumericalIssue;
|
||||
else if(m_info==NoConvergence)
|
||||
global_info = NoConvergence;
|
||||
}
|
||||
m_info = global_info;
|
||||
dest.swap(tmp);
|
||||
}
|
||||
|
||||
template<typename Rhs, typename DestDerived>
|
||||
typename internal::enable_if<Rhs::ColsAtCompileTime!=1 && DestDerived::ColsAtCompileTime!=1>::type
|
||||
_solve_with_guess_impl(const Rhs& b, MatrixBase<DestDerived> &aDest) const
|
||||
{
|
||||
eigen_assert(rows()==b.rows());
|
||||
|
||||
Index rhsCols = b.cols();
|
||||
DestDerived& dest(aDest.derived());
|
||||
ComputationInfo global_info = Success;
|
||||
for(Index k=0; k<rhsCols; ++k)
|
||||
{
|
||||
typename DestDerived::ColXpr xk(dest,k);
|
||||
typename Rhs::ConstColXpr bk(b,k);
|
||||
derived()._solve_vector_with_guess_impl(bk,xk);
|
||||
|
||||
// The call to _solve_vector_with_guess updates m_info, so if it failed for a previous column
|
||||
// we need to restore it to the worst value.
|
||||
if(m_info==NumericalIssue)
|
||||
global_info = NumericalIssue;
|
||||
else if(m_info==NoConvergence)
|
||||
global_info = NoConvergence;
|
||||
}
|
||||
m_info = global_info;
|
||||
}
|
||||
|
||||
template<typename Rhs, typename DestDerived>
|
||||
typename internal::enable_if<Rhs::ColsAtCompileTime==1 || DestDerived::ColsAtCompileTime==1>::type
|
||||
_solve_with_guess_impl(const Rhs& b, MatrixBase<DestDerived> &dest) const
|
||||
{
|
||||
derived()._solve_vector_with_guess_impl(b,dest.derived());
|
||||
}
|
||||
|
||||
/** \internal default initial guess = 0 */
|
||||
template<typename Rhs,typename Dest>
|
||||
void _solve_impl(const Rhs& b, Dest& x) const
|
||||
{
|
||||
x.resize(this->rows(),b.cols());
|
||||
x.setZero();
|
||||
derived()._solve_with_guess_impl(b,x);
|
||||
}
|
||||
|
||||
protected:
|
||||
void init()
|
||||
{
|
||||
|
@ -182,32 +182,14 @@ public:
|
||||
|
||||
/** \internal */
|
||||
template<typename Rhs,typename Dest>
|
||||
void _solve_with_guess_impl(const Rhs& b, Dest& x) const
|
||||
void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const
|
||||
{
|
||||
m_iterations = Base::maxIterations();
|
||||
m_error = Base::m_tolerance;
|
||||
|
||||
for(Index j=0; j<b.cols(); ++j)
|
||||
{
|
||||
m_iterations = Base::maxIterations();
|
||||
m_error = Base::m_tolerance;
|
||||
|
||||
typename Dest::ColXpr xj(x,j);
|
||||
internal::least_square_conjugate_gradient(matrix(), b.col(j), xj, Base::m_preconditioner, m_iterations, m_error);
|
||||
}
|
||||
|
||||
m_isInitialized = true;
|
||||
internal::least_square_conjugate_gradient(matrix(), b, x, Base::m_preconditioner, m_iterations, m_error);
|
||||
m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
|
||||
}
|
||||
|
||||
/** \internal */
|
||||
using Base::_solve_impl;
|
||||
template<typename Rhs,typename Dest>
|
||||
void _solve_impl(const MatrixBase<Rhs>& b, Dest& x) const
|
||||
{
|
||||
x.setZero();
|
||||
_solve_with_guess_impl(b.derived(),x);
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
@ -109,6 +109,7 @@ class DGMRES : public IterativeSolverBase<DGMRES<_MatrixType,_Preconditioner> >
|
||||
using Base::m_tolerance;
|
||||
public:
|
||||
using Base::_solve_impl;
|
||||
using Base::_solve_with_guess_impl;
|
||||
typedef _MatrixType MatrixType;
|
||||
typedef typename MatrixType::Scalar Scalar;
|
||||
typedef typename MatrixType::StorageIndex StorageIndex;
|
||||
@ -141,30 +142,16 @@ class DGMRES : public IterativeSolverBase<DGMRES<_MatrixType,_Preconditioner> >
|
||||
|
||||
/** \internal */
|
||||
template<typename Rhs,typename Dest>
|
||||
void _solve_with_guess_impl(const Rhs& b, Dest& x) const
|
||||
{
|
||||
bool failed = false;
|
||||
for(Index j=0; j<b.cols(); ++j)
|
||||
{
|
||||
m_iterations = Base::maxIterations();
|
||||
m_error = Base::m_tolerance;
|
||||
|
||||
typename Dest::ColXpr xj(x,j);
|
||||
dgmres(matrix(), b.col(j), xj, Base::m_preconditioner);
|
||||
}
|
||||
m_info = failed ? NumericalIssue
|
||||
: m_error <= Base::m_tolerance ? Success
|
||||
: NoConvergence;
|
||||
m_isInitialized = true;
|
||||
void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const
|
||||
{
|
||||
EIGEN_STATIC_ASSERT(Rhs::ColsAtCompileTime==1 || Dest::ColsAtCompileTime==1, YOU_TRIED_CALLING_A_VECTOR_METHOD_ON_A_MATRIX);
|
||||
|
||||
m_iterations = Base::maxIterations();
|
||||
m_error = Base::m_tolerance;
|
||||
|
||||
dgmres(matrix(), b, x, Base::m_preconditioner);
|
||||
}
|
||||
|
||||
/** \internal */
|
||||
template<typename Rhs,typename Dest>
|
||||
void _solve_impl(const Rhs& b, MatrixBase<Dest>& x) const
|
||||
{
|
||||
x = b;
|
||||
_solve_with_guess_impl(b,x.derived());
|
||||
}
|
||||
/**
|
||||
* Get the restart value
|
||||
*/
|
||||
|
@ -64,6 +64,15 @@ bool gmres(const MatrixType & mat, const Rhs & rhs, Dest & x, const Precondition
|
||||
typedef Matrix < Scalar, Dynamic, 1 > VectorType;
|
||||
typedef Matrix < Scalar, Dynamic, Dynamic, ColMajor> FMatrixType;
|
||||
|
||||
const RealScalar considerAsZero = (std::numeric_limits<RealScalar>::min)();
|
||||
|
||||
if(rhs.norm() <= considerAsZero)
|
||||
{
|
||||
x.setZero();
|
||||
tol_error = 0;
|
||||
return true;
|
||||
}
|
||||
|
||||
RealScalar tol = tol_error;
|
||||
const Index maxIters = iters;
|
||||
iters = 0;
|
||||
@ -307,31 +316,14 @@ public:
|
||||
|
||||
/** \internal */
|
||||
template<typename Rhs,typename Dest>
|
||||
void _solve_with_guess_impl(const Rhs& b, Dest& x) const
|
||||
void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const
|
||||
{
|
||||
bool failed = false;
|
||||
for(Index j=0; j<b.cols(); ++j)
|
||||
{
|
||||
m_iterations = Base::maxIterations();
|
||||
m_error = Base::m_tolerance;
|
||||
|
||||
typename Dest::ColXpr xj(x,j);
|
||||
if(!internal::gmres(matrix(), b.col(j), xj, Base::m_preconditioner, m_iterations, m_restart, m_error))
|
||||
failed = true;
|
||||
}
|
||||
m_info = failed ? NumericalIssue
|
||||
m_iterations = Base::maxIterations();
|
||||
m_error = Base::m_tolerance;
|
||||
bool ret = internal::gmres(matrix(), b, x, Base::m_preconditioner, m_iterations, m_restart, m_error);
|
||||
m_info = (!ret) ? NumericalIssue
|
||||
: m_error <= Base::m_tolerance ? Success
|
||||
: NoConvergence;
|
||||
m_isInitialized = true;
|
||||
}
|
||||
|
||||
/** \internal */
|
||||
template<typename Rhs,typename Dest>
|
||||
void _solve_impl(const Rhs& b, MatrixBase<Dest> &x) const
|
||||
{
|
||||
x = b;
|
||||
if(x.squaredNorm() == 0) return; // Check Zero right hand side
|
||||
_solve_with_guess_impl(b,x.derived());
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -233,7 +233,7 @@ namespace Eigen {
|
||||
|
||||
/** \internal */
|
||||
template<typename Rhs,typename Dest>
|
||||
void _solve_with_guess_impl(const Rhs& b, Dest& x) const
|
||||
void _solve_vector_with_guess_impl(const Rhs& b, Dest& x) const
|
||||
{
|
||||
typedef typename Base::MatrixWrapper MatrixWrapper;
|
||||
typedef typename Base::ActualMatrixType ActualMatrixType;
|
||||
@ -253,28 +253,11 @@ namespace Eigen {
|
||||
m_iterations = Base::maxIterations();
|
||||
m_error = Base::m_tolerance;
|
||||
RowMajorWrapper row_mat(matrix());
|
||||
for(int j=0; j<b.cols(); ++j)
|
||||
{
|
||||
m_iterations = Base::maxIterations();
|
||||
m_error = Base::m_tolerance;
|
||||
|
||||
typename Dest::ColXpr xj(x,j);
|
||||
internal::minres(SelfAdjointWrapper(row_mat), b.col(j), xj,
|
||||
Base::m_preconditioner, m_iterations, m_error);
|
||||
}
|
||||
|
||||
m_isInitialized = true;
|
||||
internal::minres(SelfAdjointWrapper(row_mat), b, x,
|
||||
Base::m_preconditioner, m_iterations, m_error);
|
||||
m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
|
||||
}
|
||||
|
||||
/** \internal */
|
||||
template<typename Rhs,typename Dest>
|
||||
void _solve_impl(const Rhs& b, MatrixBase<Dest> &x) const
|
||||
{
|
||||
x.setZero();
|
||||
_solve_with_guess_impl(b,x.derived());
|
||||
}
|
||||
|
||||
protected:
|
||||
|
||||
};
|
||||
|
Loading…
x
Reference in New Issue
Block a user