Fix accuracy of logistic sigmoid

This commit is contained in:
Rasmus Munk Larsen 2022-01-08 00:15:14 +00:00
parent 8b8125c574
commit 80ccacc717

View File

@ -1026,83 +1026,93 @@ struct scalar_logistic_op {
template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Packet packetOp(const Packet& x) const { Packet packetOp(const Packet& x) const {
const Packet one = pset1<Packet>(T(1)); const Packet one = pset1<Packet>(T(1));
return pdiv(one, padd(one, pexp(pnegate(x)))); const Packet e = pexp(x);
return pdiv(e, padd(one, e));
} }
}; };
#ifndef EIGEN_GPU_COMPILE_PHASE #ifndef EIGEN_GPU_COMPILE_PHASE
/** \internal /** \internal
* \brief Template specialization of the logistic function for float. * \brief Template specialization of the logistic function for float.
* * Computes S(x) = exp(x) / (1 + exp(x)), where exp(x) is implemented
* Uses just a 9/10-degree rational interpolant which * using an algorithm partly adopted from the implementation of
* interpolates 1/(1+exp(-x)) - 0.5 up to a couple of ulps in the range * pexp_float. See the individual steps described in the code below.
* [-9, 18]. Below -9 we use the more accurate approximation * Note that compared to pexp, we use an additional outer multiplicative
* 1/(1+exp(-x)) ~= exp(x), and above 18 the logistic function is 1 within * range reduction step using the identity exp(x) = exp(x/2)^2.
* one ulp. The shifted logistic is interpolated because it was easier to * This prevert us from having to call ldexp on values that could produce
* make the fit converge. * a denormal result, which allows us to call the faster implementation in
* * pldexp_fast_impl<Packet>::run(p, m).
* The final squaring, however, doubles the error bound on the final
* approximation. Exhaustive testing shows that we have a worst case error
* of 4.5 ulps (compared to computing S(x) in double precision), which is
* acceptable.
*/ */
template <> template <>
struct scalar_logistic_op<float> { struct scalar_logistic_op<float> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(const float& x) const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(const float& x) const {
return packetOp(x); const float e = numext::exp(x);
return e / (1.0f + e);
} }
template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE template <typename Packet>
Packet packetOp(const Packet& _x) const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet
const Packet cutoff_lower = pset1<Packet>(-9.f); packetOp(const Packet& _x) const {
const Packet lt_mask = pcmp_lt<Packet>(_x, cutoff_lower); const Packet cst_zero = pset1<Packet>(0.0f);
const bool any_small = predux_any(lt_mask); const Packet cst_one = pset1<Packet>(1.0f);
const Packet cst_half = pset1<Packet>(0.5f);
const Packet cst_exp_hi = pset1<Packet>(16.f);
const Packet cst_exp_lo = pset1<Packet>(-104.f);
// The upper cut-off is the smallest x for which the rational approximation evaluates to 1. // Clamp x to the non-trivial range where S(x). Outside this
// Choosing this value saves us a few instructions clamping the results at the end. // interval the correctly rounded value of S(x) is either zero
#ifdef EIGEN_VECTORIZE_FMA // or one.
const Packet cutoff_upper = pset1<Packet>(15.7243833541870117f); Packet zero_mask = pcmp_lt(_x, cst_exp_lo);
#else Packet x = pmin(_x, cst_exp_hi);
const Packet cutoff_upper = pset1<Packet>(15.6437711715698242f);
#endif
const Packet x = pmin(_x, cutoff_upper);
// The monomial coefficients of the numerator polynomial (odd). // 1. Multiplicative range reduction:
const Packet alpha_1 = pset1<Packet>(2.48287947061529e-01f); // Reduce the range of x by a factor of 2. This avoids having
const Packet alpha_3 = pset1<Packet>(8.51377133304701e-03f); // to compute exp(x) accurately where the result is a denormalized
const Packet alpha_5 = pset1<Packet>(6.08574864600143e-05f); // value.
const Packet alpha_7 = pset1<Packet>(1.15627324459942e-07f); x = pmul(x, cst_half);
const Packet alpha_9 = pset1<Packet>(4.37031012579801e-11f);
// The monomial coefficients of the denominator polynomial (even). // 2. Subtractive range reduction:
const Packet beta_0 = pset1<Packet>(9.93151921023180e-01f); // Express exp(x) as exp(m*ln(2) + r) = 2^m*exp(r), start by extracting
const Packet beta_2 = pset1<Packet>(1.16817656904453e-01f); // m = floor(x/ln(2) + 0.5), such that x = m*ln(2) + r.
const Packet beta_4 = pset1<Packet>(1.70198817374094e-03f); const Packet cst_cephes_LOG2EF = pset1<Packet>(1.44269504088896341f);
const Packet beta_6 = pset1<Packet>(6.29106785017040e-06f); Packet m = pfloor(pmadd(x, cst_cephes_LOG2EF, cst_half));
const Packet beta_8 = pset1<Packet>(5.76102136993427e-09f); // Get r = x - m*ln(2). We use a trick from Cephes where the term
const Packet beta_10 = pset1<Packet>(6.10247389755681e-13f); // m*ln(2) is subtracted out in two parts, m*C1+m*C2 = m*ln(2),
// to avoid accumulating truncation errors.
const Packet cst_cephes_exp_C1 = pset1<Packet>(-0.693359375f);
const Packet cst_cephes_exp_C2 = pset1<Packet>(2.12194440e-4f);
Packet r = pmadd(m, cst_cephes_exp_C1, x);
r = pmadd(m, cst_cephes_exp_C2, r);
// Since the polynomials are odd/even, we need x^2. // 3. Compute an approximation to exp(r) using a degree 5 minimax polynomial.
const Packet x2 = pmul(x, x); // We compute even and odd terms separately to increase instruction level
// parallelism.
Packet r2 = pmul(r, r);
const Packet cst_p2 = pset1<Packet>(0.49999141693115234375f);
const Packet cst_p3 = pset1<Packet>(0.16666877269744873046875f);
const Packet cst_p4 = pset1<Packet>(4.1898667812347412109375e-2f);
const Packet cst_p5 = pset1<Packet>(8.33471305668354034423828125e-3f);
// Evaluate the numerator polynomial p. const Packet p_even = pmadd(r2, cst_p4, cst_p2);
Packet p = pmadd(x2, alpha_9, alpha_7); const Packet p_odd = pmadd(r2, cst_p5, cst_p3);
p = pmadd(x2, p, alpha_5); const Packet p_low = padd(r, cst_one);
p = pmadd(x2, p, alpha_3); Packet p = pmadd(r, p_odd, p_even);
p = pmadd(x2, p, alpha_1); p = pmadd(r2, p, p_low);
p = pmul(x, p);
// Evaluate the denominator polynomial q. // 4. Undo subtractive range reduction exp(m*ln(2) + r) = 2^m * exp(r).
Packet q = pmadd(x2, beta_10, beta_8); Packet e = pldexp_fast_impl<Packet>::run(p, m);
q = pmadd(x2, q, beta_6);
q = pmadd(x2, q, beta_4); // 5. Undo multiplicative range reduction by using exp(r) = exp(r/2)^2.
q = pmadd(x2, q, beta_2); e = pmul(e, e);
q = pmadd(x2, q, beta_0);
// Divide the numerator by the denominator and shift it up. // Return exp(x) / (1 + exp(x))
const Packet logistic = padd(pdiv(p, q), pset1<Packet>(0.5f)); return pselect(zero_mask, cst_zero, pdiv(e, padd(cst_one, e)));
if (EIGEN_PREDICT_FALSE(any_small)) {
const Packet exponential = pexp(_x);
return pselect(lt_mask, exponential, logistic);
} else {
return logistic;
}
} }
}; };
#endif // #ifndef EIGEN_GPU_COMPILE_PHASE #endif // #ifndef EIGEN_GPU_COMPILE_PHASE