Add support for vectorizing logical comparisons.

This commit is contained in:
derekjchow 2021-07-19 16:51:17 -07:00 committed by Derek Chow
parent a77638387d
commit 66ca41bd47
2 changed files with 28 additions and 4 deletions

View File

@ -219,6 +219,7 @@ struct packet_traits<int8_t> : default_packet_traits
size = 16, size = 16,
HasHalfPacket = 1, HasHalfPacket = 1,
HasCmp = 1,
HasAdd = 1, HasAdd = 1,
HasSub = 1, HasSub = 1,
HasShift = 1, HasShift = 1,
@ -248,6 +249,7 @@ struct packet_traits<uint8_t> : default_packet_traits
size = 16, size = 16,
HasHalfPacket = 1, HasHalfPacket = 1,
HasCmp = 1,
HasAdd = 1, HasAdd = 1,
HasSub = 1, HasSub = 1,
HasShift = 1, HasShift = 1,

View File

@ -205,7 +205,8 @@ template<typename LhsScalar, typename RhsScalar, ComparisonName cmp>
struct functor_traits<scalar_cmp_op<LhsScalar,RhsScalar, cmp> > { struct functor_traits<scalar_cmp_op<LhsScalar,RhsScalar, cmp> > {
enum { enum {
Cost = (NumTraits<LhsScalar>::AddCost+NumTraits<RhsScalar>::AddCost)/2, Cost = (NumTraits<LhsScalar>::AddCost+NumTraits<RhsScalar>::AddCost)/2,
PacketAccess = false PacketAccess = is_same<LhsScalar, RhsScalar>::value &&
packet_traits<LhsScalar>::HasCmp
}; };
}; };
@ -221,6 +222,9 @@ struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_EQ> : binary_op_base<LhsScalar,Rhs
typedef bool result_type; typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a==b;} EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a==b;}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
{ return internal::pcmp_eq(a,b); }
}; };
template<typename LhsScalar, typename RhsScalar> template<typename LhsScalar, typename RhsScalar>
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_LT> : binary_op_base<LhsScalar,RhsScalar> struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_LT> : binary_op_base<LhsScalar,RhsScalar>
@ -228,6 +232,9 @@ struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_LT> : binary_op_base<LhsScalar,Rhs
typedef bool result_type; typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a<b;} EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a<b;}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
{ return internal::pcmp_lt(a,b); }
}; };
template<typename LhsScalar, typename RhsScalar> template<typename LhsScalar, typename RhsScalar>
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_LE> : binary_op_base<LhsScalar,RhsScalar> struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_LE> : binary_op_base<LhsScalar,RhsScalar>
@ -235,6 +242,9 @@ struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_LE> : binary_op_base<LhsScalar,Rhs
typedef bool result_type; typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a<=b;} EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a<=b;}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
{ return internal::pcmp_le(a,b); }
}; };
template<typename LhsScalar, typename RhsScalar> template<typename LhsScalar, typename RhsScalar>
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_GT> : binary_op_base<LhsScalar,RhsScalar> struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_GT> : binary_op_base<LhsScalar,RhsScalar>
@ -242,6 +252,9 @@ struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_GT> : binary_op_base<LhsScalar,Rhs
typedef bool result_type; typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a>b;} EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a>b;}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
{ return internal::pcmp_lt(b,a); }
}; };
template<typename LhsScalar, typename RhsScalar> template<typename LhsScalar, typename RhsScalar>
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_GE> : binary_op_base<LhsScalar,RhsScalar> struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_GE> : binary_op_base<LhsScalar,RhsScalar>
@ -249,6 +262,9 @@ struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_GE> : binary_op_base<LhsScalar,Rhs
typedef bool result_type; typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a>=b;} EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a>=b;}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
{ return internal::pcmp_le(b,a); }
}; };
template<typename LhsScalar, typename RhsScalar> template<typename LhsScalar, typename RhsScalar>
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_UNORD> : binary_op_base<LhsScalar,RhsScalar> struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_UNORD> : binary_op_base<LhsScalar,RhsScalar>
@ -256,6 +272,9 @@ struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_UNORD> : binary_op_base<LhsScalar,
typedef bool result_type; typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return !(a<=b || b<=a);} EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return !(a<=b || b<=a);}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
{ return internal::pcmp_eq(internal::por(internal::pcmp_le(a, b), internal::pcmp_le(b, a)), internal::pzero(a)); }
}; };
template<typename LhsScalar, typename RhsScalar> template<typename LhsScalar, typename RhsScalar>
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_NEQ> : binary_op_base<LhsScalar,RhsScalar> struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_NEQ> : binary_op_base<LhsScalar,RhsScalar>
@ -263,6 +282,9 @@ struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_NEQ> : binary_op_base<LhsScalar,Rh
typedef bool result_type; typedef bool result_type;
EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_cmp_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a!=b;} EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a!=b;}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
{ return internal::pcmp_eq(internal::pcmp_eq(a, b), internal::pzero(a)); }
}; };
/** \internal /** \internal