From e95c4a837fb0057cfc6ac9b55073f5d1fe61ae2f Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Tue, 4 Oct 2022 18:11:00 +0000 Subject: [PATCH] Simpler range reduction strategy for atan(). --- .../arch/Default/GenericPacketMathFunctions.h | 143 +++++++++++------- 1 file changed, 85 insertions(+), 58 deletions(-) diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 110a54485..22f6fb977 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -815,6 +815,37 @@ Packet pasin_float(const Packet& x_in) { return pselect(invalid_mask, pset1(std::numeric_limits::quiet_NaN()), p); } +// Computes elementwise atan(x) for x in [-1:1] with 2 ulp accuracy. +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +Packet patan_reduced_float(const Packet& x) { + const Packet q0 = pset1(-0.3333314359188079833984375f); + const Packet q2 = pset1(0.19993579387664794921875f); + const Packet q4 = pset1(-0.14209578931331634521484375f); + const Packet q6 = pset1(0.1066047251224517822265625f); + const Packet q8 = pset1(-7.5408883392810821533203125e-2f); + const Packet q10 = pset1(4.3082617223262786865234375e-2f); + const Packet q12 = pset1(-1.62907354533672332763671875e-2f); + const Packet q14 = pset1(2.90188402868807315826416015625e-3f); + + // Approximate atan(x) by a polynomial of the form + // P(x) = x + x^3 * Q(x^2), + // where Q(x^2) is a 7th order polynomial in x^2. + // We evaluate even and odd terms in x^2 in parallel + // to take advantage of instruction level parallelism + // and hardware with multiple FMA units. + const Packet x2 = pmul(x, x); + const Packet x4 = pmul(x2, x2); + Packet q_odd = pmadd(q14, x4, q10); + Packet q_even = pmadd(q12, x4, q8); + q_odd = pmadd(q_odd, x4, q6); + q_even = pmadd(q_even, x4, q4); + q_odd = pmadd(q_odd, x4, q2); + q_even = pmadd(q_even, x4, q0); + const Packet q = pmadd(q_odd, x2, q_even); + return pmadd(q, pmul(x, x2), x); +} + template EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patan_float(const Packet& x_in) { @@ -823,45 +854,64 @@ Packet patan_float(const Packet& x_in) { const Packet cst_one = pset1(1.0f); constexpr float kPiOverTwo = static_cast(EIGEN_PI/2); - const Packet cst_pi_over_two = pset1(kPiOverTwo); - constexpr float kPiOverFour = static_cast(EIGEN_PI/4); - const Packet cst_pi_over_four = pset1(kPiOverFour); - const Packet cst_large = pset1(2.4142135623730950488016887f); // tan(3*pi/8); - const Packet cst_medium = pset1(0.4142135623730950488016887f); // tan(pi/8); - const Packet q0 = pset1(-0.333329379558563232421875f); - const Packet q2 = pset1(0.19977366924285888671875f); - const Packet q4 = pset1(-0.13874518871307373046875f); - const Packet q6 = pset1(8.044691383838653564453125e-2f); - const Packet neg_mask = pcmp_lt(x_in, pzero(x_in)); - Packet x = pabs(x_in); - - // Use the same range reduction strategy (to [0:tan(pi/8)]) as the - // Cephes library: - // "Large": For x >= tan(3*pi/8), use atan(1/x) = pi/2 - atan(x). - // "Medium": For x in [tan(pi/8) : tan(3*pi/8)), - // use atan(x) = pi/4 + atan((x-1)/(x+1)). - // "Small": For x < pi/8, approximate atan(x) directly by a polynomial + // "Large": For |x| > 1, use atan(1/x) = sign(x)*pi/2 - atan(x). + // "Small": For |x| <= 1, approximate atan(x) directly by a polynomial // calculated using Sollya. - const Packet large_mask = pcmp_lt(cst_large, x); - x = pselect(large_mask, preciprocal(x), x); - const Packet medium_mask = pandnot(pcmp_lt(cst_medium, x), large_mask); - x = pselect(medium_mask, pdiv(psub(x, cst_one), padd(x, cst_one)), x); + const Packet neg_mask = pcmp_lt(x_in, pzero(x_in)); + const Packet large_mask = pcmp_lt(cst_one, pabs(x_in)); + const Packet large_shift = pselect(neg_mask, pset1(-kPiOverTwo), pset1(kPiOverTwo)); + const Packet x = pselect(large_mask, preciprocal(x_in), x_in); + const Packet p = patan_reduced_float(x); + + // Apply transformations according to the range reduction masks. + return pselect(large_mask, psub(large_shift, p), p); +} + +// Computes elementwise atan(x) for x in [-tan(pi/8):tan(pi/8)] +// with 2 ulp accuracy. +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet +patan_reduced_double(const Packet& x) { + const Packet q0 = + pset1(-0.33333333333330028569463365784031338989734649658203); + const Packet q2 = + pset1(0.199999999990664090177006073645316064357757568359375); + const Packet q4 = + pset1(-0.142857141937123677255527809393242932856082916259766); + const Packet q6 = + pset1(0.111111065991039953404495577160560060292482376098633); + const Packet q8 = + pset1(-9.0907812986129224452902519715280504897236824035645e-2); + const Packet q10 = + pset1(7.6900542950704739442180368769186316058039665222168e-2); + const Packet q12 = + pset1(-6.6410112986494976294871150912513257935643196105957e-2); + const Packet q14 = + pset1(5.6920144995467943094258345126945641823112964630127e-2); + const Packet q16 = + pset1(-4.3577020814990513608577771265117917209863662719727e-2); + const Packet q18 = + pset1(2.1244050233624342527427586446719942614436149597168e-2); // Approximate atan(x) on [0:tan(pi/8)] by a polynomial of the form // P(x) = x + x^3 * Q(x^2), - // where Q(x^2) is a cubic polynomial in x^2. + // where Q(x^2) is a 9th order polynomial in x^2. + // We evaluate even and odd terms in x^2 in parallel + // to take advantage of instruction level parallelism + // and hardware with multiple FMA units. const Packet x2 = pmul(x, x); const Packet x4 = pmul(x2, x2); - Packet q_odd = pmadd(q6, x4, q2); - Packet q_even = pmadd(q4, x4, q0); - const Packet q = pmadd(q_odd, x2, q_even); - Packet p = pmadd(q, pmul(x, x2), x); - - // Apply transformations according to the range reduction masks. - p = pselect(large_mask, psub(cst_pi_over_two, p), p); - p = pselect(medium_mask, padd(cst_pi_over_four, p), p); - return pselect(neg_mask, pnegate(p), p); + Packet q_odd = pmadd(q18, x4, q14); + Packet q_even = pmadd(q16, x4, q12); + q_odd = pmadd(q_odd, x4, q10); + q_even = pmadd(q_even, x4, q8); + q_odd = pmadd(q_odd, x4, q6); + q_even = pmadd(q_even, x4, q4); + q_odd = pmadd(q_odd, x4, q2); + q_even = pmadd(q_even, x4, q0); + const Packet p = pmadd(q_odd, x2, q_even); + return pmadd(p, pmul(x, x2), x); } template @@ -877,16 +927,6 @@ Packet patan_double(const Packet& x_in) { const Packet cst_pi_over_four = pset1(kPiOverFour); const Packet cst_large = pset1(2.4142135623730950488016887); // tan(3*pi/8); const Packet cst_medium = pset1(0.4142135623730950488016887); // tan(pi/8); - const Packet q0 = pset1(-0.33333333333330028569463365784031338989734649658203); - const Packet q2 = pset1(0.199999999990664090177006073645316064357757568359375); - const Packet q4 = pset1(-0.142857141937123677255527809393242932856082916259766); - const Packet q6 = pset1(0.111111065991039953404495577160560060292482376098633); - const Packet q8 = pset1(-9.0907812986129224452902519715280504897236824035645e-2); - const Packet q10 = pset1(7.6900542950704739442180368769186316058039665222168e-2); - const Packet q12 = pset1(-6.6410112986494976294871150912513257935643196105957e-2); - const Packet q14 = pset1(5.6920144995467943094258345126945641823112964630127e-2); - const Packet q16 = pset1(-4.3577020814990513608577771265117917209863662719727e-2); - const Packet q18 = pset1(2.1244050233624342527427586446719942614436149597168e-2); const Packet neg_mask = pcmp_lt(x_in, pzero(x_in)); Packet x = pabs(x_in); @@ -903,21 +943,9 @@ Packet patan_double(const Packet& x_in) { const Packet medium_mask = pandnot(pcmp_lt(cst_medium, x), large_mask); x = pselect(medium_mask, pdiv(psub(x, cst_one), padd(x, cst_one)), x); - // Approximate atan(x) on [0:tan(pi/8)] by a polynomial of the form - // P(x) = x + x^3 * Q(x^2), - // where Q(x^2) is a 9th order polynomial in x^2. - const Packet x2 = pmul(x, x); - const Packet x4 = pmul(x2, x2); - Packet q_odd = pmadd(q18, x4, q14); - Packet q_even = pmadd(q16, x4, q12); - q_odd = pmadd(q_odd, x4, q10); - q_even = pmadd(q_even, x4, q8); - q_odd = pmadd(q_odd, x4, q6); - q_even = pmadd(q_even, x4, q4); - q_odd = pmadd(q_odd, x4, q2); - q_even = pmadd(q_even, x4, q0); - const Packet q = pmadd(q_odd, x2, q_even); - Packet p = pmadd(q, pmul(x, x2), x); + // Compute approximation of p ~= atan(x') where x' is the argument reduced to + // [0:tan(pi/8)]. + Packet p = patan_reduced_double(x); // Apply transformations according to the range reduction masks. p = pselect(large_mask, psub(cst_pi_over_two, p), p); @@ -925,7 +953,6 @@ Packet patan_double(const Packet& x_in) { return pselect(neg_mask, pnegate(p), p); } - template EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pdiv_complex(const Packet& x, const Packet& y) {