Add autodiff coverage for standard library hyperbolic functions, and tests.

* * *
Corrected tanh derivatived, moved test definitions.
* * *
Added more test cases, removed lingering lines
This commit is contained in:
Geoffrey Lalonde 2016-06-15 23:33:19 -07:00
parent d7e3e4bb04
commit 72c95383e0
2 changed files with 50 additions and 0 deletions

View File

@ -646,6 +646,21 @@ EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY(acos,
using std::acos; using std::acos;
return ReturnType(acos(x.value()),x.derivatives() * (Scalar(-1)/sqrt(1-numext::abs2(x.value()))));) return ReturnType(acos(x.value()),x.derivatives() * (Scalar(-1)/sqrt(1-numext::abs2(x.value()))));)
EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY(tanh,
using std::cosh;
using std::tanh;
return ReturnType(tanh(x.value()),x.derivatives() * (Scalar(1)/numext::abs2(cosh(x.value()))));)
EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY(sinh,
using std::sinh;
using std::cosh;
return ReturnType(sinh(x.value()),x.derivatives() * cosh(x.value()));)
EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY(cosh,
using std::sinh;
using std::cosh;
return ReturnType(cosh(x.value()),x.derivatives() * sinh(x.value()));)
#undef EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY #undef EIGEN_AUTODIFF_DECLARE_GLOBAL_UNARY
template<typename DerType> struct NumTraits<AutoDiffScalar<DerType> > template<typename DerType> struct NumTraits<AutoDiffScalar<DerType> >

View File

@ -36,13 +36,48 @@ template<typename Scalar> void check_atan2()
VERIFY_IS_APPROX(res.derivatives(), x.derivatives()); VERIFY_IS_APPROX(res.derivatives(), x.derivatives());
} }
template<typename Scalar> void check_hyperbolic_functions()
{
using std::sinh;
using std::cosh;
using std::tanh;
typedef Matrix<Scalar, 1, 1> Deriv1;
typedef AutoDiffScalar<Deriv1> AD;
Deriv1 p = Deriv1::Random();
AD val(p.x(),Deriv1::UnitX());
Scalar cosh_px = std::cosh(p.x());
AD res1 = tanh(val);
VERIFY_IS_APPROX(res1.value(), std::tanh(p.x()));
VERIFY_IS_APPROX(res1.derivatives().x(), Scalar(1.0) / (cosh_px * cosh_px));
AD res2 = sinh(val);
VERIFY_IS_APPROX(res2.value(), std::sinh(p.x()));
VERIFY_IS_APPROX(res2.derivatives().x(), cosh_px);
AD res3 = cosh(val);
VERIFY_IS_APPROX(res3.value(), cosh_px);
VERIFY_IS_APPROX(res3.derivatives().x(), std::sinh(p.x()));
// Check constant values.
const Scalar sample_point = Scalar(1) / Scalar(3);
val = AD(sample_point,Deriv1::UnitX());
res1 = tanh(val);
VERIFY_IS_APPROX(res1.derivatives().x(), Scalar(0.896629559604914));
res2 = sinh(val);
VERIFY_IS_APPROX(res2.derivatives().x(), Scalar(1.056071867829939));
res3 = cosh(val);
VERIFY_IS_APPROX(res3.derivatives().x(), Scalar(0.339540557256150));
}
void test_autodiff_scalar() void test_autodiff_scalar()
{ {
for(int i = 0; i < g_repeat; i++) { for(int i = 0; i < g_repeat; i++) {
CALL_SUBTEST_1( check_atan2<float>() ); CALL_SUBTEST_1( check_atan2<float>() );
CALL_SUBTEST_2( check_atan2<double>() ); CALL_SUBTEST_2( check_atan2<double>() );
CALL_SUBTEST_3( check_hyperbolic_functions<float>() );
CALL_SUBTEST_4( check_hyperbolic_functions<double>() );
} }
} }