mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-14 20:56:00 +08:00
central sheme for numerical diff
This commit is contained in:
parent
206b5e3972
commit
87be19de4a
@ -28,7 +28,13 @@
|
|||||||
namespace Eigen
|
namespace Eigen
|
||||||
{
|
{
|
||||||
|
|
||||||
template<typename Functor> class NumericalDiff : public Functor
|
enum NumericalDiffMode {
|
||||||
|
Forward,
|
||||||
|
Central
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
template<typename Functor, NumericalDiffMode mode=Forward> class NumericalDiff : public Functor
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
typedef typename Functor::Scalar Scalar;
|
typedef typename Functor::Scalar Scalar;
|
||||||
@ -62,14 +68,23 @@ public:
|
|||||||
int nfev=0;
|
int nfev=0;
|
||||||
const int n = _x.size();
|
const int n = _x.size();
|
||||||
const Scalar eps = ei_sqrt((std::max(epsfcn,epsilon<Scalar>() )));
|
const Scalar eps = ei_sqrt((std::max(epsfcn,epsilon<Scalar>() )));
|
||||||
ValueType val, fx;
|
ValueType val1, val2;
|
||||||
InputType x = _x;
|
InputType x = _x;
|
||||||
// TODO : we should do this only if the size is not already known
|
// TODO : we should do this only if the size is not already known
|
||||||
val.resize(Functor::values());
|
val1.resize(Functor::values());
|
||||||
fx.resize(Functor::values());
|
val2.resize(Functor::values());
|
||||||
|
|
||||||
// compute f(x)
|
switch(mode) {
|
||||||
Functor::operator()(x, fx);
|
case Forward:
|
||||||
|
// compute f(x)
|
||||||
|
Functor::operator()(x, val1); nfev++;
|
||||||
|
break;
|
||||||
|
case Central:
|
||||||
|
// do nothing
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
assert(false);
|
||||||
|
};
|
||||||
|
|
||||||
/* Function Body */
|
/* Function Body */
|
||||||
|
|
||||||
@ -78,11 +93,25 @@ public:
|
|||||||
if (h == 0.) {
|
if (h == 0.) {
|
||||||
h = eps;
|
h = eps;
|
||||||
}
|
}
|
||||||
x[j] += h;
|
switch(mode) {
|
||||||
Functor::operator()(x, val);
|
case Forward:
|
||||||
nfev++;
|
x[j] += h;
|
||||||
x[j] = _x[j];
|
Functor::operator()(x, val2);
|
||||||
jac.col(j) = (val-fx)/h;
|
nfev++;
|
||||||
|
x[j] = _x[j];
|
||||||
|
jac.col(j) = (val2-val1)/h;
|
||||||
|
break;
|
||||||
|
case Central:
|
||||||
|
x[j] += h;
|
||||||
|
Functor::operator()(x, val2); nfev++;
|
||||||
|
x[j] -= 2*h;
|
||||||
|
Functor::operator()(x, val1); nfev++;
|
||||||
|
x[j] = _x[j];
|
||||||
|
jac.col(j) = (val2-val1)/(2*h);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
assert(false);
|
||||||
|
};
|
||||||
}
|
}
|
||||||
return nfev;
|
return nfev;
|
||||||
}
|
}
|
||||||
|
@ -88,8 +88,27 @@ void test_forward()
|
|||||||
VERIFY_IS_APPROX(jac, actual_jac);
|
VERIFY_IS_APPROX(jac, actual_jac);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void test_central()
|
||||||
|
{
|
||||||
|
VectorXd x(3);
|
||||||
|
MatrixXd jac(15,3);
|
||||||
|
MatrixXd actual_jac(15,3);
|
||||||
|
my_functor functor;
|
||||||
|
|
||||||
|
x << 0.082, 1.13, 2.35;
|
||||||
|
|
||||||
|
// real one
|
||||||
|
functor.df(x, actual_jac);
|
||||||
|
|
||||||
|
// using NumericalDiff
|
||||||
|
NumericalDiff<my_functor,Central> numDiff(functor);
|
||||||
|
numDiff.df(x, jac);
|
||||||
|
|
||||||
|
VERIFY_IS_APPROX(jac, actual_jac);
|
||||||
|
}
|
||||||
|
|
||||||
void test_NumericalDiff()
|
void test_NumericalDiff()
|
||||||
{
|
{
|
||||||
CALL_SUBTEST(test_forward());
|
CALL_SUBTEST(test_forward());
|
||||||
|
CALL_SUBTEST(test_central());
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user