From 6aad0f821b1c6bc4302dbbb7658d23121a94dbb4 Mon Sep 17 00:00:00 2001
From: Rasmus Munk Larsen
Date: Mon, 22 Aug 2022 20:19:35 +0000
Subject: [PATCH] Fix psign for unsigned integer types, such as bool.

---
 .../arch/Default/GenericPacketMathFunctions.h | 17 ++++++++++++++++-
 Eigen/src/Core/arch/SSE/PacketMath.h          |  1 -
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
index d27fbd0b6..158f31182 100644
--- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
+++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h
@@ -855,7 +855,7 @@ Packet psqrt_complex(const Packet& a) {
 
 /** \internal \returns -1 if a is strictly negative, 0 otherwise, +1 if a is
     strictly positive. */
-template<typename Packet>
+template <typename Packet>
 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
 std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
                   !NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet>
@@ -877,6 +877,7 @@ psign(const Packet& a) {
 template <typename Packet>
 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
 std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
+                  NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
                   NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet>
 psign(const Packet& a) {
   using Scalar = typename unpacket_traits<Packet>::type;
@@ -892,6 +893,20 @@ psign(const Packet& a) {
   return por(positive, negative);
 }
 
+template <typename Packet>
+EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
+std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
+                  !NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
+                  NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet>
+psign(const Packet& a) {
+  using Scalar = typename unpacket_traits<Packet>::type;
+  const Packet cst_one = pset1<Packet>(Scalar(1));
+  const Packet cst_zero = pzero(a);
+
+  const Packet zero_mask = pcmp_eq(cst_zero, a);
+  return pandnot(cst_one, zero_mask);
+}
+
 // \internal \returns the the sign of a complex number z, defined as z / abs(z).
 template <typename Packet>
 EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h
index b13d65420..ee2883df3 100644
--- a/Eigen/src/Core/arch/SSE/PacketMath.h
+++ b/Eigen/src/Core/arch/SSE/PacketMath.h
@@ -218,7 +218,6 @@ template<> struct packet_traits<bool> : default_packet_traits
     HasMin = 0,
     HasMax = 0,
     HasConj = 0,
-    HasSign = 0,
     HasSqrt = 1
   };
 };