Vectorize the sign operator in Eigen.

This commit is contained in:
Rasmus Munk Larsen 2022-08-09 19:54:57 +00:00
parent be20207d10
commit 97e0784dc6
7 changed files with 143 additions and 37 deletions

View File

@ -59,6 +59,7 @@ struct default_packet_traits
HasMax = 1,
HasConj = 1,
HasSetLinear = 1,
HasSign = 1,
HasBlend = 0,
// This flag is used to indicate whether packet comparison is supported.
// pcmp_eq, pcmp_lt and pcmp_le should be defined for it to be true.
@ -101,8 +102,7 @@ struct default_packet_traits
HasRound = 0,
HasRint = 0,
HasFloor = 0,
HasCeil = 0,
HasSign = 0
HasCeil = 0
};
};
@ -179,7 +179,7 @@ struct eigen_packet_wrapper
*/
template<typename Packet>
struct is_scalar {
typedef typename unpacket_traits<Packet>::type Scalar;
using Scalar = typename unpacket_traits<Packet>::type;
enum {
value = internal::is_same<Packet, Scalar>::value
};

View File

@ -229,6 +229,52 @@ struct imag_ref_retval
typedef typename NumTraits<Scalar>::Real & type;
};
/****************************************************************************
* Implementation of sign *
****************************************************************************/
template<typename Scalar, bool IsComplex = (NumTraits<Scalar>::IsComplex!=0),
bool IsInteger = (NumTraits<Scalar>::IsInteger!=0)>
struct sign_impl
{
EIGEN_DEVICE_FUNC
static inline Scalar run(const Scalar& a)
{
return Scalar( (a>Scalar(0)) - (a<Scalar(0)) );
}
};
template<typename Scalar>
struct sign_impl<Scalar, false, false>
{
EIGEN_DEVICE_FUNC
static inline Scalar run(const Scalar& a)
{
return (std::isnan)(a) ? a : Scalar( (a>Scalar(0)) - (a<Scalar(0)) );
}
};
template<typename Scalar, bool IsInteger>
struct sign_impl<Scalar, true, IsInteger>
{
EIGEN_DEVICE_FUNC
static inline Scalar run(const Scalar& a)
{
using real_type = typename NumTraits<Scalar>::Real;
real_type aa = std::abs(a);
if (aa==real_type(0))
return Scalar(0);
aa = real_type(1)/aa;
return Scalar(a.real()*aa, a.imag()*aa );
}
};
template<typename Scalar>
struct sign_retval
{
typedef Scalar type;
};
/****************************************************************************
* Implementation of conj *
****************************************************************************/
@ -1279,6 +1325,13 @@ inline EIGEN_MATHFUNC_RETVAL(conj, Scalar) conj(const Scalar& x)
return EIGEN_MATHFUNC_IMPL(conj, Scalar)::run(x);
}
template<typename Scalar>
EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(sign, Scalar) sign(const Scalar& x)
{
return EIGEN_MATHFUNC_IMPL(sign, Scalar)::run(x);
}
template<typename Scalar>
EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(abs2, Scalar) abs2(const Scalar& x)

View File

@ -852,6 +852,79 @@ Packet psqrt_complex(const Packet& a) {
pselect(is_real_inf, real_inf_result,result));
}
/** \internal \returns -1 if a is strictly negative, 0 otherwise, +1 if a is
strictly positive. */
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
!NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet>
psign(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type;
const Packet cst_one = pset1<Packet>(Scalar(1));
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
const Packet cst_zero = pzero(a);
const Packet not_nan_mask = pcmp_eq(a, a);
const Packet positive_mask = pcmp_lt(cst_zero, a);
const Packet positive = pand(positive_mask, cst_one);
const Packet negative_mask = pcmp_lt(a, cst_zero);
const Packet negative = pand(negative_mask, cst_minus_one);
return pselect(not_nan_mask, por(positive, negative), a);
}
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet>
psign(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type;
const Packet cst_one = pset1<Packet>(Scalar(1));
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
const Packet cst_zero = pzero(a);
const Packet positive_mask = pcmp_lt(cst_zero, a);
const Packet positive = pand(positive_mask, cst_one);
const Packet negative_mask = pcmp_lt(a, cst_zero);
const Packet negative = pand(negative_mask, cst_minus_one);
return por(positive, negative);
}
// \internal \returns the the sign of a complex number z, defined as z / abs(z).
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
std::enable_if_t<NumTraits<typename unpacket_traits<Packet>::type>::IsComplex, Packet>
psign(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename Scalar::value_type RealScalar;
typedef typename unpacket_traits<Packet>::as_real RealPacket;
// Step 1. Compute (for each element z = x + i*y in a)
// l = abs(z) = sqrt(x^2 + y^2).
// To avoid over- and underflow, we use the stable formula for each hypotenuse
// l = (zmin == 0 ? zmax : zmax * sqrt(1 + (zmin/zmax)**2)),
// where zmax = max(|x|, |y|), zmin = min(|x|, |y|),
RealPacket a_abs = pabs(a.v);
RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v;
RealPacket a_max = pmax(a_abs, a_abs_flip);
RealPacket a_min = pmin(a_abs, a_abs_flip);
RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min));
RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max));
RealPacket r = pdiv(a_min, a_max);
const RealPacket cst_one = pset1<RealPacket>(RealScalar(1));
RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1]
// Set l to a_max if a_min is zero, since the roundtrip sqrt(a_max^2) may be
// lossy.
l = pselect(a_min_zero_mask, a_max, l);
// Step 2 compute a / abs(a).
RealPacket sign_as_real = pandnot(pdiv(a.v, l), a_max_zero_mask);
Packet sign;
sign.v = sign_as_real;
return sign;
}
// TODO(rmlarsen): The following set of utilities for double word arithmetic
// should perhaps be refactored as a separate file, since it would be generally
// useful for special function implementation etc. Writing the algorithms in

