diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index e1e9bfc4b..b67c4eda6 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -921,6 +921,24 @@ Packet print(const Packet& a) { using numext::rint; return rint(a); } template EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pceil(const Packet& a) { using numext::ceil; return ceil(a); } +template +struct psign_impl { + static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) { + return numext::sign(a); + } +}; + +/** \internal \returns the sign of \a a (coeff-wise) */ +template EIGEN_DEVICE_FUNC inline Packet +psign(const Packet& a) { + return psign_impl::run(a); +} + +template<> EIGEN_DEVICE_FUNC inline bool +psign(const bool& a) { + return a; +} + /** \internal \returns the first element of a packet */ template EIGEN_DEVICE_FUNC inline typename unpacket_traits::type diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index cf7caddcd..af8b0296d 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -269,6 +269,17 @@ struct sign_impl } }; +// The sign function for bool is the identity. +template<> +struct sign_impl +{ + EIGEN_DEVICE_FUNC + static inline bool run(const bool& a) + { + return a; + } +}; + template struct sign_retval { diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index 1542215ce..28e65d4a0 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -695,6 +695,12 @@ template<> EIGEN_STRONG_INLINE Packet8i pmax(const Packet8i& a, const #endif } +#ifdef EIGEN_VECTORIZE_AVX2 +template<> EIGEN_STRONG_INLINE Packet8i psign(const Packet8i& a) { + return _mm256_sign_epi32(_mm256_set1_epi32(1), a); +} +#endif + // Add specializations for min/max with prescribed NaN progation. template<> EIGEN_STRONG_INLINE Packet8f pmin(const Packet8f& a, const Packet8f& b) { diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 8f01ae580..ea72b6a28 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -853,92 +853,95 @@ Packet psqrt_complex(const Packet& a) { } -/** \internal \returns -1 if a is strictly negative, 0 otherwise, +1 if a is - strictly positive. */ -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS -std::enable_if_t<(!NumTraits::type>::IsComplex && - !NumTraits::type>::IsInteger), Packet> -psign(const Packet& a) { - using Scalar = typename unpacket_traits::type; - const Packet cst_one = pset1(Scalar(1)); - const Packet cst_minus_one = pset1(Scalar(-1)); - const Packet cst_zero = pzero(a); +template +struct psign_impl< + Packet, + std::enable_if_t< + !NumTraits::type>::IsComplex && + !NumTraits::type>::IsInteger>> { + static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) { + using Scalar = typename unpacket_traits::type; + const Packet cst_one = pset1(Scalar(1)); + const Packet cst_minus_one = pset1(Scalar(-1)); + const Packet cst_zero = pzero(a); - const Packet not_nan_mask = pcmp_eq(a, a); - const Packet positive_mask = pcmp_lt(cst_zero, a); - const Packet positive = pand(positive_mask, cst_one); - const Packet negative_mask = pcmp_lt(a, cst_zero); - const Packet negative = pand(negative_mask, cst_minus_one); + const Packet not_nan_mask = pcmp_eq(a, a); + const Packet positive_mask = pcmp_lt(cst_zero, a); + const Packet positive = pand(positive_mask, cst_one); + const Packet negative_mask = pcmp_lt(a, cst_zero); + const Packet negative = pand(negative_mask, cst_minus_one); - return pselect(not_nan_mask, por(positive, negative), a); -} + return pselect(not_nan_mask, por(positive, negative), a); + } +}; -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS -std::enable_if_t<(!NumTraits::type>::IsComplex && - NumTraits::type>::IsSigned && - NumTraits::type>::IsInteger), Packet> -psign(const Packet& a) { - using Scalar = typename unpacket_traits::type; - const Packet cst_one = pset1(Scalar(1)); - const Packet cst_minus_one = pset1(Scalar(-1)); - const Packet cst_zero = pzero(a); +template +struct psign_impl< + Packet, std::enable_if_t< + !NumTraits::type>::IsComplex && + NumTraits::type>::IsSigned && + NumTraits::type>::IsInteger>> { + static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) { + using Scalar = typename unpacket_traits::type; + const Packet cst_one = pset1(Scalar(1)); + const Packet cst_minus_one = pset1(Scalar(-1)); + const Packet cst_zero = pzero(a); - const Packet positive_mask = pcmp_lt(cst_zero, a); - const Packet positive = pand(positive_mask, cst_one); - const Packet negative_mask = pcmp_lt(a, cst_zero); - const Packet negative = pand(negative_mask, cst_minus_one); + const Packet positive_mask = pcmp_lt(cst_zero, a); + const Packet positive = pand(positive_mask, cst_one); + const Packet negative_mask = pcmp_lt(a, cst_zero); + const Packet negative = pand(negative_mask, cst_minus_one); - return por(positive, negative); -} + return por(positive, negative); + } +}; -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS -std::enable_if_t<(!NumTraits::type>::IsComplex && - !NumTraits::type>::IsSigned && - NumTraits::type>::IsInteger), Packet> -psign(const Packet& a) { - using Scalar = typename unpacket_traits::type; - const Packet cst_one = pset1(Scalar(1)); - const Packet cst_zero = pzero(a); +template +struct psign_impl::type>::IsComplex && + !NumTraits::type>::IsSigned && + NumTraits::type>::IsInteger>> { + static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) { + using Scalar = typename unpacket_traits::type; + const Packet cst_one = pset1(Scalar(1)); + const Packet cst_zero = pzero(a); - const Packet zero_mask = pcmp_eq(cst_zero, a); - return pandnot(cst_one, zero_mask); -} + const Packet zero_mask = pcmp_eq(cst_zero, a); + return pandnot(cst_one, zero_mask); + } +}; // \internal \returns the the sign of a complex number z, defined as z / abs(z). -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS -std::enable_if_t::type>::IsComplex, Packet> -psign(const Packet& a) { - typedef typename unpacket_traits::type Scalar; - typedef typename Scalar::value_type RealScalar; - typedef typename unpacket_traits::as_real RealPacket; +template +struct psign_impl::type>::IsComplex>> { + static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) { + typedef typename unpacket_traits::type Scalar; + typedef typename Scalar::value_type RealScalar; + typedef typename unpacket_traits::as_real RealPacket; - // Step 1. Compute (for each element z = x + i*y in a) - // l = abs(z) = sqrt(x^2 + y^2). - // To avoid over- and underflow, we use the stable formula for each hypotenuse - // l = (zmin == 0 ? zmax : zmax * sqrt(1 + (zmin/zmax)**2)), - // where zmax = max(|x|, |y|), zmin = min(|x|, |y|), - RealPacket a_abs = pabs(a.v); - RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v; - RealPacket a_max = pmax(a_abs, a_abs_flip); - RealPacket a_min = pmin(a_abs, a_abs_flip); - RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min)); - RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max)); - RealPacket r = pdiv(a_min, a_max); - const RealPacket cst_one = pset1(RealScalar(1)); - RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1] - // Set l to a_max if a_min is zero, since the roundtrip sqrt(a_max^2) may be - // lossy. - l = pselect(a_min_zero_mask, a_max, l); - // Step 2 compute a / abs(a). - RealPacket sign_as_real = pandnot(pdiv(a.v, l), a_max_zero_mask); - Packet sign; - sign.v = sign_as_real; - return sign; -} + // Step 1. Compute (for each element z = x + i*y in a) + // l = abs(z) = sqrt(x^2 + y^2). + // To avoid over- and underflow, we use the stable formula for each hypotenuse + // l = (zmin == 0 ? zmax : zmax * sqrt(1 + (zmin/zmax)**2)), + // where zmax = max(|x|, |y|), zmin = min(|x|, |y|), + RealPacket a_abs = pabs(a.v); + RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v; + RealPacket a_max = pmax(a_abs, a_abs_flip); + RealPacket a_min = pmin(a_abs, a_abs_flip); + RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min)); + RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max)); + RealPacket r = pdiv(a_min, a_max); + const RealPacket cst_one = pset1(RealScalar(1)); + RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1] + // Set l to a_max if a_min is zero, since the roundtrip sqrt(a_max^2) may be + // lossy. + l = pselect(a_min_zero_mask, a_max, l); + // Step 2 compute a / abs(a). + RealPacket sign_as_real = pandnot(pdiv(a.v, l), a_max_zero_mask); + Packet sign; + sign.v = sign_as_real; + return sign; + } +}; // TODO(rmlarsen): The following set of utilities for double word arithmetic // should perhaps be refactored as a separate file, since it would be generally diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h index ee2883df3..42b698ac1 100644 --- a/Eigen/src/Core/arch/SSE/PacketMath.h +++ b/Eigen/src/Core/arch/SSE/PacketMath.h @@ -207,7 +207,7 @@ template<> struct packet_traits : default_packet_traits AlignedOnScalar = 1, HasHalfPacket = 0, size=16, - + HasAdd = 1, HasSub = 1, HasShift = 0, @@ -218,7 +218,8 @@ template<> struct packet_traits : default_packet_traits HasMin = 0, HasMax = 0, HasConj = 0, - HasSqrt = 1 + HasSqrt = 1, + HasSign = 0 // Don't try to vectorize psign = identity. }; };