Speed up and improve accuracy of tanh.

This commit is contained in:
Rasmus Munk Larsen 2024-08-16 23:46:28 +00:00
parent 92e373e6f5
commit cc240eea2f

View File

@ -1097,62 +1097,68 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patan_double(const Pa
}
/** \internal \returns the hyperbolic tan of \a a (coeff-wise)
Doesn't do anything fancy, just a 13/6-degree rational interpolant which
Doesn't do anything fancy, just a 9/8-degree rational interpolant which
is accurate up to a couple of ulps in the (approximate) range [-8, 8],
outside of which tanh(x) = +/-1 in single precision. The input is clamped
to the range [-c, c]. The value c is chosen as the smallest value where
the approximation evaluates to exactly 1. In the reange [-0.0004, 0.0004]
the approximation tanh(x) ~= x is used for better accuracy as x tends to zero.
the approximation evaluates to exactly 1.
This implementation works on both scalars and packets.
*/
template <typename T>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS T ptanh_float(const T& a_x) {
// Clamp the inputs to the range [-c, c]
// Clamp the inputs to the range [-c, c] and set everything
// outside that range to 1.0. The value c is chosen as the smallest
// floating point argument such that the approximation is exactly 1.
// This saves clamping the value at the end.
#ifdef EIGEN_VECTORIZE_FMA
const T plus_clamp = pset1<T>(7.99881172180175781f);
const T minus_clamp = pset1<T>(-7.99881172180175781f);
const T plus_clamp = pset1<T>(8.05359363555908203f);
const T minus_clamp = pset1<T>(-8.05359363555908203f);
#else
const T plus_clamp = pset1<T>(7.90531110763549805f);
const T minus_clamp = pset1<T>(-7.90531110763549805f);
const T plus_clamp = pset1<T>(7.98551225662231445);
const T minus_clamp = pset1<T>(-7.98551225662231445);
#endif
const T tiny = pset1<T>(0.0004f);
const T x = pmax(pmin(a_x, plus_clamp), minus_clamp);
const T tiny_mask = pcmp_lt(pabs(a_x), tiny);
// The following rational approximation was generated by rminimax
// (https://gitlab.inria.fr/sfilip/rminimax) using the following
// command:
// $ ratapprox --function="tanh(x)" --dom='[-8.67,8.67]' --num="odd"
// --den="even" --type="[9,8]" --numF="[SG]" --denF="[SG]" --log
// --output=tanhf.sollya --dispCoeff="dec"
// The monomial coefficients of the numerator polynomial (odd).
const T alpha_1 = pset1<T>(4.89352455891786e-03f);
const T alpha_3 = pset1<T>(6.37261928875436e-04f);
const T alpha_5 = pset1<T>(1.48572235717979e-05f);
const T alpha_7 = pset1<T>(5.12229709037114e-08f);
const T alpha_9 = pset1<T>(-8.60467152213735e-11f);
const T alpha_11 = pset1<T>(2.00018790482477e-13f);
const T alpha_13 = pset1<T>(-2.76076847742355e-16f);
const T alpha_3 = pset1<T>(1.340216100215911865234375e-1f);
const T alpha_5 = pset1<T>(3.52075672708451747894287109375e-3f);
const T alpha_7 = pset1<T>(2.10273356060497462749481201171875e-5f);
const T alpha_9 = pset1<T>(1.39455362813123429077677428722381591796875e-8f);
// The monomial coefficients of the denominator polynomial (even).
const T beta_0 = pset1<T>(4.89352518554385e-03f);
const T beta_2 = pset1<T>(2.26843463243900e-03f);
const T beta_4 = pset1<T>(1.18534705686654e-04f);
const T beta_6 = pset1<T>(1.19825839466702e-06f);
const T beta_0 = pset1<T>(1.0f);
const T beta_2 = pset1<T>(4.67354834079742431640625e-1f);
const T beta_4 = pset1<T>(2.5972545146942138671875e-2f);
const T beta_6 = pset1<T>(3.326951409690082073211669921875e-4f);
const T beta_8 = pset1<T>(8.015776984393596649169921875e-7f);
// Since the polynomials are odd/even, we need x^2.
const T x2 = pmul(x, x);
const T x3 = pmul(x2, x);
// Evaluate the numerator polynomial p.
T p = pmadd(x2, alpha_13, alpha_11);
p = pmadd(x2, p, alpha_9);
p = pmadd(x2, p, alpha_7);
// Interleave the evaluation of the numerator polynomial p and
// denominator polynomial q.
T p = pmadd(x2, alpha_9, alpha_7);
T q = pmadd(x2, beta_8, beta_6);
p = pmadd(x2, p, alpha_5);
q = pmadd(x2, q, beta_4);
p = pmadd(x2, p, alpha_3);
p = pmadd(x2, p, alpha_1);
p = pmul(x, p);
// Evaluate the denominator polynomial q.
T q = pmadd(x2, beta_6, beta_4);
q = pmadd(x2, q, beta_2);
// Take advantage of the fact that alpha_1 = 1 to compute
// x*(x^2*p + alpha_1) = x^3 * p + x.
p = pmadd(x3, p, x);
q = pmadd(x2, q, beta_0);
// Divide the numerator by the denominator.
return pselect(tiny_mask, x, pdiv(p, q));
return pdiv(p, q);
}
template <typename Packet>