central sheme for numerical diff

This commit is contained in:
Thomas Capricelli 2009-09-28 02:55:30 +02:00
parent 206b5e3972
commit 87be19de4a
2 changed files with 59 additions and 11 deletions

View File

@ -28,7 +28,13 @@
namespace Eigen namespace Eigen
{ {
template<typename Functor> class NumericalDiff : public Functor enum NumericalDiffMode {
Forward,
Central
};
template<typename Functor, NumericalDiffMode mode=Forward> class NumericalDiff : public Functor
{ {
public: public:
typedef typename Functor::Scalar Scalar; typedef typename Functor::Scalar Scalar;
@ -62,14 +68,23 @@ public:
int nfev=0; int nfev=0;
const int n = _x.size(); const int n = _x.size();
const Scalar eps = ei_sqrt((std::max(epsfcn,epsilon<Scalar>() ))); const Scalar eps = ei_sqrt((std::max(epsfcn,epsilon<Scalar>() )));
ValueType val, fx; ValueType val1, val2;
InputType x = _x; InputType x = _x;
// TODO : we should do this only if the size is not already known // TODO : we should do this only if the size is not already known
val.resize(Functor::values()); val1.resize(Functor::values());
fx.resize(Functor::values()); val2.resize(Functor::values());
// compute f(x) switch(mode) {
Functor::operator()(x, fx); case Forward:
// compute f(x)
Functor::operator()(x, val1); nfev++;
break;
case Central:
// do nothing
break;
default:
assert(false);
};
/* Function Body */ /* Function Body */
@ -78,11 +93,25 @@ public:
if (h == 0.) { if (h == 0.) {
h = eps; h = eps;
} }
x[j] += h; switch(mode) {
Functor::operator()(x, val); case Forward:
nfev++; x[j] += h;
x[j] = _x[j]; Functor::operator()(x, val2);
jac.col(j) = (val-fx)/h; nfev++;
x[j] = _x[j];
jac.col(j) = (val2-val1)/h;
break;
case Central:
x[j] += h;
Functor::operator()(x, val2); nfev++;
x[j] -= 2*h;
Functor::operator()(x, val1); nfev++;
x[j] = _x[j];
jac.col(j) = (val2-val1)/(2*h);
break;
default:
assert(false);
};
} }
return nfev; return nfev;
} }

View File

@ -88,8 +88,27 @@ void test_forward()
VERIFY_IS_APPROX(jac, actual_jac); VERIFY_IS_APPROX(jac, actual_jac);
} }
void test_central()
{
VectorXd x(3);
MatrixXd jac(15,3);
MatrixXd actual_jac(15,3);
my_functor functor;
x << 0.082, 1.13, 2.35;
// real one
functor.df(x, actual_jac);
// using NumericalDiff
NumericalDiff<my_functor,Central> numDiff(functor);
numDiff.df(x, jac);
VERIFY_IS_APPROX(jac, actual_jac);
}
void test_NumericalDiff() void test_NumericalDiff()
{ {
CALL_SUBTEST(test_forward()); CALL_SUBTEST(test_forward());
CALL_SUBTEST(test_central());
} }