mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-12 09:23:12 +08:00
Fix accuracy of logistic sigmoid
This commit is contained in:
parent
8b8125c574
commit
80ccacc717
@ -1026,83 +1026,93 @@ struct scalar_logistic_op {
|
|||||||
template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
Packet packetOp(const Packet& x) const {
|
Packet packetOp(const Packet& x) const {
|
||||||
const Packet one = pset1<Packet>(T(1));
|
const Packet one = pset1<Packet>(T(1));
|
||||||
return pdiv(one, padd(one, pexp(pnegate(x))));
|
const Packet e = pexp(x);
|
||||||
|
return pdiv(e, padd(one, e));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifndef EIGEN_GPU_COMPILE_PHASE
|
#ifndef EIGEN_GPU_COMPILE_PHASE
|
||||||
|
|
||||||
/** \internal
|
/** \internal
|
||||||
* \brief Template specialization of the logistic function for float.
|
* \brief Template specialization of the logistic function for float.
|
||||||
*
|
* Computes S(x) = exp(x) / (1 + exp(x)), where exp(x) is implemented
|
||||||
* Uses just a 9/10-degree rational interpolant which
|
* using an algorithm partly adopted from the implementation of
|
||||||
* interpolates 1/(1+exp(-x)) - 0.5 up to a couple of ulps in the range
|
* pexp_float. See the individual steps described in the code below.
|
||||||
* [-9, 18]. Below -9 we use the more accurate approximation
|
* Note that compared to pexp, we use an additional outer multiplicative
|
||||||
* 1/(1+exp(-x)) ~= exp(x), and above 18 the logistic function is 1 within
|
* range reduction step using the identity exp(x) = exp(x/2)^2.
|
||||||
* one ulp. The shifted logistic is interpolated because it was easier to
|
* This prevert us from having to call ldexp on values that could produce
|
||||||
* make the fit converge.
|
* a denormal result, which allows us to call the faster implementation in
|
||||||
*
|
* pldexp_fast_impl<Packet>::run(p, m).
|
||||||
|
* The final squaring, however, doubles the error bound on the final
|
||||||
|
* approximation. Exhaustive testing shows that we have a worst case error
|
||||||
|
* of 4.5 ulps (compared to computing S(x) in double precision), which is
|
||||||
|
* acceptable.
|
||||||
*/
|
*/
|
||||||
template <>
|
template <>
|
||||||
struct scalar_logistic_op<float> {
|
struct scalar_logistic_op<float> {
|
||||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op)
|
EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op)
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(const float& x) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(const float& x) const {
|
||||||
return packetOp(x);
|
const float e = numext::exp(x);
|
||||||
|
return e / (1.0f + e);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
template <typename Packet>
|
||||||
Packet packetOp(const Packet& _x) const {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet
|
||||||
const Packet cutoff_lower = pset1<Packet>(-9.f);
|
packetOp(const Packet& _x) const {
|
||||||
const Packet lt_mask = pcmp_lt<Packet>(_x, cutoff_lower);
|
const Packet cst_zero = pset1<Packet>(0.0f);
|
||||||
const bool any_small = predux_any(lt_mask);
|
const Packet cst_one = pset1<Packet>(1.0f);
|
||||||
|
const Packet cst_half = pset1<Packet>(0.5f);
|
||||||
|
const Packet cst_exp_hi = pset1<Packet>(16.f);
|
||||||
|
const Packet cst_exp_lo = pset1<Packet>(-104.f);
|
||||||
|
|
||||||
// The upper cut-off is the smallest x for which the rational approximation evaluates to 1.
|
// Clamp x to the non-trivial range where S(x). Outside this
|
||||||
// Choosing this value saves us a few instructions clamping the results at the end.
|
// interval the correctly rounded value of S(x) is either zero
|
||||||
#ifdef EIGEN_VECTORIZE_FMA
|
// or one.
|
||||||
const Packet cutoff_upper = pset1<Packet>(15.7243833541870117f);
|
Packet zero_mask = pcmp_lt(_x, cst_exp_lo);
|
||||||
#else
|
Packet x = pmin(_x, cst_exp_hi);
|
||||||
const Packet cutoff_upper = pset1<Packet>(15.6437711715698242f);
|
|
||||||
#endif
|
|
||||||
const Packet x = pmin(_x, cutoff_upper);
|
|
||||||
|
|
||||||
// The monomial coefficients of the numerator polynomial (odd).
|
// 1. Multiplicative range reduction:
|
||||||
const Packet alpha_1 = pset1<Packet>(2.48287947061529e-01f);
|
// Reduce the range of x by a factor of 2. This avoids having
|
||||||
const Packet alpha_3 = pset1<Packet>(8.51377133304701e-03f);
|
// to compute exp(x) accurately where the result is a denormalized
|
||||||
const Packet alpha_5 = pset1<Packet>(6.08574864600143e-05f);
|
// value.
|
||||||
const Packet alpha_7 = pset1<Packet>(1.15627324459942e-07f);
|
x = pmul(x, cst_half);
|
||||||
const Packet alpha_9 = pset1<Packet>(4.37031012579801e-11f);
|
|
||||||
|
|
||||||
// The monomial coefficients of the denominator polynomial (even).
|
// 2. Subtractive range reduction:
|
||||||
const Packet beta_0 = pset1<Packet>(9.93151921023180e-01f);
|
// Express exp(x) as exp(m*ln(2) + r) = 2^m*exp(r), start by extracting
|
||||||
const Packet beta_2 = pset1<Packet>(1.16817656904453e-01f);
|
// m = floor(x/ln(2) + 0.5), such that x = m*ln(2) + r.
|
||||||
const Packet beta_4 = pset1<Packet>(1.70198817374094e-03f);
|
const Packet cst_cephes_LOG2EF = pset1<Packet>(1.44269504088896341f);
|
||||||
const Packet beta_6 = pset1<Packet>(6.29106785017040e-06f);
|
Packet m = pfloor(pmadd(x, cst_cephes_LOG2EF, cst_half));
|
||||||
const Packet beta_8 = pset1<Packet>(5.76102136993427e-09f);
|
// Get r = x - m*ln(2). We use a trick from Cephes where the term
|
||||||
const Packet beta_10 = pset1<Packet>(6.10247389755681e-13f);
|
// m*ln(2) is subtracted out in two parts, m*C1+m*C2 = m*ln(2),
|
||||||
|
// to avoid accumulating truncation errors.
|
||||||
|
const Packet cst_cephes_exp_C1 = pset1<Packet>(-0.693359375f);
|
||||||
|
const Packet cst_cephes_exp_C2 = pset1<Packet>(2.12194440e-4f);
|
||||||
|
Packet r = pmadd(m, cst_cephes_exp_C1, x);
|
||||||
|
r = pmadd(m, cst_cephes_exp_C2, r);
|
||||||
|
|
||||||
// Since the polynomials are odd/even, we need x^2.
|
// 3. Compute an approximation to exp(r) using a degree 5 minimax polynomial.
|
||||||
const Packet x2 = pmul(x, x);
|
// We compute even and odd terms separately to increase instruction level
|
||||||
|
// parallelism.
|
||||||
|
Packet r2 = pmul(r, r);
|
||||||
|
const Packet cst_p2 = pset1<Packet>(0.49999141693115234375f);
|
||||||
|
const Packet cst_p3 = pset1<Packet>(0.16666877269744873046875f);
|
||||||
|
const Packet cst_p4 = pset1<Packet>(4.1898667812347412109375e-2f);
|
||||||
|
const Packet cst_p5 = pset1<Packet>(8.33471305668354034423828125e-3f);
|
||||||
|
|
||||||
// Evaluate the numerator polynomial p.
|
const Packet p_even = pmadd(r2, cst_p4, cst_p2);
|
||||||
Packet p = pmadd(x2, alpha_9, alpha_7);
|
const Packet p_odd = pmadd(r2, cst_p5, cst_p3);
|
||||||
p = pmadd(x2, p, alpha_5);
|
const Packet p_low = padd(r, cst_one);
|
||||||
p = pmadd(x2, p, alpha_3);
|
Packet p = pmadd(r, p_odd, p_even);
|
||||||
p = pmadd(x2, p, alpha_1);
|
p = pmadd(r2, p, p_low);
|
||||||
p = pmul(x, p);
|
|
||||||
|
|
||||||
// Evaluate the denominator polynomial q.
|
// 4. Undo subtractive range reduction exp(m*ln(2) + r) = 2^m * exp(r).
|
||||||
Packet q = pmadd(x2, beta_10, beta_8);
|
Packet e = pldexp_fast_impl<Packet>::run(p, m);
|
||||||
q = pmadd(x2, q, beta_6);
|
|
||||||
q = pmadd(x2, q, beta_4);
|
// 5. Undo multiplicative range reduction by using exp(r) = exp(r/2)^2.
|
||||||
q = pmadd(x2, q, beta_2);
|
e = pmul(e, e);
|
||||||
q = pmadd(x2, q, beta_0);
|
|
||||||
// Divide the numerator by the denominator and shift it up.
|
// Return exp(x) / (1 + exp(x))
|
||||||
const Packet logistic = padd(pdiv(p, q), pset1<Packet>(0.5f));
|
return pselect(zero_mask, cst_zero, pdiv(e, padd(cst_one, e)));
|
||||||
if (EIGEN_PREDICT_FALSE(any_small)) {
|
|
||||||
const Packet exponential = pexp(_x);
|
|
||||||
return pselect(lt_mask, exponential, logistic);
|
|
||||||
} else {
|
|
||||||
return logistic;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
#endif // #ifndef EIGEN_GPU_COMPILE_PHASE
|
#endif // #ifndef EIGEN_GPU_COMPILE_PHASE
|
||||||
|
Loading…
x
Reference in New Issue
Block a user