Enable default behavior for pmin<PropagateFast>, predux_min, etc

This commit is contained in:
Charles Schlosser 2025-06-02 17:23:37 +00:00 committed by Rasmus Munk Larsen
parent 4fdf87bbf5
commit 21e89b930c
4 changed files with 71 additions and 134 deletions

View File

@ -608,7 +608,7 @@ EIGEN_DEVICE_FUNC inline bool pselect<bool>(const bool& cond, const bool& a, con
/** \internal \returns the min or of \a a and \a b (coeff-wise)
If either \a a or \a b are NaN, the result is implementation defined. */
template <int NaNPropagation>
template <int NaNPropagation, bool IsInteger>
struct pminmax_impl {
template <typename Packet, typename Op>
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a, const Packet& b, Op op) {
@ -619,7 +619,7 @@ struct pminmax_impl {
/** \internal \returns the min or max of \a a and \a b (coeff-wise)
If either \a a or \a b are NaN, NaN is returned. */
template <>
struct pminmax_impl<PropagateNaN> {
struct pminmax_impl<PropagateNaN, false> {
template <typename Packet, typename Op>
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a, const Packet& b, Op op) {
Packet not_nan_mask_a = pcmp_eq(a, a);
@ -632,7 +632,7 @@ struct pminmax_impl<PropagateNaN> {
If both \a a and \a b are NaN, NaN is returned.
Equivalent to std::fmin(a, b). */
template <>
struct pminmax_impl<PropagateNumbers> {
struct pminmax_impl<PropagateNumbers, false> {
template <typename Packet, typename Op>
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a, const Packet& b, Op op) {
Packet not_nan_mask_a = pcmp_eq(a, a);
@ -654,7 +654,8 @@ EIGEN_DEVICE_FUNC inline Packet pmin(const Packet& a, const Packet& b) {
NaNPropagation determines the NaN propagation semantics. */
template <int NaNPropagation, typename Packet>
EIGEN_DEVICE_FUNC inline Packet pmin(const Packet& a, const Packet& b) {
return pminmax_impl<NaNPropagation>::run(a, b, EIGEN_BINARY_OP_NAN_PROPAGATION(Packet, (pmin<Packet>)));
constexpr bool IsInteger = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger;
return pminmax_impl<NaNPropagation, IsInteger>::run(a, b, EIGEN_BINARY_OP_NAN_PROPAGATION(Packet, (pmin<Packet>)));
}
/** \internal \returns the max of \a a and \a b (coeff-wise)
@ -668,7 +669,8 @@ EIGEN_DEVICE_FUNC inline Packet pmax(const Packet& a, const Packet& b) {
NaNPropagation determines the NaN propagation semantics. */
template <int NaNPropagation, typename Packet>
EIGEN_DEVICE_FUNC inline Packet pmax(const Packet& a, const Packet& b) {
return pminmax_impl<NaNPropagation>::run(a, b, EIGEN_BINARY_OP_NAN_PROPAGATION(Packet, (pmax<Packet>)));
constexpr bool IsInteger = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger;
return pminmax_impl<NaNPropagation, IsInteger>::run(a, b, EIGEN_BINARY_OP_NAN_PROPAGATION(Packet, (pmax<Packet>)));
}
/** \internal \returns the absolute value of \a a */
@ -1244,26 +1246,46 @@ EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_mul(const
template <typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_min(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar;
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmin<PropagateFast, Scalar>)));
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmin<Scalar>)));
}
template <int NaNPropagation, typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_min(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar;
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmin<NaNPropagation, Scalar>)));
}
/** \internal \returns the min of the elements of \a a */
/** \internal \returns the max of the elements of \a a */
template <typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_max(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar;
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmax<PropagateFast, Scalar>)));
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmax<Scalar>)));
}
template <int NaNPropagation, typename Packet>
struct predux_min_max_helper_impl {
using Scalar = typename unpacket_traits<Packet>::type;
static constexpr bool UsePredux_ = NaNPropagation == PropagateFast || NumTraits<Scalar>::IsInteger;
template <bool UsePredux = UsePredux_, std::enable_if_t<!UsePredux, bool> = true>
static EIGEN_DEVICE_FUNC inline Scalar run_min(const Packet& a) {
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmin<NaNPropagation, Scalar>)));
}
template <bool UsePredux = UsePredux_, std::enable_if_t<!UsePredux, bool> = true>
static EIGEN_DEVICE_FUNC inline Scalar run_max(const Packet& a) {
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmax<NaNPropagation, Scalar>)));
}
template <bool UsePredux = UsePredux_, std::enable_if_t<UsePredux, bool> = true>
static EIGEN_DEVICE_FUNC inline Scalar run_min(const Packet& a) {
return predux_min(a);
}
template <bool UsePredux = UsePredux_, std::enable_if_t<UsePredux, bool> = true>
static EIGEN_DEVICE_FUNC inline Scalar run_max(const Packet& a) {
return predux_max(a);
}
};
template <int NaNPropagation, typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_min(const Packet& a) {
return predux_min_max_helper_impl<NaNPropagation, Packet>::run_min(a);
}
template <int NaNPropagation, typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_max(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar;
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmax<NaNPropagation, Scalar>)));
return predux_min_max_helper_impl<NaNPropagation, Packet>::run_max(a);
}
#undef EIGEN_BINARY_OP_NAN_PROPAGATION

