Don't use generic sign function for sign(complex) unless it is vectorizable

This commit is contained in:
Rasmus Munk Larsen 2022-10-12 16:03:29 +00:00
parent c0d6a72611
commit 462758e8a3
2 changed files with 3 additions and 2 deletions

View File

@ -1146,7 +1146,8 @@ struct psign_impl<Packet, std::enable_if_t<!NumTraits<typename unpacket_traits<P
// \internal \returns the the sign of a complex number z, defined as z / abs(z).
template <typename Packet>
struct psign_impl<Packet, std::enable_if_t<NumTraits<typename unpacket_traits<Packet>::type>::IsComplex>> {
struct psign_impl<Packet, std::enable_if_t<NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
unpacket_traits<Packet>::vectorizable>> {
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename Scalar::value_type RealScalar;

View File

@ -930,7 +930,7 @@ struct functor_traits<scalar_sign_op<Scalar> >
NumTraits<Scalar>::IsComplex
? ( 8*NumTraits<Scalar>::MulCost ) // roughly
: ( 3*NumTraits<Scalar>::AddCost),
PacketAccess = packet_traits<Scalar>::HasSign
PacketAccess = packet_traits<Scalar>::HasSign && packet_traits<Scalar>::Vectorizable
};
};