Tweak special case handling in atan2.

This commit is contained in:
Rasmus Munk Larsen 2023-01-31 17:48:00 -08:00
parent a1cdcdb038
commit 37b2e97175

View File

@ -1233,12 +1233,6 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet patan2(const Packet& y, const Packe
// See https://en.cppreference.com/w/cpp/numeric/math/atan2 // See https://en.cppreference.com/w/cpp/numeric/math/atan2
// for how corner cases are supposed to be handled according to the // for how corner cases are supposed to be handled according to the
// IEEE floating-point standard (IEC 60559). // IEEE floating-point standard (IEC 60559).
// bend two rules:
// 1) inf / inf == 1
// 2) 0 / 0 == 0
// otherwise, evaluate atan(y/x) as usual and shift to the appropriate quadrant
const Packet kSignMask = pset1<Packet>(-Scalar(0)); const Packet kSignMask = pset1<Packet>(-Scalar(0));
const Packet kZero = pzero(x); const Packet kZero = pzero(x);
const Packet kOne = pset1<Packet>(Scalar(1)); const Packet kOne = pset1<Packet>(Scalar(1));
@ -1246,21 +1240,16 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet patan2(const Packet& y, const Packe
const Packet x_has_signbit = psignbit(x); const Packet x_has_signbit = psignbit(x);
const Packet y_signmask = pand(y, kSignMask); const Packet y_signmask = pand(y, kSignMask);
const Packet x_signmask = pand(x, kSignMask);
const Packet result_signmask = pxor(y_signmask, x_signmask);
const Packet shift = por(pand(x_has_signbit, kPi), y_signmask); const Packet shift = por(pand(x_has_signbit, kPi), y_signmask);
const Packet xor_xy = pxor(x, y); const Packet x_and_y_are_same = pcmp_eq(pabs(x), pabs(y));
// if x and y have the same absolute value, then xor(x,y) is zero
// make sure that neither x nor y is nan
// furthermore, xor(x,y) has the sign of the result
const Packet x_and_y_are_same = pand(pcmp_eq(xor_xy, kZero), pcmp_eq(x, x));
// more strictly, if x and y are both zero, then or(x,y) is zero
// this implicitly checks for nan
// the sign of or(x,y) is not meaningful
const Packet x_and_y_are_zero = pcmp_eq(por(x, y), kZero); const Packet x_and_y_are_zero = pcmp_eq(por(x, y), kZero);
Packet arg = pdiv(y, x); Packet arg = pdiv(y, x);
arg = pselect(x_and_y_are_same, por(kOne, xor_xy), arg); arg = pselect(x_and_y_are_same, por(kOne, result_signmask), arg);
arg = pselect(x_and_y_are_zero, xor_xy, arg); arg = pselect(x_and_y_are_zero, result_signmask, arg);
Packet result = patan(arg); Packet result = patan(arg);
result = padd(result, shift); result = padd(result, shift);