From c475228b28d1bc81d2247860c6cd907bb316a0d6 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Sat, 1 Oct 2022 01:49:30 +0000 Subject: [PATCH] Vectorize atan() for double. --- Eigen/src/Core/arch/AVX/MathFunctions.h | 6 ++ Eigen/src/Core/arch/AVX/PacketMath.h | 1 + Eigen/src/Core/arch/AVX512/MathFunctions.h | 6 ++ Eigen/src/Core/arch/AVX512/PacketMath.h | 1 + Eigen/src/Core/arch/AltiVec/MathFunctions.h | 6 ++ Eigen/src/Core/arch/AltiVec/PacketMath.h | 1 + .../arch/Default/GenericPacketMathFunctions.h | 66 ++++++++++++++++++- .../Default/GenericPacketMathFunctionsFwd.h | 5 ++ Eigen/src/Core/arch/NEON/MathFunctions.h | 3 + Eigen/src/Core/arch/NEON/PacketMath.h | 7 +- Eigen/src/Core/arch/SSE/MathFunctions.h | 5 ++ Eigen/src/Core/arch/SSE/PacketMath.h | 1 + 12 files changed, 104 insertions(+), 4 deletions(-) diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h index 43aea548f..cb7d7b82a 100644 --- a/Eigen/src/Core/arch/AVX/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX/MathFunctions.h @@ -50,6 +50,12 @@ patan(const Packet8f& _x) { return patan_float(_x); } +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet4d +patan(const Packet4d& _x) { + return patan_double(_x); +} + template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8f plog(const Packet8f& _x) { diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index 227e88a86..ecbb73c6c 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -109,6 +109,7 @@ template<> struct packet_traits : default_packet_traits HasExp = 1, HasSqrt = 1, HasRsqrt = 1, + HasATan = 1, HasBlend = 1, HasRound = 1, HasFloor = 1, diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h index 79a904389..af47a85cf 100644 --- a/Eigen/src/Core/arch/AVX512/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h @@ -275,6 +275,12 @@ patan(const Packet16f& _x) { return patan_float(_x); } +template <> +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d +patan(const Packet8d& _x) { + return patan_double(_x); +} + template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f ptanh(const Packet16f& _x) { diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index 5e9670c52..bf6fc3721 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -161,6 +161,7 @@ template<> struct packet_traits : default_packet_traits HasSqrt = EIGEN_FAST_MATH, HasRsqrt = EIGEN_FAST_MATH, #endif + HasATan = 1, HasCmp = 1, HasDiv = 1, HasRound = 1, diff --git a/Eigen/src/Core/arch/AltiVec/MathFunctions.h b/Eigen/src/Core/arch/AltiVec/MathFunctions.h index fffd2e5cd..057796121 100644 --- a/Eigen/src/Core/arch/AltiVec/MathFunctions.h +++ b/Eigen/src/Core/arch/AltiVec/MathFunctions.h @@ -60,6 +60,12 @@ Packet4f patan(const Packet4f& _x) return patan_float(_x); } +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +Packet2d patan(const Packet2d& _x) +{ + return patan_double(_x); +} + #ifdef __VSX__ #ifndef EIGEN_COMP_CLANG template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h index f03018987..37398de15 100644 --- a/Eigen/src/Core/arch/AltiVec/PacketMath.h +++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h @@ -2708,6 +2708,7 @@ template<> struct packet_traits : default_packet_traits HasAbs = 1, HasSin = 0, HasCos = 0, + HasATan = 1, HasLog = 0, HasExp = 1, HasSqrt = 1, diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index a8837b235..110a54485 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -864,6 +864,68 @@ Packet patan_float(const Packet& x_in) { return pselect(neg_mask, pnegate(p), p); } +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +Packet patan_double(const Packet& x_in) { + typedef typename unpacket_traits::type Scalar; + static_assert(std::is_same::value, "Scalar type must be double"); + + const Packet cst_one = pset1(1.0); + constexpr double kPiOverTwo = static_cast(M_PI_2); + const Packet cst_pi_over_two = pset1(kPiOverTwo); + constexpr double kPiOverFour = static_cast(M_PI_4); + const Packet cst_pi_over_four = pset1(kPiOverFour); + const Packet cst_large = pset1(2.4142135623730950488016887); // tan(3*pi/8); + const Packet cst_medium = pset1(0.4142135623730950488016887); // tan(pi/8); + const Packet q0 = pset1(-0.33333333333330028569463365784031338989734649658203); + const Packet q2 = pset1(0.199999999990664090177006073645316064357757568359375); + const Packet q4 = pset1(-0.142857141937123677255527809393242932856082916259766); + const Packet q6 = pset1(0.111111065991039953404495577160560060292482376098633); + const Packet q8 = pset1(-9.0907812986129224452902519715280504897236824035645e-2); + const Packet q10 = pset1(7.6900542950704739442180368769186316058039665222168e-2); + const Packet q12 = pset1(-6.6410112986494976294871150912513257935643196105957e-2); + const Packet q14 = pset1(5.6920144995467943094258345126945641823112964630127e-2); + const Packet q16 = pset1(-4.3577020814990513608577771265117917209863662719727e-2); + const Packet q18 = pset1(2.1244050233624342527427586446719942614436149597168e-2); + + const Packet neg_mask = pcmp_lt(x_in, pzero(x_in)); + Packet x = pabs(x_in); + + // Use the same range reduction strategy (to [0:tan(pi/8)]) as the + // Cephes library: + // "Large": For x >= tan(3*pi/8), use atan(1/x) = pi/2 - atan(x). + // "Medium": For x in [tan(pi/8) : tan(3*pi/8)), + // use atan(x) = pi/4 + atan((x-1)/(x+1)). + // "Small": For x < tan(pi/8), approximate atan(x) directly by a polynomial + // calculated using Sollya. + const Packet large_mask = pcmp_lt(cst_large, x); + x = pselect(large_mask, preciprocal(x), x); + const Packet medium_mask = pandnot(pcmp_lt(cst_medium, x), large_mask); + x = pselect(medium_mask, pdiv(psub(x, cst_one), padd(x, cst_one)), x); + + // Approximate atan(x) on [0:tan(pi/8)] by a polynomial of the form + // P(x) = x + x^3 * Q(x^2), + // where Q(x^2) is a 9th order polynomial in x^2. + const Packet x2 = pmul(x, x); + const Packet x4 = pmul(x2, x2); + Packet q_odd = pmadd(q18, x4, q14); + Packet q_even = pmadd(q16, x4, q12); + q_odd = pmadd(q_odd, x4, q10); + q_even = pmadd(q_even, x4, q8); + q_odd = pmadd(q_odd, x4, q6); + q_even = pmadd(q_even, x4, q4); + q_odd = pmadd(q_odd, x4, q2); + q_even = pmadd(q_even, x4, q0); + const Packet q = pmadd(q_odd, x2, q_even); + Packet p = pmadd(q, pmul(x, x2), x); + + // Apply transformations according to the range reduction masks. + p = pselect(large_mask, psub(cst_pi_over_two, p), p); + p = pselect(medium_mask, padd(cst_pi_over_four, p), p); + return pselect(neg_mask, pnegate(p), p); +} + + template EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pdiv_complex(const Packet& x, const Packet& y) { @@ -958,8 +1020,8 @@ Packet psqrt_complex(const Packet& a) { // Step 4. Compute solution for inputs with negative real part: // [|eta0|, sign(y0)*rho0, |eta1|, sign(y1)*rho1] - const RealScalar neg_zero = RealScalar(numext::bit_cast(0x80000000u)); - const RealPacket cst_imag_sign_mask = pset1(Scalar(RealScalar(0.0), neg_zero)).v; + const RealPacket cst_imag_sign_mask = + pset1(Scalar(RealScalar(0.0), RealScalar(-0.0))).v; RealPacket imag_signs = pand(a.v, cst_imag_sign_mask); Packet negative_real_result; // Notice that rho is positive, so taking it's absolute value is a noop. diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h index de6fd9542..179c55cf3 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctionsFwd.h @@ -104,6 +104,11 @@ template EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patan_float(const Packet& x); +/** \internal \returns atan(x) for double precision float */ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +Packet patan_double(const Packet& x); + /** \internal \returns sqrt(x) for complex types */ template EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS diff --git a/Eigen/src/Core/arch/NEON/MathFunctions.h b/Eigen/src/Core/arch/NEON/MathFunctions.h index 445572f8b..aea514949 100644 --- a/Eigen/src/Core/arch/NEON/MathFunctions.h +++ b/Eigen/src/Core/arch/NEON/MathFunctions.h @@ -102,6 +102,9 @@ template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet2d pexp EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet2d plog(const Packet2d& x) { return plog_double(x); } +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet2d patan(const Packet2d& x) +{ return patan_double(x); } + #endif } // end namespace internal diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h index 9ef83d498..5cbf4ac97 100644 --- a/Eigen/src/Core/arch/NEON/PacketMath.h +++ b/Eigen/src/Core/arch/NEON/PacketMath.h @@ -3762,10 +3762,13 @@ template<> struct packet_traits : default_packet_traits HasCeil = 1, HasRint = 1, +#if EIGEN_ARCH_ARM64 && !EIGEN_APPLE_DOUBLE_NEON_BUG + HasExp = 1, + HasLog = 1, + HasATan = 1, +#endif HasSin = 0, HasCos = 0, - HasLog = 1, - HasExp = 1, HasSqrt = 1, HasRsqrt = 1, HasTanh = 0, diff --git a/Eigen/src/Core/arch/SSE/MathFunctions.h b/Eigen/src/Core/arch/SSE/MathFunctions.h index 8e8a0a48d..f98fb7a3e 100644 --- a/Eigen/src/Core/arch/SSE/MathFunctions.h +++ b/Eigen/src/Core/arch/SSE/MathFunctions.h @@ -81,6 +81,11 @@ Packet4f pacos(const Packet4f& _x) return pacos_float(_x); } +template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +Packet2d patan(const Packet2d& _x) { + return patan_double(_x); +} + template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet4f pasin(const Packet4f& _x) { diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h index f9426689e..847ff07d8 100644 --- a/Eigen/src/Core/arch/SSE/PacketMath.h +++ b/Eigen/src/Core/arch/SSE/PacketMath.h @@ -178,6 +178,7 @@ struct packet_traits : default_packet_traits { HasExp = 1, HasSqrt = 1, HasRsqrt = 1, + HasATan = 1, HasBlend = 1, HasFloor = 1, HasCeil = 1,