From 32d95bb0973bf13336de3cf99d5b5e0579307d82 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Wed, 21 Aug 2024 02:29:45 +0000 Subject: [PATCH] Add vectorized implementation of tanh --- Eigen/src/Core/arch/AVX/MathFunctions.h | 1 + Eigen/src/Core/arch/AVX/PacketMath.h | 1 + Eigen/src/Core/arch/AVX512/PacketMath.h | 1 + Eigen/src/Core/arch/AltiVec/PacketMath.h | 1 + .../arch/Default/GenericPacketMathFunctions.h | 109 ++++++++++++++++-- .../Default/GenericPacketMathFunctionsFwd.h | 7 +- Eigen/src/Core/arch/NEON/PacketMath.h | 2 +- Eigen/src/Core/arch/SSE/PacketMath.h | 1 + 8 files changed, 109 insertions(+), 14 deletions(-) diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h index 321188c4b..42933f77b 100644 --- a/Eigen/src/Core/arch/AVX/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX/MathFunctions.h @@ -27,6 +27,7 @@ EIGEN_DOUBLE_PACKET_FUNCTION(atan, Packet4d) EIGEN_DOUBLE_PACKET_FUNCTION(log, Packet4d) EIGEN_DOUBLE_PACKET_FUNCTION(log2, Packet4d) EIGEN_DOUBLE_PACKET_FUNCTION(exp, Packet4d) +EIGEN_DOUBLE_PACKET_FUNCTION(tanh, Packet4d) #ifdef EIGEN_VECTORIZE_AVX2 EIGEN_DOUBLE_PACKET_FUNCTION(sin, Packet4d) EIGEN_DOUBLE_PACKET_FUNCTION(cos, Packet4d) diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index ea58f0ed2..b9a8fc3a0 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -142,6 +142,7 @@ struct packet_traits : default_packet_traits { HasSin = EIGEN_FAST_MATH, HasCos = EIGEN_FAST_MATH, #endif + HasTanh = EIGEN_FAST_MATH, HasLog = 1, HasExp = 1, HasSqrt = 1, diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index 0659ddd15..36195430d 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -153,6 +153,7 @@ struct packet_traits : default_packet_traits { HasLog = 1, HasExp = 1, HasATan = 1, + HasTanh = EIGEN_FAST_MATH, HasCmp = 1, HasDiv = 1 }; 
diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h index 6c472cdf1..29c5a5da3 100644 --- a/Eigen/src/Core/arch/AltiVec/PacketMath.h +++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h @@ -3176,6 +3176,7 @@ struct packet_traits : default_packet_traits { HasAbs = 1, HasSin = EIGEN_FAST_MATH, HasCos = EIGEN_FAST_MATH, + HasTanh = EIGEN_FAST_MATH, HasATan = 0, HasLog = 0, HasExp = 1, diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index b50ceb9d2..42d812f5a 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -1112,11 +1112,11 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS T ptanh_float(const T& a_x) // floating point argument such that the approximation is exactly 1. // This saves clamping the value at the end. #ifdef EIGEN_VECTORIZE_FMA - const T plus_clamp = pset1(8.05359363555908203f); - const T minus_clamp = pset1(-8.05359363555908203f); + const T plus_clamp = pset1(8.01773357391357422f); + const T minus_clamp = pset1(-8.01773357391357422f); #else - const T plus_clamp = pset1(7.98551225662231445); - const T minus_clamp = pset1(-7.98551225662231445); + const T plus_clamp = pset1(7.90738964080810547f); + const T minus_clamp = pset1(-7.90738964080810547f); #endif const T x = pmax(pmin(a_x, plus_clamp), minus_clamp); @@ -1128,17 +1128,17 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS T ptanh_float(const T& a_x) // --output=tanhf.sollya --dispCoeff="dec" // The monomial coefficients of the numerator polynomial (odd). 
- const T alpha_3 = pset1(1.340216100215911865234375e-1f); - const T alpha_5 = pset1(3.52075672708451747894287109375e-3f); - const T alpha_7 = pset1(2.10273356060497462749481201171875e-5f); - const T alpha_9 = pset1(1.39455362813123429077677428722381591796875e-8f); + const T alpha_3 = pset1(1.340216100e-1f); + const T alpha_5 = pset1(3.520756727e-3f); + const T alpha_7 = pset1(2.102733560e-5f); + const T alpha_9 = pset1(1.394553628e-8f); // The monomial coefficients of the denominator polynomial (even). const T beta_0 = pset1(1.0f); - const T beta_2 = pset1(4.67354834079742431640625e-1f); - const T beta_4 = pset1(2.5972545146942138671875e-2f); - const T beta_6 = pset1(3.326951409690082073211669921875e-4f); - const T beta_8 = pset1(8.015776984393596649169921875e-7f); + const T beta_2 = pset1(4.673548340e-1f); + const T beta_4 = pset1(2.597254514e-2f); + const T beta_6 = pset1(3.326951409e-4f); + const T beta_8 = pset1(8.015776984e-7f); // Since the polynomials are odd/even, we need x^2. const T x2 = pmul(x, x); @@ -1161,6 +1161,91 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS T ptanh_float(const T& a_x) return pdiv(p, q); } +/** \internal \returns the hyperbolic tan of \a a (coeff-wise) + This uses a 19/18-degree rational interpolant which + is accurate up to a couple of ulps in the (approximate) range [-18.7, 18.7], + outside of which tanh(x) = +/-1 in double precision. The input is clamped + to the range [-c, c]. The value c is chosen as the smallest value where + the approximation evaluates to exactly 1. + + This implementation works on both scalars and packets. +*/ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS T ptanh_double(const T& a_x) { + // Clamp the inputs to the range [-c, c] and set everything + // outside that range to 1.0. The value c is chosen as the smallest + // floating point argument such that the approximation is exactly 1. + // This saves clamping the value at the end. 
+#ifdef EIGEN_VECTORIZE_FMA + const T plus_clamp = pset1(17.6610191624600077); + const T minus_clamp = pset1(-17.6610191624600077); +#else + const T plus_clamp = pset1(17.714196154005176); + const T minus_clamp = pset1(-17.714196154005176); +#endif + const T x = pmax(pmin(a_x, plus_clamp), minus_clamp); + + // The following rational approximation was generated by rminimax + // (https://gitlab.inria.fr/sfilip/rminimax) using the following + // command: + // $ ./ratapprox --function="tanh(x)" --dom='[-18.72,18.72]' + // --num="odd" --den="even" --type="[19,18]" --numF="[D]" + // --denF="[D]" --log --output=tanh.sollya --dispCoeff="dec" + + // The monomial coefficients of the numerator polynomial (odd). + const T alpha_3 = pset1(1.5184719640284322e-01); + const T alpha_5 = pset1(5.9809711724441161e-03); + const T alpha_7 = pset1(9.3839087674268880e-05); + const T alpha_9 = pset1(6.8644367682497074e-07); + const T alpha_11 = pset1(2.4618379131293676e-09); + const T alpha_13 = pset1(4.2303918148209176e-12); + const T alpha_15 = pset1(3.1309488231386680e-15); + const T alpha_17 = pset1(7.6534862268749319e-19); + const T alpha_19 = pset1(2.6158007860482230e-23); + + // The monomial coefficients of the denominator polynomial (even). + const T beta_0 = pset1(1.0); + const T beta_2 = pset1(4.851805297361760360e-01); + const T beta_4 = pset1(3.437448108450402717e-02); + const T beta_6 = pset1(8.295161192716231542e-04); + const T beta_8 = pset1(8.785185266237658698e-06); + const T beta_10 = pset1(4.492975677839633985e-08); + const T beta_12 = pset1(1.123643448069621992e-10); + const T beta_14 = pset1(1.293019623712687916e-13); + const T beta_16 = pset1(5.782506856739003571e-17); + const T beta_18 = pset1(6.463747022670968018e-21); + + // Since the polynomials are odd/even, we need x^2. + const T x2 = pmul(x, x); + const T x3 = pmul(x2, x); + + // Interleave the evaluation of the numerator polynomial p and + // denominator polynomial q. 
+ T p = pmadd(x2, alpha_19, alpha_17); + T q = pmadd(x2, beta_18, beta_16); + p = pmadd(x2, p, alpha_15); + q = pmadd(x2, q, beta_14); + p = pmadd(x2, p, alpha_13); + q = pmadd(x2, q, beta_12); + p = pmadd(x2, p, alpha_11); + q = pmadd(x2, q, beta_10); + p = pmadd(x2, p, alpha_9); + q = pmadd(x2, q, beta_8); + p = pmadd(x2, p, alpha_7); + q = pmadd(x2, q, beta_6); + p = pmadd(x2, p, alpha_5); + q = pmadd(x2, q, beta_4); + p = pmadd(x2, p, alpha_3); + q = pmadd(x2, q, beta_2); + // Take advantage of the fact that alpha_1 = 1 to compute + // x*(x^2*p + alpha_1) = x^3 * p + x. + p = pmadd(x3, p, x); + q = pmadd(x2, q, beta_0); + + // Divide the numerator by the denominator. + return pdiv(p, q); +} + template EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patanh_float(const Packet& x) { typedef typename unpacket_traits::type Scalar; diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h index 4d113ca6a..b200b3be7 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h @@ -110,6 +110,10 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patan_double(const Pa template EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet ptanh_float(const Packet& x); +/** \internal \returns tanh(x) for double precision float */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet ptanh_double(const Packet& x); + /** \internal \returns atanh(x) for single precision float */ template EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patanh_float(const Packet& x); @@ -184,7 +188,8 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_round(const Packet& a); EIGEN_DOUBLE_PACKET_FUNCTION(sin, PACKET) \ EIGEN_DOUBLE_PACKET_FUNCTION(cos, PACKET) \ EIGEN_DOUBLE_PACKET_FUNCTION(log2, PACKET) \ - EIGEN_DOUBLE_PACKET_FUNCTION(exp, PACKET) + EIGEN_DOUBLE_PACKET_FUNCTION(exp, PACKET) \ 
+ EIGEN_DOUBLE_PACKET_FUNCTION(tanh, PACKET) } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h index 794d06316..56e8b2d51 100644 --- a/Eigen/src/Core/arch/NEON/PacketMath.h +++ b/Eigen/src/Core/arch/NEON/PacketMath.h @@ -5128,7 +5128,7 @@ struct packet_traits : default_packet_traits { HasCos = EIGEN_FAST_MATH, HasSqrt = 1, HasRsqrt = 1, - HasTanh = 0, + HasTanh = EIGEN_FAST_MATH, HasErf = 0 }; }; diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h index e5dce3b49..00925f685 100644 --- a/Eigen/src/Core/arch/SSE/PacketMath.h +++ b/Eigen/src/Core/arch/SSE/PacketMath.h @@ -214,6 +214,7 @@ struct packet_traits : default_packet_traits { HasDiv = 1, HasSin = EIGEN_FAST_MATH, HasCos = EIGEN_FAST_MATH, + HasTanh = EIGEN_FAST_MATH, HasLog = 1, HasExp = 1, HasSqrt = 1,