From 80ccacc71779d899c4d14a2e8a6b57c4df58d58d Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Sat, 8 Jan 2022 00:15:14 +0000 Subject: [PATCH] Fix accuracy of logistic sigmoid --- Eigen/src/Core/functors/UnaryFunctors.h | 126 +++++++++++++----------- 1 file changed, 68 insertions(+), 58 deletions(-) diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index f792c8ed8..72d485814 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -1026,83 +1026,93 @@ struct scalar_logistic_op { template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { const Packet one = pset1(T(1)); - return pdiv(one, padd(one, pexp(pnegate(x)))); + const Packet e = pexp(x); + return pdiv(e, padd(one, e)); } }; #ifndef EIGEN_GPU_COMPILE_PHASE + /** \internal * \brief Template specialization of the logistic function for float. - * - * Uses just a 9/10-degree rational interpolant which - * interpolates 1/(1+exp(-x)) - 0.5 up to a couple of ulps in the range - * [-9, 18]. Below -9 we use the more accurate approximation - * 1/(1+exp(-x)) ~= exp(x), and above 18 the logistic function is 1 within - * one ulp. The shifted logistic is interpolated because it was easier to - * make the fit converge. - * + * Computes S(x) = exp(x) / (1 + exp(x)), where exp(x) is implemented + * using an algorithm partly adopted from the implementation of + * pexp_float. See the individual steps described in the code below. + * Note that compared to pexp, we use an additional outer multiplicative + * range reduction step using the identity exp(x) = exp(x/2)^2. + * This prevents us from having to call ldexp on values that could produce + * a denormal result, which allows us to call the faster implementation in + * pldexp_fast_impl::run(p, m). + * The final squaring, however, doubles the error bound on the final + * approximation. 
Exhaustive testing shows that we have a worst case error + * of 4.5 ulps (compared to computing S(x) in double precision), which is + * acceptable. */ template <> struct scalar_logistic_op { EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(const float& x) const { - return packetOp(x); + const float e = numext::exp(x); + return e / (1.0f + e); } - template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - Packet packetOp(const Packet& _x) const { - const Packet cutoff_lower = pset1(-9.f); - const Packet lt_mask = pcmp_lt(_x, cutoff_lower); - const bool any_small = predux_any(lt_mask); + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet + packetOp(const Packet& _x) const { + const Packet cst_zero = pset1(0.0f); + const Packet cst_one = pset1(1.0f); + const Packet cst_half = pset1(0.5f); + const Packet cst_exp_hi = pset1(16.f); + const Packet cst_exp_lo = pset1(-104.f); - // The upper cut-off is the smallest x for which the rational approximation evaluates to 1. - // Choosing this value saves us a few instructions clamping the results at the end. -#ifdef EIGEN_VECTORIZE_FMA - const Packet cutoff_upper = pset1(15.7243833541870117f); -#else - const Packet cutoff_upper = pset1(15.6437711715698242f); -#endif - const Packet x = pmin(_x, cutoff_upper); + // Clamp x to the range where S(x) is non-trivial. Outside this + // interval the correctly rounded value of S(x) is either zero + // or one. + Packet zero_mask = pcmp_lt(_x, cst_exp_lo); + Packet x = pmin(_x, cst_exp_hi); - // The monomial coefficients of the numerator polynomial (odd). - const Packet alpha_1 = pset1(2.48287947061529e-01f); - const Packet alpha_3 = pset1(8.51377133304701e-03f); - const Packet alpha_5 = pset1(6.08574864600143e-05f); - const Packet alpha_7 = pset1(1.15627324459942e-07f); - const Packet alpha_9 = pset1(4.37031012579801e-11f); + // 1. Multiplicative range reduction: + // Reduce the range of x by a factor of 2. 
This avoids having + // to compute exp(x) accurately where the result is a denormalized + // value. + x = pmul(x, cst_half); - // The monomial coefficients of the denominator polynomial (even). - const Packet beta_0 = pset1(9.93151921023180e-01f); - const Packet beta_2 = pset1(1.16817656904453e-01f); - const Packet beta_4 = pset1(1.70198817374094e-03f); - const Packet beta_6 = pset1(6.29106785017040e-06f); - const Packet beta_8 = pset1(5.76102136993427e-09f); - const Packet beta_10 = pset1(6.10247389755681e-13f); + // 2. Subtractive range reduction: + // Express exp(x) as exp(m*ln(2) + r) = 2^m*exp(r), start by extracting + // m = floor(x/ln(2) + 0.5), such that x = m*ln(2) + r. + const Packet cst_cephes_LOG2EF = pset1(1.44269504088896341f); + Packet m = pfloor(pmadd(x, cst_cephes_LOG2EF, cst_half)); + // Get r = x - m*ln(2). We use a trick from Cephes where the term + // m*ln(2) is subtracted out in two parts, m*C1+m*C2 = m*ln(2), + // to avoid accumulating truncation errors. + const Packet cst_cephes_exp_C1 = pset1(-0.693359375f); + const Packet cst_cephes_exp_C2 = pset1(2.12194440e-4f); + Packet r = pmadd(m, cst_cephes_exp_C1, x); + r = pmadd(m, cst_cephes_exp_C2, r); - // Since the polynomials are odd/even, we need x^2. - const Packet x2 = pmul(x, x); + // 3. Compute an approximation to exp(r) using a degree 5 minimax polynomial. + // We compute even and odd terms separately to increase instruction level + // parallelism. + Packet r2 = pmul(r, r); + const Packet cst_p2 = pset1(0.49999141693115234375f); + const Packet cst_p3 = pset1(0.16666877269744873046875f); + const Packet cst_p4 = pset1(4.1898667812347412109375e-2f); + const Packet cst_p5 = pset1(8.33471305668354034423828125e-3f); - // Evaluate the numerator polynomial p. 
- Packet p = pmadd(x2, alpha_9, alpha_7); - p = pmadd(x2, p, alpha_5); - p = pmadd(x2, p, alpha_3); - p = pmadd(x2, p, alpha_1); - p = pmul(x, p); + const Packet p_even = pmadd(r2, cst_p4, cst_p2); + const Packet p_odd = pmadd(r2, cst_p5, cst_p3); + const Packet p_low = padd(r, cst_one); + Packet p = pmadd(r, p_odd, p_even); + p = pmadd(r2, p, p_low); - // Evaluate the denominator polynomial q. - Packet q = pmadd(x2, beta_10, beta_8); - q = pmadd(x2, q, beta_6); - q = pmadd(x2, q, beta_4); - q = pmadd(x2, q, beta_2); - q = pmadd(x2, q, beta_0); - // Divide the numerator by the denominator and shift it up. - const Packet logistic = padd(pdiv(p, q), pset1(0.5f)); - if (EIGEN_PREDICT_FALSE(any_small)) { - const Packet exponential = pexp(_x); - return pselect(lt_mask, exponential, logistic); - } else { - return logistic; - } + // 4. Undo subtractive range reduction exp(m*ln(2) + r) = 2^m * exp(r). + Packet e = pldexp_fast_impl::run(p, m); + + // 5. Undo multiplicative range reduction by using exp(r) = exp(r/2)^2. + e = pmul(e, e); + + // Return exp(x) / (1 + exp(x)) + return pselect(zero_mask, cst_zero, pdiv(e, padd(cst_one, e))); } }; #endif // #ifndef EIGEN_GPU_COMPILE_PHASE