This commit is contained in:
Charles Schlosser 2023-02-15 21:33:06 +00:00
parent 71a8e60a7a
commit 94b19dc5f2
6 changed files with 67 additions and 13 deletions

View File

@ -1192,27 +1192,27 @@ Packet prsqrt(const Packet& a) {
}
template <typename Packet, bool IsScalar = is_scalar<Packet>::value,
bool IsInteger = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>
struct psignbit_impl;
bool IsInteger = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>
struct psignbit_impl;
template <typename Packet, bool IsInteger>
struct psignbit_impl<Packet, true, IsInteger> {
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static constexpr Packet run(const Packet& a) { return numext::signbit(a); }
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static constexpr Packet run(const Packet& a) { return numext::signbit(a); }
};
template <typename Packet>
struct psignbit_impl<Packet, false, false> {
// generic implementation if not specialized in PacketMath.h
// slower than arithmetic shift
typedef typename unpacket_traits<Packet>::type Scalar;
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static Packet run(const Packet& a) {
const Packet cst_pos_one = pset1<Packet>(Scalar(1));
const Packet cst_neg_one = pset1<Packet>(Scalar(-1));
return pcmp_eq(por(pand(a, cst_neg_one), cst_pos_one), cst_neg_one);
}
// generic implementation if not specialized in PacketMath.h
// slower than arithmetic shift
typedef typename unpacket_traits<Packet>::type Scalar;
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static Packet run(const Packet& a) {
const Packet cst_pos_one = pset1<Packet>(Scalar(1));
const Packet cst_neg_one = pset1<Packet>(Scalar(-1));
return pcmp_eq(por(pand(a, cst_neg_one), cst_pos_one), cst_neg_one);
}
};
template <typename Packet>
struct psignbit_impl<Packet, false, true> {
// generic implementation for integer packets
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static constexpr Packet run(const Packet& a) { return pcmp_lt(a, pzero(a)); }
// generic implementation for integer packets
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static constexpr Packet run(const Packet& a) { return pcmp_lt(a, pzero(a)); }
};
/** \internal \returns the sign bit of \a a as a bitmask*/
template <typename Packet>
@ -1256,6 +1256,24 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet patan2(const Packet& y, const Packe
return result;
}
/** \internal \returns the argument of \a a as a complex number */
template <typename Packet, std::enable_if_t<is_scalar<Packet>::value, int> = 0>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet pcarg(const Packet& a) {
return Packet(numext::arg(a));
}
/** \internal \returns the argument of \a a as a complex number */
template <typename Packet, std::enable_if_t<!is_scalar<Packet>::value, int> = 0>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet pcarg(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type;
EIGEN_STATIC_ASSERT(NumTraits<Scalar>::IsComplex, THIS METHOD IS FOR COMPLEX TYPES ONLY)
using RealPacket = typename unpacket_traits<Packet>::as_real;
// a // r i r i ...
RealPacket aflip = pcplxflip(a).v; // i r i r ...
RealPacket result = patan2(aflip, a.v); // atan2 crap atan2 crap ...
return (Packet)pand(result, peven_mask(result)); // atan2 0 atan2 0 ...
}
} // end namespace internal
} // end namespace Eigen

View File

@ -86,6 +86,7 @@ namespace Eigen
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(abs,scalar_abs_op,absolute value,\sa ArrayBase::abs DOXCOMMA MatrixBase::cwiseAbs)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(abs2,scalar_abs2_op,squared absolute value,\sa ArrayBase::abs2 DOXCOMMA MatrixBase::cwiseAbs2)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(arg,scalar_arg_op,complex argument,\sa ArrayBase::arg DOXCOMMA MatrixBase::cwiseArg)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(carg, scalar_carg_op, complex argument, \sa ArrayBase::carg DOXCOMMA MatrixBase::cwiseCArg)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(sqrt,scalar_sqrt_op,square root,\sa ArrayBase::sqrt DOXCOMMA MatrixBase::cwiseSqrt)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(rsqrt,scalar_rsqrt_op,reciprocal square root,\sa ArrayBase::rsqrt)
EIGEN_ARRAY_DECLARE_GLOBAL_UNARY(square,scalar_square_op,square (power 2),\sa Eigen::abs2 DOXCOMMA Eigen::pow DOXCOMMA ArrayBase::square)

View File

