diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index ab9c0e142..4287fa249 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -375,7 +375,7 @@ EIGEN_DEVICE_FUNC inline bool pdiv(const bool& a, const bool& b) { return a && b; } -// In the generic case, memset to all one bits. +// In the generic packet case, memset to all one bits. template struct ptrue_impl { static EIGEN_DEVICE_FUNC inline Packet run(const Packet& /*a*/) { @@ -385,19 +385,16 @@ struct ptrue_impl { } }; +// Use a value of one for scalars. +template +struct ptrue_impl::value>> { + static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar&) { return Scalar(1); } +}; + // For booleans, we can only directly set a valid `bool` value to avoid UB. template <> struct ptrue_impl { - static EIGEN_DEVICE_FUNC inline bool run(const bool& /*a*/) { return true; } -}; - -// For non-trivial scalars, set to Scalar(1) (i.e. a non-zero value). -// Although this is technically not a valid bitmask, the scalar path for pselect -// uses a comparison to zero, so this should still work in most cases. We don't -// have another option, since the scalar type requires initialization. -template -struct ptrue_impl::value && NumTraits::RequireInitialization>> { - static EIGEN_DEVICE_FUNC inline T run(const T& /*a*/) { return T(1); } + static EIGEN_DEVICE_FUNC inline bool run(const bool&) { return true; } }; /** \internal \returns one bits. */ @@ -406,7 +403,7 @@ EIGEN_DEVICE_FUNC inline Packet ptrue(const Packet& a) { return ptrue_impl::run(a); } -// In the general case, memset to zero. +// In the general packet case, memset to zero. template struct pzero_impl { static EIGEN_DEVICE_FUNC inline Packet run(const Packet& /*a*/) { @@ -875,17 +872,29 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet plset(const typename unpacket_trait return a; } +template +struct peven_mask_impl { + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet&) { + typedef typename unpacket_traits::type Scalar; + const size_t n = unpacket_traits::size; + EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) Scalar elements[n]; + for (size_t i = 0; i < n; ++i) { + memset(elements + i, ((i & 1) == 0 ? 0xff : 0), sizeof(Scalar)); + } + return ploadu(elements); + } +}; + +template +struct peven_mask_impl::value>> { + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar&) { return Scalar(1); } +}; + /** \internal \returns a packet with constant coefficients \a a, e.g.: (x, 0, x, 0), where x is the value of all 1-bits. */ template -EIGEN_DEVICE_FUNC inline Packet peven_mask(const Packet& /*a*/) { - typedef typename unpacket_traits::type Scalar; - const size_t n = unpacket_traits::size; - EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) Scalar elements[n]; - for (size_t i = 0; i < n; ++i) { - memset(elements + i, ((i & 1) == 0 ? 0xff : 0), sizeof(Scalar)); - } - return ploadu(elements); +EIGEN_DEVICE_FUNC inline Packet peven_mask(const Packet& a) { + return peven_mask_impl::run(a); } /** \internal copy the packet \a from to \a *to, \a to must be properly aligned */ diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 941961d99..481e057d0 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -182,10 +182,6 @@ struct imag_ref_retval { typedef typename NumTraits::Real& type; }; -// implementation in MathFunctionsImpl.h -template ::value> -struct scalar_select_mask; - } // namespace internal namespace numext { @@ -211,9 +207,9 @@ EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(imag, Scalar) imag(const Scalar& return EIGEN_MATHFUNC_IMPL(imag, Scalar)::run(x); } -template -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar select(const Mask& mask, const Scalar& a, const Scalar& b) { - return internal::scalar_select_mask::run(mask) ? b : a; +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar select(const Scalar& mask, const Scalar& a, const Scalar& b) { + return numext::is_exactly_zero(mask) ? b : a; } } // namespace numext diff --git a/Eigen/src/Core/MathFunctionsImpl.h b/Eigen/src/Core/MathFunctionsImpl.h index cbac1c2a4..cf8dcc3b8 100644 --- a/Eigen/src/Core/MathFunctionsImpl.h +++ b/Eigen/src/Core/MathFunctionsImpl.h @@ -256,48 +256,6 @@ EIGEN_DEVICE_FUNC ComplexT complex_log(const ComplexT& z) { return ComplexT(numext::log(a), b); } -// For generic scalars, use ternary select. -template -struct scalar_select_mask { - static EIGEN_DEVICE_FUNC inline bool run(const Mask& mask) { return numext::is_exactly_zero(mask); } -}; - -// For built-in float mask, bitcast the mask to its integer counterpart and use ternary select. -template -struct scalar_select_mask { - using IntegerType = typename numext::get_integer_by_size::unsigned_type; - static EIGEN_DEVICE_FUNC inline bool run(const Mask& mask) { - return numext::is_exactly_zero(numext::bit_cast(std::abs(mask))); - } -}; - -template -struct ldbl_select_mask { - static constexpr int MantissaDigits = std::numeric_limits::digits; - static constexpr int NumBytes = (MantissaDigits == 64 ? 80 : 128) / CHAR_BIT; - static EIGEN_DEVICE_FUNC inline bool run(const long double& mask) { - const uint8_t* mask_bytes = reinterpret_cast(&mask); - for (Index i = 0; i < NumBytes; i++) { - if (mask_bytes[i] != 0) return false; - } - return true; - } -}; - -template <> -struct ldbl_select_mask : scalar_select_mask {}; - -template <> -struct scalar_select_mask : ldbl_select_mask<> {}; - -template -struct scalar_select_mask, false> { - using impl = scalar_select_mask; - static EIGEN_DEVICE_FUNC inline bool run(const std::complex& mask) { - return impl::run(numext::real(mask)) && impl::run(numext::imag(mask)); - } -}; - } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index e9f564b53..a46a8eff0 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -1689,7 +1689,8 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet phypot_complex(const } template -struct psign_impl::type>::IsComplex && +struct psign_impl::value && + !NumTraits::type>::IsComplex && !NumTraits::type>::IsInteger>> { static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) { using Scalar = typename unpacket_traits::type; @@ -1705,7 +1706,8 @@ struct psign_impl -struct psign_impl::type>::IsComplex && +struct psign_impl::value && + !NumTraits::type>::IsComplex && NumTraits::type>::IsSigned && NumTraits::type>::IsInteger>> { static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) { @@ -1724,7 +1726,8 @@ struct psign_impl -struct psign_impl::type>::IsComplex && +struct psign_impl::value && + !NumTraits::type>::IsComplex && !NumTraits::type>::IsSigned && NumTraits::type>::IsInteger>> { static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) { @@ -1739,7 +1742,8 @@ struct psign_impl -struct psign_impl::type>::IsComplex && +struct psign_impl::value && + NumTraits::type>::IsComplex && unpacket_traits::vectorizable>> { static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) { typedef typename unpacket_traits::type Scalar; @@ -2176,7 +2180,8 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_pow_impl(const Packet& x, c // Generic implementation of pow(x,y). template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_pow(const Packet& x, const Packet& y) { +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS std::enable_if_t::value, Packet> generic_pow( + const Packet& x, const Packet& y) { typedef typename unpacket_traits::type Scalar; const Packet cst_inf = pset1(NumTraits::infinity()); @@ -2266,6 +2271,12 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_pow(const Pac return pow; } +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS std::enable_if_t::value, Scalar> generic_pow( + const Scalar& x, const Scalar& y) { + return numext::pow(x, y); +} + namespace unary_pow { template ::IsInteger> @@ -2347,35 +2358,36 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet int_pow(const Packet& x, const Scal } template -EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet gen_pow(const Packet& x, - const typename unpacket_traits::type& exponent) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t::value, Packet> gen_pow( + const Packet& x, const typename unpacket_traits::type& exponent) { const Packet exponent_packet = pset1(exponent); return generic_pow_impl(x, exponent_packet); } +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t::value, Scalar> gen_pow( + const Scalar& x, const Scalar& exponent) { + return numext::pow(x, exponent); +} + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(const Packet& x, const Packet& powx, const ScalarExponent& exponent) { using Scalar = typename unpacket_traits::type; // non-integer base and exponent case - - const Scalar pos_zero = Scalar(0); - const Scalar all_ones = ptrue(Scalar()); - const Scalar pos_one = Scalar(1); - const Scalar pos_inf = NumTraits::infinity(); - const Packet cst_pos_zero = pzero(x); - const Packet cst_pos_one = pset1(pos_one); - const Packet cst_pos_inf = pset1(pos_inf); + const Packet cst_pos_one = pset1(Scalar(1)); + const Packet cst_pos_inf = pset1(NumTraits::infinity()); + const Packet cst_true = ptrue(x); const bool exponent_is_not_fin = !(numext::isfinite)(exponent); const bool exponent_is_neg = exponent < ScalarExponent(0); const bool exponent_is_pos = exponent > ScalarExponent(0); - const Packet exp_is_not_fin = pset1(exponent_is_not_fin ? all_ones : pos_zero); - const Packet exp_is_neg = pset1(exponent_is_neg ? all_ones : pos_zero); - const Packet exp_is_pos = pset1(exponent_is_pos ? all_ones : pos_zero); + const Packet exp_is_not_fin = exponent_is_not_fin ? cst_true : cst_pos_zero; + const Packet exp_is_neg = exponent_is_neg ? cst_true : cst_pos_zero; + const Packet exp_is_pos = exponent_is_pos ? cst_true : cst_pos_zero; const Packet exp_is_inf = pand(exp_is_not_fin, por(exp_is_neg, exp_is_pos)); const Packet exp_is_nan = pandnot(exp_is_not_fin, por(exp_is_neg, exp_is_pos)); @@ -2411,22 +2423,15 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_negative_exponent(const Pack // This routine handles negative exponents. // The return value is either 0, 1, or -1. - - const Scalar pos_zero = Scalar(0); - const Scalar all_ones = ptrue(Scalar()); - const Scalar pos_one = Scalar(1); - - const Packet cst_pos_one = pset1(pos_one); - + const Packet cst_pos_one = pset1(Scalar(1)); const bool exponent_is_odd = exponent % ScalarExponent(2) != ScalarExponent(0); - - const Packet exp_is_odd = pset1(exponent_is_odd ? all_ones : pos_zero); + const Packet exp_is_odd = exponent_is_odd ? ptrue(x) : pzero(x); const Packet abs_x = pabs(x); const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one); Packet result = pselect(exp_is_odd, x, abs_x); - result = pand(abs_x_is_one, result); + result = pselect(abs_x_is_one, result, pzero(x)); return result; } diff --git a/test/packetmath.cpp b/test/packetmath.cpp index 4e09361a3..5f48d713c 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -354,28 +354,28 @@ void packetmath_boolean_mask_ops() { for (int i = 0; i < size; ++i) { data1[i] = internal::random(); } - CHECK_CWISE1(internal::ptrue, internal::ptrue); + CHECK_CWISE1_MASK(internal::ptrue, internal::ptrue); CHECK_CWISE2_IF(true, internal::pandnot, internal::pandnot); for (int i = 0; i < PacketSize; ++i) { data1[i] = Scalar(RealScalar(i)); data1[i + PacketSize] = internal::random() ? data1[i] : Scalar(0); } - CHECK_CWISE2_IF(true, internal::pcmp_eq, internal::pcmp_eq); + CHECK_CWISE2_MASK(internal::pcmp_eq, internal::pcmp_eq); // Test (-0) == (0) for signed operations for (int i = 0; i < PacketSize; ++i) { data1[i] = Scalar(-0.0); data1[i + PacketSize] = internal::random() ? data1[i] : Scalar(0); } - CHECK_CWISE2_IF(true, internal::pcmp_eq, internal::pcmp_eq); + CHECK_CWISE2_MASK(internal::pcmp_eq, internal::pcmp_eq); // Test NaN for (int i = 0; i < PacketSize; ++i) { data1[i] = NumTraits::quiet_NaN(); data1[i + PacketSize] = internal::random() ? data1[i] : Scalar(0); } - CHECK_CWISE2_IF(true, internal::pcmp_eq, internal::pcmp_eq); + CHECK_CWISE2_MASK(internal::pcmp_eq, internal::pcmp_eq); } template @@ -384,28 +384,27 @@ void packetmath_boolean_mask_ops_real() { const int size = 2 * PacketSize; EIGEN_ALIGN_MAX Scalar data1[size]; EIGEN_ALIGN_MAX Scalar data2[size]; - EIGEN_ALIGN_MAX Scalar ref[size]; for (int i = 0; i < PacketSize; ++i) { data1[i] = internal::random(); data1[i + PacketSize] = internal::random() ? data1[i] : Scalar(0); } - CHECK_CWISE2_IF(true, internal::pcmp_lt_or_nan, internal::pcmp_lt_or_nan); + CHECK_CWISE2_MASK(internal::pcmp_lt_or_nan, internal::pcmp_lt_or_nan); // Test (-0) <=/< (0) for signed operations for (int i = 0; i < PacketSize; ++i) { data1[i] = Scalar(-0.0); data1[i + PacketSize] = internal::random() ? data1[i] : Scalar(0); } - CHECK_CWISE2_IF(true, internal::pcmp_lt_or_nan, internal::pcmp_lt_or_nan); + CHECK_CWISE2_MASK(internal::pcmp_lt_or_nan, internal::pcmp_lt_or_nan); // Test NaN for (int i = 0; i < PacketSize; ++i) { data1[i] = NumTraits::quiet_NaN(); data1[i + PacketSize] = internal::random() ? data1[i] : Scalar(0); } - CHECK_CWISE2_IF(true, internal::pcmp_lt_or_nan, internal::pcmp_lt_or_nan); + CHECK_CWISE2_MASK(internal::pcmp_lt_or_nan, internal::pcmp_lt_or_nan); } template @@ -422,31 +421,30 @@ struct packetmath_boolean_mask_ops_notcomplex_test< const int size = 2 * PacketSize; EIGEN_ALIGN_MAX Scalar data1[size]; EIGEN_ALIGN_MAX Scalar data2[size]; - EIGEN_ALIGN_MAX Scalar ref[size]; for (int i = 0; i < PacketSize; ++i) { data1[i] = internal::random(); data1[i + PacketSize] = internal::random() ? data1[i] : Scalar(0); } - CHECK_CWISE2_IF(true, internal::pcmp_le, internal::pcmp_le); - CHECK_CWISE2_IF(true, internal::pcmp_lt, internal::pcmp_lt); + CHECK_CWISE2_MASK(internal::pcmp_le, internal::pcmp_le); + CHECK_CWISE2_MASK(internal::pcmp_lt, internal::pcmp_lt); // Test (-0) <=/< (0) for signed operations for (int i = 0; i < PacketSize; ++i) { data1[i] = Scalar(-0.0); data1[i + PacketSize] = internal::random() ? data1[i] : Scalar(0); } - CHECK_CWISE2_IF(true, internal::pcmp_le, internal::pcmp_le); - CHECK_CWISE2_IF(true, internal::pcmp_lt, internal::pcmp_lt); + CHECK_CWISE2_MASK(internal::pcmp_le, internal::pcmp_le); + CHECK_CWISE2_MASK(internal::pcmp_lt, internal::pcmp_lt); // Test NaN for (int i = 0; i < PacketSize; ++i) { data1[i] = NumTraits::quiet_NaN(); data1[i + PacketSize] = internal::random() ? data1[i] : Scalar(0); } - CHECK_CWISE2_IF(true, internal::pcmp_le, internal::pcmp_le); - CHECK_CWISE2_IF(true, internal::pcmp_lt, internal::pcmp_lt); + CHECK_CWISE2_MASK(internal::pcmp_le, internal::pcmp_le); + CHECK_CWISE2_MASK(internal::pcmp_lt, internal::pcmp_lt); } }; @@ -700,7 +698,7 @@ void packetmath() { for (int i = 0; i < PacketSize; ++i) { data1[i] = internal::random(Scalar(0) - limit, limit); } - } else if (!NumTraits::IsInteger && !NumTraits::IsComplex) { + } else if (!NumTraits::IsInteger && !NumTraits::IsComplex && !std::is_same::value) { // Prevent very small product results by adjusting range. Otherwise, // we may end up with multiplying e.g. 32 Eigen::halfs with values < 1. for (int i = 0; i < PacketSize; ++i) { diff --git a/test/packetmath_test_shared.h b/test/packetmath_test_shared.h index 7d7a0dacf..64b13e336 100644 --- a/test/packetmath_test_shared.h +++ b/test/packetmath_test_shared.h @@ -115,6 +115,30 @@ bool areApprox(const Scalar* a, const Scalar* b, int size, const typename NumTra VERIFY(test::areApprox(ref, data2, PacketSize) && #POP); \ } +#define CHECK_CWISE1_MASK(REFOP, POP) \ + { \ + bool ref_mask[PacketSize] = {}; \ + bool data_mask[PacketSize] = {}; \ + internal::pstore(data2, POP(internal::pload(data1))); \ + for (int i = 0; i < PacketSize; ++i) { \ + ref_mask[i] = numext::is_exactly_zero(REFOP(data1[i])); \ + data_mask[i] = numext::is_exactly_zero(data2[i]); \ + } \ + VERIFY(test::areEqual(ref_mask, data_mask, PacketSize) && #POP); \ + } + +#define CHECK_CWISE2_MASK(REFOP, POP) \ + { \ + bool ref_mask[PacketSize] = {}; \ + bool data_mask[PacketSize] = {}; \ + internal::pstore(data2, POP(internal::pload(data1), internal::pload(data1 + PacketSize))); \ + for (int i = 0; i < PacketSize; ++i) { \ + ref_mask[i] = numext::is_exactly_zero(REFOP(data1[i], data1[i + PacketSize])); \ + data_mask[i] = numext::is_exactly_zero(data2[i]); \ + } \ + VERIFY(test::areEqual(ref_mask, data_mask, PacketSize) && #POP); \ + } + // Checks component-wise for input of size N. All of data1, data2, and ref // should have size at least ceil(N/PacketSize)*PacketSize to avoid memory // access errors.