Add a vectorized implementation of atan2 to Eigen.

This commit is contained in:
Rasmus Munk Larsen 2022-09-28 20:46:49 +00:00
parent b3bf8d6a13
commit 1e1848fdb1
7 changed files with 104 additions and 1 deletions

View File

@ -181,6 +181,25 @@ namespace Eigen
}
#endif
/** \returns an expression of the coefficient-wise atan2(\a x, \a y). \a x and \a y must be of the same scalar type.
  *
  * This function computes the coefficient-wise atan2(). The argument order follows
  * std::atan2: each coefficient of \a x is forwarded as the first argument and the
  * matching coefficient of \a y as the second argument of the scalar atan2.
  *
  * The overload only participates in overload resolution when both expressions
  * have exactly the same \c Scalar type.
  *
  * \sa ArrayBase::atan2()
  *
  * \relates ArrayBase
  */
template <typename LhsDerived, typename RhsDerived>
inline const std::enable_if_t<
    std::is_same<typename LhsDerived::Scalar, typename RhsDerived::Scalar>::value,
    Eigen::CwiseBinaryOp<Eigen::internal::scalar_atan2_op<typename LhsDerived::Scalar, typename RhsDerived::Scalar>, const LhsDerived, const RhsDerived>
    >
atan2(const Eigen::ArrayBase<LhsDerived>& x, const Eigen::ArrayBase<RhsDerived>& y) {
  // Spell the (long) expression type once instead of repeating it in the return statement.
  using ReturnType =
      Eigen::CwiseBinaryOp<Eigen::internal::scalar_atan2_op<typename LhsDerived::Scalar, typename RhsDerived::Scalar>,
                           const LhsDerived, const RhsDerived>;
  return ReturnType(x.derived(), y.derived());
}
namespace internal
{

View File

@ -626,6 +626,9 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) {
// Raises a to the power b: both operands are converted to float, ::powf is
// evaluated, and the result is converted back to bfloat16.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) {
  const float base = static_cast<float>(a);
  const float exponent = static_cast<float>(b);
  return bfloat16(::powf(base, exponent));
}
// Four-quadrant arctangent of a / b: both operands are converted to float,
// ::atan2f(a, b) is evaluated, and the result is converted back to bfloat16.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan2(const bfloat16& a, const bfloat16& b) {
  const float first = static_cast<float>(a);
  const float second = static_cast<float>(b);
  return bfloat16(::atan2f(first, second));
}
// Sine of a: the operand is converted to float, ::sinf is evaluated, and the
// result is converted back to bfloat16.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) {
  const float angle = static_cast<float>(a);
  return bfloat16(::sinf(angle));
}

View File

@ -758,6 +758,9 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sqrt(const half& a) {
// Raises a to the power b: both operands are converted to float, ::powf is
// evaluated, and the result is converted back to half.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half pow(const half& a, const half& b) {
  const float base = static_cast<float>(a);
  const float exponent = static_cast<float>(b);
  return half(::powf(base, exponent));
}
// Four-quadrant arctangent of a / b: both operands are converted to float,
// ::atan2f(a, b) is evaluated, and the result is converted back to half.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half atan2(const half& a, const half& b) {
  const float first = static_cast<float>(a);
  const float second = static_cast<float>(b);
  return half(::atan2f(first, second));
}
// Sine of a: the operand is converted to float, ::sinf is evaluated, and the
// result is converted back to half.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sin(const half& a) {
  const float angle = static_cast<float>(a);
  return half(::sinf(angle));
}

View File

@ -509,6 +509,64 @@ struct functor_traits<scalar_absolute_difference_op<LhsScalar,RhsScalar> > {
};
/** \internal
  * \brief Binary functor computing the four-quadrant arctangent atan2(y, x).
  *
  * The first operand is the ordinate \c y and the second the abscissa \c x,
  * matching the argument order of std::atan2. Both call operators are only
  * well-formed when \c LhsScalar and \c RhsScalar are the same type.
  */
