central sheme for numerical diff

This commit is contained in:
Thomas Capricelli 2009-09-28 02:55:30 +02:00
parent 206b5e3972
commit 87be19de4a
2 changed files with 59 additions and 11 deletions

View File

@ -28,7 +28,13 @@
namespace Eigen
{
template<typename Functor> class NumericalDiff : public Functor
enum NumericalDiffMode {
Forward,
Central
};
template<typename Functor, NumericalDiffMode mode=Forward> class NumericalDiff : public Functor
{
public:
typedef typename Functor::Scalar Scalar;
@ -62,14 +68,23 @@ public:
int nfev=0;
const int n = _x.size();
const Scalar eps = ei_sqrt((std::max(epsfcn,epsilon<Scalar>() )));
ValueType val, fx;
ValueType val1, val2;
InputType x = _x;
// TODO : we should do this only if the size is not already known
val.resize(Functor::values());
fx.resize(Functor::values());
val1.resize(Functor::values());
val2.resize(Functor::values());
switch(mode) {
case Forward:
// compute f(x)
Functor::operator()(x, fx);
Functor::operator()(x, val1); nfev++;
break;
case Central:
// do nothing
break;
default:
assert(false);
};
/* Function Body */
@ -78,11 +93,25 @@ public:
if (h == 0.) {
h = eps;
}
switch(mode) {
case Forward:
x[j] += h;
Functor::operator()(x, val);
Functor::operator()(x, val2);
nfev++;
x[j] = _x[j];
jac.col(j) = (val-fx)/h;
jac.col(j) = (val2-val1)/h;
break;
case Central:
x[j] += h;
Functor::operator()(x, val2); nfev++;
x[j] -= 2*h;
Functor::operator()(x, val1); nfev++;
x[j] = _x[j];
jac.col(j) = (val2-val1)/(2*h);
break;
default:
assert(false);
};
}
return nfev;
}

View File

@ -88,8 +88,27 @@ void test_forward()
VERIFY_IS_APPROX(jac, actual_jac);
}
void test_central()
{
VectorXd x(3);
MatrixXd jac(15,3);
MatrixXd actual_jac(15,3);
my_functor functor;
x << 0.082, 1.13, 2.35;
// real one
functor.df(x, actual_jac);
// using NumericalDiff
NumericalDiff<my_functor,Central> numDiff(functor);
numDiff.df(x, jac);
VERIFY_IS_APPROX(jac, actual_jac);
}
void test_NumericalDiff()
{
CALL_SUBTEST(test_forward());
CALL_SUBTEST(test_central());
}