Vectorize erfc() for float

This commit is contained in:
Rasmus Munk Larsen 2024-10-09 18:38:05 +00:00
parent 78f3c654ee
commit 7eea0a9213

View File

@ -345,6 +345,49 @@ struct erf_impl<double> {
/***************************************************************************
* Implementation of erfc, requires C++11/C99 *
****************************************************************************/
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erfc_float(const T& x) {
const T x_abs = pmin(pabs(x), pset1<T>(10.0f));
const T one = pset1<T>(1.0f);
const T x_abs_gt_one_mask = pcmp_lt(one, x_abs);
// erfc(x) = 1 + x * S(x^2), |x| <= 1.
//
// Coefficients for S and T generated with Rminimax command:
// ./ratapprox --function="erfc(x)-1" --dom='[-1,1]' --type=[11,0] --num="odd"
// --numF="[SG]" --denF="[SG]" --log --dispCoeff="dec"
constexpr float alpha[] = {5.61802298761904239654541015625e-04, -4.91381669417023658752441406250e-03,
2.67075151205062866210937500000e-02, -1.12800106406211853027343750000e-01,
3.76122951507568359375000000000e-01, -1.12837910652160644531250000000e+00};
const T x2 = pmul(x, x);
const T erfc_small = pmadd(x, ppolevl<T, 5>::run(x2, alpha), one);
// Return early if we don't need the more expensive approximation for any
// entry in a.
if (!predux_any(x_abs_gt_one_mask)) return erfc_small;
// erfc(x) = exp(-x^2) * 1/x * P(1/x^2) / Q(1/x^2), 1 < x < 9.
//
// Coefficients for P and Q generated with Rminimax command:
// ./ratapprox --function="erfc(1/sqrt(x))*exp(1/x)/sqrt(x)"
// --dom='[0.01,1]' --type=[3,4] --numF="[SG]" --denF="[SG]" --log
// --dispCoeff="dec"
constexpr float gamma[] = {1.0208116471767425537109375e-01f, 4.2920666933059692382812500e-01f,
3.2379078865051269531250000e-01f, 5.3971976041793823242187500e-02f};
constexpr float delta[] = {1.7251677811145782470703125e-02f, 3.9137163758277893066406250e-01f,
1.0000000000000000000000000e+00f, 6.2173241376876831054687500e-01f,
9.5662862062454223632812500e-02f};
const T z = pexp(pnegate(x2));
const T q2 = preciprocal(x2);
const T num = ppolevl<T, 3>::run(q2, gamma);
const T denom = pmul(x_abs, ppolevl<T, 4>::run(q2, delta));
const T r = pdiv(num, denom);
// If x < -1 then use erfc(x) = 2 - erfc(|x|).
const T x_negative = pcmp_lt(x, pset1<T>(0.0f));
const T erfc_large = pselect(x_negative, pnmadd(z, r, pset1<T>(2.0f)), pmul(z, r));
return pselect(x_abs_gt_one_mask, erfc_large, erfc_small);
}
template <typename Scalar>
struct erfc_impl {
@ -365,7 +408,7 @@ struct erfc_impl<float> {
#if defined(SYCL_DEVICE_ONLY)
return cl::sycl::erfc(x);
#else
return ::erfcf(x);
return generic_fast_erfc_float(x);
#endif
}
};
@ -462,17 +505,17 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float flipsign<float>(const float& should_
template <typename T, typename ScalarType>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_ndtri_gt_exp_neg_two(const T& b) {
const ScalarType p0[] = {ScalarType(-5.99633501014107895267e1), ScalarType(9.80010754185999661536e1),
ScalarType(-5.66762857469070293439e1), ScalarType(1.39312609387279679503e1),
ScalarType(-1.23916583867381258016e0)};
ScalarType(-5.66762857469070293439e1), ScalarType(1.39312609387279679503e1),
ScalarType(-1.23916583867381258016e0)};
const ScalarType q0[] = {ScalarType(1.0),
ScalarType(1.95448858338141759834e0),
ScalarType(4.67627912898881538453e0),
ScalarType(8.63602421390890590575e1),
ScalarType(-2.25462687854119370527e2),
ScalarType(2.00260212380060660359e2),
ScalarType(-8.20372256168333339912e1),
ScalarType(1.59056225126211695515e1),
ScalarType(-1.18331621121330003142e0)};
ScalarType(1.95448858338141759834e0),
ScalarType(4.67627912898881538453e0),
ScalarType(8.63602421390890590575e1),
ScalarType(-2.25462687854119370527e2),
ScalarType(2.00260212380060660359e2),
ScalarType(-8.20372256168333339912e1),
ScalarType(1.59056225126211695515e1),
ScalarType(-1.18331621121330003142e0)};
const T sqrt2pi = pset1<T>(ScalarType(2.50662827463100050242e0));
const T half = pset1<T>(ScalarType(0.5));
T c, c2, ndtri_gt_exp_neg_two;