diff --git a/Eigen/src/Core/functors/BinaryFunctors.h b/Eigen/src/Core/functors/BinaryFunctors.h index 38f123cff..eaa4352b8 100644 --- a/Eigen/src/Core/functors/BinaryFunctors.h +++ b/Eigen/src/Core/functors/BinaryFunctors.h @@ -518,22 +518,26 @@ struct scalar_atan2_op { return static_cast(atan2(y, x)); } template - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t::value, Packet> packetOp(const Packet& y, - const Packet& x) const { + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + std::enable_if_t::value, Packet> + packetOp(const Packet& y, const Packet& x) const { // See https://en.cppreference.com/w/cpp/numeric/math/atan2 // for how corner cases are supposed to be handled according to the // IEEE floating-point standard (IEC 60559). - const Packet kSignMask = pset1(-Scalar(0.0)); + const Packet kSignMask = pset1(-Scalar(0)); const Packet kPi = pset1(Scalar(EIGEN_PI)); const Packet kPiO2 = pset1(Scalar(EIGEN_PI / 2)); const Packet kPiO4 = pset1(Scalar(EIGEN_PI / 4)); - constexpr Scalar k3PiO4f = Scalar(3.0 * (EIGEN_PI / 4)); - const Packet k3PiO4 = pset1(k3PiO4f); - Packet x_neg = pcmp_lt(x, pzero(x)); - Packet x_sign = pand(x, kSignMask); - Packet y_sign = pand(y, kSignMask); - Packet x_zero = pcmp_eq(x, pzero(x)); - Packet y_zero = pcmp_eq(y, pzero(y)); + const Packet k3PiO4 = pset1(Scalar(3.0 * (EIGEN_PI / 4))); + + // Various predicates about the inputs. + Packet x_signbit = pand(x, kSignMask); + Packet x_has_signbit = pcmp_lt(por(x_signbit, kPi), pzero(x)); + Packet x_is_zero = pcmp_eq(x, pzero(x)); + Packet x_neg = pandnot(x_has_signbit, x_is_zero); + + Packet y_signbit = pand(y, kSignMask); + Packet y_is_zero = pcmp_eq(y, pzero(y)); Packet x_is_not_nan = pcmp_eq(x, x); Packet y_is_not_nan = pcmp_eq(y, y); @@ -542,20 +546,22 @@ struct scalar_atan2_op { Packet result = patan(pdiv(y, x)); // Compute shift for when x != 0 and y != 0. - Packet shift = pselect(x_neg, por(kPi, y_sign), pzero(x)); + Packet shift = pselect(x_neg, por(kPi, y_signbit), pzero(x)); // Special cases: // Handle x = +/-inf && y = +/-inf. Packet is_not_nan = pcmp_eq(result, result); - result = pselect(is_not_nan, padd(shift, result), - pselect(x_neg, por(k3PiO4, y_sign), por(kPiO4, y_sign))); - // Handle x == +/-0. result = - pselect(x_zero, pselect(y_zero, pzero(y), por(y_sign, kPiO2)), result); + pselect(is_not_nan, padd(shift, result), + pselect(x_neg, por(k3PiO4, y_signbit), por(kPiO4, y_signbit))); + // Handle x == +/-0. + result = pselect( + x_is_zero, pselect(y_is_zero, pzero(y), por(y_signbit, kPiO2)), result); // Handle y == +/-0. - result = pselect(y_zero, - pselect(x_sign, por(y_sign, kPi), por(y_sign, pzero(y))), - result); + result = pselect( + y_is_zero, + pselect(x_has_signbit, por(y_signbit, kPi), por(y_signbit, pzero(y))), + result); // Handle NaN inputs. Packet kQNaN = pset1(NumTraits::quiet_NaN()); return pselect(pand(x_is_not_nan, y_is_not_nan), result, kQNaN);