mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-10 18:59:01 +08:00
Specialize psign<Packet8i> for AVX2, don't vectorize psign<bool>.
This commit is contained in:
parent
98e51c9e24
commit
7064ed1345
@ -921,6 +921,24 @@ Packet print(const Packet& a) { using numext::rint; return rint(a); }
|
||||
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
Packet pceil(const Packet& a) { using numext::ceil; return ceil(a); }
|
||||
|
||||
template<typename Packet, typename EnableIf = void>
|
||||
struct psign_impl {
|
||||
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
|
||||
return numext::sign(a);
|
||||
}
|
||||
};
|
||||
|
||||
/** \internal \returns the sign of \a a (coeff-wise) */
|
||||
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
||||
psign(const Packet& a) {
|
||||
return psign_impl<Packet>::run(a);
|
||||
}
|
||||
|
||||
template<> EIGEN_DEVICE_FUNC inline bool
|
||||
psign(const bool& a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
/** \internal \returns the first element of a packet */
|
||||
template<typename Packet>
|
||||
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type
|
||||
|
@ -269,6 +269,17 @@ struct sign_impl<Scalar, true, IsInteger>
|
||||
}
|
||||
};
|
||||
|
||||
// The sign function for bool is the identity.
|
||||
template<>
|
||||
struct sign_impl<bool, false, true>
|
||||
{
|
||||
EIGEN_DEVICE_FUNC
|
||||
static inline bool run(const bool& a)
|
||||
{
|
||||
return a;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Scalar>
|
||||
struct sign_retval
|
||||
{
|
||||
|
@ -695,6 +695,12 @@ template<> EIGEN_STRONG_INLINE Packet8i pmax<Packet8i>(const Packet8i& a, const
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef EIGEN_VECTORIZE_AVX2
|
||||
template<> EIGEN_STRONG_INLINE Packet8i psign(const Packet8i& a) {
|
||||
return _mm256_sign_epi32(_mm256_set1_epi32(1), a);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Add specializations for min/max with prescribed NaN progation.
|
||||
template<>
|
||||
EIGEN_STRONG_INLINE Packet8f pmin<PropagateNumbers, Packet8f>(const Packet8f& a, const Packet8f& b) {
|
||||
|
@ -853,92 +853,95 @@ Packet psqrt_complex(const Packet& a) {
|
||||
}
|
||||
|
||||
|
||||
/** \internal \returns -1 if a is strictly negative, 0 otherwise, +1 if a is
|
||||
strictly positive. */
|
||||
template<typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
|
||||
!NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet>
|
||||
psign(const Packet& a) {
|
||||
using Scalar = typename unpacket_traits<Packet>::type;
|
||||
const Packet cst_one = pset1<Packet>(Scalar(1));
|
||||
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
|
||||
const Packet cst_zero = pzero(a);
|
||||
template <typename Packet>
|
||||
struct psign_impl<
|
||||
Packet,
|
||||
std::enable_if_t<
|
||||
!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
|
||||
!NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
|
||||
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
|
||||
using Scalar = typename unpacket_traits<Packet>::type;
|
||||
const Packet cst_one = pset1<Packet>(Scalar(1));
|
||||
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
|
||||
const Packet cst_zero = pzero(a);
|
||||
|
||||
const Packet not_nan_mask = pcmp_eq(a, a);
|
||||
const Packet positive_mask = pcmp_lt(cst_zero, a);
|
||||
const Packet positive = pand(positive_mask, cst_one);
|
||||
const Packet negative_mask = pcmp_lt(a, cst_zero);
|
||||
const Packet negative = pand(negative_mask, cst_minus_one);
|
||||
const Packet not_nan_mask = pcmp_eq(a, a);
|
||||
const Packet positive_mask = pcmp_lt(cst_zero, a);
|
||||
const Packet positive = pand(positive_mask, cst_one);
|
||||
const Packet negative_mask = pcmp_lt(a, cst_zero);
|
||||
const Packet negative = pand(negative_mask, cst_minus_one);
|
||||
|
||||
return pselect(not_nan_mask, por(positive, negative), a);
|
||||
}
|
||||
return pselect(not_nan_mask, por(positive, negative), a);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
|
||||
NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
|
||||
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet>
|
||||
psign(const Packet& a) {
|
||||
using Scalar = typename unpacket_traits<Packet>::type;
|
||||
const Packet cst_one = pset1<Packet>(Scalar(1));
|
||||
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
|
||||
const Packet cst_zero = pzero(a);
|
||||
template <typename Packet>
|
||||
struct psign_impl<
|
||||
Packet, std::enable_if_t<
|
||||
!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
|
||||
NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
|
||||
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
|
||||
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
|
||||
using Scalar = typename unpacket_traits<Packet>::type;
|
||||
const Packet cst_one = pset1<Packet>(Scalar(1));
|
||||
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
|
||||
const Packet cst_zero = pzero(a);
|
||||
|
||||
const Packet positive_mask = pcmp_lt(cst_zero, a);
|
||||
const Packet positive = pand(positive_mask, cst_one);
|
||||
const Packet negative_mask = pcmp_lt(a, cst_zero);
|
||||
const Packet negative = pand(negative_mask, cst_minus_one);
|
||||
const Packet positive_mask = pcmp_lt(cst_zero, a);
|
||||
const Packet positive = pand(positive_mask, cst_one);
|
||||
const Packet negative_mask = pcmp_lt(a, cst_zero);
|
||||
const Packet negative = pand(negative_mask, cst_minus_one);
|
||||
|
||||
return por(positive, negative);
|
||||
}
|
||||
return por(positive, negative);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
|
||||
!NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
|
||||
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet>
|
||||
psign(const Packet& a) {
|
||||
using Scalar = typename unpacket_traits<Packet>::type;
|
||||
const Packet cst_one = pset1<Packet>(Scalar(1));
|
||||
const Packet cst_zero = pzero(a);
|
||||
template <typename Packet>
|
||||
struct psign_impl<Packet, std::enable_if_t<!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
|
||||
!NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
|
||||
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
|
||||
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
|
||||
using Scalar = typename unpacket_traits<Packet>::type;
|
||||
const Packet cst_one = pset1<Packet>(Scalar(1));
|
||||
const Packet cst_zero = pzero(a);
|
||||
|
||||
const Packet zero_mask = pcmp_eq(cst_zero, a);
|
||||
return pandnot(cst_one, zero_mask);
|
||||
}
|
||||
const Packet zero_mask = pcmp_eq(cst_zero, a);
|
||||
return pandnot(cst_one, zero_mask);
|
||||
}
|
||||
};
|
||||
|
||||
// \internal \returns the the sign of a complex number z, defined as z / abs(z).
|
||||
template<typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
std::enable_if_t<NumTraits<typename unpacket_traits<Packet>::type>::IsComplex, Packet>
|
||||
psign(const Packet& a) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
typedef typename Scalar::value_type RealScalar;
|
||||
typedef typename unpacket_traits<Packet>::as_real RealPacket;
|
||||
template <typename Packet>
|
||||
struct psign_impl<Packet, std::enable_if_t<NumTraits<typename unpacket_traits<Packet>::type>::IsComplex>> {
|
||||
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
typedef typename Scalar::value_type RealScalar;
|
||||
typedef typename unpacket_traits<Packet>::as_real RealPacket;
|
||||
|
||||
// Step 1. Compute (for each element z = x + i*y in a)
|
||||
// l = abs(z) = sqrt(x^2 + y^2).
|
||||
// To avoid over- and underflow, we use the stable formula for each hypotenuse
|
||||
// l = (zmin == 0 ? zmax : zmax * sqrt(1 + (zmin/zmax)**2)),
|
||||
// where zmax = max(|x|, |y|), zmin = min(|x|, |y|),
|
||||
RealPacket a_abs = pabs(a.v);
|
||||
RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v;
|
||||
RealPacket a_max = pmax(a_abs, a_abs_flip);
|
||||
RealPacket a_min = pmin(a_abs, a_abs_flip);
|
||||
RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min));
|
||||
RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max));
|
||||
RealPacket r = pdiv(a_min, a_max);
|
||||
const RealPacket cst_one = pset1<RealPacket>(RealScalar(1));
|
||||
RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1]
|
||||
// Set l to a_max if a_min is zero, since the roundtrip sqrt(a_max^2) may be
|
||||
// lossy.
|
||||
l = pselect(a_min_zero_mask, a_max, l);
|
||||
// Step 2 compute a / abs(a).
|
||||
RealPacket sign_as_real = pandnot(pdiv(a.v, l), a_max_zero_mask);
|
||||
Packet sign;
|
||||
sign.v = sign_as_real;
|
||||
return sign;
|
||||
}
|
||||
// Step 1. Compute (for each element z = x + i*y in a)
|
||||
// l = abs(z) = sqrt(x^2 + y^2).
|
||||
// To avoid over- and underflow, we use the stable formula for each hypotenuse
|
||||
// l = (zmin == 0 ? zmax : zmax * sqrt(1 + (zmin/zmax)**2)),
|
||||
// where zmax = max(|x|, |y|), zmin = min(|x|, |y|),
|
||||
RealPacket a_abs = pabs(a.v);
|
||||
RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v;
|
||||
RealPacket a_max = pmax(a_abs, a_abs_flip);
|
||||
RealPacket a_min = pmin(a_abs, a_abs_flip);
|
||||
RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min));
|
||||
RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max));
|
||||
RealPacket r = pdiv(a_min, a_max);
|
||||
const RealPacket cst_one = pset1<RealPacket>(RealScalar(1));
|
||||
RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1]
|
||||
// Set l to a_max if a_min is zero, since the roundtrip sqrt(a_max^2) may be
|
||||
// lossy.
|
||||
l = pselect(a_min_zero_mask, a_max, l);
|
||||
// Step 2 compute a / abs(a).
|
||||
RealPacket sign_as_real = pandnot(pdiv(a.v, l), a_max_zero_mask);
|
||||
Packet sign;
|
||||
sign.v = sign_as_real;
|
||||
return sign;
|
||||
}
|
||||
};
|
||||
|
||||
// TODO(rmlarsen): The following set of utilities for double word arithmetic
|
||||
// should perhaps be refactored as a separate file, since it would be generally
|
||||
|
@ -207,7 +207,7 @@ template<> struct packet_traits<bool> : default_packet_traits
|
||||
AlignedOnScalar = 1,
|
||||
HasHalfPacket = 0,
|
||||
size=16,
|
||||
|
||||
|
||||
HasAdd = 1,
|
||||
HasSub = 1,
|
||||
HasShift = 0,
|
||||
@ -218,7 +218,8 @@ template<> struct packet_traits<bool> : default_packet_traits
|
||||
HasMin = 0,
|
||||
HasMax = 0,
|
||||
HasConj = 0,
|
||||
HasSqrt = 1
|
||||
HasSqrt = 1,
|
||||
HasSign = 0 // Don't try to vectorize psign<bool> = identity.
|
||||
};
|
||||
};
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user