diff --git a/Eigen/src/Core/functors/BinaryFunctors.h b/Eigen/src/Core/functors/BinaryFunctors.h index 2169c1e4e..dae0d71b7 100644 --- a/Eigen/src/Core/functors/BinaryFunctors.h +++ b/Eigen/src/Core/functors/BinaryFunctors.h @@ -521,19 +521,21 @@ struct scalar_atan2_op { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t::value, Packet> packetOp(const Packet& y, const Packet& x) const { // See https://en.cppreference.com/w/cpp/numeric/math/atan2 - // for how corner cases are supposed to be handles according to the + // for how corner cases are supposed to be handled according to the // IEEE floating-point standard (IEC 60559). - constexpr Scalar k3PiO3f = Scalar(3 * (EIGEN_PI / 4)); const Packet kSignMask = pset1(Scalar(-0.0)); const Packet kPi = pset1(Scalar(EIGEN_PI)); const Packet kPiO2 = pset1(Scalar(EIGEN_PI / 2)); const Packet kPiO4 = pset1(Scalar(EIGEN_PI / 4)); - const Packet k3PiO4 = pset1(k3PiO3f); + constexpr Scalar k3PiO4f = Scalar(3.0 * (EIGEN_PI / 4)); + const Packet k3PiO4 = pset1(k3PiO4f); Packet x_neg = pcmp_lt(x, pzero(x)); Packet x_sign = pand(x, kSignMask); Packet y_sign = pand(y, kSignMask); Packet x_zero = pcmp_eq(x, pzero(x)); Packet y_zero = pcmp_eq(y, pzero(y)); + Packet x_is_not_nan = pcmp_eq(x, x); + Packet y_is_not_nan = pcmp_eq(y, y); // Compute the normal case. Notice that we expect that // finite/infinite = +/-0 here. @@ -554,7 +556,9 @@ struct scalar_atan2_op { result = pselect(y_zero, pselect(x_sign, por(y_sign, kPi), por(y_sign, pzero(y))), result); - return result; + // Handle NaN inputs. + Packet kQNaN = pset1(NumTraits::quiet_NaN()); + return pselect(pand(x_is_not_nan, y_is_not_nan), result, kQNaN); } };