mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-13 20:26:03 +08:00
Specialize psign<Packet8i> for AVX2, don't vectorize psign<bool>.
This commit is contained in:
parent
98e51c9e24
commit
7064ed1345
@ -921,6 +921,24 @@ Packet print(const Packet& a) { using numext::rint; return rint(a); }
|
|||||||
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||||
Packet pceil(const Packet& a) { using numext::ceil; return ceil(a); }
|
Packet pceil(const Packet& a) { using numext::ceil; return ceil(a); }
|
||||||
|
|
||||||
|
template<typename Packet, typename EnableIf = void>
|
||||||
|
struct psign_impl {
|
||||||
|
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
|
||||||
|
return numext::sign(a);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/** \internal \returns the sign of \a a (coeff-wise) */
|
||||||
|
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
||||||
|
psign(const Packet& a) {
|
||||||
|
return psign_impl<Packet>::run(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<> EIGEN_DEVICE_FUNC inline bool
|
||||||
|
psign(const bool& a) {
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
|
||||||
/** \internal \returns the first element of a packet */
|
/** \internal \returns the first element of a packet */
|
||||||
template<typename Packet>
|
template<typename Packet>
|
||||||
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type
|
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type
|
||||||
|
@ -269,6 +269,17 @@ struct sign_impl<Scalar, true, IsInteger>
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// The sign function for bool is the identity.
|
||||||
|
template<>
|
||||||
|
struct sign_impl<bool, false, true>
|
||||||
|
{
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
static inline bool run(const bool& a)
|
||||||
|
{
|
||||||
|
return a;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template<typename Scalar>
|
template<typename Scalar>
|
||||||
struct sign_retval
|
struct sign_retval
|
||||||
{
|
{
|
||||||
|
@ -695,6 +695,12 @@ template<> EIGEN_STRONG_INLINE Packet8i pmax<Packet8i>(const Packet8i& a, const
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef EIGEN_VECTORIZE_AVX2
|
||||||
|
template<> EIGEN_STRONG_INLINE Packet8i psign(const Packet8i& a) {
|
||||||
|
return _mm256_sign_epi32(_mm256_set1_epi32(1), a);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// Add specializations for min/max with prescribed NaN propagation.
|
// Add specializations for min/max with prescribed NaN propagation.
|
||||||
template<>
|
template<>
|
||||||
EIGEN_STRONG_INLINE Packet8f pmin<PropagateNumbers, Packet8f>(const Packet8f& a, const Packet8f& b) {
|
EIGEN_STRONG_INLINE Packet8f pmin<PropagateNumbers, Packet8f>(const Packet8f& a, const Packet8f& b) {
|
||||||
|
@ -853,92 +853,95 @@ Packet psqrt_complex(const Packet& a) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/** \internal \returns -1 if a is strictly negative, 0 otherwise, +1 if a is
|
template <typename Packet>
|
||||||
strictly positive. */
|
struct psign_impl<
|
||||||
template<typename Packet>
|
Packet,
|
||||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
std::enable_if_t<
|
||||||
std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
|
!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
|
||||||
!NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet>
|
!NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
|
||||||
psign(const Packet& a) {
|
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
|
||||||
using Scalar = typename unpacket_traits<Packet>::type;
|
using Scalar = typename unpacket_traits<Packet>::type;
|
||||||
const Packet cst_one = pset1<Packet>(Scalar(1));
|
const Packet cst_one = pset1<Packet>(Scalar(1));
|
||||||
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
|
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
|
||||||
const Packet cst_zero = pzero(a);
|
const Packet cst_zero = pzero(a);
|
||||||
|
|
||||||
const Packet not_nan_mask = pcmp_eq(a, a);
|
const Packet not_nan_mask = pcmp_eq(a, a);
|
||||||
const Packet positive_mask = pcmp_lt(cst_zero, a);
|
const Packet positive_mask = pcmp_lt(cst_zero, a);
|
||||||
const Packet positive = pand(positive_mask, cst_one);
|
const Packet positive = pand(positive_mask, cst_one);
|
||||||
const Packet negative_mask = pcmp_lt(a, cst_zero);
|
const Packet negative_mask = pcmp_lt(a, cst_zero);
|
||||||
const Packet negative = pand(negative_mask, cst_minus_one);
|
const Packet negative = pand(negative_mask, cst_minus_one);
|
||||||
|
|
||||||
return pselect(not_nan_mask, por(positive, negative), a);
|
return pselect(not_nan_mask, por(positive, negative), a);
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template<typename Packet>
|
template <typename Packet>
|
||||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
struct psign_impl<
|
||||||
std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
|
Packet, std::enable_if_t<
|
||||||
NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
|
!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
|
||||||
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet>
|
NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
|
||||||
psign(const Packet& a) {
|
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
|
||||||
using Scalar = typename unpacket_traits<Packet>::type;
|
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
|
||||||
const Packet cst_one = pset1<Packet>(Scalar(1));
|
using Scalar = typename unpacket_traits<Packet>::type;
|
||||||
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
|
const Packet cst_one = pset1<Packet>(Scalar(1));
|
||||||
const Packet cst_zero = pzero(a);
|
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
|
||||||
|
const Packet cst_zero = pzero(a);
|
||||||
|
|
||||||
const Packet positive_mask = pcmp_lt(cst_zero, a);
|
const Packet positive_mask = pcmp_lt(cst_zero, a);
|
||||||
const Packet positive = pand(positive_mask, cst_one);
|
const Packet positive = pand(positive_mask, cst_one);
|
||||||
const Packet negative_mask = pcmp_lt(a, cst_zero);
|
const Packet negative_mask = pcmp_lt(a, cst_zero);
|
||||||
const Packet negative = pand(negative_mask, cst_minus_one);
|
const Packet negative = pand(negative_mask, cst_minus_one);
|
||||||
|
|
||||||
return por(positive, negative);
|
return por(positive, negative);
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template<typename Packet>
|
template <typename Packet>
|
||||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
struct psign_impl<Packet, std::enable_if_t<!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
|
||||||
std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
|
!NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
|
||||||
!NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
|
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
|
||||||
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet>
|
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
|
||||||
psign(const Packet& a) {
|
using Scalar = typename unpacket_traits<Packet>::type;
|
||||||
using Scalar = typename unpacket_traits<Packet>::type;
|
const Packet cst_one = pset1<Packet>(Scalar(1));
|
||||||
const Packet cst_one = pset1<Packet>(Scalar(1));
|
const Packet cst_zero = pzero(a);
|
||||||
const Packet cst_zero = pzero(a);
|
|
||||||
|
|
||||||
const Packet zero_mask = pcmp_eq(cst_zero, a);
|
const Packet zero_mask = pcmp_eq(cst_zero, a);
|
||||||
return pandnot(cst_one, zero_mask);
|
return pandnot(cst_one, zero_mask);
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// \internal \returns the sign of a complex number z, defined as z / abs(z).
|
// \internal \returns the sign of a complex number z, defined as z / abs(z).
|
||||||
template<typename Packet>
|
template <typename Packet>
|
||||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
struct psign_impl<Packet, std::enable_if_t<NumTraits<typename unpacket_traits<Packet>::type>::IsComplex>> {
|
||||||
std::enable_if_t<NumTraits<typename unpacket_traits<Packet>::type>::IsComplex, Packet>
|
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
|
||||||
psign(const Packet& a) {
|
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
typedef typename Scalar::value_type RealScalar;
|
||||||
typedef typename Scalar::value_type RealScalar;
|
typedef typename unpacket_traits<Packet>::as_real RealPacket;
|
||||||
typedef typename unpacket_traits<Packet>::as_real RealPacket;
|
|
||||||
|
|
||||||
// Step 1. Compute (for each element z = x + i*y in a)
|
// Step 1. Compute (for each element z = x + i*y in a)
|
||||||
// l = abs(z) = sqrt(x^2 + y^2).
|
// l = abs(z) = sqrt(x^2 + y^2).
|
||||||
// To avoid over- and underflow, we use the stable formula for each hypotenuse
|
// To avoid over- and underflow, we use the stable formula for each hypotenuse
|
||||||
// l = (zmin == 0 ? zmax : zmax * sqrt(1 + (zmin/zmax)**2)),
|
// l = (zmin == 0 ? zmax : zmax * sqrt(1 + (zmin/zmax)**2)),
|
||||||
// where zmax = max(|x|, |y|), zmin = min(|x|, |y|),
|
// where zmax = max(|x|, |y|), zmin = min(|x|, |y|),
|
||||||
RealPacket a_abs = pabs(a.v);
|
RealPacket a_abs = pabs(a.v);
|
||||||
RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v;
|
RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v;
|
||||||
RealPacket a_max = pmax(a_abs, a_abs_flip);
|
RealPacket a_max = pmax(a_abs, a_abs_flip);
|
||||||
RealPacket a_min = pmin(a_abs, a_abs_flip);
|
RealPacket a_min = pmin(a_abs, a_abs_flip);
|
||||||
RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min));
|
RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min));
|
||||||
RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max));
|
RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max));
|
||||||
RealPacket r = pdiv(a_min, a_max);
|
RealPacket r = pdiv(a_min, a_max);
|
||||||
const RealPacket cst_one = pset1<RealPacket>(RealScalar(1));
|
const RealPacket cst_one = pset1<RealPacket>(RealScalar(1));
|
||||||
RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1]
|
RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1]
|
||||||
// Set l to a_max if a_min is zero, since the roundtrip sqrt(a_max^2) may be
|
// Set l to a_max if a_min is zero, since the roundtrip sqrt(a_max^2) may be
|
||||||
// lossy.
|
// lossy.
|
||||||
l = pselect(a_min_zero_mask, a_max, l);
|
l = pselect(a_min_zero_mask, a_max, l);
|
||||||
// Step 2. Compute a / abs(a).
|
// Step 2. Compute a / abs(a).
|
||||||
RealPacket sign_as_real = pandnot(pdiv(a.v, l), a_max_zero_mask);
|
RealPacket sign_as_real = pandnot(pdiv(a.v, l), a_max_zero_mask);
|
||||||
Packet sign;
|
Packet sign;
|
||||||
sign.v = sign_as_real;
|
sign.v = sign_as_real;
|
||||||
return sign;
|
return sign;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// TODO(rmlarsen): The following set of utilities for double word arithmetic
|
// TODO(rmlarsen): The following set of utilities for double word arithmetic
|
||||||
// should perhaps be refactored as a separate file, since it would be generally
|
// should perhaps be refactored as a separate file, since it would be generally
|
||||||
|
@ -218,7 +218,8 @@ template<> struct packet_traits<bool> : default_packet_traits
|
|||||||
HasMin = 0,
|
HasMin = 0,
|
||||||
HasMax = 0,
|
HasMax = 0,
|
||||||
HasConj = 0,
|
HasConj = 0,
|
||||||
HasSqrt = 1
|
HasSqrt = 1,
|
||||||
|
HasSign = 0 // Don't try to vectorize psign<bool> = identity.
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user