template <typename LhsScalar, typename RhsScalar>
struct scalar_atan2_op {
  using Scalar = LhsScalar;

  /** Scalar path: defers to the atan2 found through EIGEN_USING_STD. */
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<is_same<LhsScalar,RhsScalar>::value, Scalar>
  operator()(const Scalar& y, const Scalar& x) const {
    EIGEN_USING_STD(atan2);
    return static_cast<Scalar>(atan2(y, x));
  }

  /** Packet path: evaluates patan(y / x) and then patches quadrant and special cases. */
  template <typename Packet>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<is_same<LhsScalar,RhsScalar>::value, Packet>
  packetOp(const Packet& y, const Packet& x) const {
    // See https://en.cppreference.com/w/cpp/numeric/math/atan2
    // for how corner cases are supposed to be handled according to the
    // IEEE floating-point standard (IEC 60559).
    //
    // The constants are derived from EIGEN_PI rather than the non-standard
    // M_PI_2 / M_PI_4 macros, which are not available on every toolchain.
    const Packet kSignMask = pset1<Packet>(Scalar(-0.0));
    const Packet kPi = pset1<Packet>(Scalar(EIGEN_PI));
    const Packet kPiO2 = pset1<Packet>(Scalar(EIGEN_PI / 2));
    const Packet kPiO4 = pset1<Packet>(Scalar(EIGEN_PI / 4));
    const Packet k3PiO4 = pset1<Packet>(Scalar(3.0 * (EIGEN_PI / 4)));

    // Predicates on x. x_sign only carries the raw sign bit, which is fine for
    // por() but is NOT a valid selection mask; x_has_signbit turns it into a
    // proper all-ones/all-zeros lane mask (true for negative values and -0).
    Packet x_sign = pand(x, kSignMask);
    Packet x_has_signbit = pcmp_lt(por(x_sign, kPi), pzero(x));
    Packet x_neg = pcmp_lt(x, pzero(x));
    Packet x_zero = pcmp_eq(x, pzero(x));
    Packet x_is_not_nan = pcmp_eq(x, x);

    // Predicates on y.
    Packet y_sign = pand(y, kSignMask);
    Packet y_zero = pcmp_eq(y, pzero(y));
    Packet y_is_not_nan = pcmp_eq(y, y);

    // Compute the normal case. Notice that we expect that
    // finite/infinite = +/-0 here.
    Packet result = patan(pdiv(y, x));

    // Compute shift for when x != 0 and y != 0: +/-pi (with the sign of y)
    // moves the principal value into the correct half-plane when x < 0.
    Packet shift = pselect(x_neg, por(kPi, y_sign), pzero(x));

    // Special cases:
    //   Handle x = +/-inf && y = +/-inf, where y / x is NaN.
    Packet is_not_nan = pcmp_eq(result, result);
    result = pselect(is_not_nan, padd(shift, result),
                     pselect(x_neg, por(k3PiO4, y_sign), por(kPiO4, y_sign)));
    //   Handle x == +/-0.
    result =
        pselect(x_zero, pselect(y_zero, pzero(y), por(y_sign, kPiO2)), result);
    //   Handle y == +/-0: the result is +/-pi if x carries a sign bit and
    //   +/-0 otherwise, always with the sign of y.
    result = pselect(y_zero,
                     pselect(x_has_signbit, por(y_sign, kPi), por(y_sign, pzero(y))),
                     result);
    //   Handle NaN inputs: without this they would be mistaken for the
    //   inf/inf case above and yield +/-pi/4 or +/-3pi/4.
    const Packet kQNaN = pset1<Packet>(NumTraits<Scalar>::quiet_NaN());
    return pselect(pand(x_is_not_nan, y_is_not_nan), result, kQNaN);
  }
};
// Evaluation traits of scalar_atan2_op: whether a packet path may be used and
// an estimate of the per-coefficient cost.
template<typename LhsScalar,typename RhsScalar>
struct functor_traits<scalar_atan2_op<LhsScalar, RhsScalar>> {
  enum {
    // Vectorization needs identical real floating-point scalar types plus
    // packet support for both atan and division.
    PacketAccess = is_same<LhsScalar,RhsScalar>::value
                && !NumTraits<LhsScalar>::IsInteger
                && !NumTraits<LhsScalar>::IsComplex
                && packet_traits<LhsScalar>::HasATan
                && packet_traits<LhsScalar>::HasDiv,
    // One division plus five additions and five multiplications.
    Cost = 5 * NumTraits<LhsScalar>::AddCost
         + 5 * NumTraits<LhsScalar>::MulCost
         + scalar_div_cost<LhsScalar, PacketAccess>::value
  };
};
//---------- binary functors bound to a constant, thus appearing as a unary functor ----------

