diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index 8119200ba..e1e9bfc4b 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -59,6 +59,7 @@ struct default_packet_traits HasMax = 1, HasConj = 1, HasSetLinear = 1, + HasSign = 1, HasBlend = 0, // This flag is used to indicate whether packet comparison is supported. // pcmp_eq, pcmp_lt and pcmp_le should be defined for it to be true. @@ -101,8 +102,7 @@ struct default_packet_traits HasRound = 0, HasRint = 0, HasFloor = 0, - HasCeil = 0, - HasSign = 0 + HasCeil = 0 }; }; @@ -179,7 +179,7 @@ struct eigen_packet_wrapper */ template struct is_scalar { - typedef typename unpacket_traits::type Scalar; + using Scalar = typename unpacket_traits::type; enum { value = internal::is_same::value }; diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 59cc831fc..cf7caddcd 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -229,6 +229,52 @@ struct imag_ref_retval typedef typename NumTraits::Real & type; }; + +/**************************************************************************** +* Implementation of sign * +****************************************************************************/ +template::IsComplex!=0), + bool IsInteger = (NumTraits::IsInteger!=0)> +struct sign_impl +{ + EIGEN_DEVICE_FUNC + static inline Scalar run(const Scalar& a) + { + return Scalar( (a>Scalar(0)) - (a +struct sign_impl +{ + EIGEN_DEVICE_FUNC + static inline Scalar run(const Scalar& a) + { + return (std::isnan)(a) ? 
a : Scalar( (a>Scalar(0)) - (a +struct sign_impl +{ + EIGEN_DEVICE_FUNC + static inline Scalar run(const Scalar& a) + { + using real_type = typename NumTraits::Real; + real_type aa = std::abs(a); + if (aa==real_type(0)) + return Scalar(0); + aa = real_type(1)/aa; + return Scalar(a.real()*aa, a.imag()*aa ); + } +}; + +template +struct sign_retval +{ + typedef Scalar type; +}; + /**************************************************************************** * Implementation of conj * ****************************************************************************/ @@ -1279,6 +1325,13 @@ inline EIGEN_MATHFUNC_RETVAL(conj, Scalar) conj(const Scalar& x) return EIGEN_MATHFUNC_IMPL(conj, Scalar)::run(x); } +template +EIGEN_DEVICE_FUNC +inline EIGEN_MATHFUNC_RETVAL(sign, Scalar) sign(const Scalar& x) +{ + return EIGEN_MATHFUNC_IMPL(sign, Scalar)::run(x); +} + template EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(abs2, Scalar) abs2(const Scalar& x) diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 5949d2ce2..66310b651 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -852,6 +852,79 @@ Packet psqrt_complex(const Packet& a) { pselect(is_real_inf, real_inf_result,result)); } + +/** \internal \returns -1 if a is strictly negative, 0 otherwise, +1 if a is + strictly positive. 
*/ +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +std::enable_if_t<(!NumTraits::type>::IsComplex && + !NumTraits::type>::IsInteger), Packet> +psign(const Packet& a) { + using Scalar = typename unpacket_traits::type; + const Packet cst_one = pset1(Scalar(1)); + const Packet cst_minus_one = pset1(Scalar(-1)); + const Packet cst_zero = pzero(a); + + const Packet not_nan_mask = pcmp_eq(a, a); + const Packet positive_mask = pcmp_lt(cst_zero, a); + const Packet positive = pand(positive_mask, cst_one); + const Packet negative_mask = pcmp_lt(a, cst_zero); + const Packet negative = pand(negative_mask, cst_minus_one); + + return pselect(not_nan_mask, por(positive, negative), a); +} + +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +std::enable_if_t<(!NumTraits::type>::IsComplex && + NumTraits::type>::IsInteger), Packet> +psign(const Packet& a) { + using Scalar = typename unpacket_traits::type; + const Packet cst_one = pset1(Scalar(1)); + const Packet cst_minus_one = pset1(Scalar(-1)); + const Packet cst_zero = pzero(a); + + const Packet positive_mask = pcmp_lt(cst_zero, a); + const Packet positive = pand(positive_mask, cst_one); + const Packet negative_mask = pcmp_lt(a, cst_zero); + const Packet negative = pand(negative_mask, cst_minus_one); + + return por(positive, negative); +} + +// \internal \returns the sign of a complex number z, defined as z / abs(z). +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +std::enable_if_t::type>::IsComplex, Packet> +psign(const Packet& a) { + typedef typename unpacket_traits::type Scalar; + typedef typename Scalar::value_type RealScalar; + typedef typename unpacket_traits::as_real RealPacket; + + // Step 1. Compute (for each element z = x + i*y in a) + // l = abs(z) = sqrt(x^2 + y^2). + // To avoid over- and underflow, we use the stable formula for each hypotenuse + // l = (zmin == 0 ? 
zmax : zmax * sqrt(1 + (zmin/zmax)**2)), + // where zmax = max(|x|, |y|), zmin = min(|x|, |y|). + RealPacket a_abs = pabs(a.v); + RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v; + RealPacket a_max = pmax(a_abs, a_abs_flip); + RealPacket a_min = pmin(a_abs, a_abs_flip); + RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min)); + RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max)); + RealPacket r = pdiv(a_min, a_max); + const RealPacket cst_one = pset1(RealScalar(1)); + RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1] + // Set l to a_max if a_min is zero, since the roundtrip sqrt(a_max^2) may be + // lossy. + l = pselect(a_min_zero_mask, a_max, l); + // Step 2. Compute a / abs(a). + RealPacket sign_as_real = pandnot(pdiv(a.v, l), a_max_zero_mask); + Packet sign; + sign.v = sign_as_real; + return sign; +} + // TODO(rmlarsen): The following set of utilities for double word arithmetic // should perhaps be refactored as a separate file, since it would be generally // useful for special function implementation etc. Writing the algorithms in diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h index e896040b5..b13d65420 100644 --- a/Eigen/src/Core/arch/SSE/PacketMath.h +++ b/Eigen/src/Core/arch/SSE/PacketMath.h @@ -155,7 +155,8 @@ struct packet_traits : default_packet_traits { #ifdef EIGEN_VECTORIZE_SSE4_1 HasRound = 1, #endif - HasRint = 1 + HasRint = 1, + HasSign = 0 // The manually vectorized version is slightly slower for SSE. 
}; }; template <> @@ -217,6 +218,7 @@ template<> struct packet_traits : default_packet_traits HasMin = 0, HasMax = 0, HasConj = 0, + HasSign = 0, HasSqrt = 1 }; }; diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index fafb53376..f4d5fcae4 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -916,44 +916,19 @@ struct functor_traits > { * \brief Template functor to compute the signum of a scalar * \sa class CwiseUnaryOp, Cwise::sign() */ -template::IsComplex!=0), bool is_integer=(NumTraits::IsInteger!=0) > struct scalar_sign_op; template -struct scalar_sign_op { +struct scalar_sign_op { EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { - return Scalar( (a>Scalar(0)) - (a + EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { + return internal::psign(a); } - //TODO - //template - //EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::psign(a); } }; -template -struct scalar_sign_op { - EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const - { - return (numext::isnan)(a) ? 
a : Scalar( (a>Scalar(0)) - (a - //EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::psign(a); } -}; - -template -struct scalar_sign_op { - EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const - { - typedef typename NumTraits::Real real_type; - real_type aa = numext::abs(a); - if (aa==real_type(0)) - return Scalar(0); - aa = real_type(1)/aa; - return Scalar(a.real()*aa, a.imag()*aa ); - } - //TODO - //template - //EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::psign(a); } -}; template struct functor_traits > { enum { diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index a69696121..bbed82c1f 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -196,7 +196,7 @@ template struct scalar_cast_op; template struct scalar_random_op; template struct scalar_constant_op; template struct scalar_identity_op; -template struct scalar_sign_op; +template struct scalar_sign_op; template struct scalar_pow_op; template struct scalar_hypot_op; template struct scalar_product_op; diff --git a/test/packetmath.cpp b/test/packetmath.cpp index 0f61ea84c..d7c7c9cad 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -527,6 +527,7 @@ void packetmath() { CHECK_CWISE1_IF(PacketTraits::HasNegate, internal::negate, internal::pnegate); CHECK_CWISE1_IF(PacketTraits::HasReciprocal, REF_RECIPROCAL, internal::preciprocal); CHECK_CWISE1(numext::conj, internal::pconj); + CHECK_CWISE1_IF(PacketTraits::HasSign, numext::sign, internal::psign); for (int offset = 0; offset < 3; ++offset) { @@ -816,6 +817,7 @@ void packetmath_real() { CHECK_CWISE1_EXACT_IF(PacketTraits::HasCeil, numext::ceil, internal::pceil); CHECK_CWISE1_EXACT_IF(PacketTraits::HasFloor, numext::floor, internal::pfloor); CHECK_CWISE1_EXACT_IF(PacketTraits::HasRint, numext::rint, internal::print); + CHECK_CWISE1_IF(PacketTraits::HasSign, 
numext::sign, internal::psign); packetmath_boolean_mask_ops_real(); @@ -1340,6 +1342,7 @@ void packetmath_complex() { data1[i] = Scalar(internal::random(), internal::random()); } CHECK_CWISE1_N(numext::sqrt, internal::psqrt, size); + CHECK_CWISE1_IF(PacketTraits::HasSign, numext::sign, internal::psign); // Test misc. corner cases. const RealScalar zero = RealScalar(0);