Specialize psign<Packet8i> for AVX2, don't vectorize psign<bool>.

This commit is contained in:
Rasmus Munk Larsen 2022-08-26 17:02:37 +00:00
parent 98e51c9e24
commit 7064ed1345
5 changed files with 117 additions and 78 deletions

View File

@ -921,6 +921,24 @@ Packet print(const Packet& a) { using numext::rint; return rint(a); }
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pceil(const Packet& a) { using numext::ceil; return ceil(a); } Packet pceil(const Packet& a) { using numext::ceil; return ceil(a); }
template<typename Packet, typename EnableIf = void>
struct psign_impl {
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
return numext::sign(a);
}
};
/** \internal \returns the sign of \a a (coeff-wise) */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
psign(const Packet& a) {
return psign_impl<Packet>::run(a);
}
template<> EIGEN_DEVICE_FUNC inline bool
psign(const bool& a) {
return a;
}
/** \internal \returns the first element of a packet */ /** \internal \returns the first element of a packet */
template<typename Packet> template<typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type

View File

@ -269,6 +269,17 @@ struct sign_impl<Scalar, true, IsInteger>
} }
}; };
// The sign function for bool is the identity.
template<>
struct sign_impl<bool, false, true>
{
EIGEN_DEVICE_FUNC
static inline bool run(const bool& a)
{
return a;
}
};
template<typename Scalar> template<typename Scalar>
struct sign_retval struct sign_retval
{ {

View File

@ -695,6 +695,12 @@ template<> EIGEN_STRONG_INLINE Packet8i pmax<Packet8i>(const Packet8i& a, const
#endif #endif
} }
#ifdef EIGEN_VECTORIZE_AVX2
template<> EIGEN_STRONG_INLINE Packet8i psign(const Packet8i& a) {
return _mm256_sign_epi32(_mm256_set1_epi32(1), a);
}
#endif
// Add specializations for min/max with prescribed NaN progation. // Add specializations for min/max with prescribed NaN progation.
template<> template<>
EIGEN_STRONG_INLINE Packet8f pmin<PropagateNumbers, Packet8f>(const Packet8f& a, const Packet8f& b) { EIGEN_STRONG_INLINE Packet8f pmin<PropagateNumbers, Packet8f>(const Packet8f& a, const Packet8f& b) {

View File

@ -853,13 +853,13 @@ Packet psqrt_complex(const Packet& a) {
} }
/** \internal \returns -1 if a is strictly negative, 0 otherwise, +1 if a is template <typename Packet>
strictly positive. */ struct psign_impl<
template<typename Packet> Packet,
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS std::enable_if_t<
std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex && !NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
!NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet> !NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
psign(const Packet& a) { static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type; using Scalar = typename unpacket_traits<Packet>::type;
const Packet cst_one = pset1<Packet>(Scalar(1)); const Packet cst_one = pset1<Packet>(Scalar(1));
const Packet cst_minus_one = pset1<Packet>(Scalar(-1)); const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
@ -872,14 +872,16 @@ psign(const Packet& a) {
const Packet negative = pand(negative_mask, cst_minus_one); const Packet negative = pand(negative_mask, cst_minus_one);
return pselect(not_nan_mask, por(positive, negative), a); return pselect(not_nan_mask, por(positive, negative), a);
} }
};
template<typename Packet> template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS struct psign_impl<
std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex && Packet, std::enable_if_t<
!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
NumTraits<typename unpacket_traits<Packet>::type>::IsSigned && NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet> NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
psign(const Packet& a) { static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type; using Scalar = typename unpacket_traits<Packet>::type;
const Packet cst_one = pset1<Packet>(Scalar(1)); const Packet cst_one = pset1<Packet>(Scalar(1));
const Packet cst_minus_one = pset1<Packet>(Scalar(-1)); const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
@ -891,27 +893,27 @@ psign(const Packet& a) {
const Packet negative = pand(negative_mask, cst_minus_one); const Packet negative = pand(negative_mask, cst_minus_one);
return por(positive, negative); return por(positive, negative);
} }
};
template<typename Packet> template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS struct psign_impl<Packet, std::enable_if_t<!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
!NumTraits<typename unpacket_traits<Packet>::type>::IsSigned && !NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet> NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
psign(const Packet& a) { static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type; using Scalar = typename unpacket_traits<Packet>::type;
const Packet cst_one = pset1<Packet>(Scalar(1)); const Packet cst_one = pset1<Packet>(Scalar(1));
const Packet cst_zero = pzero(a); const Packet cst_zero = pzero(a);
const Packet zero_mask = pcmp_eq(cst_zero, a); const Packet zero_mask = pcmp_eq(cst_zero, a);
return pandnot(cst_one, zero_mask); return pandnot(cst_one, zero_mask);
} }
};
// \internal \returns the the sign of a complex number z, defined as z / abs(z). // \internal \returns the the sign of a complex number z, defined as z / abs(z).
template<typename Packet> template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS struct psign_impl<Packet, std::enable_if_t<NumTraits<typename unpacket_traits<Packet>::type>::IsComplex>> {
std::enable_if_t<NumTraits<typename unpacket_traits<Packet>::type>::IsComplex, Packet> static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
psign(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar; typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename Scalar::value_type RealScalar; typedef typename Scalar::value_type RealScalar;
typedef typename unpacket_traits<Packet>::as_real RealPacket; typedef typename unpacket_traits<Packet>::as_real RealPacket;
@ -938,7 +940,8 @@ psign(const Packet& a) {
Packet sign; Packet sign;
sign.v = sign_as_real; sign.v = sign_as_real;
return sign; return sign;
} }
};
// TODO(rmlarsen): The following set of utilities for double word arithmetic // TODO(rmlarsen): The following set of utilities for double word arithmetic
// should perhaps be refactored as a separate file, since it would be generally // should perhaps be refactored as a separate file, since it would be generally

View File

@ -218,7 +218,8 @@ template<> struct packet_traits<bool> : default_packet_traits
HasMin = 0, HasMin = 0,
HasMax = 0, HasMax = 0,
HasConj = 0, HasConj = 0,
HasSqrt = 1 HasSqrt = 1,
HasSign = 0 // Don't try to vectorize psign<bool> = identity.
}; };
}; };