mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-22 01:29:35 +08:00
Vectorize the sign operator in Eigen.
This commit is contained in:
parent
be20207d10
commit
97e0784dc6
@ -59,6 +59,7 @@ struct default_packet_traits
|
||||
HasMax = 1,
|
||||
HasConj = 1,
|
||||
HasSetLinear = 1,
|
||||
HasSign = 1,
|
||||
HasBlend = 0,
|
||||
// This flag is used to indicate whether packet comparison is supported.
|
||||
// pcmp_eq, pcmp_lt and pcmp_le should be defined for it to be true.
|
||||
@ -101,8 +102,7 @@ struct default_packet_traits
|
||||
HasRound = 0,
|
||||
HasRint = 0,
|
||||
HasFloor = 0,
|
||||
HasCeil = 0,
|
||||
HasSign = 0
|
||||
HasCeil = 0
|
||||
};
|
||||
};
|
||||
|
||||
@ -179,7 +179,7 @@ struct eigen_packet_wrapper
|
||||
*/
|
||||
template<typename Packet>
|
||||
struct is_scalar {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
using Scalar = typename unpacket_traits<Packet>::type;
|
||||
enum {
|
||||
value = internal::is_same<Packet, Scalar>::value
|
||||
};
|
||||
|
@ -229,6 +229,52 @@ struct imag_ref_retval
|
||||
typedef typename NumTraits<Scalar>::Real & type;
|
||||
};
|
||||
|
||||
|
||||
/****************************************************************************
|
||||
* Implementation of sign *
|
||||
****************************************************************************/
|
||||
template<typename Scalar, bool IsComplex = (NumTraits<Scalar>::IsComplex!=0),
|
||||
bool IsInteger = (NumTraits<Scalar>::IsInteger!=0)>
|
||||
struct sign_impl
|
||||
{
|
||||
EIGEN_DEVICE_FUNC
|
||||
static inline Scalar run(const Scalar& a)
|
||||
{
|
||||
return Scalar( (a>Scalar(0)) - (a<Scalar(0)) );
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Scalar>
|
||||
struct sign_impl<Scalar, false, false>
|
||||
{
|
||||
EIGEN_DEVICE_FUNC
|
||||
static inline Scalar run(const Scalar& a)
|
||||
{
|
||||
return (std::isnan)(a) ? a : Scalar( (a>Scalar(0)) - (a<Scalar(0)) );
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Scalar, bool IsInteger>
|
||||
struct sign_impl<Scalar, true, IsInteger>
|
||||
{
|
||||
EIGEN_DEVICE_FUNC
|
||||
static inline Scalar run(const Scalar& a)
|
||||
{
|
||||
using real_type = typename NumTraits<Scalar>::Real;
|
||||
real_type aa = std::abs(a);
|
||||
if (aa==real_type(0))
|
||||
return Scalar(0);
|
||||
aa = real_type(1)/aa;
|
||||
return Scalar(a.real()*aa, a.imag()*aa );
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Scalar>
|
||||
struct sign_retval
|
||||
{
|
||||
typedef Scalar type;
|
||||
};
|
||||
|
||||
/****************************************************************************
|
||||
* Implementation of conj *
|
||||
****************************************************************************/
|
||||
@ -1279,6 +1325,13 @@ inline EIGEN_MATHFUNC_RETVAL(conj, Scalar) conj(const Scalar& x)
|
||||
return EIGEN_MATHFUNC_IMPL(conj, Scalar)::run(x);
|
||||
}
|
||||
|
||||
template<typename Scalar>
|
||||
EIGEN_DEVICE_FUNC
|
||||
inline EIGEN_MATHFUNC_RETVAL(sign, Scalar) sign(const Scalar& x)
|
||||
{
|
||||
return EIGEN_MATHFUNC_IMPL(sign, Scalar)::run(x);
|
||||
}
|
||||
|
||||
template<typename Scalar>
|
||||
EIGEN_DEVICE_FUNC
|
||||
inline EIGEN_MATHFUNC_RETVAL(abs2, Scalar) abs2(const Scalar& x)
|
||||
|
@ -852,6 +852,79 @@ Packet psqrt_complex(const Packet& a) {
|
||||
pselect(is_real_inf, real_inf_result,result));
|
||||
}
|
||||
|
||||
|
||||
/** \internal \returns -1 if a is strictly negative, 0 otherwise, +1 if a is
|
||||
strictly positive. */
|
||||
template<typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
|
||||
!NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet>
|
||||
psign(const Packet& a) {
|
||||
using Scalar = typename unpacket_traits<Packet>::type;
|
||||
const Packet cst_one = pset1<Packet>(Scalar(1));
|
||||
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
|
||||
const Packet cst_zero = pzero(a);
|
||||
|
||||
const Packet not_nan_mask = pcmp_eq(a, a);
|
||||
const Packet positive_mask = pcmp_lt(cst_zero, a);
|
||||
const Packet positive = pand(positive_mask, cst_one);
|
||||
const Packet negative_mask = pcmp_lt(a, cst_zero);
|
||||
const Packet negative = pand(negative_mask, cst_minus_one);
|
||||
|
||||
return pselect(not_nan_mask, por(positive, negative), a);
|
||||
}
|
||||
|
||||
template<typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
std::enable_if_t<(!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
|
||||
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger), Packet>
|
||||
psign(const Packet& a) {
|
||||
using Scalar = typename unpacket_traits<Packet>::type;
|
||||
const Packet cst_one = pset1<Packet>(Scalar(1));
|
||||
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
|
||||
const Packet cst_zero = pzero(a);
|
||||
|
||||
const Packet positive_mask = pcmp_lt(cst_zero, a);
|
||||
const Packet positive = pand(positive_mask, cst_one);
|
||||
const Packet negative_mask = pcmp_lt(a, cst_zero);
|
||||
const Packet negative = pand(negative_mask, cst_minus_one);
|
||||
|
||||
return por(positive, negative);
|
||||
}
|
||||
|
||||
// \internal \returns the the sign of a complex number z, defined as z / abs(z).
|
||||
template<typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
std::enable_if_t<NumTraits<typename unpacket_traits<Packet>::type>::IsComplex, Packet>
|
||||
psign(const Packet& a) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
typedef typename Scalar::value_type RealScalar;
|
||||
typedef typename unpacket_traits<Packet>::as_real RealPacket;
|
||||
|
||||
// Step 1. Compute (for each element z = x + i*y in a)
|
||||
// l = abs(z) = sqrt(x^2 + y^2).
|
||||
// To avoid over- and underflow, we use the stable formula for each hypotenuse
|
||||
// l = (zmin == 0 ? zmax : zmax * sqrt(1 + (zmin/zmax)**2)),
|
||||
// where zmax = max(|x|, |y|), zmin = min(|x|, |y|),
|
||||
RealPacket a_abs = pabs(a.v);
|
||||
RealPacket a_abs_flip = pcplxflip(Packet(a_abs)).v;
|
||||
RealPacket a_max = pmax(a_abs, a_abs_flip);
|
||||
RealPacket a_min = pmin(a_abs, a_abs_flip);
|
||||
RealPacket a_min_zero_mask = pcmp_eq(a_min, pzero(a_min));
|
||||
RealPacket a_max_zero_mask = pcmp_eq(a_max, pzero(a_max));
|
||||
RealPacket r = pdiv(a_min, a_max);
|
||||
const RealPacket cst_one = pset1<RealPacket>(RealScalar(1));
|
||||
RealPacket l = pmul(a_max, psqrt(padd(cst_one, pmul(r, r)))); // [l0, l0, l1, l1]
|
||||
// Set l to a_max if a_min is zero, since the roundtrip sqrt(a_max^2) may be
|
||||
// lossy.
|
||||
l = pselect(a_min_zero_mask, a_max, l);
|
||||
// Step 2 compute a / abs(a).
|
||||
RealPacket sign_as_real = pandnot(pdiv(a.v, l), a_max_zero_mask);
|
||||
Packet sign;
|
||||
sign.v = sign_as_real;
|
||||
return sign;
|
||||
}
|
||||
|
||||
// TODO(rmlarsen): The following set of utilities for double word arithmetic
|
||||
// should perhaps be refactored as a separate file, since it would be generally
|
||||
// useful for special function implementation etc. Writing the algorithms in
|
||||
|
@ -155,7 +155,8 @@ struct packet_traits<float> : default_packet_traits {
|
||||
#ifdef EIGEN_VECTORIZE_SSE4_1
|
||||
HasRound = 1,
|
||||
#endif
|
||||
HasRint = 1
|
||||
HasRint = 1,
|
||||
HasSign = 0 // The manually vectorized version is slightly slower for SSE.
|
||||
};
|
||||
};
|
||||
template <>
|
||||
@ -217,6 +218,7 @@ template<> struct packet_traits<bool> : default_packet_traits
|
||||
HasMin = 0,
|
||||
HasMax = 0,
|
||||
HasConj = 0,
|
||||
HasSign = 0,
|
||||
HasSqrt = 1
|
||||
};
|
||||
};
|
||||
|
@ -916,44 +916,19 @@ struct functor_traits<scalar_boolean_not_op<Scalar> > {
|
||||
* \brief Template functor to compute the signum of a scalar
|
||||
* \sa class CwiseUnaryOp, Cwise::sign()
|
||||
*/
|
||||
template<typename Scalar,bool is_complex=(NumTraits<Scalar>::IsComplex!=0), bool is_integer=(NumTraits<Scalar>::IsInteger!=0) > struct scalar_sign_op;
|
||||
template<typename Scalar>
|
||||
struct scalar_sign_op<Scalar, false, true> {
|
||||
struct scalar_sign_op {
|
||||
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const
|
||||
{
|
||||
return Scalar( (a>Scalar(0)) - (a<Scalar(0)) );
|
||||
return numext::sign(a);
|
||||
}
|
||||
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const {
|
||||
return internal::psign(a);
|
||||
}
|
||||
//TODO
|
||||
//template <typename Packet>
|
||||
//EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::psign(a); }
|
||||
};
|
||||
|
||||
template<typename Scalar>
|
||||
struct scalar_sign_op<Scalar, false, false> {
|
||||
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const
|
||||
{
|
||||
return (numext::isnan)(a) ? a : Scalar( (a>Scalar(0)) - (a<Scalar(0)) );
|
||||
}
|
||||
//TODO
|
||||
//template <typename Packet>
|
||||
//EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::psign(a); }
|
||||
};
|
||||
|
||||
template<typename Scalar, bool is_integer>
|
||||
struct scalar_sign_op<Scalar,true, is_integer> {
|
||||
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const
|
||||
{
|
||||
typedef typename NumTraits<Scalar>::Real real_type;
|
||||
real_type aa = numext::abs(a);
|
||||
if (aa==real_type(0))
|
||||
return Scalar(0);
|
||||
aa = real_type(1)/aa;
|
||||
return Scalar(a.real()*aa, a.imag()*aa );
|
||||
}
|
||||
//TODO
|
||||
//template <typename Packet>
|
||||
//EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& a) const { return internal::psign(a); }
|
||||
};
|
||||
template<typename Scalar>
|
||||
struct functor_traits<scalar_sign_op<Scalar> >
|
||||
{ enum {
|
||||
|
@ -196,7 +196,7 @@ template<typename Scalar, typename NewType> struct scalar_cast_op;
|
||||
template<typename Scalar> struct scalar_random_op;
|
||||
template<typename Scalar> struct scalar_constant_op;
|
||||
template<typename Scalar> struct scalar_identity_op;
|
||||
template<typename Scalar,bool is_complex, bool is_integer> struct scalar_sign_op;
|
||||
template<typename Scalar> struct scalar_sign_op;
|
||||
template<typename Scalar,typename ScalarExponent> struct scalar_pow_op;
|
||||
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_hypot_op;
|
||||
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_product_op;
|
||||
|
@ -527,6 +527,7 @@ void packetmath() {
|
||||
CHECK_CWISE1_IF(PacketTraits::HasNegate, internal::negate, internal::pnegate);
|
||||
CHECK_CWISE1_IF(PacketTraits::HasReciprocal, REF_RECIPROCAL, internal::preciprocal);
|
||||
CHECK_CWISE1(numext::conj, internal::pconj);
|
||||
CHECK_CWISE1_IF(PacketTraits::HasSign, numext::sign, internal::psign);
|
||||
|
||||
|
||||
for (int offset = 0; offset < 3; ++offset) {
|
||||
@ -816,6 +817,7 @@ void packetmath_real() {
|
||||
CHECK_CWISE1_EXACT_IF(PacketTraits::HasCeil, numext::ceil, internal::pceil);
|
||||
CHECK_CWISE1_EXACT_IF(PacketTraits::HasFloor, numext::floor, internal::pfloor);
|
||||
CHECK_CWISE1_EXACT_IF(PacketTraits::HasRint, numext::rint, internal::print);
|
||||
CHECK_CWISE1_IF(PacketTraits::HasSign, numext::sign, internal::psign);
|
||||
|
||||
packetmath_boolean_mask_ops_real<Scalar,Packet>();
|
||||
|
||||
@ -1340,6 +1342,7 @@ void packetmath_complex() {
|
||||
data1[i] = Scalar(internal::random<RealScalar>(), internal::random<RealScalar>());
|
||||
}
|
||||
CHECK_CWISE1_N(numext::sqrt, internal::psqrt, size);
|
||||
CHECK_CWISE1_IF(PacketTraits::HasSign, numext::sign, internal::psign);
|
||||
|
||||
// Test misc. corner cases.
|
||||
const RealScalar zero = RealScalar(0);
|
||||
|
Loading…
x
Reference in New Issue
Block a user