Fix two corner cases in the new implementation of logistic sigmoid.

This commit is contained in:
Rasmus Munk Larsen 2022-01-12 00:41:29 +00:00
parent 5d7ffe2ca9
commit 0b58738938

View File

@@ -1026,8 +1026,10 @@ struct scalar_logistic_op {
template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Packet packetOp(const Packet& x) const { Packet packetOp(const Packet& x) const {
const Packet one = pset1<Packet>(T(1)); const Packet one = pset1<Packet>(T(1));
const Packet inf = pset1<Packet>(NumTraits<T>::infinity());
const Packet e = pexp(x); const Packet e = pexp(x);
return pdiv(e, padd(one, e)); const Packet inf_mask = pcmp_eq(e, inf);
return pselect(inf_mask, one, pdiv(e, padd(one, e)));
} }
}; };
@@ -1052,7 +1054,9 @@ template <>
struct scalar_logistic_op<float> { struct scalar_logistic_op<float> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(const float& x) const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(const float& x) const {
const float e = numext::exp(x); // Truncate at the first point where the interpolant is exactly one.
const float cst_exp_hi = 16.6355324f;
const float e = numext::exp(numext::mini(x, cst_exp_hi));
return e / (1.0f + e); return e / (1.0f + e);
} }
@@ -1062,7 +1066,8 @@ struct scalar_logistic_op<float> {
const Packet cst_zero = pset1<Packet>(0.0f); const Packet cst_zero = pset1<Packet>(0.0f);
const Packet cst_one = pset1<Packet>(1.0f); const Packet cst_one = pset1<Packet>(1.0f);
const Packet cst_half = pset1<Packet>(0.5f); const Packet cst_half = pset1<Packet>(0.5f);
const Packet cst_exp_hi = pset1<Packet>(16.f); // Truncate at the first point where the interpolant is exactly one.
const Packet cst_exp_hi = pset1<Packet>(16.6355324f);
const Packet cst_exp_lo = pset1<Packet>(-104.f); const Packet cst_exp_lo = pset1<Packet>(-104.f);
// Clamp x to the non-trivial range where S(x). Outside this // Clamp x to the non-trivial range where S(x). Outside this
@@ -1117,6 +1122,7 @@ struct scalar_logistic_op<float> {
}; };
#endif // #ifndef EIGEN_GPU_COMPILE_PHASE #endif // #ifndef EIGEN_GPU_COMPILE_PHASE
template <typename T> template <typename T>
struct functor_traits<scalar_logistic_op<T> > { struct functor_traits<scalar_logistic_op<T> > {
enum { enum {