Improve accuracy of erf().

This commit is contained in:
Rasmus Munk Larsen 2023-04-14 16:57:56 +00:00
parent 554fe02ae3
commit 3026fc0d3c

View File

@ -296,49 +296,50 @@ struct digamma_impl {
****************************************************************************/ ****************************************************************************/
/** \internal \returns the error function of \a a (coeff-wise) /** \internal \returns the error function of \a a (coeff-wise)
Doesn't do anything fancy, just a 13/8-degree rational interpolant which Doesn't do anything fancy, just a 9/12-degree rational interpolant which
is accurate up to a couple of ulp in the range [-4, 4], outside of which is accurate to 3 ulp for normalized floats in the range [-c;c], where
fl(erf(x)) = +/-1. c = erfinv(1-2^-23), outside of which x should be +/-1 in single precision.
Strictly speaking c should be erfinv(1-2^-24), but we clamp slightly earlier
to avoid returning values greater than 1.
This implementation works on both scalars and Ts. This implementation works on both scalars and Ts.
*/ */
template <typename T> template <typename T>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf_float(const T& a_x) { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf_float(const T& a_x) {
// Clamp the inputs to the range [-4, 4] since anything outside constexpr float kErfInvOneMinusULP = 3.7439211627767994f;
// this range is +/-1.0f in single-precision. const T plus_clamp = pset1<T>(kErfInvOneMinusULP);
const T plus_4 = pset1<T>(4.f); const T minus_clamp = pset1<T>(-kErfInvOneMinusULP);
const T minus_4 = pset1<T>(-4.f); const T x = pmax(pmin(a_x, plus_clamp), minus_clamp);
const T x = pmax(pmin(a_x, plus_4), minus_4);
// The monomial coefficients of the numerator polynomial (odd). // The monomial coefficients of the numerator polynomial (odd).
const T alpha_1 = pset1<T>(-1.60960333262415e-02f); const T alpha_1 = pset1<T>(1.128379143519084f);
const T alpha_3 = pset1<T>(-2.95459980854025e-03f); const T alpha_3 = pset1<T>(0.18520832239976145f);
const T alpha_5 = pset1<T>(-7.34990630326855e-04f); const T alpha_5 = pset1<T>(0.050955695062380861f);
const T alpha_7 = pset1<T>(-5.69250639462346e-05f); const T alpha_7 = pset1<T>(0.0034082910107109506f);
const T alpha_9 = pset1<T>(-2.10102402082508e-06f); const T alpha_9 = pset1<T>(0.00022905065861350646f);
const T alpha_11 = pset1<T>(2.77068142495902e-08f);
const T alpha_13 = pset1<T>(-2.72614225801306e-10f);
// The monomial coefficients of the denominator polynomial (even). // The monomial coefficients of the denominator polynomial (even).
const T beta_0 = pset1<T>(-1.42647390514189e-02f); const T beta_0 = pset1<T>(1.0f);
const T beta_2 = pset1<T>(-7.37332916720468e-03f); const T beta_2 = pset1<T>(0.49746925110067538f);
const T beta_4 = pset1<T>(-1.68282697438203e-03f); const T beta_4 = pset1<T>(0.11098505178285362f);
const T beta_6 = pset1<T>(-2.13374055278905e-04f); const T beta_6 = pset1<T>(0.014070470171167667f);
const T beta_8 = pset1<T>(-1.45660718464996e-05f); const T beta_8 = pset1<T>(0.0010179625278914885f);
const T beta_10 = pset1<T>(0.000023547966471313185f);
const T beta_12 = pset1<T>(-1.1791602954361697e-7);
// Since the polynomials are odd/even, we need x^2. // Since the polynomials are odd/even, we need x^2.
const T x2 = pmul(x, x); const T x2 = pmul(x, x);
// Evaluate the numerator polynomial p. // Evaluate the numerator polynomial p.
T p = pmadd(x2, alpha_13, alpha_11); T p = pmadd(x2, alpha_9, alpha_7);
p = pmadd(x2, p, alpha_9);
p = pmadd(x2, p, alpha_7);
p = pmadd(x2, p, alpha_5); p = pmadd(x2, p, alpha_5);
p = pmadd(x2, p, alpha_3); p = pmadd(x2, p, alpha_3);
p = pmadd(x2, p, alpha_1); p = pmadd(x2, p, alpha_1);
p = pmul(x, p); p = pmul(x, p);
// Evaluate the denominator polynomial p. // Evaluate the denominator polynomial p.
T q = pmadd(x2, beta_8, beta_6); T q = pmadd(x2, beta_12, beta_10);
q = pmadd(x2, q, beta_8);
q = pmadd(x2, q, beta_6);
q = pmadd(x2, q, beta_4); q = pmadd(x2, q, beta_4);
q = pmadd(x2, q, beta_2); q = pmadd(x2, q, beta_2);
q = pmadd(x2, q, beta_0); q = pmadd(x2, q, beta_0);