mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-19 16:19:37 +08:00
Add a vectorized implementation of atan2 to Eigen.
This commit is contained in:
parent
b3bf8d6a13
commit
1e1848fdb1
@ -181,6 +181,25 @@ namespace Eigen
|
||||
}
|
||||
#endif
|
||||
|
||||
/** \returns an expression of the coefficient-wise atan2(\a x, \a y). \a x and \a y must be of the same type.
|
||||
*
|
||||
* This function computes the coefficient-wise atan2().
|
||||
*
|
||||
* \sa ArrayBase::atan2()
|
||||
*
|
||||
* \relates ArrayBase
|
||||
*/
|
||||
template <typename LhsDerived, typename RhsDerived>
|
||||
inline const std::enable_if_t<
|
||||
std::is_same<typename LhsDerived::Scalar, typename RhsDerived::Scalar>::value,
|
||||
Eigen::CwiseBinaryOp<Eigen::internal::scalar_atan2_op<typename LhsDerived::Scalar, typename RhsDerived::Scalar>, const LhsDerived, const RhsDerived>
|
||||
>
|
||||
atan2(const Eigen::ArrayBase<LhsDerived>& x, const Eigen::ArrayBase<RhsDerived>& exponents) {
|
||||
return Eigen::CwiseBinaryOp<Eigen::internal::scalar_atan2_op<typename LhsDerived::Scalar, typename RhsDerived::Scalar>, const LhsDerived, const RhsDerived>(
|
||||
x.derived(),
|
||||
exponents.derived()
|
||||
);
|
||||
}
|
||||
|
||||
namespace internal
|
||||
{
|
||||
|
@ -626,6 +626,9 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) {
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) {
|
||||
return bfloat16(::powf(float(a), float(b)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan2(const bfloat16& a, const bfloat16& b) {
|
||||
return bfloat16(::atan2f(float(a), float(b)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) {
|
||||
return bfloat16(::sinf(float(a)));
|
||||
}
|
||||
|
@ -758,6 +758,9 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sqrt(const half& a) {
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half pow(const half& a, const half& b) {
|
||||
return half(::powf(float(a), float(b)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half atan2(const half& a, const half& b) {
|
||||
return half(::atan2f(float(a), float(b)));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sin(const half& a) {
|
||||
return half(::sinf(float(a)));
|
||||
}
|
||||
|
@ -509,6 +509,64 @@ struct functor_traits<scalar_absolute_difference_op<LhsScalar,RhsScalar> > {
|
||||
};
|
||||
|
||||
|
||||
template <typename LhsScalar, typename RhsScalar>
|
||||
struct scalar_atan2_op {
|
||||
using Scalar = LhsScalar;
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<is_same<LhsScalar,RhsScalar>::value, Scalar>
|
||||
operator()(const Scalar& y, const Scalar& x) const {
|
||||
EIGEN_USING_STD(atan2);
|
||||
return static_cast<Scalar>(atan2(y, x));
|
||||
}
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<is_same<LhsScalar,RhsScalar>::value, Packet> packetOp(const Packet& y,
|
||||
const Packet& x) const {
|
||||
// See https://en.cppreference.com/w/cpp/numeric/math/atan2
|
||||
// for how corner cases are supposed to be handles according to the
|
||||
// IEEE floating-point standard (IEC 60559).
|
||||
constexpr Scalar k3PiO3f = Scalar(3.0 * M_PI_4);
|
||||
const Packet kSignMask = pset1<Packet>(Scalar(-0.0));
|
||||
const Packet kPi = pset1<Packet>(Scalar(EIGEN_PI));
|
||||
const Packet kPiO2 = pset1<Packet>(Scalar(M_PI_2));
|
||||
const Packet kPiO4 = pset1<Packet>(Scalar(M_PI_4));
|
||||
const Packet k3PiO4 = pset1<Packet>(k3PiO3f);
|
||||
Packet x_neg = pcmp_lt(x, pzero(x));
|
||||
Packet x_sign = pand(x, kSignMask);
|
||||
Packet y_sign = pand(y, kSignMask);
|
||||
Packet x_zero = pcmp_eq(x, pzero(x));
|
||||
Packet y_zero = pcmp_eq(y, pzero(y));
|
||||
|
||||
// Compute the normal case. Notice that we expect that
|
||||
// finite/infinite = +/-0 here.
|
||||
Packet result = patan(pdiv(y, x));
|
||||
|
||||
// Compute shift for when x != 0 and y != 0.
|
||||
Packet shift = pselect(x_neg, por(kPi, y_sign), pzero(x));
|
||||
|
||||
// Special cases:
|
||||
// Handle x = +/-inf && y = +/-inf.
|
||||
Packet is_not_nan = pcmp_eq(result, result);
|
||||
result = pselect(is_not_nan, padd(shift, result),
|
||||
pselect(x_neg, por(k3PiO4, y_sign), por(kPiO4, y_sign)));
|
||||
// Handle x == +/-0.
|
||||
result =
|
||||
pselect(x_zero, pselect(y_zero, pzero(y), por(y_sign, kPiO2)), result);
|
||||
// Handle y == +/-0.
|
||||
result = pselect(y_zero,
|
||||
pselect(x_sign, por(y_sign, kPi), por(y_sign, pzero(y))),
|
||||
result);
|
||||
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename LhsScalar,typename RhsScalar>
|
||||
struct functor_traits<scalar_atan2_op<LhsScalar, RhsScalar>> {
|
||||
enum {
|
||||
PacketAccess = is_same<LhsScalar,RhsScalar>::value && packet_traits<LhsScalar>::HasATan && packet_traits<LhsScalar>::HasDiv && !NumTraits<LhsScalar>::IsInteger && !NumTraits<LhsScalar>::IsComplex,
|
||||
Cost =
|
||||
scalar_div_cost<LhsScalar, PacketAccess>::value + 5 * NumTraits<LhsScalar>::MulCost + 5 * NumTraits<LhsScalar>::AddCost
|
||||
};
|
||||
};
|
||||
|
||||
//---------- binary functors bound to a constant, thus appearing as a unary functor ----------
|
||||
|
||||
|
@ -134,6 +134,14 @@ absolute_difference
|
||||
*/
|
||||
EIGEN_MAKE_CWISE_BINARY_OP(pow,pow)
|
||||
|
||||
/** \returns an expression of the coefficient-wise atan2(\c *this, \a y), where \a y is the given array argument.
|
||||
*
|
||||
* This function computes the coefficient-wise atan2.
|
||||
*
|
||||
*/
|
||||
EIGEN_MAKE_CWISE_BINARY_OP(atan2,atan2)
|
||||
|
||||
|
||||
// TODO code generating macros could be moved to Macros.h and could include generation of documentation
|
||||
#define EIGEN_MAKE_CWISE_COMP_OP(OP, COMPARATOR) \
|
||||
template<typename OtherDerived> \
|
||||
|
4
doc/snippets/Cwise_array_atan2_array.cpp
Normal file
4
doc/snippets/Cwise_array_atan2_array.cpp
Normal file
@ -0,0 +1,4 @@
|
||||
Array<double,1,3> x(8,-25,3),
|
||||
y(1./3.,0.5,-2.);
|
||||
cout << "atan2([" << x << "], [" << y << "]) = " << x.atan2(y) << endl; // using ArrayBase::pow
|
||||
cout << "atan2([" << x << "], [" << y << "] = " << atan2(x,y) << endl; // using Eigen::pow
|
@ -531,6 +531,8 @@ template<typename ArrayType> void array_real(const ArrayType& m)
|
||||
VERIFY_IS_APPROX(m1.sinh(), sinh(m1));
|
||||
VERIFY_IS_APPROX(m1.cosh(), cosh(m1));
|
||||
VERIFY_IS_APPROX(m1.tanh(), tanh(m1));
|
||||
VERIFY_IS_APPROX(m1.atan2(m2), atan2(m1,m2));
|
||||
|
||||
#if EIGEN_HAS_CXX11_MATH
|
||||
VERIFY_IS_APPROX(m1.tanh().atanh(), atanh(tanh(m1)));
|
||||
VERIFY_IS_APPROX(m1.sinh().asinh(), asinh(sinh(m1)));
|
||||
@ -592,6 +594,13 @@ template<typename ArrayType> void array_real(const ArrayType& m)
|
||||
VERIFY_IS_APPROX( m1.sign(), -(-m1).sign() );
|
||||
VERIFY_IS_APPROX( m1*m1.sign(),m1.abs());
|
||||
VERIFY_IS_APPROX(m1.sign() * m1.abs(), m1);
|
||||
|
||||
ArrayType tmp = m1.atan2(m2);
|
||||
for (Index i = 0; i < tmp.size(); ++i) {
|
||||
Scalar actual = tmp.array()(i);
|
||||
Scalar expected = atan2(m1.array()(i), m2.array()(i));
|
||||
VERIFY_IS_APPROX(actual, expected);
|
||||
}
|
||||
|
||||
VERIFY_IS_APPROX(numext::abs2(numext::real(m1)) + numext::abs2(numext::imag(m1)), numext::abs2(m1));
|
||||
VERIFY_IS_APPROX(numext::abs2(Eigen::real(m1)) + numext::abs2(Eigen::imag(m1)), numext::abs2(m1));
|
||||
@ -684,7 +693,6 @@ template<typename ArrayType> void array_complex(const ArrayType& m)
|
||||
VERIFY_IS_APPROX(cos(m1+RealScalar(3)*m2), cos((m1+RealScalar(3)*m2).eval()));
|
||||
VERIFY_IS_APPROX(m1.sign(), sign(m1));
|
||||
|
||||
|
||||
VERIFY_IS_APPROX(m1.exp() * m2.exp(), exp(m1+m2));
|
||||
VERIFY_IS_APPROX(m1.exp(), exp(m1));
|
||||
VERIFY_IS_APPROX(m1.exp() / m2.exp(),(m1-m2).exp());
|
||||
|
Loading…
x
Reference in New Issue
Block a user