mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-13 04:09:10 +08:00
extend CG for multiple right hand sides
This commit is contained in:
parent
b94c00226f
commit
5dc8458293
@ -25,7 +25,7 @@
|
|||||||
#ifndef EIGEN_ITERATIVE_SOLVERS_MODULE_H
|
#ifndef EIGEN_ITERATIVE_SOLVERS_MODULE_H
|
||||||
#define EIGEN_ITERATIVE_SOLVERS_MODULE_H
|
#define EIGEN_ITERATIVE_SOLVERS_MODULE_H
|
||||||
|
|
||||||
#include <Eigen/Core>
|
#include <Eigen/Sparse>
|
||||||
|
|
||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
|
|
||||||
@ -42,6 +42,7 @@ namespace Eigen {
|
|||||||
//@{
|
//@{
|
||||||
|
|
||||||
#include "../../Eigen/src/misc/Solve.h"
|
#include "../../Eigen/src/misc/Solve.h"
|
||||||
|
#include "src/SparseExtra/Solve.h"
|
||||||
|
|
||||||
#include "src/IterativeSolvers/IterativeSolverBase.h"
|
#include "src/IterativeSolvers/IterativeSolverBase.h"
|
||||||
#include "src/IterativeSolvers/IterationController.h"
|
#include "src/IterativeSolvers/IterationController.h"
|
||||||
|
@ -45,7 +45,7 @@ void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x,
|
|||||||
using std::abs;
|
using std::abs;
|
||||||
typedef typename Dest::RealScalar RealScalar;
|
typedef typename Dest::RealScalar RealScalar;
|
||||||
typedef typename Dest::Scalar Scalar;
|
typedef typename Dest::Scalar Scalar;
|
||||||
typedef Dest VectorType;
|
typedef Matrix<Scalar,Dynamic,1> VectorType;
|
||||||
|
|
||||||
RealScalar tol = tol_error;
|
RealScalar tol = tol_error;
|
||||||
int maxIters = iters;
|
int maxIters = iters;
|
||||||
@ -206,12 +206,16 @@ public:
|
|||||||
/** \internal */
|
/** \internal */
|
||||||
template<typename Rhs,typename Dest>
|
template<typename Rhs,typename Dest>
|
||||||
void _solve(const Rhs& b, Dest& x) const
|
void _solve(const Rhs& b, Dest& x) const
|
||||||
|
{
|
||||||
|
for(int j=0; j<b.cols(); ++j)
|
||||||
{
|
{
|
||||||
m_iterations = Base::m_maxIterations;
|
m_iterations = Base::m_maxIterations;
|
||||||
m_error = Base::m_tolerance;
|
m_error = Base::m_tolerance;
|
||||||
|
|
||||||
internal::conjugate_gradient(mp_matrix->template selfadjointView<UpLo>(), b, x,
|
typename Dest::ColXpr xj(x,j);
|
||||||
|
internal::conjugate_gradient(mp_matrix->template selfadjointView<UpLo>(), b.col(j), xj,
|
||||||
Base::m_preconditioner, m_iterations, m_error);
|
Base::m_preconditioner, m_iterations, m_error);
|
||||||
|
}
|
||||||
|
|
||||||
m_isInitialized = true;
|
m_isInitialized = true;
|
||||||
m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
|
m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
|
||||||
@ -233,7 +237,7 @@ struct solve_retval<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs>
|
|||||||
|
|
||||||
template<typename Dest> void evalTo(Dest& dst) const
|
template<typename Dest> void evalTo(Dest& dst) const
|
||||||
{
|
{
|
||||||
dst.setZero();
|
dst.setOnes();
|
||||||
dec()._solve(rhs(),dst);
|
dec()._solve(rhs(),dst);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -147,6 +147,20 @@ public:
|
|||||||
return internal::solve_retval<Derived, Rhs>(derived(), b.derived());
|
return internal::solve_retval<Derived, Rhs>(derived(), b.derived());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** \returns the solution x of \f$ A x = b \f$ using the current decomposition of A.
|
||||||
|
*
|
||||||
|
* \sa compute()
|
||||||
|
*/
|
||||||
|
template<typename Rhs>
|
||||||
|
inline const internal::sparse_solve_retval<IterativeSolverBase, Rhs>
|
||||||
|
solve(const SparseMatrixBase<Rhs>& b) const
|
||||||
|
{
|
||||||
|
eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
|
||||||
|
eigen_assert(rows()==b.rows()
|
||||||
|
&& "IterativeSolverBase::solve(): invalid number of rows of the right hand side matrix b");
|
||||||
|
return internal::sparse_solve_retval<IterativeSolverBase, Rhs>(*this, b.derived());
|
||||||
|
}
|
||||||
|
|
||||||
/** \returns Success if the iterations converged, and NoConvergence otherwise. */
|
/** \returns Success if the iterations converged, and NoConvergence otherwise. */
|
||||||
ComputationInfo info() const
|
ComputationInfo info() const
|
||||||
{
|
{
|
||||||
@ -154,6 +168,24 @@ public:
|
|||||||
return m_info;
|
return m_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** \internal */
|
||||||
|
template<typename Rhs, typename DestScalar, int DestOptions, typename DestIndex>
|
||||||
|
void _solve_sparse(const Rhs& b, SparseMatrix<DestScalar,DestOptions,DestIndex> &dest) const
|
||||||
|
{
|
||||||
|
eigen_assert(rows()==b.rows());
|
||||||
|
|
||||||
|
int rhsCols = b.cols();
|
||||||
|
int size = b.rows();
|
||||||
|
Eigen::Matrix<DestScalar,Dynamic,1> tb(size);
|
||||||
|
Eigen::Matrix<DestScalar,Dynamic,1> tx(size);
|
||||||
|
for(int k=0; k<rhsCols; ++k)
|
||||||
|
{
|
||||||
|
tb = b.col(k);
|
||||||
|
tx = derived().solve(tb);
|
||||||
|
dest.col(k) = tx.sparseView(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void init()
|
void init()
|
||||||
{
|
{
|
||||||
@ -173,5 +205,21 @@ protected:
|
|||||||
mutable bool m_isInitialized;
|
mutable bool m_isInitialized;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
template<typename Derived, typename Rhs>
|
||||||
|
struct sparse_solve_retval<IterativeSolverBase<Derived>, Rhs>
|
||||||
|
: sparse_solve_retval_base<IterativeSolverBase<Derived>, Rhs>
|
||||||
|
{
|
||||||
|
typedef IterativeSolverBase<Derived> Dec;
|
||||||
|
EIGEN_MAKE_SPARSE_SOLVE_HELPERS(Dec,Rhs)
|
||||||
|
|
||||||
|
template<typename Dest> void evalTo(Dest& dst) const
|
||||||
|
{
|
||||||
|
dec().derived()._solve_sparse(rhs(),dst);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
#endif // EIGEN_ITERATIVE_SOLVER_BASE_H
|
#endif // EIGEN_ITERATIVE_SOLVER_BASE_H
|
||||||
|
Loading…
x
Reference in New Issue
Block a user