View File

@ -155,7 +155,8 @@ struct packet_traits<float> : default_packet_traits {
#ifdef EIGEN_VECTORIZE_SSE4_1
HasRound = 1,
#endif
HasRint = 1
HasRint = 1,
HasSign = 0 // The manually vectorized version is slightly slower for SSE.
};
};
template <>
@ -217,6 +218,7 @@ template<> struct packet_traits<bool> : default_packet_traits
HasMin = 0,
HasMax = 0,
HasConj = 0,
HasSign = 0,
HasSqrt = 1
};
};

View File

@ -916,44 +916,19 @@ struct functor_traits<scalar_boolean_not_op<Scalar> > {
* \brief Template functor to compute the signum of a scalar
* \sa class CwiseUnaryOp, Cwise::sign()
*/
template<typename Scalar,bool is_complex=(NumTraits<Scalar>::IsComplex!=0), bool is_integer=(NumTraits<Scalar>::IsInteger!=0) > struct scalar_sign_op;
template<typename Scalar>
struct scalar_sign_op<Scalar, false, true> {
struct scalar_sign_op {
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const
{
return Scalar( (a>Scalar(0)) - (a<Scalar(0)) );
return numext::sign(a);
}
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const {
return internal::psign(a);
}
//TODO
//template <typename Packet>
//EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::psign(a); }
};
template<typename Scalar>
struct scalar_sign_op<Scalar, false, false> {
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const
{
return (numext::isnan)(a) ? a : Scalar( (a>Scalar(0)) - (a<Scalar(0)) );
}
//TODO
//template <typename Packet>
//EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::psign(a); }
};
template<typename Scalar, bool is_integer>
struct scalar_sign_op<Scalar,true, is_integer> {
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const
{
typedef typename NumTraits<Scalar>::Real real_type;
real_type aa = numext::abs(a);
if (aa==real_type(0))
return Scalar(0);
aa = real_type(1)/aa;
return Scalar(a.real()*aa, a.imag()*aa );
}
//TODO
//template <typename Packet>
//EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::psign(a); }
};
template<typename Scalar>
struct functor_traits<scalar_sign_op<Scalar> >
{ enum {

View File

@ -196,7 +196,7 @@ template<typename Scalar, typename NewType> struct scalar_cast_op;
template<typename Scalar> struct scalar_random_op;
template<typename Scalar> struct scalar_constant_op;
template<typename Scalar> struct scalar_identity_op;
template<typename Scalar,bool is_complex, bool is_integer> struct scalar_sign_op;
template<typename Scalar> struct scalar_sign_op;
template<typename Scalar,typename ScalarExponent> struct scalar_pow_op;
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_hypot_op;
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_product_op;

View File

@ -527,6 +527,7 @@ void packetmath() {
CHECK_CWISE1_IF(PacketTraits::HasNegate, internal::negate, internal::pnegate);
CHECK_CWISE1_IF(PacketTraits::HasReciprocal, REF_RECIPROCAL, internal::preciprocal);
CHECK_CWISE1(numext::conj, internal::pconj);
CHECK_CWISE1_IF(PacketTraits::HasSign, numext::sign, internal::psign);
for (int offset = 0; offset < 3; ++offset) {
@ -816,6 +817,7 @@ void packetmath_real() {
CHECK_CWISE1_EXACT_IF(PacketTraits::HasCeil, numext::ceil, internal::pceil);
CHECK_CWISE1_EXACT_IF(PacketTraits::HasFloor, numext::floor, internal::pfloor);
CHECK_CWISE1_EXACT_IF(PacketTraits::HasRint, numext::rint, internal::print);
CHECK_CWISE1_IF(PacketTraits::HasSign, numext::sign, internal::psign);
packetmath_boolean_mask_ops_real<Scalar,Packet>();
@ -1340,6 +1342,7 @@ void packetmath_complex() {
data1[i] = Scalar(internal::random<RealScalar>(), internal::random<RealScalar>());
}
CHECK_CWISE1_N(numext::sqrt, internal::psqrt, size);
CHECK_CWISE1_IF(PacketTraits::HasSign, numext::sign, internal::psign);
// Test misc. corner cases.
const RealScalar zero = RealScalar(0);