Specialize psign<Packet8i> for AVX2, don't vectorize psign<bool>.

This commit is contained in:
Rasmus Munk Larsen 2022-08-26 17:02:37 +00:00
parent 98e51c9e24
commit 7064ed1345
5 changed files with 117 additions and 78 deletions

View File

@ -921,6 +921,24 @@ Packet print(const Packet& a) { using numext::rint; return rint(a); }
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pceil(const Packet& a) { using numext::ceil; return ceil(a); }
template<typename Packet, typename EnableIf = void>
struct psign_impl {
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
return numext::sign(a);
}
};
/** \internal \returns the sign of \a a (coeff-wise) */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
psign(const Packet& a) {
return psign_impl<Packet>::run(a);
}
template<> EIGEN_DEVICE_FUNC inline bool
psign(const bool& a) {
return a;
}
/** \internal \returns the first element of a packet */
template<typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type

View File

@ -269,6 +269,17 @@ struct sign_impl<Scalar, true, IsInteger>
}
};
// The sign function for bool is the identity.
template<>
struct sign_impl<bool, false, true>
{
EIGEN_DEVICE_FUNC
static inline bool run(const bool& a)
{
return a;
}
};
template<typename Scalar>
struct sign_retval
{

View File

@ -695,6 +695,12 @@ template<> EIGEN_STRONG_INLINE Packet8i pmax<Packet8i>(const Packet8i& a, const
#endif
}
#ifdef EIGEN_VECTORIZE_AVX2
template<> EIGEN_STRONG_INLINE Packet8i psign(const Packet8i& a) {
return _mm256_sign_epi32(_mm256_set1_epi32(1), a);
}
#endif
// Add specializations for min/max with prescribed NaN progation.
template<>
EIGEN_STRONG_INLINE Packet8f pmin<PropagateNumbers, Packet8f>(const Packet8f& a, const Packet8f& b) {

View File

@ -853,92 +853,95 @@ Packet psqrt_complex(const Packet& a) {
}
/** \internal \returns -1 if a is strictly negative, 0 otherwise, +1 if a is
strictly positive. */
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
!NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet>
psign(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type;
const Packet cst_one = pset1<Packet>(Scalar(1));
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
const Packet cst_zero = pzero(a);
template <typename Packet>
struct psign_impl<
Packet,
std::enable_if_t<
!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
!NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type;
const Packet cst_one = pset1<Packet>(Scalar(1));
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
const Packet cst_zero = pzero(a);
const Packet not_nan_mask = pcmp_eq(a, a);
const Packet positive_mask = pcmp_lt(cst_zero, a);
const Packet positive = pand(positive_mask, cst_one);
const Packet negative_mask = pcmp_lt(a, cst_zero);
const Packet negative = pand(negative_mask, cst_minus_one);
const Packet not_nan_mask = pcmp_eq(a, a);
const Packet positive_mask = pcmp_lt(cst_zero, a);
const Packet positive = pand(positive_mask, cst_one);
const Packet negative_mask = pcmp_lt(a, cst_zero);
const Packet negative = pand(negative_mask, cst_minus_one);
return pselect(not_nan_mask, por(positive, negative), a);
}
return pselect(not_nan_mask, por(positive, negative), a);
}
};
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet>
psign(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type;
const Packet cst_one = pset1<Packet>(Scalar(1));
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
const Packet cst_zero = pzero(a);
template <typename Packet>
struct psign_impl<
Packet, std::enable_if_t<
!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type;
const Packet cst_one = pset1<Packet>(Scalar(1));
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
const Packet cst_zero = pzero(a);
const Packet positive_mask = pcmp_lt(cst_zero, a);
const Packet positive = pand(positive_mask, cst_one);
const Packet negative_mask = pcmp_lt(a, cst_zero);
const Packet negative = pand(negative_mask, cst_minus_one);
const Packet positive_mask = pcmp_lt(cst_zero, a);
const Packet positive = pand(positive_mask, cst_one);
const Packet negative_mask = pcmp_lt(a, cst_zero);
const Packet negative = pand(negative_mask, cst_minus_one);
return por(positive, negative);
}
return por(positive, negative);
}
};
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
!NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet>
psign(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type;
const Packet cst_one = pset1<Packet>(Scalar(1));
const Packet cst_zero = pzero(a);
template <typename Packet>
struct psign_impl<Packet, std::enable_if_t<!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
!NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type;
const Packet cst_one = pset1<Packet>(Scalar(1));
const Packet cst_zero = pzero(a);
const Packet zero_mask = pcmp_eq(cst_zero, a);
return pandnot(cst_one, zero_mask);
}
const Packet zero_mask = pcmp_eq(cst_zero, a);
return pandnot(cst_one, zero_mask);
}
};
// \internal \returns the the sign of a complex number z, defined as z / abs(z).
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
std::enable_if_t<NumTraits<typename unpacket_traits<Packet>::type>::IsComplex, Packet>
psign(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename Scalar::value_type RealScalar;
typedef typename unpacket_traits<Packet>::as_real RealPacket;
template <typename Packet>
struct psign_impl<Packet, std::enable_if_t<NumTraits<typename unpacket_traits<Packet>::type>::IsComplex>> {
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename Scalar::value_type RealScalar;
typedef typename unpacket_traits<Packet>::as_real RealPacket;
// Step 1. Compute (for each element z = x + i*y in a)
// l = abs(z) = sqrt(x^2 + y^2).
// To avoid over- and underflow, we use the stable formula for each hypotenuse
// l = (zmin == 0 ? zmax : zmax * sqrt(1 + (zmin/zmax)**2)),
// where zmax = max(|x|, |y|), zmin = min(|x|, |y|),
RealPacket a_abs = pabs(a.v);
RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v;
RealPacket a_max = pmax(a_abs, a_abs_flip);
RealPacket a_min = pmin(a_abs, a_abs_flip);
RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min));
RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max));
RealPacket r = pdiv(a_min, a_max);
const RealPacket cst_one = pset1<RealPacket>(RealScalar(1));
RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1]
// Set l to a_max if a_min is zero, since the roundtrip sqrt(a_max^2) may be
// lossy.
l = pselect(a_min_zero_mask, a_max, l);
// Step 2 compute a / abs(a).
RealPacket sign_as_real = pandnot(pdiv(a.v, l), a_max_zero_mask);
Packet sign;
sign.v = sign_as_real;
return sign;
}
// Step 1. Compute (for each element z = x + i*y in a)
// l = abs(z) = sqrt(x^2 + y^2).
// To avoid over- and underflow, we use the stable formula for each hypotenuse
// l = (zmin == 0 ? zmax : zmax * sqrt(1 + (zmin/zmax)**2)),
// where zmax = max(|x|, |y|), zmin = min(|x|, |y|),
RealPacket a_abs = pabs(a.v);
RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v;
RealPacket a_max = pmax(a_abs, a_abs_flip);
RealPacket a_min = pmin(a_abs, a_abs_flip);
RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min));
RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max));
RealPacket r = pdiv(a_min, a_max);
const RealPacket cst_one = pset1<RealPacket>(RealScalar(1));
RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1]
// Set l to a_max if a_min is zero, since the roundtrip sqrt(a_max^2) may be
// lossy.
l = pselect(a_min_zero_mask, a_max, l);
// Step 2 compute a / abs(a).
RealPacket sign_as_real = pandnot(pdiv(a.v, l), a_max_zero_mask);
Packet sign;
sign.v = sign_as_real;
return sign;
}
};
// TODO(rmlarsen): The following set of utilities for double word arithmetic
// should perhaps be refactored as a separate file, since it would be generally

View File

@ -207,7 +207,7 @@ template<> struct packet_traits<bool> : default_packet_traits
AlignedOnScalar = 1,
HasHalfPacket = 0,
size=16,
HasAdd = 1,
HasSub = 1,
HasShift = 0,
@ -218,7 +218,8 @@ template<> struct packet_traits<bool> : default_packet_traits
HasMin = 0,
HasMax = 0,
HasConj = 0,
HasSqrt = 1
HasSqrt = 1,
HasSign = 0 // Don't try to vectorize psign<bool> = identity.
};
};