Optimize psign

This commit is contained in:
Charles Schlosser 2023-02-09 22:15:26 +00:00 committed by Rasmus Munk Larsen
parent 0e490d452d
commit 325e3063d9

View File

@ -1097,33 +1097,25 @@ Packet psqrt_complex(const Packet& a) {
template <typename Packet> template <typename Packet>
struct psign_impl< struct psign_impl<Packet, std::enable_if_t<!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
Packet, !NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
std::enable_if_t<
!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
!NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) { static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type; using Scalar = typename unpacket_traits<Packet>::type;
const Packet cst_one = pset1<Packet>(Scalar(1)); const Packet cst_one = pset1<Packet>(Scalar(1));
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
const Packet cst_zero = pzero(a); const Packet cst_zero = pzero(a);
const Packet not_nan_mask = pcmp_eq(a, a); const Packet abs_a = pabs(a);
const Packet positive_mask = pcmp_lt(cst_zero, a); const Packet sign_mask = pandnot(a, abs_a);
const Packet positive = pand(positive_mask, cst_one); const Packet nonzero_mask = pcmp_lt(cst_zero, abs_a);
const Packet negative_mask = pcmp_lt(a, cst_zero);
const Packet negative = pand(negative_mask, cst_minus_one);
return pselect(not_nan_mask, por(positive, negative), a); return pselect(nonzero_mask, por(sign_mask, cst_one), abs_a);
} }
}; };
template <typename Packet> template <typename Packet>
struct psign_impl< struct psign_impl<Packet, std::enable_if_t<!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
Packet, std::enable_if_t< NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex && NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) { static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type; using Scalar = typename unpacket_traits<Packet>::type;
const Packet cst_one = pset1<Packet>(Scalar(1)); const Packet cst_one = pset1<Packet>(Scalar(1));