Specialize psign<Packet8i> for AVX2, don't vectorize psign<bool>.

This commit is contained in:
Rasmus Munk Larsen 2022-08-26 17:02:37 +00:00
parent 98e51c9e24
commit 7064ed1345
5 changed files with 117 additions and 78 deletions

View File

@ -921,6 +921,24 @@ Packet print(const Packet& a) { using numext::rint; return rint(a); }
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pceil(const Packet& a) { using numext::ceil; return ceil(a); } Packet pceil(const Packet& a) { using numext::ceil; return ceil(a); }
template<typename Packet, typename EnableIf = void>
struct psign_impl {
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
return numext::sign(a);
}
};
/** \internal \returns the sign of \a a (coeff-wise) */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
psign(const Packet& a) {
return psign_impl<Packet>::run(a);
}
template<> EIGEN_DEVICE_FUNC inline bool
psign(const bool& a) {
return a;
}
/** \internal \returns the first element of a packet */ /** \internal \returns the first element of a packet */
template<typename Packet> template<typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type

View File

@ -269,6 +269,17 @@ struct sign_impl<Scalar, true, IsInteger>
} }
}; };
// The sign function for bool is the identity.
template<>
struct sign_impl<bool, false, true>
{
EIGEN_DEVICE_FUNC
static inline bool run(const bool& a)
{
return a;
}
};
template<typename Scalar> template<typename Scalar>
struct sign_retval struct sign_retval
{ {

View File

@ -695,6 +695,12 @@ template<> EIGEN_STRONG_INLINE Packet8i pmax<Packet8i>(const Packet8i& a, const
#endif #endif
} }
#ifdef EIGEN_VECTORIZE_AVX2
template<> EIGEN_STRONG_INLINE Packet8i psign(const Packet8i& a) {
return _mm256_sign_epi32(_mm256_set1_epi32(1), a);
}
#endif
// Add specializations for min/max with prescribed NaN progation. // Add specializations for min/max with prescribed NaN progation.
template<> template<>
EIGEN_STRONG_INLINE Packet8f pmin<PropagateNumbers, Packet8f>(const Packet8f& a, const Packet8f& b) { EIGEN_STRONG_INLINE Packet8f pmin<PropagateNumbers, Packet8f>(const Packet8f& a, const Packet8f& b) {

View File

@ -853,92 +853,95 @@ Packet psqrt_complex(const Packet& a) {
} }
/** \internal \returns -1 if a is strictly negative, 0 otherwise, +1 if a is template <typename Packet>
strictly positive. */ struct psign_impl<
template<typename Packet> Packet,
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS std::enable_if_t<
std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex && !NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
!NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet> !NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
psign(const Packet& a) { static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type; using Scalar = typename unpacket_traits<Packet>::type;
const Packet cst_one = pset1<Packet>(Scalar(1)); const Packet cst_one = pset1<Packet>(Scalar(1));
const Packet cst_minus_one = pset1<Packet>(Scalar(-1)); const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
const Packet cst_zero = pzero(a); const Packet cst_zero = pzero(a);
const Packet not_nan_mask = pcmp_eq(a, a); const Packet not_nan_mask = pcmp_eq(a, a);
const Packet positive_mask = pcmp_lt(cst_zero, a); const Packet positive_mask = pcmp_lt(cst_zero, a);
const Packet positive = pand(positive_mask, cst_one); const Packet positive = pand(positive_mask, cst_one);
const Packet negative_mask = pcmp_lt(a, cst_zero); const Packet negative_mask = pcmp_lt(a, cst_zero);
const Packet negative = pand(negative_mask, cst_minus_one); const Packet negative = pand(negative_mask, cst_minus_one);
return pselect(not_nan_mask, por(positive, negative), a); return pselect(not_nan_mask, por(positive, negative), a);
} }
};
template<typename Packet> template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS struct psign_impl<
std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex && Packet, std::enable_if_t<
NumTraits<typename unpacket_traits<Packet>::type>::IsSigned && !NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet> NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
psign(const Packet& a) { NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
using Scalar = typename unpacket_traits<Packet>::type; static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
const Packet cst_one = pset1<Packet>(Scalar(1)); using Scalar = typename unpacket_traits<Packet>::type;
const Packet cst_minus_one = pset1<Packet>(Scalar(-1)); const Packet cst_one = pset1<Packet>(Scalar(1));
const Packet cst_zero = pzero(a); const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
const Packet cst_zero = pzero(a);
const Packet positive_mask = pcmp_lt(cst_zero, a); const Packet positive_mask = pcmp_lt(cst_zero, a);
const Packet positive = pand(positive_mask, cst_one); const Packet positive = pand(positive_mask, cst_one);
const Packet negative_mask = pcmp_lt(a, cst_zero); const Packet negative_mask = pcmp_lt(a, cst_zero);
const Packet negative = pand(negative_mask, cst_minus_one); const Packet negative = pand(negative_mask, cst_minus_one);
return por(positive, negative); return por(positive, negative);
} }
};
template<typename Packet> template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS struct psign_impl<Packet, std::enable_if_t<!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex && !NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
!NumTraits<typename unpacket_traits<Packet>::type>::IsSigned && NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet> static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
psign(const Packet& a) { using Scalar = typename unpacket_traits<Packet>::type;
using Scalar = typename unpacket_traits<Packet>::type; const Packet cst_one = pset1<Packet>(Scalar(1));
const Packet cst_one = pset1<Packet>(Scalar(1)); const Packet cst_zero = pzero(a);
const Packet cst_zero = pzero(a);
const Packet zero_mask = pcmp_eq(cst_zero, a); const Packet zero_mask = pcmp_eq(cst_zero, a);
return pandnot(cst_one, zero_mask); return pandnot(cst_one, zero_mask);
} }
};
// \internal \returns the the sign of a complex number z, defined as z / abs(z). // \internal \returns the the sign of a complex number z, defined as z / abs(z).
template<typename Packet> template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS struct psign_impl<Packet, std::enable_if_t<NumTraits<typename unpacket_traits<Packet>::type>::IsComplex>> {
std::enable_if_t<NumTraits<typename unpacket_traits<Packet>::type>::IsComplex, Packet> static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
psign(const Packet& a) { typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename unpacket_traits<Packet>::type Scalar; typedef typename Scalar::value_type RealScalar;
typedef typename Scalar::value_type RealScalar; typedef typename unpacket_traits<Packet>::as_real RealPacket;
typedef typename unpacket_traits<Packet>::as_real RealPacket;
// Step 1. Compute (for each element z = x + i*y in a) // Step 1. Compute (for each element z = x + i*y in a)
// l = abs(z) = sqrt(x^2 + y^2). // l = abs(z) = sqrt(x^2 + y^2).
// To avoid over- and underflow, we use the stable formula for each hypotenuse // To avoid over- and underflow, we use the stable formula for each hypotenuse
// l = (zmin == 0 ? zmax : zmax * sqrt(1 + (zmin/zmax)**2)), // l = (zmin == 0 ? zmax : zmax * sqrt(1 + (zmin/zmax)**2)),
// where zmax = max(|x|, |y|), zmin = min(|x|, |y|), // where zmax = max(|x|, |y|), zmin = min(|x|, |y|),
RealPacket a_abs = pabs(a.v); RealPacket a_abs = pabs(a.v);
RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v; RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v;
RealPacket a_max = pmax(a_abs, a_abs_flip); RealPacket a_max = pmax(a_abs, a_abs_flip);
RealPacket a_min = pmin(a_abs, a_abs_flip); RealPacket a_min = pmin(a_abs, a_abs_flip);
RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min)); RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min));
RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max)); RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max));
RealPacket r = pdiv(a_min, a_max); RealPacket r = pdiv(a_min, a_max);
const RealPacket cst_one = pset1<RealPacket>(RealScalar(1)); const RealPacket cst_one = pset1<RealPacket>(RealScalar(1));
RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1] RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1]
// Set l to a_max if a_min is zero, since the roundtrip sqrt(a_max^2) may be // Set l to a_max if a_min is zero, since the roundtrip sqrt(a_max^2) may be
// lossy. // lossy.
l = pselect(a_min_zero_mask, a_max, l); l = pselect(a_min_zero_mask, a_max, l);
// Step 2 compute a / abs(a). // Step 2 compute a / abs(a).
RealPacket sign_as_real = pandnot(pdiv(a.v, l), a_max_zero_mask); RealPacket sign_as_real = pandnot(pdiv(a.v, l), a_max_zero_mask);
Packet sign; Packet sign;
sign.v = sign_as_real; sign.v = sign_as_real;
return sign; return sign;
} }
};
// TODO(rmlarsen): The following set of utilities for double word arithmetic // TODO(rmlarsen): The following set of utilities for double word arithmetic
// should perhaps be refactored as a separate file, since it would be generally // should perhaps be refactored as a separate file, since it would be generally

View File

@ -218,7 +218,8 @@ template<> struct packet_traits<bool> : default_packet_traits
HasMin = 0, HasMin = 0,
HasMax = 0, HasMax = 0,
HasConj = 0, HasConj = 0,
HasSqrt = 1 HasSqrt = 1,
HasSign = 0 // Don't try to vectorize psign<bool> = identity.
}; };
}; };