mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-16 14:49:39 +08:00
Tweak atan2
This commit is contained in:
parent
6fc9de7d93
commit
6d9f662a70
@ -1219,6 +1219,56 @@ template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE constexpr Packet
|
||||
psignbit(const Packet& a) { return psignbit_impl<Packet>::run(a); }
|
||||
|
||||
/** \internal \returns the 2-argument arc tangent of \a y and \a x (coeff-wise) */
|
||||
template <typename Packet, std::enable_if_t<is_scalar<Packet>::value, int> = 0>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet patan2(const Packet& y, const Packet& x) {
|
||||
return numext::atan2(y, x);
|
||||
}
|
||||
|
||||
/** \internal \returns the 2-argument arc tangent of \a y and \a x (coeff-wise) */
|
||||
template <typename Packet, std::enable_if_t<!is_scalar<Packet>::value, int> = 0>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet patan2(const Packet& y, const Packet& x) {
|
||||
typedef typename internal::unpacket_traits<Packet>::type Scalar;
|
||||
|
||||
// See https://en.cppreference.com/w/cpp/numeric/math/atan2
|
||||
// for how corner cases are supposed to be handled according to the
|
||||
// IEEE floating-point standard (IEC 60559).
|
||||
|
||||
const Packet kSignMask = pset1<Packet>(-Scalar(0));
|
||||
const Packet kZero = pzero(x);
|
||||
const Packet kOne = pset1<Packet>(Scalar(1));
|
||||
const Packet kPi = pset1<Packet>(Scalar(EIGEN_PI));
|
||||
const Packet kInf = pset1<Packet>(NumTraits<Scalar>::infinity());
|
||||
|
||||
const Packet abs_x = pabs(x);
|
||||
const Packet x_is_zero = pcmp_eq(abs_x, kZero);
|
||||
const Packet x_is_inf = pcmp_eq(abs_x, kInf);
|
||||
const Packet x_has_signbit = psignbit(x);
|
||||
|
||||
const Packet abs_y = pabs(y);
|
||||
const Packet y_is_zero = pcmp_eq(abs_y, kZero);
|
||||
const Packet y_is_inf = pcmp_eq(abs_y, kInf);
|
||||
const Packet y_signmask = pand(y, kSignMask);
|
||||
|
||||
const Packet arg_signmask = pand(pxor(x, y), kSignMask);
|
||||
const Packet shift = pxor(pand(x_has_signbit, kPi), y_signmask);
|
||||
|
||||
// bend two rules:
|
||||
// 1) 0 / 0 == 0
|
||||
// 2) inf / inf == 1
|
||||
// otherwise, evaluate atan(y/x) as usual and shift to the appropriate quadrant
|
||||
|
||||
Packet arg = pdiv(abs_y, abs_x);
|
||||
arg = pselect(pand(x_is_zero, y_is_zero), kZero, arg);
|
||||
arg = pselect(pand(x_is_inf, y_is_inf), kOne, arg);
|
||||
|
||||
Packet result = patan(arg);
|
||||
result = pxor(result, arg_signmask);
|
||||
result = padd(result, shift);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
@ -1734,6 +1734,12 @@ T atan(const T &x) {
|
||||
return static_cast<T>(atan(x));
|
||||
}
|
||||
|
||||
template <typename T, std::enable_if_t<!NumTraits<T>::IsComplex, int> = 0>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T atan2(const T& y, const T& x) {
|
||||
EIGEN_USING_STD(atan2);
|
||||
return static_cast<T>(atan2(y, x));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
|
||||
T atanh(const T &x) {
|
||||
|
@ -516,68 +516,25 @@ struct functor_traits<scalar_absolute_difference_op<LhsScalar,RhsScalar> > {
|
||||
template <typename LhsScalar, typename RhsScalar>
|
||||
struct scalar_atan2_op {
|
||||
using Scalar = LhsScalar;
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<is_same<LhsScalar,RhsScalar>::value, Scalar>
|
||||
operator()(const Scalar& y, const Scalar& x) const {
|
||||
EIGEN_USING_STD(atan2);
|
||||
return static_cast<Scalar>(atan2(y, x));
|
||||
|
||||
static constexpr bool Enable = is_same<LhsScalar, RhsScalar>::value && !NumTraits<Scalar>::IsInteger && !NumTraits<Scalar>::IsComplex;
|
||||
EIGEN_STATIC_ASSERT(Enable, "LhsScalar and RhsScalar must be the same non-integer, non-complex type")
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& y, const Scalar& x) const {
|
||||
return numext::atan2(y, x);
|
||||
}
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
std::enable_if_t<is_same<LhsScalar, RhsScalar>::value, Packet>
|
||||
packetOp(const Packet& y, const Packet& x) const {
|
||||
// See https://en.cppreference.com/w/cpp/numeric/math/atan2
|
||||
// for how corner cases are supposed to be handled according to the
|
||||
// IEEE floating-point standard (IEC 60559).
|
||||
const Packet kSignMask = pset1<Packet>(-Scalar(0));
|
||||
const Packet kPi = pset1<Packet>(Scalar(EIGEN_PI));
|
||||
const Packet kPiO2 = pset1<Packet>(Scalar(EIGEN_PI / 2));
|
||||
const Packet kPiO4 = pset1<Packet>(Scalar(EIGEN_PI / 4));
|
||||
const Packet k3PiO4 = pset1<Packet>(Scalar(3.0 * (EIGEN_PI / 4)));
|
||||
|
||||
// Various predicates about the inputs.
|
||||
Packet x_signbit = pand(x, kSignMask);
|
||||
Packet x_has_signbit = pcmp_lt(por(x_signbit, kPi), pzero(x));
|
||||
Packet x_is_zero = pcmp_eq(x, pzero(x));
|
||||
Packet x_neg = pandnot(x_has_signbit, x_is_zero);
|
||||
|
||||
Packet y_signbit = pand(y, kSignMask);
|
||||
Packet y_is_zero = pcmp_eq(y, pzero(y));
|
||||
Packet x_is_not_nan = pcmp_eq(x, x);
|
||||
Packet y_is_not_nan = pcmp_eq(y, y);
|
||||
|
||||
// Compute the normal case. Notice that we expect that
|
||||
// finite/infinite = +/-0 here.
|
||||
Packet result = patan(pdiv(y, x));
|
||||
|
||||
// Compute shift for when x != 0 and y != 0.
|
||||
Packet shift = pselect(x_neg, por(kPi, y_signbit), pzero(x));
|
||||
|
||||
// Special cases:
|
||||
// Handle x = +/-inf && y = +/-inf.
|
||||
Packet is_not_nan = pcmp_eq(result, result);
|
||||
result =
|
||||
pselect(is_not_nan, padd(shift, result),
|
||||
pselect(x_neg, por(k3PiO4, y_signbit), por(kPiO4, y_signbit)));
|
||||
// Handle x == +/-0.
|
||||
result = pselect(
|
||||
x_is_zero, pselect(y_is_zero, pzero(y), por(y_signbit, kPiO2)), result);
|
||||
// Handle y == +/-0.
|
||||
result = pselect(
|
||||
y_is_zero,
|
||||
pselect(x_has_signbit, por(y_signbit, kPi), por(y_signbit, pzero(y))),
|
||||
result);
|
||||
// Handle NaN inputs.
|
||||
Packet kQNaN = pset1<Packet>(NumTraits<Scalar>::quiet_NaN());
|
||||
return pselect(pand(x_is_not_nan, y_is_not_nan), result, kQNaN);
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& y, const Packet& x) const {
|
||||
return internal::patan2(y, x);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename LhsScalar,typename RhsScalar>
|
||||
struct functor_traits<scalar_atan2_op<LhsScalar, RhsScalar>> {
|
||||
using Scalar = LhsScalar;
|
||||
enum {
|
||||
PacketAccess = is_same<LhsScalar,RhsScalar>::value && packet_traits<LhsScalar>::HasATan && packet_traits<LhsScalar>::HasDiv && !NumTraits<LhsScalar>::IsInteger && !NumTraits<LhsScalar>::IsComplex,
|
||||
Cost =
|
||||
scalar_div_cost<LhsScalar, PacketAccess>::value + 5 * NumTraits<LhsScalar>::MulCost + 5 * NumTraits<LhsScalar>::AddCost
|
||||
PacketAccess = is_same<LhsScalar,RhsScalar>::value && packet_traits<Scalar>::HasATan && packet_traits<Scalar>::HasDiv && !NumTraits<Scalar>::IsInteger && !NumTraits<Scalar>::IsComplex,
|
||||
Cost = scalar_div_cost<Scalar, PacketAccess>::value + functor_traits<scalar_atan_op<Scalar>>::Cost
|
||||
};
|
||||
};
|
||||
|
||||
|
@ -192,6 +192,8 @@ template<typename Scalar> struct scalar_sin_op;
|
||||
template<typename Scalar> struct scalar_acos_op;
|
||||
template<typename Scalar> struct scalar_asin_op;
|
||||
template<typename Scalar> struct scalar_tan_op;
|
||||
template<typename Scalar> struct scalar_atan_op;
|
||||
template <typename LhsScalar, typename RhsScalar = LhsScalar> struct scalar_atan2_op;
|
||||
template<typename Scalar> struct scalar_inverse_op;
|
||||
template<typename Scalar> struct scalar_square_op;
|
||||
template<typename Scalar> struct scalar_cube_op;
|
||||
|
@ -95,11 +95,11 @@ void binary_op_test(std::string name, Fn fun, RefFn ref) {
|
||||
template <typename Scalar>
|
||||
void binary_ops_test() {
|
||||
binary_op_test<Scalar>("pow",
|
||||
[](auto x, auto y) { return Eigen::pow(x, y); },
|
||||
[](auto x, auto y) { return std::pow(x, y); });
|
||||
[](const auto& x, const auto& y) { return Eigen::pow(x, y); },
|
||||
[](const auto& x, const auto& y) { return std::pow(x, y); });
|
||||
binary_op_test<Scalar>("atan2",
|
||||
[](auto x, auto y) { return Eigen::atan2(x, y); },
|
||||
[](auto x, auto y) { return std::atan2(x, y); });
|
||||
[](const auto& x, const auto& y) { return Eigen::atan2(x, y); },
|
||||
[](const auto& x, const auto& y) { return std::atan2(x, y); });
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
|
Loading…
x
Reference in New Issue
Block a user