extend CG for multiple right hand sides

This commit is contained in:
Gael Guennebaud 2011-10-11 11:29:50 +02:00
parent b94c00226f
commit 5dc8458293
3 changed files with 61 additions and 8 deletions

View File

@ -25,7 +25,7 @@
#ifndef EIGEN_ITERATIVE_SOLVERS_MODULE_H
#define EIGEN_ITERATIVE_SOLVERS_MODULE_H
#include <Eigen/Core>
#include <Eigen/Sparse>
namespace Eigen {
@ -42,6 +42,7 @@ namespace Eigen {
//@{
#include "../../Eigen/src/misc/Solve.h"
#include "src/SparseExtra/Solve.h"
#include "src/IterativeSolvers/IterativeSolverBase.h"
#include "src/IterativeSolvers/IterationController.h"

View File

@ -45,7 +45,7 @@ void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x,
using std::abs;
typedef typename Dest::RealScalar RealScalar;
typedef typename Dest::Scalar Scalar;
typedef Dest VectorType;
typedef Matrix<Scalar,Dynamic,1> VectorType;
RealScalar tol = tol_error;
int maxIters = iters;
@ -207,11 +207,15 @@ public:
template<typename Rhs,typename Dest>
void _solve(const Rhs& b, Dest& x) const
{
m_iterations = Base::m_maxIterations;
m_error = Base::m_tolerance;
internal::conjugate_gradient(mp_matrix->template selfadjointView<UpLo>(), b, x,
Base::m_preconditioner, m_iterations, m_error);
for(int j=0; j<b.cols(); ++j)
{
m_iterations = Base::m_maxIterations;
m_error = Base::m_tolerance;
typename Dest::ColXpr xj(x,j);
internal::conjugate_gradient(mp_matrix->template selfadjointView<UpLo>(), b.col(j), xj,
Base::m_preconditioner, m_iterations, m_error);
}
m_isInitialized = true;
m_info = m_error <= Base::m_tolerance ? Success : NoConvergence;
@ -233,7 +237,7 @@ struct solve_retval<ConjugateGradient<_MatrixType,_UpLo,_Preconditioner>, Rhs>
template<typename Dest> void evalTo(Dest& dst) const
{
dst.setZero();
dst.setOnes();
dec()._solve(rhs(),dst);
}
};

View File

@ -146,6 +146,20 @@ public:
&& "IterativeSolverBase::solve(): invalid number of rows of the right hand side matrix b");
return internal::solve_retval<Derived, Rhs>(derived(), b.derived());
}
/** \returns the solution x of \f$ A x = b \f$ using the current decomposition of A.
*
* \sa compute()
*/
template<typename Rhs>
inline const internal::sparse_solve_retval<IterativeSolverBase, Rhs>
solve(const SparseMatrixBase<Rhs>& b) const
{
eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
eigen_assert(rows()==b.rows()
&& "IterativeSolverBase::solve(): invalid number of rows of the right hand side matrix b");
return internal::sparse_solve_retval<IterativeSolverBase, Rhs>(*this, b.derived());
}
/** \returns Success if the iterations converged, and NoConvergence otherwise. */
ComputationInfo info() const
@ -153,6 +167,24 @@ public:
eigen_assert(m_isInitialized && "IterativeSolverBase is not initialized.");
return m_info;
}
/** \internal */
template<typename Rhs, typename DestScalar, int DestOptions, typename DestIndex>
void _solve_sparse(const Rhs& b, SparseMatrix<DestScalar,DestOptions,DestIndex> &dest) const
{
eigen_assert(rows()==b.rows());
int rhsCols = b.cols();
int size = b.rows();
Eigen::Matrix<DestScalar,Dynamic,1> tb(size);
Eigen::Matrix<DestScalar,Dynamic,1> tx(size);
for(int k=0; k<rhsCols; ++k)
{
tb = b.col(k);
tx = derived().solve(tb);
dest.col(k) = tx.sparseView(0);
}
}
protected:
void init()
@ -173,5 +205,21 @@ protected:
mutable bool m_isInitialized;
};
namespace internal {
template<typename Derived, typename Rhs>
struct sparse_solve_retval<IterativeSolverBase<Derived>, Rhs>
: sparse_solve_retval_base<IterativeSolverBase<Derived>, Rhs>
{
typedef IterativeSolverBase<Derived> Dec;
EIGEN_MAKE_SPARSE_SOLVE_HELPERS(Dec,Rhs)
template<typename Dest> void evalTo(Dest& dst) const
{
dec().derived()._solve_sparse(rhs(),dst);
}
};
}
#endif // EIGEN_ITERATIVE_SOLVER_BASE_H