mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-08 17:59:00 +08:00
Fix two corner cases in the new implementation of logistic sigmoid.
This commit is contained in:
parent
5d7ffe2ca9
commit
0b58738938
@ -1026,8 +1026,10 @@ struct scalar_logistic_op {
|
|||||||
template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
Packet packetOp(const Packet& x) const {
|
Packet packetOp(const Packet& x) const {
|
||||||
const Packet one = pset1<Packet>(T(1));
|
const Packet one = pset1<Packet>(T(1));
|
||||||
|
const Packet inf = pset1<Packet>(NumTraits<T>::infinity());
|
||||||
const Packet e = pexp(x);
|
const Packet e = pexp(x);
|
||||||
return pdiv(e, padd(one, e));
|
const Packet inf_mask = pcmp_eq(e, inf);
|
||||||
|
return pselect(inf_mask, one, pdiv(e, padd(one, e)));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1052,7 +1054,9 @@ template <>
|
|||||||
struct scalar_logistic_op<float> {
|
struct scalar_logistic_op<float> {
|
||||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op)
|
EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op)
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(const float& x) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(const float& x) const {
|
||||||
const float e = numext::exp(x);
|
// Truncate at the first point where the interpolant is exactly one.
|
||||||
|
const float cst_exp_hi = 16.6355324f;
|
||||||
|
const float e = numext::exp(numext::mini(x, cst_exp_hi));
|
||||||
return e / (1.0f + e);
|
return e / (1.0f + e);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1062,7 +1066,8 @@ struct scalar_logistic_op<float> {
|
|||||||
const Packet cst_zero = pset1<Packet>(0.0f);
|
const Packet cst_zero = pset1<Packet>(0.0f);
|
||||||
const Packet cst_one = pset1<Packet>(1.0f);
|
const Packet cst_one = pset1<Packet>(1.0f);
|
||||||
const Packet cst_half = pset1<Packet>(0.5f);
|
const Packet cst_half = pset1<Packet>(0.5f);
|
||||||
const Packet cst_exp_hi = pset1<Packet>(16.f);
|
// Truncate at the first point where the interpolant is exactly one.
|
||||||
|
const Packet cst_exp_hi = pset1<Packet>(16.6355324f);
|
||||||
const Packet cst_exp_lo = pset1<Packet>(-104.f);
|
const Packet cst_exp_lo = pset1<Packet>(-104.f);
|
||||||
|
|
||||||
// Clamp x to the non-trivial range where S(x). Outside this
|
// Clamp x to the non-trivial range where S(x). Outside this
|
||||||
@ -1117,6 +1122,7 @@ struct scalar_logistic_op<float> {
|
|||||||
};
|
};
|
||||||
#endif // #ifndef EIGEN_GPU_COMPILE_PHASE
|
#endif // #ifndef EIGEN_GPU_COMPILE_PHASE
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct functor_traits<scalar_logistic_op<T> > {
|
struct functor_traits<scalar_logistic_op<T> > {
|
||||||
enum {
|
enum {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user