@ -146,6 +146,28 @@ struct functor_traits<scalar_arg_op<Scalar> >
PacketAccess = packet_traits<Scalar>::HasArg
};
};
/** \internal
* \brief Template functor to compute the complex argument, returned as a complex type
*
* \sa class CwiseUnaryOp, Cwise::carg
*/
template <typename Scalar>
struct scalar_carg_op {
using result_type = Scalar;
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const Scalar operator()(const Scalar& a) const { return Scalar(numext::arg(a)); }
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Packet packetOp(const Packet& a) const {
return pcarg(a);
}
};
template <typename Scalar>
struct functor_traits<scalar_carg_op<Scalar>> {
using RealScalar = typename NumTraits<Scalar>::Real;
enum { Cost = functor_traits<scalar_atan2_op<RealScalar>>::Cost, PacketAccess = packet_traits<RealScalar>::HasATan };
};
/** \internal
* \brief Template functor to cast a scalar to another type
*

View File

@ -2,6 +2,7 @@
typedef CwiseUnaryOp<internal::scalar_abs_op<Scalar>, const Derived> AbsReturnType;
typedef CwiseUnaryOp<internal::scalar_arg_op<Scalar>, const Derived> ArgReturnType;
typedef CwiseUnaryOp<internal::scalar_carg_op<Scalar>, const Derived> CArgReturnType;
typedef CwiseUnaryOp<internal::scalar_abs2_op<Scalar>, const Derived> Abs2ReturnType;
typedef CwiseUnaryOp<internal::scalar_sqrt_op<Scalar>, const Derived> SqrtReturnType;
typedef CwiseUnaryOp<internal::scalar_rsqrt_op<Scalar>, const Derived> RsqrtReturnType;
@ -66,6 +67,10 @@ arg() const
return ArgReturnType(derived());
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const CArgReturnType
carg() const { return CArgReturnType(derived()); }
/** \returns an expression of the coefficient-wise squared absolute value of \c *this
*
* Example: \include Cwise_abs2.cpp

View File

@ -15,6 +15,7 @@
typedef CwiseUnaryOp<internal::scalar_abs_op<Scalar>, const Derived> CwiseAbsReturnType;
typedef CwiseUnaryOp<internal::scalar_abs2_op<Scalar>, const Derived> CwiseAbs2ReturnType;
typedef CwiseUnaryOp<internal::scalar_arg_op<Scalar>, const Derived> CwiseArgReturnType;
typedef CwiseUnaryOp<internal::scalar_carg_op<Scalar>, const Derived> CwiseCArgReturnType;
typedef CwiseUnaryOp<internal::scalar_sqrt_op<Scalar>, const Derived> CwiseSqrtReturnType;
typedef CwiseUnaryOp<internal::scalar_sign_op<Scalar>, const Derived> CwiseSignReturnType;
typedef CwiseUnaryOp<internal::scalar_inverse_op<Scalar>, const Derived> CwiseInverseReturnType;
@ -94,6 +95,10 @@ EIGEN_DEVICE_FUNC
inline const CwiseArgReturnType
cwiseArg() const { return CwiseArgReturnType(derived()); }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const CwiseCArgReturnType
cwiseCArg() const { return CwiseCArgReturnType(derived()); }
template <typename ScalarExponent>
using CwisePowReturnType =
std::enable_if_t<internal::is_arithmetic<typename NumTraits<ScalarExponent>::Real>::value,

View File

@ -773,6 +773,8 @@ template<typename ArrayType> void array_complex(const ArrayType& m)
VERIFY_IS_APPROX(m1.tanh(), tanh(m1));
VERIFY_IS_APPROX(m1.logistic(), logistic(m1));
VERIFY_IS_APPROX(m1.arg(), arg(m1));
VERIFY_IS_APPROX(m1.carg(), carg(m1));
VERIFY_IS_APPROX(arg(m1), carg(m1));
VERIFY((m1.isNaN() == (Eigen::isnan)(m1)).all());
VERIFY((m1.isInf() == (Eigen::isinf)(m1)).all());
VERIFY((m1.isFinite() == (Eigen::isfinite)(m1)).all());
@ -806,6 +808,7 @@ template<typename ArrayType> void array_complex(const ArrayType& m)
for (Index j = 0; j < m.cols(); ++j)
m3(i,j) = std::atan2(m1(i,j).imag(), m1(i,j).real());
VERIFY_IS_APPROX(arg(m1), m3);
VERIFY_IS_APPROX(carg(m1), m3);
std::complex<RealScalar> zero(0.0,0.0);
VERIFY((Eigen::isnan)(m1*zero/zero).all());