Make sure we return +/-1 above the clamping point for Erf().

This commit is contained in:
Rasmus Munk Larsen 2023-04-18 13:27:47 -07:00
parent e2bbf496f6
commit b378014fef

View File

@ -305,11 +305,9 @@ struct digamma_impl {
This implementation works on both scalars and Ts.
*/
template <typename T>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf_float(const T& a_x) {
constexpr float kErfInvOneMinusULP = 3.7439211627767994f;
const T plus_clamp = pset1<T>(kErfInvOneMinusULP);
const T minus_clamp = pset1<T>(-kErfInvOneMinusULP);
const T x = pmax(pmin(a_x, plus_clamp), minus_clamp);
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf_float(const T& x) {
constexpr float kErfInvOneMinusHalfULP = 3.832506856900711f;
const T clamp = pcmp_le(pset1<T>(kErfInvOneMinusHalfULP), pabs(x));
// The monomial coefficients of the numerator polynomial (odd).
const T alpha_1 = pset1<T>(1.128379143519084f);
const T alpha_3 = pset1<T>(0.18520832239976145f);
@ -345,7 +343,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T generic_fast_erf_float(const T& a_x) {
q = pmadd(x2, q, beta_0);
// Divide the numerator by the denominator.
return pdiv(p, q);
return pselect(clamp, psign(x), pdiv(p, q));
}
template <typename T>