Make sure we return +/-1 above the clamping point for Erf().

(cherry picked from commit b378014fef017a829fb42c7fad15f3764bfb8ef9)
This commit is contained in:
Rasmus Munk Larsen 2023-04-18 13:27:47 -07:00 committed by Antonio Sanchez
parent f04d02dbf6
commit f296720d7d

View File

@@ -301,12 +301,9 @@ struct digamma_impl {
This implementation works on both scalars and packets (SIMD vector types).
*/
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf_float(const T& a_x) {
// Clamp the inputs to the range [-4, 4] since anything outside
// this range is +/-1.0f in single-precision.
const T plus_4 = pset1<T>(4.f);
const T minus_4 = pset1<T>(-4.f);
const T x = pmax(pmin(a_x, plus_4), minus_4);
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf_float(const T& x) {
const float kErfInvOneMinusHalfULP = 3.832506856900711f;
const T clamp = pcmp_le(pset1<T>(kErfInvOneMinusHalfULP), pabs(x));
// The monomial coefficients of the numerator polynomial (odd).
const T alpha_1 = pset1<T>(-1.60960333262415e-02f);
const T alpha_3 = pset1<T>(-2.95459980854025e-03f);
@@ -342,7 +339,8 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf_float(const T& a_x) {
q = pmadd(x2, q, beta_0);
// Divide the numerator by the denominator.
return pdiv(p, q);
const T sign = pselect(pcmp_le(x, pset1<T>(0.0f)), pset1<T>(-1.0f), pset1<T>(1.0f));
return pselect(clamp, sign, pdiv(p, q));
}
template <typename T>