From 3026fc0d3cf7becebd0761d8213536071eb1aa50 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Fri, 14 Apr 2023 16:57:56 +0000 Subject: [PATCH] Improve accuracy of erf(). --- .../SpecialFunctions/SpecialFunctionsImpl.h | 49 ++++++++++--------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h index d85729d88..5d84701e3 100644 --- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h +++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h @@ -296,49 +296,50 @@ struct digamma_impl { ****************************************************************************/ /** \internal \returns the error function of \a a (coeff-wise) - Doesn't do anything fancy, just a 13/8-degree rational interpolant which - is accurate up to a couple of ulp in the range [-4, 4], outside of which - fl(erf(x)) = +/-1. + Doesn't do anything fancy, just a 9/12-degree rational interpolant which + is accurate to 3 ulp for normalized floats in the range [-c;c], where + c = erfinv(1-2^-23), outside of which x should be +/-1 in single precision. + Strictly speaking c should be erfinv(1-2^-24), but we clamp slightly earlier + to avoid returning values greater than 1. This implementation works on both scalars and Ts. */ template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf_float(const T& a_x) { - // Clamp the inputs to the range [-4, 4] since anything outside - // this range is +/-1.0f in single-precision. - const T plus_4 = pset1(4.f); - const T minus_4 = pset1(-4.f); - const T x = pmax(pmin(a_x, plus_4), minus_4); + constexpr float kErfInvOneMinusULP = 3.7439211627767994f; + const T plus_clamp = pset1(kErfInvOneMinusULP); + const T minus_clamp = pset1(-kErfInvOneMinusULP); + const T x = pmax(pmin(a_x, plus_clamp), minus_clamp); // The monomial coefficients of the numerator polynomial (odd). - const T alpha_1 = pset1(-1.60960333262415e-02f); - const T alpha_3 = pset1(-2.95459980854025e-03f); - const T alpha_5 = pset1(-7.34990630326855e-04f); - const T alpha_7 = pset1(-5.69250639462346e-05f); - const T alpha_9 = pset1(-2.10102402082508e-06f); - const T alpha_11 = pset1(2.77068142495902e-08f); - const T alpha_13 = pset1(-2.72614225801306e-10f); + const T alpha_1 = pset1(1.128379143519084f); + const T alpha_3 = pset1(0.18520832239976145f); + const T alpha_5 = pset1(0.050955695062380861f); + const T alpha_7 = pset1(0.0034082910107109506f); + const T alpha_9 = pset1(0.00022905065861350646f); // The monomial coefficients of the denominator polynomial (even). - const T beta_0 = pset1(-1.42647390514189e-02f); - const T beta_2 = pset1(-7.37332916720468e-03f); - const T beta_4 = pset1(-1.68282697438203e-03f); - const T beta_6 = pset1(-2.13374055278905e-04f); - const T beta_8 = pset1(-1.45660718464996e-05f); + const T beta_0 = pset1(1.0f); + const T beta_2 = pset1(0.49746925110067538f); + const T beta_4 = pset1(0.11098505178285362f); + const T beta_6 = pset1(0.014070470171167667f); + const T beta_8 = pset1(0.0010179625278914885f); + const T beta_10 = pset1(0.000023547966471313185f); + const T beta_12 = pset1(-1.1791602954361697e-7); // Since the polynomials are odd/even, we need x^2. const T x2 = pmul(x, x); // Evaluate the numerator polynomial p. - T p = pmadd(x2, alpha_13, alpha_11); - p = pmadd(x2, p, alpha_9); - p = pmadd(x2, p, alpha_7); + T p = pmadd(x2, alpha_9, alpha_7); p = pmadd(x2, p, alpha_5); p = pmadd(x2, p, alpha_3); p = pmadd(x2, p, alpha_1); p = pmul(x, p); // Evaluate the denominator polynomial p. - T q = pmadd(x2, beta_8, beta_6); + T q = pmadd(x2, beta_12, beta_10); + q = pmadd(x2, q, beta_8); + q = pmadd(x2, q, beta_6); q = pmadd(x2, q, beta_4); q = pmadd(x2, q, beta_2); q = pmadd(x2, q, beta_0);