Fix stopping criteria of CG

This commit is contained in:
Gael Guennebaud 2012-06-06 17:11:16 +02:00
parent b9f0eabd93
commit fd32697074

View File

@ -54,6 +54,7 @@ void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x,
int maxIters = iters; int maxIters = iters;
int n = mat.cols(); int n = mat.cols();
VectorType residual = rhs - mat * x; //initial residual VectorType residual = rhs - mat * x; //initial residual
VectorType p(n); VectorType p(n);
@ -61,26 +62,31 @@ void conjugate_gradient(const MatrixType& mat, const Rhs& rhs, Dest& x,
VectorType z(n), tmp(n); VectorType z(n), tmp(n);
RealScalar absNew = internal::real(residual.dot(p)); // the square of the absolute value of r scaled by invM RealScalar absNew = internal::real(residual.dot(p)); // the square of the absolute value of r scaled by invM
RealScalar absInit = absNew; // the initial absolute value RealScalar rhsNorm2 = rhs.squaredNorm();
RealScalar residualNorm2 = 0;
RealScalar threshold = tol*tol*rhsNorm2;
int i = 0; int i = 0;
while ((i < maxIters) && (absNew > tol*tol*absInit)) while(i < maxIters)
{ {
tmp.noalias() = mat * p; // the bottleneck of the algorithm tmp.noalias() = mat * p; // the bottleneck of the algorithm
Scalar alpha = absNew / p.dot(tmp); // the amount we travel on dir Scalar alpha = absNew / p.dot(tmp); // the amount we travel on dir
x += alpha * p; // update solution x += alpha * p; // update solution
residual -= alpha * tmp; // update residue residual -= alpha * tmp; // update residue
residualNorm2 = residual.squaredNorm();
if(residualNorm2 < threshold)
break;
z = precond.solve(residual); // approximately solve for "A z = residual" z = precond.solve(residual); // approximately solve for "A z = residual"
RealScalar absOld = absNew; RealScalar absOld = absNew;
absNew = internal::real(residual.dot(z)); // update the absolute value of r absNew = internal::real(residual.dot(z)); // update the absolute value of r
RealScalar beta = absNew / absOld; // calculate the Gram-Schmidit value used to create the new search direction RealScalar beta = absNew / absOld; // calculate the Gram-Schmidt value used to create the new search direction
p = z + beta * p; // update search direction p = z + beta * p; // update search direction
i++; i++;
} }
tol_error = sqrt(residualNorm2 / rhsNorm2);
tol_error = sqrt(abs(absNew / absInit));
iters = i; iters = i;
} }