diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index 72d485814..1092fee50 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -1026,8 +1026,10 @@ struct scalar_logistic_op { template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { const Packet one = pset1(T(1)); + const Packet inf = pset1(NumTraits::infinity()); const Packet e = pexp(x); - return pdiv(e, padd(one, e)); + const Packet inf_mask = pcmp_eq(e, inf); + return pselect(inf_mask, one, pdiv(e, padd(one, e))); } }; @@ -1052,7 +1054,9 @@ template <> struct scalar_logistic_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(const float& x) const { - const float e = numext::exp(x); + // Truncate at the first point where the interpolant is exactly one. + const float cst_exp_hi = 16.6355324f; + const float e = numext::exp(numext::mini(x, cst_exp_hi)); return e / (1.0f + e); } @@ -1062,7 +1066,8 @@ struct scalar_logistic_op { const Packet cst_zero = pset1(0.0f); const Packet cst_one = pset1(1.0f); const Packet cst_half = pset1(0.5f); - const Packet cst_exp_hi = pset1(16.f); + // Truncate at the first point where the interpolant is exactly one. + const Packet cst_exp_hi = pset1(16.6355324f); const Packet cst_exp_lo = pset1(-104.f); // Clamp x to the non-trivial range where S(x). Outside this @@ -1117,6 +1122,7 @@ struct scalar_logistic_op { }; #endif // #ifndef EIGEN_GPU_COMPILE_PHASE + template struct functor_traits > { enum {