From 71a8e60a7aa6f3c46529e0f029dce5d4c9630890 Mon Sep 17 00:00:00 2001 From: Charles Schlosser Date: Wed, 15 Feb 2023 01:01:14 +0000 Subject: [PATCH] Tweak pasin_float, fix psqrt_complex --- .../arch/Default/GenericPacketMathFunctions.h | 49 +++++++++++-------- test/array_cwise.cpp | 1 + 2 files changed, 29 insertions(+), 21 deletions(-) diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 03c3e9f7a..a322bc40e 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -770,44 +770,49 @@ Packet pasin_float(const Packet& x_in) { typedef typename unpacket_traits::type Scalar; static_assert(std::is_same::value, "Scalar type must be float"); + constexpr float kPiOverTwo = static_cast(EIGEN_PI / 2); + + const Packet cst_half = pset1(0.5f); + const Packet cst_one = pset1(1.0f); + const Packet cst_two = pset1(2.0f); + const Packet cst_pi_over_two = pset1(kPiOverTwo); // For |x| < 0.5 approximate asin(x)/x by an 8th order polynomial with // even terms only. - const Packet p9 = pset1(Scalar(5.08838854730129241943359375e-2f)); - const Packet p7 = pset1(Scalar(3.95139865577220916748046875e-2f)); - const Packet p5 = pset1(Scalar(7.550220191478729248046875e-2f)); - const Packet p3 = pset1(Scalar(0.16664917767047882080078125f)); - const Packet p1 = pset1(Scalar(1.00000011920928955078125f)); + const Packet p9 = pset1(5.08838854730129241943359375e-2f); + const Packet p7 = pset1(3.95139865577220916748046875e-2f); + const Packet p5 = pset1(7.550220191478729248046875e-2f); + const Packet p3 = pset1(0.16664917767047882080078125f); + const Packet p1 = pset1(1.00000011920928955078125f); + + const Packet abs_x = pabs(x_in); + const Packet sign_mask = pandnot(x_in, abs_x); + const Packet invalid_mask = pcmp_lt(cst_one, abs_x); - const Packet neg_mask = pcmp_lt(x_in, pzero(x_in)); - Packet x = pabs(x_in); - const Packet invalid_mask = pcmp_lt(pset1(1.0f), x); // For arguments |x| > 0.5, we map x back to [0:0.5] using // the transformation x_large = sqrt(0.5*(1-x)), and use the // identity // asin(x) = pi/2 - 2 * asin( sqrt( 0.5 * (1 - x))) - const Packet cst_half = pset1(Scalar(0.5f)); - const Packet cst_two = pset1(Scalar(2)); - Packet x_large = psqrt(pnmadd(cst_half, x, cst_half)); - const Packet large_mask = pcmp_lt(cst_half, x); - x = pselect(large_mask, x_large, x); + + const Packet x_large = psqrt(pnmadd(cst_half, abs_x, cst_half)); + const Packet large_mask = pcmp_lt(cst_half, abs_x); + const Packet x = pselect(large_mask, x_large, abs_x); + const Packet x2 = pmul(x, x); // Compute polynomial. // x * (p1 + x^2*(p3 + x^2*(p5 + x^2*(p7 + x^2*p9)))) - Packet x2 = pmul(x, x); + Packet p = pmadd(p9, x2, p7); p = pmadd(p, x2, p5); p = pmadd(p, x2, p3); p = pmadd(p, x2, p1); p = pmul(p, x); - constexpr float kPiOverTwo = static_cast(EIGEN_PI/2); - Packet p_large = pnmadd(cst_two, p, pset1(kPiOverTwo)); + const Packet p_large = pnmadd(cst_two, p, cst_pi_over_two); p = pselect(large_mask, p_large, p); // Flip the sign for negative arguments. - p = pselect(neg_mask, pnegate(p), p); - + p = pxor(p, sign_mask); // Return NaN for arguments outside [-1:1]. - return pselect(invalid_mask, pset1(std::numeric_limits::quiet_NaN()), p); + return por(invalid_mask, p); } // Computes elementwise atan(x) for x in [-1:1] with 2 ulp accuracy. @@ -1090,9 +1095,11 @@ Packet psqrt_complex(const Packet& a) { is_imag_inf = por(is_imag_inf, pcplxflip(is_imag_inf)); Packet imag_inf_result; imag_inf_result.v = por(pand(cst_pos_inf, real_mask), pandnot(a.v, real_mask)); + // unless otherwise specified, if either the real or imaginary component is nan, the entire result is nan + Packet result_is_nan = pandnot(ptrue(result), pcmp_eq(result, result)); + result = por(result_is_nan, result); - return pselect(is_imag_inf, imag_inf_result, - pselect(is_real_inf, real_inf_result,result)); + return pselect(is_imag_inf, imag_inf_result, pselect(is_real_inf, real_inf_result, result)); } diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index b55723de4..451b1ccb6 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ -1005,6 +1005,7 @@ EIGEN_DECLARE_TEST(array_cwise) } for(int i = 0; i < g_repeat; i++) { CALL_SUBTEST_4( array_complex(ArrayXXcf(internal::random(1,EIGEN_TEST_MAX_SIZE), internal::random(1,EIGEN_TEST_MAX_SIZE))) ); + CALL_SUBTEST_5( array_complex(ArrayXXcd(internal::random(1,EIGEN_TEST_MAX_SIZE), internal::random(1,EIGEN_TEST_MAX_SIZE)))); } for(int i = 0; i < g_repeat; i++) {