This commit is contained in:
Charles Schlosser 2025-03-22 17:19:44 +00:00
parent ac2165c11f
commit 754bd24f5e
3 changed files with 54 additions and 3 deletions

View File

@ -580,7 +580,7 @@ EIGEN_DEVICE_FUNC inline Packet pandnot(const Packet& a, const Packet& b) {
}
// In the general case, use bitwise select.
template <typename Packet, typename EnableIf = void>
template <typename Packet, bool is_scalar = is_scalar<Packet>::value>
struct pselect_impl {
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& mask, const Packet& a, const Packet& b) {
return por(pand(a, mask), pandnot(b, mask));
@ -589,9 +589,9 @@ struct pselect_impl {
// For scalars, use ternary select.
// Specialization of pselect_impl chosen when the packet type is a plain scalar.
// NOTE(review): this listing looks like a rendered diff whose +/- markers were stripped, so each
// replaced line is immediately followed by its replacement (the two `struct` headers and the two
// `return` statements below). Presumably the first line of each pair is the removed version and
// the second the added one -- confirm against the actual header before editing.
template <typename Packet>
struct pselect_impl<Packet, std::enable_if_t<is_scalar<Packet>::value>> {
struct pselect_impl<Packet, true> {
// Picks between `a` and `b` based on `mask`; `mask`, `a` and `b` all share the scalar type Packet.
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& mask, const Packet& a, const Packet& b) {
// First form: yields `b` when `mask` strictly equals Packet(0), otherwise `a`.
return numext::equal_strict(mask, Packet(0)) ? b : a;
// Second form: hands the decision to numext::select with the same argument order.
return numext::select(mask, a, b);
}
};

View File

@ -182,6 +182,10 @@ struct imag_ref_retval {
typedef typename NumTraits<Scalar>::Real& type;
};
// Forward declaration only; implementation in MathFunctionsImpl.h.
// scalar_select_mask<Mask>::run(mask) is what numext::select queries to decide which operand to return.
// The second parameter defaults to std::is_floating_point<Mask>::value, so built-in floating-point
// masks and all other mask types can be routed to different partial specializations.
template <typename Mask, bool is_built_in_float = std::is_floating_point<Mask>::value>
struct scalar_select_mask;
} // namespace internal
namespace numext {
@ -207,6 +211,11 @@ EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(imag, Scalar) imag(const Scalar&
return EIGEN_MATHFUNC_IMPL(imag, Scalar)::run(x);
}
// Scalar select: returns `b` when internal::scalar_select_mask<Mask>::run(mask) reports true
// (i.e. the mask is considered zero), and `a` otherwise.
// The mask type may differ from the operand type.
template <typename Scalar, typename Mask>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar select(const Mask& mask, const Scalar& a, const Scalar& b) {
  const bool mask_is_zero = internal::scalar_select_mask<Mask>::run(mask);
  if (mask_is_zero) {
    return b;
  }
  return a;
}
} // namespace numext
namespace internal {

View File

@ -256,6 +256,48 @@ EIGEN_DEVICE_FUNC ComplexT complex_log(const ComplexT& z) {
return ComplexT(numext::log(a), b);
}
// Mask test for every mask type that is not a built-in floating-point type.
// run() forwards to numext::is_exactly_zero and returns its verdict: true means "mask is zero".
template <typename Mask>
struct scalar_select_mask<Mask, /*is_built_in_float*/ false> {
  static EIGEN_DEVICE_FUNC inline bool run(const Mask& mask) {
    const bool mask_is_zero = numext::is_exactly_zero(mask);
    return mask_is_zero;
  }
};
// Mask test for built-in floating-point masks.
// The mask is inspected through its bit pattern rather than through a floating-point comparison:
// take the absolute value, reinterpret the result as the unsigned integer of the same size, and
// report whether that integer is exactly zero.
template <typename Mask>
struct scalar_select_mask<Mask, /*is_built_in_float*/ true> {
  // Unsigned integer type with the same byte size as Mask.
  using IntegerType = typename numext::get_integer_by_size<sizeof(Mask)>::unsigned_type;

  static EIGEN_DEVICE_FUNC inline bool run(const Mask& mask) {
    // Presumably std::abs is applied so that a negative zero also maps to an all-zero pattern.
    const Mask magnitude = std::abs(mask);
    const IntegerType bits = numext::bit_cast<IntegerType>(magnitude);
    return numext::is_exactly_zero(bits);
  }
};
// Mask test for long double, implemented as a byte-wise scan of the object representation.
// `Size` defaults to sizeof(long double) and is not used in this body; it only lets a
// specialization take over for particular sizes.
template <int Size = sizeof(long double)>
struct ldbl_select_mask {
  // Number of mantissa digits reported for long double.
  static constexpr int MantissaDigits = std::numeric_limits<long double>::digits;
  // Bytes to inspect: 80 bits' worth for a 64-digit mantissa, 128 bits' worth otherwise.
  // Presumably this skips padding bytes beyond the value bits -- TODO confirm.
  static constexpr int NumBytes = (MantissaDigits == 64 ? 80 : 128) / CHAR_BIT;

  // Returns true only if each of the first NumBytes bytes of `mask` is zero.
  // NOTE(review): unlike the float/double path, no absolute value is taken first, so a negative
  // zero would appear to read as nonzero here -- confirm this is intended.
  static EIGEN_DEVICE_FUNC inline bool run(const long double& mask) {
    const uint8_t* raw = reinterpret_cast<const uint8_t*>(&mask);
    Index offset = 0;
    while (offset < NumBytes) {
      if (raw[offset] != 0) {
        return false;
      }
      offset++;
    }
    return true;
  }
};
// Selected when sizeof(long double) equals sizeof(double): skip the byte-wise scan and inherit
// run() from the double mask test instead (its parameter type is then const double&).
template <>
struct ldbl_select_mask<sizeof(double)> : scalar_select_mask<double> {};
// long double masks bypass the generic built-in-float specialization and take run() from
// ldbl_select_mask<>, whose default argument picks the variant matching sizeof(long double).
template <>
struct scalar_select_mask<long double, true> : ldbl_select_mask<> {};
// Mask test for std::complex masks: the mask counts as zero only when both its real and its
// imaginary part pass the mask test of the underlying real type.
template <typename RealMask>
struct scalar_select_mask<std::complex<RealMask>, false> {
  // Mask test applied to each component.
  using impl = scalar_select_mask<RealMask>;

  static EIGEN_DEVICE_FUNC inline bool run(const std::complex<RealMask>& mask) {
    // The imaginary part is examined only if the real part already tested as zero.
    if (!impl::run(numext::real(mask))) {
      return false;
    }
    return impl::run(numext::imag(mask));
  }
};
} // end namespace internal
} // end namespace Eigen