From fd326970747c828b2053d5a378e6f16f2803f540 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Wed, 6 Jun 2012 17:11:16 +0200 Subject: [PATCH] Fix stopping criteria of CG --- .../IterativeLinearSolvers/ConjugateGradient.h | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h b/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h index 8f74d1b91..edab2299e 100644 --- a/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h +++ b/Eigen/src/IterativeLinearSolvers/ConjugateGradient.h @@ -54,6 +54,7 @@ void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x, int maxIters = iters; int n = mat.cols(); + VectorType residual = rhs - mat * x; //initial residual VectorType p(n); @@ -61,26 +62,31 @@ void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x, VectorType z(n), tmp(n); RealScalar absNew = internal::real(residual.dot(p)); // the square of the absolute value of r scaled by invM - RealScalar absInit = absNew; // the initial absolute value - + RealScalar rhsNorm2 = rhs.squaredNorm(); + RealScalar residualNorm2 = 0; + RealScalar threshold = tol*tol*rhsNorm2; int i = 0; - while ((i < maxIters) && (absNew > tol*tol*absInit)) + while(i < maxIters) { tmp.noalias() = mat * p; // the bottleneck of the algorithm Scalar alpha = absNew / p.dot(tmp); // the amount we travel on dir x += alpha * p; // update solution residual -= alpha * tmp; // update residue + + residualNorm2 = residual.squaredNorm(); + if(residualNorm2 < threshold) + break; + z = precond.solve(residual); // approximately solve for "A z = residual" RealScalar absOld = absNew; absNew = internal::real(residual.dot(z)); // update the absolute value of r - RealScalar beta = absNew / absOld; // calculate the Gram-Schmidit value used to create the new search direction + RealScalar beta = absNew / absOld; // calculate the Gram-Schmidt value used to create the new search direction p = z + beta * p; // update search direction i++; } - - tol_error = sqrt(abs(absNew / absInit)); + tol_error = sqrt(residualNorm2 / rhsNorm2); iters = i; }