mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-07-26 06:44:27 +08:00
Add a vectorized implementation of atan2 to Eigen.
This commit is contained in:
parent
b3bf8d6a13
commit
1e1848fdb1
@ -181,6 +181,25 @@ namespace Eigen
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/** \returns an expression of the coefficient-wise atan2(\a x, \a y). \a x and \a y must be of the same type.
|
||||||
|
*
|
||||||
|
* This function computes the coefficient-wise atan2().
|
||||||
|
*
|
||||||
|
* \sa ArrayBase::atan2()
|
||||||
|
*
|
||||||
|
* \relates ArrayBase
|
||||||
|
*/
|
||||||
|
template <typename LhsDerived, typename RhsDerived>
|
||||||
|
inline const std::enable_if_t<
|
||||||
|
std::is_same<typename LhsDerived::Scalar, typename RhsDerived::Scalar>::value,
|
||||||
|
Eigen::CwiseBinaryOp<Eigen::internal::scalar_atan2_op<typename LhsDerived::Scalar, typename RhsDerived::Scalar>, const LhsDerived, const RhsDerived>
|
||||||
|
>
|
||||||
|
atan2(const Eigen::ArrayBase<LhsDerived>& x, const Eigen::ArrayBase<RhsDerived>& exponents) {
|
||||||
|
return Eigen::CwiseBinaryOp<Eigen::internal::scalar_atan2_op<typename LhsDerived::Scalar, typename RhsDerived::Scalar>, const LhsDerived, const RhsDerived>(
|
||||||
|
x.derived(),
|
||||||
|
exponents.derived()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
namespace internal
|
namespace internal
|
||||||
{
|
{
|
||||||
|
@ -626,6 +626,9 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sqrt(const bfloat16& a) {
|
|||||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) {
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 pow(const bfloat16& a, const bfloat16& b) {
|
||||||
return bfloat16(::powf(float(a), float(b)));
|
return bfloat16(::powf(float(a), float(b)));
|
||||||
}
|
}
|
||||||
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 atan2(const bfloat16& a, const bfloat16& b) {
|
||||||
|
return bfloat16(::atan2f(float(a), float(b)));
|
||||||
|
}
|
||||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) {
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 sin(const bfloat16& a) {
|
||||||
return bfloat16(::sinf(float(a)));
|
return bfloat16(::sinf(float(a)));
|
||||||
}
|
}
|
||||||
|
@ -758,6 +758,9 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sqrt(const half& a) {
|
|||||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half pow(const half& a, const half& b) {
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half pow(const half& a, const half& b) {
|
||||||
return half(::powf(float(a), float(b)));
|
return half(::powf(float(a), float(b)));
|
||||||
}
|
}
|
||||||
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half atan2(const half& a, const half& b) {
|
||||||
|
return half(::atan2f(float(a), float(b)));
|
||||||
|
}
|
||||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sin(const half& a) {
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half sin(const half& a) {
|
||||||
return half(::sinf(float(a)));
|
return half(::sinf(float(a)));
|
||||||
}
|
}
|
||||||
|
@ -509,6 +509,64 @@ struct functor_traits<scalar_absolute_difference_op<LhsScalar,RhsScalar> > {
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
template <typename LhsScalar, typename RhsScalar>
|
||||||
|
struct scalar_atan2_op {
|
||||||
|
using Scalar = LhsScalar;
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<is_same<LhsScalar,RhsScalar>::value, Scalar>
|
||||||
|
operator()(const Scalar& y, const Scalar& x) const {
|
||||||
|
EIGEN_USING_STD(atan2);
|
||||||
|
return static_cast<Scalar>(atan2(y, x));
|
||||||
|
}
|
||||||
|
template <typename Packet>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<is_same<LhsScalar,RhsScalar>::value, Packet> packetOp(const Packet& y,
|
||||||
|
const Packet& x) const {
|
||||||
|
// See https://en.cppreference.com/w/cpp/numeric/math/atan2
|
||||||
|
// for how corner cases are supposed to be handles according to the
|
||||||
|
// IEEE floating-point standard (IEC 60559).
|
||||||
|
constexpr Scalar k3PiO3f = Scalar(3.0 * M_PI_4);
|
||||||
|
const Packet kSignMask = pset1<Packet>(Scalar(-0.0));
|
||||||
|
const Packet kPi = pset1<Packet>(Scalar(EIGEN_PI));
|
||||||
|
const Packet kPiO2 = pset1<Packet>(Scalar(M_PI_2));
|
||||||
|
const Packet kPiO4 = pset1<Packet>(Scalar(M_PI_4));
|
||||||
|
const Packet k3PiO4 = pset1<Packet>(k3PiO3f);
|
||||||
|
Packet x_neg = pcmp_lt(x, pzero(x));
|
||||||
|
Packet x_sign = pand(x, kSignMask);
|
||||||
|
Packet y_sign = pand(y, kSignMask);
|
||||||
|
Packet x_zero = pcmp_eq(x, pzero(x));
|
||||||
|
Packet y_zero = pcmp_eq(y, pzero(y));
|
||||||
|
|
||||||
|
// Compute the normal case. Notice that we expect that
|
||||||
|
// finite/infinite = +/-0 here.
|
||||||
|
Packet result = patan(pdiv(y, x));
|
||||||
|
|
||||||
|
// Compute shift for when x != 0 and y != 0.
|
||||||
|
Packet shift = pselect(x_neg, por(kPi, y_sign), pzero(x));
|
||||||
|
|
||||||
|
// Special cases:
|
||||||
|
// Handle x = +/-inf && y = +/-inf.
|
||||||
|
Packet is_not_nan = pcmp_eq(result, result);
|
||||||
|
result = pselect(is_not_nan, padd(shift, result),
|
||||||
|
pselect(x_neg, por(k3PiO4, y_sign), por(kPiO4, y_sign)));
|
||||||
|
// Handle x == +/-0.
|
||||||
|
result =
|
||||||
|
pselect(x_zero, pselect(y_zero, pzero(y), por(y_sign, kPiO2)), result);
|
||||||
|
// Handle y == +/-0.
|
||||||
|
result = pselect(y_zero,
|
||||||
|
pselect(x_sign, por(y_sign, kPi), por(y_sign, pzero(y))),
|
||||||
|
result);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename LhsScalar,typename RhsScalar>
|
||||||
|
struct functor_traits<scalar_atan2_op<LhsScalar, RhsScalar>> {
|
||||||
|
enum {
|
||||||
|
PacketAccess = is_same<LhsScalar,RhsScalar>::value && packet_traits<LhsScalar>::HasATan && packet_traits<LhsScalar>::HasDiv && !NumTraits<LhsScalar>::IsInteger && !NumTraits<LhsScalar>::IsComplex,
|
||||||
|
Cost =
|
||||||
|
scalar_div_cost<LhsScalar, PacketAccess>::value + 5 * NumTraits<LhsScalar>::MulCost + 5 * NumTraits<LhsScalar>::AddCost
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
//---------- binary functors bound to a constant, thus appearing as a unary functor ----------
|
//---------- binary functors bound to a constant, thus appearing as a unary functor ----------
|
||||||
|
|
||||||
|
@ -134,6 +134,14 @@ absolute_difference
|
|||||||
*/
|
*/
|
||||||
EIGEN_MAKE_CWISE_BINARY_OP(pow,pow)
|
EIGEN_MAKE_CWISE_BINARY_OP(pow,pow)
|
||||||
|
|
||||||
|
/** \returns an expression of the coefficient-wise atan2(\c *this, \a y), where \a y is the given array argument.
|
||||||
|
*
|
||||||
|
* This function computes the coefficient-wise atan2.
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
EIGEN_MAKE_CWISE_BINARY_OP(atan2,atan2)
|
||||||
|
|
||||||
|
|
||||||
// TODO code generating macros could be moved to Macros.h and could include generation of documentation
|
// TODO code generating macros could be moved to Macros.h and could include generation of documentation
|
||||||
#define EIGEN_MAKE_CWISE_COMP_OP(OP, COMPARATOR) \
|
#define EIGEN_MAKE_CWISE_COMP_OP(OP, COMPARATOR) \
|
||||||
template<typename OtherDerived> \
|
template<typename OtherDerived> \
|
||||||
|
4
doc/snippets/Cwise_array_atan2_array.cpp
Normal file
4
doc/snippets/Cwise_array_atan2_array.cpp
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
Array<double,1,3> x(8,-25,3),
|
||||||
|
y(1./3.,0.5,-2.);
|
||||||
|
cout << "atan2([" << x << "], [" << y << "]) = " << x.atan2(y) << endl; // using ArrayBase::pow
|
||||||
|
cout << "atan2([" << x << "], [" << y << "] = " << atan2(x,y) << endl; // using Eigen::pow
|
@ -531,6 +531,8 @@ template<typename ArrayType> void array_real(const ArrayType& m)
|
|||||||
VERIFY_IS_APPROX(m1.sinh(), sinh(m1));
|
VERIFY_IS_APPROX(m1.sinh(), sinh(m1));
|
||||||
VERIFY_IS_APPROX(m1.cosh(), cosh(m1));
|
VERIFY_IS_APPROX(m1.cosh(), cosh(m1));
|
||||||
VERIFY_IS_APPROX(m1.tanh(), tanh(m1));
|
VERIFY_IS_APPROX(m1.tanh(), tanh(m1));
|
||||||
|
VERIFY_IS_APPROX(m1.atan2(m2), atan2(m1,m2));
|
||||||
|
|
||||||
#if EIGEN_HAS_CXX11_MATH
|
#if EIGEN_HAS_CXX11_MATH
|
||||||
VERIFY_IS_APPROX(m1.tanh().atanh(), atanh(tanh(m1)));
|
VERIFY_IS_APPROX(m1.tanh().atanh(), atanh(tanh(m1)));
|
||||||
VERIFY_IS_APPROX(m1.sinh().asinh(), asinh(sinh(m1)));
|
VERIFY_IS_APPROX(m1.sinh().asinh(), asinh(sinh(m1)));
|
||||||
@ -592,6 +594,13 @@ template<typename ArrayType> void array_real(const ArrayType& m)
|
|||||||
VERIFY_IS_APPROX( m1.sign(), -(-m1).sign() );
|
VERIFY_IS_APPROX( m1.sign(), -(-m1).sign() );
|
||||||
VERIFY_IS_APPROX( m1*m1.sign(),m1.abs());
|
VERIFY_IS_APPROX( m1*m1.sign(),m1.abs());
|
||||||
VERIFY_IS_APPROX(m1.sign() * m1.abs(), m1);
|
VERIFY_IS_APPROX(m1.sign() * m1.abs(), m1);
|
||||||
|
|
||||||
|
ArrayType tmp = m1.atan2(m2);
|
||||||
|
for (Index i = 0; i < tmp.size(); ++i) {
|
||||||
|
Scalar actual = tmp.array()(i);
|
||||||
|
Scalar expected = atan2(m1.array()(i), m2.array()(i));
|
||||||
|
VERIFY_IS_APPROX(actual, expected);
|
||||||
|
}
|
||||||
|
|
||||||
VERIFY_IS_APPROX(numext::abs2(numext::real(m1)) + numext::abs2(numext::imag(m1)), numext::abs2(m1));
|
VERIFY_IS_APPROX(numext::abs2(numext::real(m1)) + numext::abs2(numext::imag(m1)), numext::abs2(m1));
|
||||||
VERIFY_IS_APPROX(numext::abs2(Eigen::real(m1)) + numext::abs2(Eigen::imag(m1)), numext::abs2(m1));
|
VERIFY_IS_APPROX(numext::abs2(Eigen::real(m1)) + numext::abs2(Eigen::imag(m1)), numext::abs2(m1));
|
||||||
@ -684,7 +693,6 @@ template<typename ArrayType> void array_complex(const ArrayType& m)
|
|||||||
VERIFY_IS_APPROX(cos(m1+RealScalar(3)*m2), cos((m1+RealScalar(3)*m2).eval()));
|
VERIFY_IS_APPROX(cos(m1+RealScalar(3)*m2), cos((m1+RealScalar(3)*m2).eval()));
|
||||||
VERIFY_IS_APPROX(m1.sign(), sign(m1));
|
VERIFY_IS_APPROX(m1.sign(), sign(m1));
|
||||||
|
|
||||||
|
|
||||||
VERIFY_IS_APPROX(m1.exp() * m2.exp(), exp(m1+m2));
|
VERIFY_IS_APPROX(m1.exp() * m2.exp(), exp(m1+m2));
|
||||||
VERIFY_IS_APPROX(m1.exp(), exp(m1));
|
VERIFY_IS_APPROX(m1.exp(), exp(m1));
|
||||||
VERIFY_IS_APPROX(m1.exp() / m2.exp(),(m1-m2).exp());
|
VERIFY_IS_APPROX(m1.exp() / m2.exp(),(m1-m2).exp());
|
||||||
|
Loading…
x
Reference in New Issue
Block a user