diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index af773dde2..adb8f9d77 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -1219,6 +1219,56 @@ template EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE constexpr Packet psignbit(const Packet& a) { return psignbit_impl::run(a); } +/** \internal \returns the 2-argument arc tangent of \a y and \a x (coeff-wise) */ +template ::value, int> = 0> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet patan2(const Packet& y, const Packet& x) { + return numext::atan2(y, x); +} + +/** \internal \returns the 2-argument arc tangent of \a y and \a x (coeff-wise) */ +template ::value, int> = 0> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet patan2(const Packet& y, const Packet& x) { + typedef typename internal::unpacket_traits::type Scalar; + + // See https://en.cppreference.com/w/cpp/numeric/math/atan2 + // for how corner cases are supposed to be handled according to the + // IEEE floating-point standard (IEC 60559). + + const Packet kSignMask = pset1(-Scalar(0)); + const Packet kZero = pzero(x); + const Packet kOne = pset1(Scalar(1)); + const Packet kPi = pset1(Scalar(EIGEN_PI)); + const Packet kInf = pset1(NumTraits::infinity()); + + const Packet abs_x = pabs(x); + const Packet x_is_zero = pcmp_eq(abs_x, kZero); + const Packet x_is_inf = pcmp_eq(abs_x, kInf); + const Packet x_has_signbit = psignbit(x); + + const Packet abs_y = pabs(y); + const Packet y_is_zero = pcmp_eq(abs_y, kZero); + const Packet y_is_inf = pcmp_eq(abs_y, kInf); + const Packet y_signmask = pand(y, kSignMask); + + const Packet arg_signmask = pand(pxor(x, y), kSignMask); + const Packet shift = pxor(pand(x_has_signbit, kPi), y_signmask); + + // bend two rules: + // 1) 0 / 0 == 0 + // 2) inf / inf == 1 + // otherwise, evaluate atan(y/x) as usual and shift to the appropriate quadrant + + Packet arg = pdiv(abs_y, abs_x); + arg = pselect(pand(x_is_zero, y_is_zero), kZero, arg); + arg = pselect(pand(x_is_inf, y_is_inf), kOne, arg); + + Packet result = patan(arg); + result = pxor(result, arg_signmask); + result = padd(result, shift); + + return result; +} + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index b194353d6..def5428c2 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -1734,6 +1734,12 @@ T atan(const T &x) { return static_cast(atan(x)); } +template ::IsComplex, int> = 0> +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T atan2(const T& y, const T& x) { + EIGEN_USING_STD(atan2); + return static_cast(atan2(y, x)); +} + template EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T atanh(const T &x) { diff --git a/Eigen/src/Core/functors/BinaryFunctors.h b/Eigen/src/Core/functors/BinaryFunctors.h index c8bb4e775..3c7601963 100644 --- a/Eigen/src/Core/functors/BinaryFunctors.h +++ b/Eigen/src/Core/functors/BinaryFunctors.h @@ -516,68 +516,25 @@ struct functor_traits > { template struct scalar_atan2_op { using Scalar = LhsScalar; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t::value, Scalar> - operator()(const Scalar& y, const Scalar& x) const { - EIGEN_USING_STD(atan2); - return static_cast(atan2(y, x)); + + static constexpr bool Enable = is_same::value && !NumTraits::IsInteger && !NumTraits::IsComplex; + EIGEN_STATIC_ASSERT(Enable, "LhsScalar and RhsScalar must be the same non-integer, non-complex type") + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& y, const Scalar& x) const { + return numext::atan2(y, x); } template - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - std::enable_if_t::value, Packet> - packetOp(const Packet& y, const Packet& x) const { - // See https://en.cppreference.com/w/cpp/numeric/math/atan2 - // for how corner cases are supposed to be handled according to the - // IEEE floating-point standard (IEC 60559). - const Packet kSignMask = pset1(-Scalar(0)); - const Packet kPi = pset1(Scalar(EIGEN_PI)); - const Packet kPiO2 = pset1(Scalar(EIGEN_PI / 2)); - const Packet kPiO4 = pset1(Scalar(EIGEN_PI / 4)); - const Packet k3PiO4 = pset1(Scalar(3.0 * (EIGEN_PI / 4))); - - // Various predicates about the inputs. - Packet x_signbit = pand(x, kSignMask); - Packet x_has_signbit = pcmp_lt(por(x_signbit, kPi), pzero(x)); - Packet x_is_zero = pcmp_eq(x, pzero(x)); - Packet x_neg = pandnot(x_has_signbit, x_is_zero); - - Packet y_signbit = pand(y, kSignMask); - Packet y_is_zero = pcmp_eq(y, pzero(y)); - Packet x_is_not_nan = pcmp_eq(x, x); - Packet y_is_not_nan = pcmp_eq(y, y); - - // Compute the normal case. Notice that we expect that - // finite/infinite = +/-0 here. - Packet result = patan(pdiv(y, x)); - - // Compute shift for when x != 0 and y != 0. - Packet shift = pselect(x_neg, por(kPi, y_signbit), pzero(x)); - - // Special cases: - // Handle x = +/-inf && y = +/-inf. - Packet is_not_nan = pcmp_eq(result, result); - result = - pselect(is_not_nan, padd(shift, result), - pselect(x_neg, por(k3PiO4, y_signbit), por(kPiO4, y_signbit))); - // Handle x == +/-0. - result = pselect( - x_is_zero, pselect(y_is_zero, pzero(y), por(y_signbit, kPiO2)), result); - // Handle y == +/-0. - result = pselect( - y_is_zero, - pselect(x_has_signbit, por(y_signbit, kPi), por(y_signbit, pzero(y))), - result); - // Handle NaN inputs. - Packet kQNaN = pset1(NumTraits::quiet_NaN()); - return pselect(pand(x_is_not_nan, y_is_not_nan), result, kQNaN); + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& y, const Packet& x) const { + return internal::patan2(y, x); } }; template struct functor_traits> { + using Scalar = LhsScalar; enum { - PacketAccess = is_same::value && packet_traits::HasATan && packet_traits::HasDiv && !NumTraits::IsInteger && !NumTraits::IsComplex, - Cost = - scalar_div_cost::value + 5 * NumTraits::MulCost + 5 * NumTraits::AddCost + PacketAccess = is_same::value && packet_traits::HasATan && packet_traits::HasDiv && !NumTraits::IsInteger && !NumTraits::IsComplex, + Cost = scalar_div_cost::value + functor_traits>::Cost }; }; diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index c63ef3629..7f178b9f1 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -192,6 +192,8 @@ template struct scalar_sin_op; template struct scalar_acos_op; template struct scalar_asin_op; template struct scalar_tan_op; +template struct scalar_atan_op; +template struct scalar_atan2_op; template struct scalar_inverse_op; template struct scalar_square_op; template struct scalar_cube_op; diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index af8a1ef1d..a8853752d 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ -95,11 +95,11 @@ void binary_op_test(std::string name, Fn fun, RefFn ref) { template void binary_ops_test() { binary_op_test("pow", - [](auto x, auto y) { return Eigen::pow(x, y); }, - [](auto x, auto y) { return std::pow(x, y); }); + [](const auto& x, const auto& y) { return Eigen::pow(x, y); }, + [](const auto& x, const auto& y) { return std::pow(x, y); }); binary_op_test("atan2", - [](auto x, auto y) { return Eigen::atan2(x, y); }, - [](auto x, auto y) { return std::atan2(x, y); }); + [](const auto& x, const auto& y) { return Eigen::atan2(x, y); }, + [](const auto& x, const auto& y) { return std::atan2(x, y); }); } template