diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index 26a4634d3..d885d9fbd 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -580,7 +580,7 @@ EIGEN_DEVICE_FUNC inline Packet pandnot(const Packet& a, const Packet& b) { } // In the general case, use bitwise select. -template +template ::value> struct pselect_impl { static EIGEN_DEVICE_FUNC inline Packet run(const Packet& mask, const Packet& a, const Packet& b) { return por(pand(a, mask), pandnot(b, mask)); @@ -589,9 +589,9 @@ struct pselect_impl { // For scalars, use ternary select. template -struct pselect_impl::value>> { +struct pselect_impl { static EIGEN_DEVICE_FUNC inline Packet run(const Packet& mask, const Packet& a, const Packet& b) { - return numext::equal_strict(mask, Packet(0)) ? b : a; + return numext::select(mask, a, b); } }; diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 528aed2c2..9659c2bfc 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -182,6 +182,10 @@ struct imag_ref_retval { typedef typename NumTraits::Real& type; }; +// implementation in MathFunctionsImpl.h +template ::value> +struct scalar_select_mask; + } // namespace internal namespace numext { @@ -207,6 +211,11 @@ EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(imag, Scalar) imag(const Scalar& return EIGEN_MATHFUNC_IMPL(imag, Scalar)::run(x); } +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar select(const Mask& mask, const Scalar& a, const Scalar& b) { + return internal::scalar_select_mask::run(mask) ? b : a; +} + } // namespace numext namespace internal { diff --git a/Eigen/src/Core/MathFunctionsImpl.h b/Eigen/src/Core/MathFunctionsImpl.h index cf8dcc3b8..cbac1c2a4 100644 --- a/Eigen/src/Core/MathFunctionsImpl.h +++ b/Eigen/src/Core/MathFunctionsImpl.h @@ -256,6 +256,48 @@ EIGEN_DEVICE_FUNC ComplexT complex_log(const ComplexT& z) { return ComplexT(numext::log(a), b); } +// For generic scalars, use ternary select. +template +struct scalar_select_mask { + static EIGEN_DEVICE_FUNC inline bool run(const Mask& mask) { return numext::is_exactly_zero(mask); } +}; + +// For built-in float mask, bitcast the mask to its integer counterpart and use ternary select. +template +struct scalar_select_mask { + using IntegerType = typename numext::get_integer_by_size::unsigned_type; + static EIGEN_DEVICE_FUNC inline bool run(const Mask& mask) { + return numext::is_exactly_zero(numext::bit_cast(std::abs(mask))); + } +}; + +template +struct ldbl_select_mask { + static constexpr int MantissaDigits = std::numeric_limits::digits; + static constexpr int NumBytes = (MantissaDigits == 64 ? 80 : 128) / CHAR_BIT; + static EIGEN_DEVICE_FUNC inline bool run(const long double& mask) { + const uint8_t* mask_bytes = reinterpret_cast(&mask); + for (Index i = 0; i < NumBytes; i++) { + if (mask_bytes[i] != 0) return false; + } + return true; + } +}; + +template <> +struct ldbl_select_mask : scalar_select_mask {}; + +template <> +struct scalar_select_mask : ldbl_select_mask<> {}; + +template +struct scalar_select_mask, false> { + using impl = scalar_select_mask; + static EIGEN_DEVICE_FUNC inline bool run(const std::complex& mask) { + return impl::run(numext::real(mask)) && impl::run(numext::imag(mask)); + } +}; + } // end namespace internal } // end namespace Eigen