mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 11:19:02 +08:00
extend CG for multiple right hand sides
This commit is contained in:
parent
b94c00226f
commit
5dc8458293
@ -25,7 +25,7 @@
|
||||
#ifndef EIGEN_ITERATIVE_SOLVERS_MODULE_H
|
||||
#define EIGEN_ITERATIVE_SOLVERS_MODULE_H
|
||||
|
||||
#include <Eigen/Core>
|
||||
#include <Eigen/Sparse>
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
@ -42,6 +42,7 @@ namespace Eigen {
|
||||
//@{
|
||||
|
||||
#include "../../Eigen/src/misc/Solve.h"
|
||||
#include "src/SparseExtra/Solve.h"
|
||||
|
||||
#include "src/IterativeSolvers/IterativeSolverBase.h"
|
||||
#include "src/IterativeSolvers/IterationController.h"
|
||||
|
@ -45,7 +45,7 @@ void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x,
|
||||
using std::abs;
|
||||
typedef typename Dest::RealScalar RealScalar;
|
||||
typedef typename Dest::Scalar Scalar;
|
||||
typedef Dest VectorType;
|
||||
typedef Matrix<Scalar,Dynamic,1> VectorType;
|
||||
|
||||
RealScalar tol = tol_error;
|
||||
int maxIters = iters;
|
||||
@ -207,11 +207,15 @@ public:
|
||||
template<typename Rhs,typename Dest>
|
||||
void _solve(const Rhs& b, Dest& x) const
|
||||
{
|
||||
m_iterations = Base::m_maxIterations;
|
||||
m_error = Base::m_tolerance;
|
||||
|
||||
internal::conjugate_gradient(mp_matrix->template selfadjointView<UpLo>(), b, x,
|
||||
Base::m_preconditioner, m_iterations, m_error);
|
||||
for(int j=0; j<b.cols(); ++j)
|
||||
{
|
||||
m_iterations = Base::m_maxIterations;
|
||||
m_error = Base::m_tolerance;
|
||||
|
||||
typename Dest::ColXpr xj(x,j);
|
||||
internal::conjugate_gradient(mp_matrix->template selfadjointView<UpLo>(), b.col(j), xj,
|
||||
Base::m_preconditioner, m_iterations, m_error);
|
||||
}
|
||||
|
||||
m_isInitialized = true;
|
||||
m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
|
||||
@ -233,7 +237,7 @@ struct solve_retval<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs>
|
||||
|
||||
template<typename Dest> void evalTo(Dest& dst) const
|
||||
{
|
||||
dst.setZero();
|
||||
dst.setOnes();
|
||||
dec()._solve(rhs(),dst);
|
||||
}
|
||||
};
|
||||
|
@ -146,6 +146,20 @@ public:
|
||||
&& "IterativeSolverBase::solve(): invalid number of rows of the right hand side matrix b");
|
||||
return internal::solve_retval<Derived, Rhs>(derived(), b.derived());
|
||||
}
|
||||
|
||||
/** \returns the solution x of \f$ A x = b \f$ using the current decomposition of A.
|
||||
*
|
||||
* \sa compute()
|
||||
*/
|
||||
template<typename Rhs>
|
||||
inline const internal::sparse_solve_retval<IterativeSolverBase, Rhs>
|
||||
solve(const SparseMatrixBase<Rhs>& b) const
|
||||
{
|
||||
eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
|
||||
eigen_assert(rows()==b.rows()
|
||||
&& "IterativeSolverBase::solve(): invalid number of rows of the right hand side matrix b");
|
||||
return internal::sparse_solve_retval<IterativeSolverBase, Rhs>(*this, b.derived());
|
||||
}
|
||||
|
||||
/** \returns Success if the iterations converged, and NoConvergence otherwise. */
|
||||
ComputationInfo info() const
|
||||
@ -153,6 +167,24 @@ public:
|
||||
eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
|
||||
return m_info;
|
||||
}
|
||||
|
||||
/** \internal */
|
||||
template<typename Rhs, typename DestScalar, int DestOptions, typename DestIndex>
|
||||
void _solve_sparse(const Rhs& b, SparseMatrix<DestScalar,DestOptions,DestIndex> &dest) const
|
||||
{
|
||||
eigen_assert(rows()==b.rows());
|
||||
|
||||
int rhsCols = b.cols();
|
||||
int size = b.rows();
|
||||
Eigen::Matrix<DestScalar,Dynamic,1> tb(size);
|
||||
Eigen::Matrix<DestScalar,Dynamic,1> tx(size);
|
||||
for(int k=0; k<rhsCols; ++k)
|
||||
{
|
||||
tb = b.col(k);
|
||||
tx = derived().solve(tb);
|
||||
dest.col(k) = tx.sparseView(0);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
void init()
|
||||
@ -173,5 +205,21 @@ protected:
|
||||
mutable bool m_isInitialized;
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
|
||||
template<typename Derived, typename Rhs>
|
||||
struct sparse_solve_retval<IterativeSolverBase<Derived>, Rhs>
|
||||
: sparse_solve_retval_base<IterativeSolverBase<Derived>, Rhs>
|
||||
{
|
||||
typedef IterativeSolverBase<Derived> Dec;
|
||||
EIGEN_MAKE_SPARSE_SOLVE_HELPERS(Dec,Rhs)
|
||||
|
||||
template<typename Dest> void evalTo(Dest& dst) const
|
||||
{
|
||||
dec().derived()._solve_sparse(rhs(),dst);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif // EIGEN_ITERATIVE_SOLVER_BASE_H
|
||||
|
Loading…
x
Reference in New Issue
Block a user