Handle NaN inputs to atan2.

This commit is contained in:
Rasmus Munk Larsen 2022-10-10 19:36:36 -07:00
parent 72db3f0fa5
commit 3167544873

View File

@ -521,19 +521,21 @@ struct scalar_atan2_op {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<is_same<LhsScalar,RhsScalar>::value, Packet> packetOp(const Packet& y,
const Packet& x) const {
// See https://en.cppreference.com/w/cpp/numeric/math/atan2
// for how corner cases are supposed to be handles according to the
// for how corner cases are supposed to be handled according to the
// IEEE floating-point standard (IEC 60559).
constexpr Scalar k3PiO3f = Scalar(3 * (EIGEN_PI / 4));
const Packet kSignMask = pset1<Packet>(Scalar(-0.0));
const Packet kPi = pset1<Packet>(Scalar(EIGEN_PI));
const Packet kPiO2 = pset1<Packet>(Scalar(EIGEN_PI / 2));
const Packet kPiO4 = pset1<Packet>(Scalar(EIGEN_PI / 4));
const Packet k3PiO4 = pset1<Packet>(k3PiO3f);
constexpr Scalar k3PiO4f = Scalar(3.0 * (EIGEN_PI / 4));
const Packet k3PiO4 = pset1<Packet>(k3PiO4f);
Packet x_neg = pcmp_lt(x, pzero(x));
Packet x_sign = pand(x, kSignMask);
Packet y_sign = pand(y, kSignMask);
Packet x_zero = pcmp_eq(x, pzero(x));
Packet y_zero = pcmp_eq(y, pzero(y));
Packet x_is_not_nan = pcmp_eq(x, x);
Packet y_is_not_nan = pcmp_eq(y, y);
// Compute the normal case. Notice that we expect that
// finite/infinite = +/-0 here.
@ -554,7 +556,9 @@ struct scalar_atan2_op {
result = pselect(y_zero,
pselect(x_sign, por(y_sign, kPi), por(y_sign, pzero(y))),
result);
return result;
// Handle NaN inputs.
Packet kQNaN = pset1<Packet>(NumTraits<Scalar>::quiet_NaN());
return pselect(pand(x_is_not_nan, y_is_not_nan), result, kQNaN);
}
};