View File

@ -148,11 +148,6 @@ EIGEN_STRONG_INLINE float predux_min(const Packet8f& a) {
return predux_min(pmin(lo, hi));
}
template <>
EIGEN_STRONG_INLINE float predux_min<PropagateFast>(const Packet8f& a) {
return predux_min(a);
}
template <>
EIGEN_STRONG_INLINE float predux_min<PropagateNumbers>(const Packet8f& a) {
Packet4f lo = _mm256_castps256_ps128(a);
@ -174,11 +169,6 @@ EIGEN_STRONG_INLINE float predux_max(const Packet8f& a) {
return predux_max(pmax(lo, hi));
}
template <>
EIGEN_STRONG_INLINE float predux_max<PropagateFast>(const Packet8f& a) {
return predux_max(a);
}
template <>
EIGEN_STRONG_INLINE float predux_max<PropagateNumbers>(const Packet8f& a) {
Packet4f lo = _mm256_castps256_ps128(a);
@ -221,11 +211,6 @@ EIGEN_STRONG_INLINE double predux_min(const Packet4d& a) {
return predux_min(pmin(lo, hi));
}
template <>
EIGEN_STRONG_INLINE double predux_min<PropagateFast>(const Packet4d& a) {
return predux_min(a);
}
template <>
EIGEN_STRONG_INLINE double predux_min<PropagateNumbers>(const Packet4d& a) {
Packet2d lo = _mm256_castpd256_pd128(a);
@ -247,11 +232,6 @@ EIGEN_STRONG_INLINE double predux_max(const Packet4d& a) {
return predux_max(pmax(lo, hi));
}
template <>
EIGEN_STRONG_INLINE double predux_max<PropagateFast>(const Packet4d& a) {
return predux_max(a);
}
template <>
EIGEN_STRONG_INLINE double predux_max<PropagateNumbers>(const Packet4d& a) {
Packet2d lo = _mm256_castpd256_pd128(a);
@ -289,11 +269,6 @@ EIGEN_STRONG_INLINE half predux_min(const Packet8h& a) {
return static_cast<half>(predux_min(half2float(a)));
}
template <>
EIGEN_STRONG_INLINE half predux_min<PropagateFast>(const Packet8h& a) {
return static_cast<half>(predux_min<PropagateFast>(half2float(a)));
}
template <>
EIGEN_STRONG_INLINE half predux_min<PropagateNumbers>(const Packet8h& a) {
return static_cast<half>(predux_min<PropagateNumbers>(half2float(a)));
@ -309,11 +284,6 @@ EIGEN_STRONG_INLINE half predux_max(const Packet8h& a) {
return static_cast<half>(predux_max(half2float(a)));
}
template <>
EIGEN_STRONG_INLINE half predux_max<PropagateFast>(const Packet8h& a) {
return static_cast<half>(predux_max<PropagateFast>(half2float(a)));
}
template <>
EIGEN_STRONG_INLINE half predux_max<PropagateNumbers>(const Packet8h& a) {
return static_cast<half>(predux_max<PropagateNumbers>(half2float(a)));
@ -347,11 +317,6 @@ EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet8bf& a) {
return static_cast<bfloat16>(predux_min(Bf16ToF32(a)));
}
template <>
EIGEN_STRONG_INLINE bfloat16 predux_min<PropagateFast>(const Packet8bf& a) {
return predux_min(a);
}
template <>
EIGEN_STRONG_INLINE bfloat16 predux_min<PropagateNumbers>(const Packet8bf& a) {
return static_cast<bfloat16>(predux_min<PropagateNumbers>(Bf16ToF32(a)));
@ -367,11 +332,6 @@ EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet8bf& a) {
return static_cast<bfloat16>(predux_max<Packet8f>(Bf16ToF32(a)));
}
template <>
EIGEN_STRONG_INLINE bfloat16 predux_max<PropagateFast>(const Packet8bf& a) {
return predux_max(a);
}
template <>
EIGEN_STRONG_INLINE bfloat16 predux_max<PropagateNumbers>(const Packet8bf& a) {
return static_cast<bfloat16>(predux_max<PropagateNumbers>(Bf16ToF32(a)));

View File

@ -102,11 +102,6 @@ EIGEN_STRONG_INLINE float predux_min(const Packet16f& a) {
return _mm512_reduce_min_ps(a);
}
template <>
EIGEN_STRONG_INLINE float predux_min<PropagateFast>(const Packet16f& a) {
return _mm512_reduce_min_ps(a);
}
template <>
EIGEN_STRONG_INLINE float predux_min<PropagateNumbers>(const Packet16f& a) {
Packet8f lane0 = _mm512_extractf32x8_ps(a, 0);
@ -126,11 +121,6 @@ EIGEN_STRONG_INLINE float predux_max(const Packet16f& a) {
return _mm512_reduce_max_ps(a);
}
template <>
EIGEN_STRONG_INLINE float predux_max<PropagateFast>(const Packet16f& a) {
return _mm512_reduce_max_ps(a);
}
template <>
EIGEN_STRONG_INLINE float predux_max<PropagateNumbers>(const Packet16f& a) {
Packet8f lane0 = _mm512_extractf32x8_ps(a, 0);
@ -146,8 +136,8 @@ EIGEN_STRONG_INLINE float predux_max<PropagateNaN>(const Packet16f& a) {
}
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet16f& x) {
return _mm512_reduce_or_epi32(_mm512_castps_si512(x)) != 0;
EIGEN_STRONG_INLINE bool predux_any(const Packet16f& a) {
return _mm512_reduce_or_epi32(_mm512_castps_si512(a)) != 0;
}
/* -- -- -- -- -- -- -- -- -- -- -- -- Packet8d -- -- -- -- -- -- -- -- -- -- -- -- */
@ -167,11 +157,6 @@ EIGEN_STRONG_INLINE double predux_min(const Packet8d& a) {
return _mm512_reduce_min_pd(a);
}
template <>
EIGEN_STRONG_INLINE double predux_min<PropagateFast>(const Packet8d& a) {
return _mm512_reduce_min_pd(a);
}
template <>
EIGEN_STRONG_INLINE double predux_min<PropagateNumbers>(const Packet8d& a) {
Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
@ -191,11 +176,6 @@ EIGEN_STRONG_INLINE double predux_max(const Packet8d& a) {
return _mm512_reduce_max_pd(a);
}
template <>
EIGEN_STRONG_INLINE double predux_max<PropagateFast>(const Packet8d& a) {
return _mm512_reduce_max_pd(a);
}
template <>
EIGEN_STRONG_INLINE double predux_max<PropagateNumbers>(const Packet8d& a) {
Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
@ -211,8 +191,8 @@ EIGEN_STRONG_INLINE double predux_max<PropagateNaN>(const Packet8d& a) {
}
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet8d& x) {
return _mm512_reduce_or_epi64(_mm512_castpd_si512(x)) != 0;
EIGEN_STRONG_INLINE bool predux_any(const Packet8d& a) {
return _mm512_reduce_or_epi64(_mm512_castpd_si512(a)) != 0;
}
#ifndef EIGEN_VECTORIZE_AVX512FP16
@ -233,11 +213,6 @@ EIGEN_STRONG_INLINE half predux_min(const Packet16h& from) {
return half(predux_min(half2float(from)));
}
template <>
EIGEN_STRONG_INLINE half predux_min<PropagateFast>(const Packet16h& from) {
return half(predux_min<PropagateFast>(half2float(from)));
}
template <>
EIGEN_STRONG_INLINE half predux_min<PropagateNumbers>(const Packet16h& from) {
return half(predux_min<PropagateNumbers>(half2float(from)));
@ -253,11 +228,6 @@ EIGEN_STRONG_INLINE half predux_max(const Packet16h& from) {
return half(predux_max(half2float(from)));
}
template <>
EIGEN_STRONG_INLINE half predux_max<PropagateFast>(const Packet16h& from) {
return half(predux_max<PropagateFast>(half2float(from)));
}
template <>
EIGEN_STRONG_INLINE half predux_max<PropagateNumbers>(const Packet16h& from) {
return half(predux_max<PropagateNumbers>(half2float(from)));
@ -269,8 +239,8 @@ EIGEN_STRONG_INLINE half predux_max<PropagateNaN>(const Packet16h& from) {
}
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet16h& x) {
return predux_any<Packet8i>(x.m_val);
EIGEN_STRONG_INLINE bool predux_any(const Packet16h& a) {
return predux_any<Packet8i>(a.m_val);
}
#endif
@ -291,11 +261,6 @@ EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet16bf& from) {
return static_cast<bfloat16>(predux_min<Packet16f>(Bf16ToF32(from)));
}
template <>
EIGEN_STRONG_INLINE bfloat16 predux_min<PropagateFast>(const Packet16bf& from) {
return static_cast<bfloat16>(predux_min<PropagateFast>(Bf16ToF32(from)));
}
template <>
EIGEN_STRONG_INLINE bfloat16 predux_min<PropagateNumbers>(const Packet16bf& from) {
return static_cast<bfloat16>(predux_min<PropagateNumbers>(Bf16ToF32(from)));
@ -311,11 +276,6 @@ EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet16bf& from) {
return static_cast<bfloat16>(predux_max(Bf16ToF32(from)));
}
template <>
EIGEN_STRONG_INLINE bfloat16 predux_max<PropagateFast>(const Packet16bf& from) {
return static_cast<bfloat16>(predux_max<PropagateFast>(Bf16ToF32(from)));
}
template <>
EIGEN_STRONG_INLINE bfloat16 predux_max<PropagateNumbers>(const Packet16bf& from) {
return static_cast<bfloat16>(predux_max<PropagateNumbers>(Bf16ToF32(from)));
@ -327,8 +287,8 @@ EIGEN_STRONG_INLINE bfloat16 predux_max<PropagateNaN>(const Packet16bf& from) {
}
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet16bf& x) {
return predux_any<Packet8i>(x.m_val);
EIGEN_STRONG_INLINE bool predux_any(const Packet16bf& a) {
return predux_any<Packet8i>(a.m_val);
}
} // end namespace internal

View File

@ -86,6 +86,21 @@ EIGEN_STRONG_INLINE bool predux_mul(const Packet16b& a) {
return ((pfirst<Packet4i>(tmp) == 0x01010101) && (pfirst<Packet4i>(_mm_shuffle_epi32(tmp, 1)) == 0x01010101));
}
template <>
EIGEN_STRONG_INLINE bool predux_min(const Packet16b& a) {
return predux_mul(a);
}
template <>
EIGEN_STRONG_INLINE bool predux_max(const Packet16b& a) {
return predux(a);
}
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet16b& a) {
return predux(a);
}
/* -- -- -- -- -- -- -- -- -- -- -- -- Packet4i -- -- -- -- -- -- -- -- -- -- -- -- */
template <typename Op>
@ -121,8 +136,8 @@ EIGEN_STRONG_INLINE int predux_max(const Packet4i& a) {
#endif
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet4i& x) {
return _mm_movemask_ps(_mm_castsi128_ps(x)) != 0x0;
EIGEN_STRONG_INLINE bool predux_any(const Packet4i& a) {
return _mm_movemask_ps(_mm_castsi128_ps(a)) != 0x0;
}
/* -- -- -- -- -- -- -- -- -- -- -- -- Packet4ui -- -- -- -- -- -- -- -- -- -- -- -- */
@ -160,8 +175,8 @@ EIGEN_STRONG_INLINE uint32_t predux_max(const Packet4ui& a) {
#endif
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet4ui& x) {
return _mm_movemask_ps(_mm_castsi128_ps(x)) != 0x0;
EIGEN_STRONG_INLINE bool predux_any(const Packet4ui& a) {
return _mm_movemask_ps(_mm_castsi128_ps(a)) != 0x0;
}
/* -- -- -- -- -- -- -- -- -- -- -- -- Packet2l -- -- -- -- -- -- -- -- -- -- -- -- */
@ -181,8 +196,8 @@ EIGEN_STRONG_INLINE int64_t predux(const Packet2l& a) {
}
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet2l& x) {
return _mm_movemask_pd(_mm_castsi128_pd(x)) != 0x0;
EIGEN_STRONG_INLINE bool predux_any(const Packet2l& a) {
return _mm_movemask_pd(_mm_castsi128_pd(a)) != 0x0;
}
/* -- -- -- -- -- -- -- -- -- -- -- -- Packet4f -- -- -- -- -- -- -- -- -- -- -- -- */
@ -216,11 +231,6 @@ EIGEN_STRONG_INLINE float predux_min(const Packet4f& a) {
return sse_predux_min_impl<Packet4f>::run(a);
}
template <>
EIGEN_STRONG_INLINE float predux_min<PropagateFast>(const Packet4f& a) {
return sse_predux_min_prop_impl<PropagateFast, Packet4f>::run(a);
}
template <>
EIGEN_STRONG_INLINE float predux_min<PropagateNumbers>(const Packet4f& a) {
return sse_predux_min_prop_impl<PropagateNumbers, Packet4f>::run(a);
@ -236,11 +246,6 @@ EIGEN_STRONG_INLINE float predux_max(const Packet4f& a) {
return sse_predux_max_impl<Packet4f>::run(a);
}
template <>
EIGEN_STRONG_INLINE float predux_max<PropagateFast>(const Packet4f& a) {
return sse_predux_max_prop_impl<PropagateFast, Packet4f>::run(a);
}
template <>
EIGEN_STRONG_INLINE float predux_max<PropagateNumbers>(const Packet4f& a) {
return sse_predux_max_prop_impl<PropagateNumbers, Packet4f>::run(a);
@ -252,8 +257,8 @@ EIGEN_STRONG_INLINE float predux_max<PropagateNaN>(const Packet4f& a) {
}
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet4f& x) {
return _mm_movemask_ps(x) != 0x0;
EIGEN_STRONG_INLINE bool predux_any(const Packet4f& a) {
return _mm_movemask_ps(a) != 0x0;
}
/* -- -- -- -- -- -- -- -- -- -- -- -- Packet2d -- -- -- -- -- -- -- -- -- -- -- -- */
@ -282,11 +287,6 @@ EIGEN_STRONG_INLINE double predux_min(const Packet2d& a) {
return sse_predux_min_impl<Packet2d>::run(a);
}
template <>
EIGEN_STRONG_INLINE double predux_min<PropagateFast>(const Packet2d& a) {
return sse_predux_min_prop_impl<PropagateFast, Packet2d>::run(a);
}
template <>
EIGEN_STRONG_INLINE double predux_min<PropagateNumbers>(const Packet2d& a) {
return sse_predux_min_prop_impl<PropagateNumbers, Packet2d>::run(a);
@ -302,11 +302,6 @@ EIGEN_STRONG_INLINE double predux_max(const Packet2d& a) {
return sse_predux_max_impl<Packet2d>::run(a);
}
template <>
EIGEN_STRONG_INLINE double predux_max<PropagateFast>(const Packet2d& a) {
return sse_predux_max_prop_impl<PropagateFast, Packet2d>::run(a);
}
template <>
EIGEN_STRONG_INLINE double predux_max<PropagateNumbers>(const Packet2d& a) {
return sse_predux_max_prop_impl<PropagateNumbers, Packet2d>::run(a);
@ -318,8 +313,8 @@ EIGEN_STRONG_INLINE double predux_max<PropagateNaN>(const Packet2d& a) {
}
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet2d& x) {
return _mm_movemask_pd(x) != 0x0;
EIGEN_STRONG_INLINE bool predux_any(const Packet2d& a) {
return _mm_movemask_pd(a) != 0x0;
}
} // end namespace internal