Fix accuracy of logistic sigmoid

This commit is contained in:
Rasmus Munk Larsen 2022-01-08 00:15:14 +00:00
parent 8b8125c574
commit 80ccacc717

View File

@ -1026,83 +1026,93 @@ struct scalar_logistic_op {
template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Packet packetOp(const Packet& x) const {
const Packet one = pset1<Packet>(T(1));
return pdiv(one, padd(one, pexp(pnegate(x))));
const Packet e = pexp(x);
return pdiv(e, padd(one, e));
}
};
#ifndef EIGEN_GPU_COMPILE_PHASE
/** \internal
* \brief Template specialization of the logistic function for float.
*
* Uses just a 9/10-degree rational interpolant which
* interpolates 1/(1+exp(-x)) - 0.5 up to a couple of ulps in the range
* [-9, 18]. Below -9 we use the more accurate approximation
* 1/(1+exp(-x)) ~= exp(x), and above 18 the logistic function is 1 within
* one ulp. The shifted logistic is interpolated because it was easier to
* make the fit converge.
*
* Computes S(x) = exp(x) / (1 + exp(x)), where exp(x) is implemented
* using an algorithm partly adopted from the implementation of
* pexp_float. See the individual steps described in the code below.
* Note that compared to pexp, we use an additional outer multiplicative
* range reduction step using the identity exp(x) = exp(x/2)^2.
* This prevert us from having to call ldexp on values that could produce
* a denormal result, which allows us to call the faster implementation in
* pldexp_fast_impl<Packet>::run(p, m).
* The final squaring, however, doubles the error bound on the final
* approximation. Exhaustive testing shows that we have a worst case error
* of 4.5 ulps (compared to computing S(x) in double precision), which is
* acceptable.
*/
template <>
struct scalar_logistic_op<float> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(const float& x) const {
return packetOp(x);
const float e = numext::exp(x);
return e / (1.0f + e);
}
template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Packet packetOp(const Packet& _x) const {
const Packet cutoff_lower = pset1<Packet>(-9.f);
const Packet lt_mask = pcmp_lt<Packet>(_x, cutoff_lower);
const bool any_small = predux_any(lt_mask);
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet
packetOp(const Packet& _x) const {
const Packet cst_zero = pset1<Packet>(0.0f);
const Packet cst_one = pset1<Packet>(1.0f);
const Packet cst_half = pset1<Packet>(0.5f);
const Packet cst_exp_hi = pset1<Packet>(16.f);
const Packet cst_exp_lo = pset1<Packet>(-104.f);
// The upper cut-off is the smallest x for which the rational approximation evaluates to 1.
// Choosing this value saves us a few instructions clamping the results at the end.
#ifdef EIGEN_VECTORIZE_FMA
const Packet cutoff_upper = pset1<Packet>(15.7243833541870117f);
#else
const Packet cutoff_upper = pset1<Packet>(15.6437711715698242f);
#endif
const Packet x = pmin(_x, cutoff_upper);
// Clamp x to the non-trivial range where S(x). Outside this
// interval the correctly rounded value of S(x) is either zero
// or one.
Packet zero_mask = pcmp_lt(_x, cst_exp_lo);
Packet x = pmin(_x, cst_exp_hi);
// The monomial coefficients of the numerator polynomial (odd).
const Packet alpha_1 = pset1<Packet>(2.48287947061529e-01f);
const Packet alpha_3 = pset1<Packet>(8.51377133304701e-03f);
const Packet alpha_5 = pset1<Packet>(6.08574864600143e-05f);
const Packet alpha_7 = pset1<Packet>(1.15627324459942e-07f);
const Packet alpha_9 = pset1<Packet>(4.37031012579801e-11f);
// 1. Multiplicative range reduction:
// Reduce the range of x by a factor of 2. This avoids having
// to compute exp(x) accurately where the result is a denormalized
// value.
x = pmul(x, cst_half);
// The monomial coefficients of the denominator polynomial (even).
const Packet beta_0 = pset1<Packet>(9.93151921023180e-01f);
const Packet beta_2 = pset1<Packet>(1.16817656904453e-01f);
const Packet beta_4 = pset1<Packet>(1.70198817374094e-03f);
const Packet beta_6 = pset1<Packet>(6.29106785017040e-06f);
const Packet beta_8 = pset1<Packet>(5.76102136993427e-09f);
const Packet beta_10 = pset1<Packet>(6.10247389755681e-13f);
// 2. Subtractive range reduction:
// Express exp(x) as exp(m*ln(2) + r) = 2^m*exp(r), start by extracting
// m = floor(x/ln(2) + 0.5), such that x = m*ln(2) + r.
const Packet cst_cephes_LOG2EF = pset1<Packet>(1.44269504088896341f);
Packet m = pfloor(pmadd(x, cst_cephes_LOG2EF, cst_half));
// Get r = x - m*ln(2). We use a trick from Cephes where the term
// m*ln(2) is subtracted out in two parts, m*C1+m*C2 = m*ln(2),
// to avoid accumulating truncation errors.
const Packet cst_cephes_exp_C1 = pset1<Packet>(-0.693359375f);
const Packet cst_cephes_exp_C2 = pset1<Packet>(2.12194440e-4f);
Packet r = pmadd(m, cst_cephes_exp_C1, x);
r = pmadd(m, cst_cephes_exp_C2, r);
// Since the polynomials are odd/even, we need x^2.
const Packet x2 = pmul(x, x);
// 3. Compute an approximation to exp(r) using a degree 5 minimax polynomial.
// We compute even and odd terms separately to increase instruction level
// parallelism.
Packet r2 = pmul(r, r);
const Packet cst_p2 = pset1<Packet>(0.49999141693115234375f);
const Packet cst_p3 = pset1<Packet>(0.16666877269744873046875f);
const Packet cst_p4 = pset1<Packet>(4.1898667812347412109375e-2f);
const Packet cst_p5 = pset1<Packet>(8.33471305668354034423828125e-3f);
// Evaluate the numerator polynomial p.
Packet p = pmadd(x2, alpha_9, alpha_7);
p = pmadd(x2, p, alpha_5);
p = pmadd(x2, p, alpha_3);
p = pmadd(x2, p, alpha_1);
p = pmul(x, p);
const Packet p_even = pmadd(r2, cst_p4, cst_p2);
const Packet p_odd = pmadd(r2, cst_p5, cst_p3);
const Packet p_low = padd(r, cst_one);
Packet p = pmadd(r, p_odd, p_even);
p = pmadd(r2, p, p_low);
// Evaluate the denominator polynomial q.
Packet q = pmadd(x2, beta_10, beta_8);
q = pmadd(x2, q, beta_6);
q = pmadd(x2, q, beta_4);
q = pmadd(x2, q, beta_2);
q = pmadd(x2, q, beta_0);
// Divide the numerator by the denominator and shift it up.
const Packet logistic = padd(pdiv(p, q), pset1<Packet>(0.5f));
if (EIGEN_PREDICT_FALSE(any_small)) {
const Packet exponential = pexp(_x);
return pselect(lt_mask, exponential, logistic);
} else {
return logistic;
}
// 4. Undo subtractive range reduction exp(m*ln(2) + r) = 2^m * exp(r).
Packet e = pldexp_fast_impl<Packet>::run(p, m);
// 5. Undo multiplicative range reduction by using exp(r) = exp(r/2)^2.
e = pmul(e, e);
// Return exp(x) / (1 + exp(x))
return pselect(zero_mask, cst_zero, pdiv(e, padd(cst_one, e)));
}
};
#endif // #ifndef EIGEN_GPU_COMPILE_PHASE