diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index d45cb4bf4..ab9c0e142 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -608,7 +608,7 @@ EIGEN_DEVICE_FUNC inline bool pselect(const bool& cond, const bool& a, con /** \internal \returns the min or of \a a and \a b (coeff-wise) If either \a a or \a b are NaN, the result is implementation defined. */ -template +template struct pminmax_impl { template static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a, const Packet& b, Op op) { @@ -619,7 +619,7 @@ struct pminmax_impl { /** \internal \returns the min or max of \a a and \a b (coeff-wise) If either \a a or \a b are NaN, NaN is returned. */ template <> -struct pminmax_impl { +struct pminmax_impl { template static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a, const Packet& b, Op op) { Packet not_nan_mask_a = pcmp_eq(a, a); @@ -632,7 +632,7 @@ struct pminmax_impl { If both \a a and \a b are NaN, NaN is returned. Equivalent to std::fmin(a, b). */ template <> -struct pminmax_impl { +struct pminmax_impl { template static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a, const Packet& b, Op op) { Packet not_nan_mask_a = pcmp_eq(a, a); @@ -654,7 +654,8 @@ EIGEN_DEVICE_FUNC inline Packet pmin(const Packet& a, const Packet& b) { NaNPropagation determines the NaN propagation semantics. */ template EIGEN_DEVICE_FUNC inline Packet pmin(const Packet& a, const Packet& b) { - return pminmax_impl::run(a, b, EIGEN_BINARY_OP_NAN_PROPAGATION(Packet, (pmin))); + constexpr bool IsInteger = NumTraits::type>::IsInteger; + return pminmax_impl::run(a, b, EIGEN_BINARY_OP_NAN_PROPAGATION(Packet, (pmin))); } /** \internal \returns the max of \a a and \a b (coeff-wise) @@ -668,7 +669,8 @@ EIGEN_DEVICE_FUNC inline Packet pmax(const Packet& a, const Packet& b) { NaNPropagation determines the NaN propagation semantics. */ template EIGEN_DEVICE_FUNC inline Packet pmax(const Packet& a, const Packet& b) { - return pminmax_impl::run(a, b, EIGEN_BINARY_OP_NAN_PROPAGATION(Packet, (pmax))); + constexpr bool IsInteger = NumTraits::type>::IsInteger; + return pminmax_impl::run(a, b, EIGEN_BINARY_OP_NAN_PROPAGATION(Packet, (pmax))); } /** \internal \returns the absolute value of \a a */ @@ -1244,26 +1246,46 @@ EIGEN_DEVICE_FUNC inline typename unpacket_traits::type predux_mul(const template EIGEN_DEVICE_FUNC inline typename unpacket_traits::type predux_min(const Packet& a) { typedef typename unpacket_traits::type Scalar; - return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmin))); + return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmin))); } -template -EIGEN_DEVICE_FUNC inline typename unpacket_traits::type predux_min(const Packet& a) { - typedef typename unpacket_traits::type Scalar; - return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmin))); -} - -/** \internal \returns the min of the elements of \a a */ +/** \internal \returns the max of the elements of \a a */ template EIGEN_DEVICE_FUNC inline typename unpacket_traits::type predux_max(const Packet& a) { typedef typename unpacket_traits::type Scalar; - return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmax))); + return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmax))); +} + +template +struct predux_min_max_helper_impl { + using Scalar = typename unpacket_traits::type; + static constexpr bool UsePredux_ = NaNPropagation == PropagateFast || NumTraits::IsInteger; + template = true> + static EIGEN_DEVICE_FUNC inline Scalar run_min(const Packet& a) { + return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmin))); + } + template = true> + static EIGEN_DEVICE_FUNC inline Scalar run_max(const Packet& a) { + return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmax))); + } + template = true> + static EIGEN_DEVICE_FUNC inline Scalar run_min(const Packet& a) { + return predux_min(a); + } + template = true> + static EIGEN_DEVICE_FUNC inline Scalar run_max(const Packet& a) { + return predux_max(a); + } +}; + +template +EIGEN_DEVICE_FUNC inline typename unpacket_traits::type predux_min(const Packet& a) { + return predux_min_max_helper_impl::run_min(a); } template EIGEN_DEVICE_FUNC inline typename unpacket_traits::type predux_max(const Packet& a) { - typedef typename unpacket_traits::type Scalar; - return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmax))); + return predux_min_max_helper_impl::run_max(a); } #undef EIGEN_BINARY_OP_NAN_PROPAGATION diff --git a/Eigen/src/Core/arch/AVX/Reductions.h b/Eigen/src/Core/arch/AVX/Reductions.h index 8eed4be31..237617c5a 100644 --- a/Eigen/src/Core/arch/AVX/Reductions.h +++ b/Eigen/src/Core/arch/AVX/Reductions.h @@ -148,11 +148,6 @@ EIGEN_STRONG_INLINE float predux_min(const Packet8f& a) { return predux_min(pmin(lo, hi)); } -template <> -EIGEN_STRONG_INLINE float predux_min(const Packet8f& a) { - return predux_min(a); -} - template <> EIGEN_STRONG_INLINE float predux_min(const Packet8f& a) { Packet4f lo = _mm256_castps256_ps128(a); @@ -174,11 +169,6 @@ EIGEN_STRONG_INLINE float predux_max(const Packet8f& a) { return predux_max(pmax(lo, hi)); } -template <> -EIGEN_STRONG_INLINE float predux_max(const Packet8f& a) { - return predux_max(a); -} - template <> EIGEN_STRONG_INLINE float predux_max(const Packet8f& a) { Packet4f lo = _mm256_castps256_ps128(a); @@ -221,11 +211,6 @@ EIGEN_STRONG_INLINE double predux_min(const Packet4d& a) { return predux_min(pmin(lo, hi)); } -template <> -EIGEN_STRONG_INLINE double predux_min(const Packet4d& a) { - return predux_min(a); -} - template <> EIGEN_STRONG_INLINE double predux_min(const Packet4d& a) { Packet2d lo = _mm256_castpd256_pd128(a); @@ -247,11 +232,6 @@ EIGEN_STRONG_INLINE double predux_max(const Packet4d& a) { return predux_max(pmax(lo, hi)); } -template <> -EIGEN_STRONG_INLINE double predux_max(const Packet4d& a) { - return predux_max(a); -} - template <> EIGEN_STRONG_INLINE double predux_max(const Packet4d& a) { Packet2d lo = _mm256_castpd256_pd128(a); @@ -289,11 +269,6 @@ EIGEN_STRONG_INLINE half predux_min(const Packet8h& a) { return static_cast(predux_min(half2float(a))); } -template <> -EIGEN_STRONG_INLINE half predux_min(const Packet8h& a) { - return static_cast(predux_min(half2float(a))); -} - template <> EIGEN_STRONG_INLINE half predux_min(const Packet8h& a) { return static_cast(predux_min(half2float(a))); @@ -309,11 +284,6 @@ EIGEN_STRONG_INLINE half predux_max(const Packet8h& a) { return static_cast(predux_max(half2float(a))); } -template <> -EIGEN_STRONG_INLINE half predux_max(const Packet8h& a) { - return static_cast(predux_max(half2float(a))); -} - template <> EIGEN_STRONG_INLINE half predux_max(const Packet8h& a) { return static_cast(predux_max(half2float(a))); @@ -347,11 +317,6 @@ EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet8bf& a) { return static_cast(predux_min(Bf16ToF32(a))); } -template <> -EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet8bf& a) { - return predux_min(a); -} - template <> EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet8bf& a) { return static_cast(predux_min(Bf16ToF32(a))); @@ -367,11 +332,6 @@ EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet8bf& a) { return static_cast(predux_max(Bf16ToF32(a))); } -template <> -EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet8bf& a) { - return predux_max(a); -} - template <> EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet8bf& a) { return static_cast(predux_max(Bf16ToF32(a))); diff --git a/Eigen/src/Core/arch/AVX512/Reductions.h b/Eigen/src/Core/arch/AVX512/Reductions.h index e6d5bae26..f7b4c25a1 100644 --- a/Eigen/src/Core/arch/AVX512/Reductions.h +++ b/Eigen/src/Core/arch/AVX512/Reductions.h @@ -102,11 +102,6 @@ EIGEN_STRONG_INLINE float predux_min(const Packet16f& a) { return _mm512_reduce_min_ps(a); } -template <> -EIGEN_STRONG_INLINE float predux_min(const Packet16f& a) { - return _mm512_reduce_min_ps(a); -} - template <> EIGEN_STRONG_INLINE float predux_min(const Packet16f& a) { Packet8f lane0 = _mm512_extractf32x8_ps(a, 0); @@ -126,11 +121,6 @@ EIGEN_STRONG_INLINE float predux_max(const Packet16f& a) { return _mm512_reduce_max_ps(a); } -template <> -EIGEN_STRONG_INLINE float predux_max(const Packet16f& a) { - return _mm512_reduce_max_ps(a); -} - template <> EIGEN_STRONG_INLINE float predux_max(const Packet16f& a) { Packet8f lane0 = _mm512_extractf32x8_ps(a, 0); @@ -146,8 +136,8 @@ EIGEN_STRONG_INLINE float predux_max(const Packet16f& a) { } template <> -EIGEN_STRONG_INLINE bool predux_any(const Packet16f& x) { - return _mm512_reduce_or_epi32(_mm512_castps_si512(x)) != 0; +EIGEN_STRONG_INLINE bool predux_any(const Packet16f& a) { + return _mm512_reduce_or_epi32(_mm512_castps_si512(a)) != 0; } /* -- -- -- -- -- -- -- -- -- -- -- -- Packet8d -- -- -- -- -- -- -- -- -- -- -- -- */ @@ -167,11 +157,6 @@ EIGEN_STRONG_INLINE double predux_min(const Packet8d& a) { return _mm512_reduce_min_pd(a); } -template <> -EIGEN_STRONG_INLINE double predux_min(const Packet8d& a) { - return _mm512_reduce_min_pd(a); -} - template <> EIGEN_STRONG_INLINE double predux_min(const Packet8d& a) { Packet4d lane0 = _mm512_extractf64x4_pd(a, 0); @@ -191,11 +176,6 @@ EIGEN_STRONG_INLINE double predux_max(const Packet8d& a) { return _mm512_reduce_max_pd(a); } -template <> -EIGEN_STRONG_INLINE double predux_max(const Packet8d& a) { - return _mm512_reduce_max_pd(a); -} - template <> EIGEN_STRONG_INLINE double predux_max(const Packet8d& a) { Packet4d lane0 = _mm512_extractf64x4_pd(a, 0); @@ -211,8 +191,8 @@ EIGEN_STRONG_INLINE double predux_max(const Packet8d& a) { } template <> -EIGEN_STRONG_INLINE bool predux_any(const Packet8d& x) { - return _mm512_reduce_or_epi64(_mm512_castpd_si512(x)) != 0; +EIGEN_STRONG_INLINE bool predux_any(const Packet8d& a) { + return _mm512_reduce_or_epi64(_mm512_castpd_si512(a)) != 0; } #ifndef EIGEN_VECTORIZE_AVX512FP16 @@ -233,11 +213,6 @@ EIGEN_STRONG_INLINE half predux_min(const Packet16h& from) { return half(predux_min(half2float(from))); } -template <> -EIGEN_STRONG_INLINE half predux_min(const Packet16h& from) { - return half(predux_min(half2float(from))); -} - template <> EIGEN_STRONG_INLINE half predux_min(const Packet16h& from) { return half(predux_min(half2float(from))); @@ -253,11 +228,6 @@ EIGEN_STRONG_INLINE half predux_max(const Packet16h& from) { return half(predux_max(half2float(from))); } -template <> -EIGEN_STRONG_INLINE half predux_max(const Packet16h& from) { - return half(predux_max(half2float(from))); -} - template <> EIGEN_STRONG_INLINE half predux_max(const Packet16h& from) { return half(predux_max(half2float(from))); @@ -269,8 +239,8 @@ EIGEN_STRONG_INLINE half predux_max(const Packet16h& from) { } template <> -EIGEN_STRONG_INLINE bool predux_any(const Packet16h& x) { - return predux_any(x.m_val); +EIGEN_STRONG_INLINE bool predux_any(const Packet16h& a) { + return predux_any(a.m_val); } #endif @@ -291,11 +261,6 @@ EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet16bf& from) { return static_cast(predux_min(Bf16ToF32(from))); } -template <> -EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet16bf& from) { - return static_cast(predux_min(Bf16ToF32(from))); -} - template <> EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet16bf& from) { return static_cast(predux_min(Bf16ToF32(from))); @@ -311,11 +276,6 @@ EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet16bf& from) { return static_cast(predux_max(Bf16ToF32(from))); } -template <> -EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet16bf& from) { - return static_cast(predux_max(Bf16ToF32(from))); -} - template <> EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet16bf& from) { return static_cast(predux_max(Bf16ToF32(from))); @@ -327,8 +287,8 @@ EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet16bf& from) { } template <> -EIGEN_STRONG_INLINE bool predux_any(const Packet16bf& x) { - return predux_any(x.m_val); +EIGEN_STRONG_INLINE bool predux_any(const Packet16bf& a) { + return predux_any(a.m_val); } } // end namespace internal diff --git a/Eigen/src/Core/arch/SSE/Reductions.h b/Eigen/src/Core/arch/SSE/Reductions.h index dd33da211..f38df4e4a 100644 --- a/Eigen/src/Core/arch/SSE/Reductions.h +++ b/Eigen/src/Core/arch/SSE/Reductions.h @@ -86,6 +86,21 @@ EIGEN_STRONG_INLINE bool predux_mul(const Packet16b& a) { return ((pfirst(tmp) == 0x01010101) && (pfirst(_mm_shuffle_epi32(tmp, 1)) == 0x01010101)); } +template <> +EIGEN_STRONG_INLINE bool predux_min(const Packet16b& a) { + return predux_mul(a); +} + +template <> +EIGEN_STRONG_INLINE bool predux_max(const Packet16b& a) { + return predux(a); +} + +template <> +EIGEN_STRONG_INLINE bool predux_any(const Packet16b& a) { + return predux(a); +} + /* -- -- -- -- -- -- -- -- -- -- -- -- Packet4i -- -- -- -- -- -- -- -- -- -- -- -- */ template @@ -121,8 +136,8 @@ EIGEN_STRONG_INLINE int predux_max(const Packet4i& a) { #endif template <> -EIGEN_STRONG_INLINE bool predux_any(const Packet4i& x) { - return _mm_movemask_ps(_mm_castsi128_ps(x)) != 0x0; +EIGEN_STRONG_INLINE bool predux_any(const Packet4i& a) { + return _mm_movemask_ps(_mm_castsi128_ps(a)) != 0x0; } /* -- -- -- -- -- -- -- -- -- -- -- -- Packet4ui -- -- -- -- -- -- -- -- -- -- -- -- */ @@ -160,8 +175,8 @@ EIGEN_STRONG_INLINE uint32_t predux_max(const Packet4ui& a) { #endif template <> -EIGEN_STRONG_INLINE bool predux_any(const Packet4ui& x) { - return _mm_movemask_ps(_mm_castsi128_ps(x)) != 0x0; +EIGEN_STRONG_INLINE bool predux_any(const Packet4ui& a) { + return _mm_movemask_ps(_mm_castsi128_ps(a)) != 0x0; } /* -- -- -- -- -- -- -- -- -- -- -- -- Packet2l -- -- -- -- -- -- -- -- -- -- -- -- */ @@ -181,8 +196,8 @@ EIGEN_STRONG_INLINE int64_t predux(const Packet2l& a) { } template <> -EIGEN_STRONG_INLINE bool predux_any(const Packet2l& x) { - return _mm_movemask_pd(_mm_castsi128_pd(x)) != 0x0; +EIGEN_STRONG_INLINE bool predux_any(const Packet2l& a) { + return _mm_movemask_pd(_mm_castsi128_pd(a)) != 0x0; } /* -- -- -- -- -- -- -- -- -- -- -- -- Packet4f -- -- -- -- -- -- -- -- -- -- -- -- */ @@ -216,11 +231,6 @@ EIGEN_STRONG_INLINE float predux_min(const Packet4f& a) { return sse_predux_min_impl::run(a); } -template <> -EIGEN_STRONG_INLINE float predux_min(const Packet4f& a) { - return sse_predux_min_prop_impl::run(a); -} - template <> EIGEN_STRONG_INLINE float predux_min(const Packet4f& a) { return sse_predux_min_prop_impl::run(a); @@ -236,11 +246,6 @@ EIGEN_STRONG_INLINE float predux_max(const Packet4f& a) { return sse_predux_max_impl::run(a); } -template <> -EIGEN_STRONG_INLINE float predux_max(const Packet4f& a) { - return sse_predux_max_prop_impl::run(a); -} - template <> EIGEN_STRONG_INLINE float predux_max(const Packet4f& a) { return sse_predux_max_prop_impl::run(a); @@ -252,8 +257,8 @@ EIGEN_STRONG_INLINE float predux_max(const Packet4f& a) { } template <> -EIGEN_STRONG_INLINE bool predux_any(const Packet4f& x) { - return _mm_movemask_ps(x) != 0x0; +EIGEN_STRONG_INLINE bool predux_any(const Packet4f& a) { + return _mm_movemask_ps(a) != 0x0; } /* -- -- -- -- -- -- -- -- -- -- -- -- Packet2d -- -- -- -- -- -- -- -- -- -- -- -- */ @@ -282,11 +287,6 @@ EIGEN_STRONG_INLINE double predux_min(const Packet2d& a) { return sse_predux_min_impl::run(a); } -template <> -EIGEN_STRONG_INLINE double predux_min(const Packet2d& a) { - return sse_predux_min_prop_impl::run(a); -} - template <> EIGEN_STRONG_INLINE double predux_min(const Packet2d& a) { return sse_predux_min_prop_impl::run(a); @@ -302,11 +302,6 @@ EIGEN_STRONG_INLINE double predux_max(const Packet2d& a) { return sse_predux_max_impl::run(a); } -template <> -EIGEN_STRONG_INLINE double predux_max(const Packet2d& a) { - return sse_predux_max_prop_impl::run(a); -} - template <> EIGEN_STRONG_INLINE double predux_max(const Packet2d& a) { return sse_predux_max_prop_impl::run(a); @@ -318,8 +313,8 @@ EIGEN_STRONG_INLINE double predux_max(const Packet2d& a) { } template <> -EIGEN_STRONG_INLINE bool predux_any(const Packet2d& x) { - return _mm_movemask_pd(x) != 0x0; +EIGEN_STRONG_INLINE bool predux_any(const Packet2d& a) { + return _mm_movemask_pd(a) != 0x0; } } // end namespace internal