Add vectorized implementation of tanh<double>

This commit is contained in:
Rasmus Munk Larsen 2024-08-21 02:29:45 +00:00
parent cc240eea2f
commit 32d95bb097
8 changed files with 109 additions and 14 deletions

View File

@ -27,6 +27,7 @@ EIGEN_DOUBLE_PACKET_FUNCTION(atan, Packet4d)
EIGEN_DOUBLE_PACKET_FUNCTION(log, Packet4d)
EIGEN_DOUBLE_PACKET_FUNCTION(log2, Packet4d)
EIGEN_DOUBLE_PACKET_FUNCTION(exp, Packet4d)
EIGEN_DOUBLE_PACKET_FUNCTION(tanh, Packet4d)
#ifdef EIGEN_VECTORIZE_AVX2
EIGEN_DOUBLE_PACKET_FUNCTION(sin, Packet4d)
EIGEN_DOUBLE_PACKET_FUNCTION(cos, Packet4d)

View File

@ -142,6 +142,7 @@ struct packet_traits<double> : default_packet_traits {
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
#endif
HasTanh = EIGEN_FAST_MATH,
HasLog = 1,
HasExp = 1,
HasSqrt = 1,

View File

@ -153,6 +153,7 @@ struct packet_traits<double> : default_packet_traits {
HasLog = 1,
HasExp = 1,
HasATan = 1,
HasTanh = EIGEN_FAST_MATH,
HasCmp = 1,
HasDiv = 1
};

View File

@ -3176,6 +3176,7 @@ struct packet_traits<double> : default_packet_traits {
HasAbs = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasTanh = EIGEN_FAST_MATH,
HasATan = 0,
HasLog = 0,
HasExp = 1,

View File

@ -1112,11 +1112,11 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS T ptanh_float(const T& a_x)
// floating point argument such that the approximation is exactly 1.
// This saves clamping the value at the end.
#ifdef EIGEN_VECTORIZE_FMA
const T plus_clamp = pset1<T>(8.05359363555908203f);
const T minus_clamp = pset1<T>(-8.05359363555908203f);
const T plus_clamp = pset1<T>(8.01773357391357422f);
const T minus_clamp = pset1<T>(-8.01773357391357422f);
#else
const T plus_clamp = pset1<T>(7.98551225662231445);
const T minus_clamp = pset1<T>(-7.98551225662231445);
const T plus_clamp = pset1<T>(7.90738964080810547f);
const T minus_clamp = pset1<T>(-7.90738964080810547f);
#endif
const T x = pmax(pmin(a_x, plus_clamp), minus_clamp);
@ -1128,17 +1128,17 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS T ptanh_float(const T& a_x)
// --output=tanhf.sollya --dispCoeff="dec"
// The monomial coefficients of the numerator polynomial (odd).
const T alpha_3 = pset1<T>(1.340216100215911865234375e-1f);
const T alpha_5 = pset1<T>(3.52075672708451747894287109375e-3f);
const T alpha_7 = pset1<T>(2.10273356060497462749481201171875e-5f);
const T alpha_9 = pset1<T>(1.39455362813123429077677428722381591796875e-8f);
const T alpha_3 = pset1<T>(1.340216100e-1f);
const T alpha_5 = pset1<T>(3.520756727e-3f);
const T alpha_7 = pset1<T>(2.102733560e-5f);
const T alpha_9 = pset1<T>(1.394553628e-8f);
// The monomial coefficients of the denominator polynomial (even).
const T beta_0 = pset1<T>(1.0f);
const T beta_2 = pset1<T>(4.67354834079742431640625e-1f);
const T beta_4 = pset1<T>(2.5972545146942138671875e-2f);
const T beta_6 = pset1<T>(3.326951409690082073211669921875e-4f);
const T beta_8 = pset1<T>(8.015776984393596649169921875e-7f);
const T beta_2 = pset1<T>(4.673548340e-1f);
const T beta_4 = pset1<T>(2.597254514e-2f);
const T beta_6 = pset1<T>(3.326951409e-4f);
const T beta_8 = pset1<T>(8.015776984e-7f);
// Since the polynomials are odd/even, we need x^2.
const T x2 = pmul(x, x);
@ -1161,6 +1161,91 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS T ptanh_float(const T& a_x)
return pdiv(p, q);
}
/** \internal \returns the hyperbolic tan of \a a (coeff-wise)
This uses a 19/18-degree rational interpolant which
is accurate up to a couple of ulps in the (approximate) range [-18.7, 18.7],
outside of which tanh(x) = +/-1 in single precision. The input is clamped
to the range [-c, c]. The value c is chosen as the smallest value where
the approximation evaluates to exactly 1.
This implementation works on both scalars and packets.
*/
template <typename T>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS T ptanh_double(const T& a_x) {
// Clamp the inputs to the range [-c, c] and set everything
// outside that range to 1.0. The value c is chosen as the smallest
// floating point argument such that the approximation is exactly 1.
// This saves clamping the value at the end.
#ifdef EIGEN_VECTORIZE_FMA
const T plus_clamp = pset1<T>(17.6610191624600077);
const T minus_clamp = pset1<T>(-17.6610191624600077);
#else
const T plus_clamp = pset1<T>(17.714196154005176);
const T minus_clamp = pset1<T>(-17.714196154005176);
#endif
const T x = pmax(pmin(a_x, plus_clamp), minus_clamp);
// The following rational approximation was generated by rminimax
// (https://gitlab.inria.fr/sfilip/rminimax) using the following
// command:
// $ ./ratapprox --function="tanh(x)" --dom='[-18.72,18.72]'
// --num="odd" --den="even" --type="[19,18]" --numF="[D]"
// --denF="[D]" --log --output=tanh.sollya --dispCoeff="dec"
// The monomial coefficients of the numerator polynomial (odd).
const T alpha_3 = pset1<T>(1.5184719640284322e-01);
const T alpha_5 = pset1<T>(5.9809711724441161e-03);
const T alpha_7 = pset1<T>(9.3839087674268880e-05);
const T alpha_9 = pset1<T>(6.8644367682497074e-07);
const T alpha_11 = pset1<T>(2.4618379131293676e-09);
const T alpha_13 = pset1<T>(4.2303918148209176e-12);
const T alpha_15 = pset1<T>(3.1309488231386680e-15);
const T alpha_17 = pset1<T>(7.6534862268749319e-19);
const T alpha_19 = pset1<T>(2.6158007860482230e-23);
// The monomial coefficients of the denominator polynomial (even).
const T beta_0 = pset1<T>(1.0);
const T beta_2 = pset1<T>(4.851805297361760360e-01);
const T beta_4 = pset1<T>(3.437448108450402717e-02);
const T beta_6 = pset1<T>(8.295161192716231542e-04);
const T beta_8 = pset1<T>(8.785185266237658698e-06);
const T beta_10 = pset1<T>(4.492975677839633985e-08);
const T beta_12 = pset1<T>(1.123643448069621992e-10);
const T beta_14 = pset1<T>(1.293019623712687916e-13);
const T beta_16 = pset1<T>(5.782506856739003571e-17);
const T beta_18 = pset1<T>(6.463747022670968018e-21);
// Since the polynomials are odd/even, we need x^2.
const T x2 = pmul(x, x);
const T x3 = pmul(x2, x);
// Interleave the evaluation of the numerator polynomial p and
// denominator polynomial q.
T p = pmadd(x2, alpha_19, alpha_17);
T q = pmadd(x2, beta_18, beta_16);
p = pmadd(x2, p, alpha_15);
q = pmadd(x2, q, beta_14);
p = pmadd(x2, p, alpha_13);
q = pmadd(x2, q, beta_12);
p = pmadd(x2, p, alpha_11);
q = pmadd(x2, q, beta_10);
p = pmadd(x2, p, alpha_9);
q = pmadd(x2, q, beta_8);
p = pmadd(x2, p, alpha_7);
q = pmadd(x2, q, beta_6);
p = pmadd(x2, p, alpha_5);
q = pmadd(x2, q, beta_4);
p = pmadd(x2, p, alpha_3);
q = pmadd(x2, q, beta_2);
// Take advantage of the fact that alpha_1 = 1 to compute
// x*(x^2*p + alpha_1) = x^3 * p + x.
p = pmadd(x3, p, x);
q = pmadd(x2, q, beta_0);
// Divide the numerator by the denominator.
return pdiv(p, q);
}
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patanh_float(const Packet& x) {
typedef typename unpacket_traits<Packet>::type Scalar;

View File

@ -110,6 +110,10 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patan_double(const Pa
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet ptanh_float(const Packet& x);
/** \internal \returns tanh(x) for double precision float */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet ptanh_double(const Packet& x);
/** \internal \returns atanh(x) for single precision float */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patanh_float(const Packet& x);
@ -184,7 +188,8 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_round(const Packet& a);
EIGEN_DOUBLE_PACKET_FUNCTION(sin, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(cos, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(log2, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(exp, PACKET)
EIGEN_DOUBLE_PACKET_FUNCTION(exp, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(tanh, PACKET)
} // end namespace internal
} // end namespace Eigen

View File

@ -5128,7 +5128,7 @@ struct packet_traits<double> : default_packet_traits {
HasCos = EIGEN_FAST_MATH,
HasSqrt = 1,
HasRsqrt = 1,
HasTanh = 0,
HasTanh = EIGEN_FAST_MATH,
HasErf = 0
};
};

View File

@ -214,6 +214,7 @@ struct packet_traits<double> : default_packet_traits {
HasDiv = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasTanh = EIGEN_FAST_MATH,
HasLog = 1,
HasExp = 1,
HasSqrt = 1,