View File

@ -134,6 +134,14 @@ absolute_difference
*/
EIGEN_MAKE_CWISE_BINARY_OP(pow,pow)
/** \returns an expression of the coefficient-wise atan2(\c *this, \a y), where \a y is the given array argument.
*
* This function computes the coefficient-wise atan2: each coefficient of \c *this
* is used as the first argument and the matching coefficient of \a y as the
* second argument of the scalar atan2.
*
*/
EIGEN_MAKE_CWISE_BINARY_OP(atan2,atan2)
// TODO code generating macros could be moved to Macros.h and could include generation of documentation
#define EIGEN_MAKE_CWISE_COMP_OP(OP, COMPARATOR) \
template<typename OtherDerived> \

View File

@ -0,0 +1,4 @@
// Demonstrates both spellings of the coefficient-wise atan2 on two 1x3 arrays.
Array<double,1,3> x(8,-25,3),
                  y(1./3.,0.5,-2.);
cout << "atan2([" << x << "], [" << y << "]) = " << x.atan2(y) << endl; // using ArrayBase::atan2
cout << "atan2([" << x << "], [" << y << "]) = " << atan2(x,y) << endl; // using Eigen::atan2

View File

@ -531,6 +531,8 @@ template<typename ArrayType> void array_real(const ArrayType& m)
VERIFY_IS_APPROX(m1.sinh(), sinh(m1));
VERIFY_IS_APPROX(m1.cosh(), cosh(m1));
VERIFY_IS_APPROX(m1.tanh(), tanh(m1));
VERIFY_IS_APPROX(m1.atan2(m2), atan2(m1,m2));
#if EIGEN_HAS_CXX11_MATH
VERIFY_IS_APPROX(m1.tanh().atanh(), atanh(tanh(m1)));
VERIFY_IS_APPROX(m1.sinh().asinh(), asinh(sinh(m1)));
@ -592,6 +594,13 @@ template<typename ArrayType> void array_real(const ArrayType& m)
VERIFY_IS_APPROX( m1.sign(), -(-m1).sign() );
VERIFY_IS_APPROX( m1*m1.sign(),m1.abs());
VERIFY_IS_APPROX(m1.sign() * m1.abs(), m1);
ArrayType tmp = m1.atan2(m2);
for (Index i = 0; i < tmp.size(); ++i) {
Scalar actual = tmp.array()(i);
Scalar expected = atan2(m1.array()(i), m2.array()(i));
VERIFY_IS_APPROX(actual, expected);
}
VERIFY_IS_APPROX(numext::abs2(numext::real(m1)) + numext::abs2(numext::imag(m1)), numext::abs2(m1));
VERIFY_IS_APPROX(numext::abs2(Eigen::real(m1)) + numext::abs2(Eigen::imag(m1)), numext::abs2(m1));
@ -684,7 +693,6 @@ template<typename ArrayType> void array_complex(const ArrayType& m)
VERIFY_IS_APPROX(cos(m1+RealScalar(3)*m2), cos((m1+RealScalar(3)*m2).eval()));
VERIFY_IS_APPROX(m1.sign(), sign(m1));
VERIFY_IS_APPROX(m1.exp() * m2.exp(), exp(m1+m2));
VERIFY_IS_APPROX(m1.exp(), exp(m1));
VERIFY_IS_APPROX(m1.exp() / m2.exp(),(m1-m2).exp());