mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 19:59:05 +08:00
central sheme for numerical diff
This commit is contained in:
parent
206b5e3972
commit
87be19de4a
@ -28,7 +28,13 @@
|
||||
namespace Eigen
|
||||
{
|
||||
|
||||
template<typename Functor> class NumericalDiff : public Functor
|
||||
enum NumericalDiffMode {
|
||||
Forward,
|
||||
Central
|
||||
};
|
||||
|
||||
|
||||
template<typename Functor, NumericalDiffMode mode=Forward> class NumericalDiff : public Functor
|
||||
{
|
||||
public:
|
||||
typedef typename Functor::Scalar Scalar;
|
||||
@ -62,14 +68,23 @@ public:
|
||||
int nfev=0;
|
||||
const int n = _x.size();
|
||||
const Scalar eps = ei_sqrt((std::max(epsfcn,epsilon<Scalar>() )));
|
||||
ValueType val, fx;
|
||||
ValueType val1, val2;
|
||||
InputType x = _x;
|
||||
// TODO : we should do this only if the size is not already known
|
||||
val.resize(Functor::values());
|
||||
fx.resize(Functor::values());
|
||||
val1.resize(Functor::values());
|
||||
val2.resize(Functor::values());
|
||||
|
||||
// compute f(x)
|
||||
Functor::operator()(x, fx);
|
||||
switch(mode) {
|
||||
case Forward:
|
||||
// compute f(x)
|
||||
Functor::operator()(x, val1); nfev++;
|
||||
break;
|
||||
case Central:
|
||||
// do nothing
|
||||
break;
|
||||
default:
|
||||
assert(false);
|
||||
};
|
||||
|
||||
/* Function Body */
|
||||
|
||||
@ -78,11 +93,25 @@ public:
|
||||
if (h == 0.) {
|
||||
h = eps;
|
||||
}
|
||||
x[j] += h;
|
||||
Functor::operator()(x, val);
|
||||
nfev++;
|
||||
x[j] = _x[j];
|
||||
jac.col(j) = (val-fx)/h;
|
||||
switch(mode) {
|
||||
case Forward:
|
||||
x[j] += h;
|
||||
Functor::operator()(x, val2);
|
||||
nfev++;
|
||||
x[j] = _x[j];
|
||||
jac.col(j) = (val2-val1)/h;
|
||||
break;
|
||||
case Central:
|
||||
x[j] += h;
|
||||
Functor::operator()(x, val2); nfev++;
|
||||
x[j] -= 2*h;
|
||||
Functor::operator()(x, val1); nfev++;
|
||||
x[j] = _x[j];
|
||||
jac.col(j) = (val2-val1)/(2*h);
|
||||
break;
|
||||
default:
|
||||
assert(false);
|
||||
};
|
||||
}
|
||||
return nfev;
|
||||
}
|
||||
|
@ -88,8 +88,27 @@ void test_forward()
|
||||
VERIFY_IS_APPROX(jac, actual_jac);
|
||||
}
|
||||
|
||||
void test_central()
|
||||
{
|
||||
VectorXd x(3);
|
||||
MatrixXd jac(15,3);
|
||||
MatrixXd actual_jac(15,3);
|
||||
my_functor functor;
|
||||
|
||||
x << 0.082, 1.13, 2.35;
|
||||
|
||||
// real one
|
||||
functor.df(x, actual_jac);
|
||||
|
||||
// using NumericalDiff
|
||||
NumericalDiff<my_functor,Central> numDiff(functor);
|
||||
numDiff.df(x, jac);
|
||||
|
||||
VERIFY_IS_APPROX(jac, actual_jac);
|
||||
}
|
||||
|
||||
void test_NumericalDiff()
|
||||
{
|
||||
CALL_SUBTEST(test_forward());
|
||||
CALL_SUBTEST(test_central());
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user