From 7eea0a9213e801ad9479a6499fd0330ec1db8693 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Wed, 9 Oct 2024 18:38:05 +0000 Subject: [PATCH] Vectorize erfc() for float --- .../SpecialFunctions/SpecialFunctionsImpl.h | 65 +++++++++++++++---- 1 file changed, 54 insertions(+), 11 deletions(-) diff --git a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h index 86a49b679..5169f1cb6 100644 --- a/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h +++ b/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h @@ -345,6 +345,49 @@ struct erf_impl { /*************************************************************************** * Implementation of erfc, requires C++11/C99 * ****************************************************************************/ +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erfc_float(const T& x) { + const T x_abs = pmin(pabs(x), pset1(10.0f)); + const T one = pset1(1.0f); + const T x_abs_gt_one_mask = pcmp_lt(one, x_abs); + + // erfc(x) = 1 + x * S(x^2), |x| <= 1. + // + // Coefficients for S and T generated with Rminimax command: + // ./ratapprox --function="erfc(x)-1" --dom='[-1,1]' --type=[11,0] --num="odd" + // --numF="[SG]" --denF="[SG]" --log --dispCoeff="dec" + constexpr float alpha[] = {5.61802298761904239654541015625e-04, -4.91381669417023658752441406250e-03, + 2.67075151205062866210937500000e-02, -1.12800106406211853027343750000e-01, + 3.76122951507568359375000000000e-01, -1.12837910652160644531250000000e+00}; + const T x2 = pmul(x, x); + const T erfc_small = pmadd(x, ppolevl::run(x2, alpha), one); + + // Return early if we don't need the more expensive approximation for any + // entry in a. + if (!predux_any(x_abs_gt_one_mask)) return erfc_small; + + // erfc(x) = exp(-x^2) * 1/x * P(1/x^2) / Q(1/x^2), 1 < x < 9. + // + // Coefficients for P and Q generated with Rminimax command: + // ./ratapprox --function="erfc(1/sqrt(x))*exp(1/x)/sqrt(x)" + // --dom='[0.01,1]' --type=[3,4] --numF="[SG]" --denF="[SG]" --log + // --dispCoeff="dec" + constexpr float gamma[] = {1.0208116471767425537109375e-01f, 4.2920666933059692382812500e-01f, + 3.2379078865051269531250000e-01f, 5.3971976041793823242187500e-02f}; + constexpr float delta[] = {1.7251677811145782470703125e-02f, 3.9137163758277893066406250e-01f, + 1.0000000000000000000000000e+00f, 6.2173241376876831054687500e-01f, + 9.5662862062454223632812500e-02f}; + const T z = pexp(pnegate(x2)); + const T q2 = preciprocal(x2); + const T num = ppolevl::run(q2, gamma); + const T denom = pmul(x_abs, ppolevl::run(q2, delta)); + const T r = pdiv(num, denom); + // If x < -1 then use erfc(x) = 2 - erfc(|x|). + const T x_negative = pcmp_lt(x, pset1(0.0f)); + const T erfc_large = pselect(x_negative, pnmadd(z, r, pset1(2.0f)), pmul(z, r)); + + return pselect(x_abs_gt_one_mask, erfc_large, erfc_small); +} template struct erfc_impl { @@ -365,7 +408,7 @@ struct erfc_impl { #if defined(SYCL_DEVICE_ONLY) return cl::sycl::erfc(x); #else - return ::erfcf(x); + return generic_fast_erfc_float(x); #endif } }; @@ -462,17 +505,17 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float flipsign(const float& should_ template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_ndtri_gt_exp_neg_two(const T& b) { const ScalarType p0[] = {ScalarType(-5.99633501014107895267e1), ScalarType(9.80010754185999661536e1), - ScalarType(-5.66762857469070293439e1), ScalarType(1.39312609387279679503e1), - ScalarType(-1.23916583867381258016e0)}; + ScalarType(-5.66762857469070293439e1), ScalarType(1.39312609387279679503e1), + ScalarType(-1.23916583867381258016e0)}; const ScalarType q0[] = {ScalarType(1.0), - ScalarType(1.95448858338141759834e0), - ScalarType(4.67627912898881538453e0), - ScalarType(8.63602421390890590575e1), - ScalarType(-2.25462687854119370527e2), - ScalarType(2.00260212380060660359e2), - ScalarType(-8.20372256168333339912e1), - ScalarType(1.59056225126211695515e1), - ScalarType(-1.18331621121330003142e0)}; + ScalarType(1.95448858338141759834e0), + ScalarType(4.67627912898881538453e0), + ScalarType(8.63602421390890590575e1), + ScalarType(-2.25462687854119370527e2), + ScalarType(2.00260212380060660359e2), + ScalarType(-8.20372256168333339912e1), + ScalarType(1.59056225126211695515e1), + ScalarType(-1.18331621121330003142e0)}; const T sqrt2pi = pset1(ScalarType(2.50662827463100050242e0)); const T half = pset1(ScalarType(0.5)); T c, c2, ndtri_gt_exp_neg_two;