GPU PacketMath: add less-or-equal packet comparisons.

Adds scalar le_mask() helpers for float/double and pcmp_le for float4,
double2, half2 and Packet4h2, plus pcmp_lt for Packet4h2. Each comparison
yields an all-ones bit pattern per lane when the predicate holds and
all-zeros otherwise, mirroring the existing lt_mask()/pcmp_lt code.

NOTE(review): this patch was restored from a whitespace-collapsed copy in
which template argument lists had been stripped. Line breaks, indentation
and the <...> arguments were reconstructed -- verify against the target
file, or apply with `git apply --ignore-whitespace`.

diff --git a/Eigen/src/Core/arch/GPU/PacketMath.h b/Eigen/src/Core/arch/GPU/PacketMath.h
index e147fd1b1..e2bcf483a 100644
--- a/Eigen/src/Core/arch/GPU/PacketMath.h
+++ b/Eigen/src/Core/arch/GPU/PacketMath.h
@@ -176,11 +176,22 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float lt_mask(const float& a,
                                                     const float& b) {
   return __int_as_float(a < b ? 0xffffffffu : 0u);
 }
+
 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double lt_mask(const double& a,
                                                      const double& b) {
   return __longlong_as_double(a < b ? 0xffffffffffffffffull : 0ull);
 }
 
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float le_mask(const float& a,
+                                                    const float& b) {
+  return __int_as_float(a <= b ? 0xffffffffu : 0u);
+}
+
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double le_mask(const double& a,
+                                                     const double& b) {
+  return __longlong_as_double(a <= b ? 0xffffffffffffffffull : 0ull);
+}
+
 template <>
 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pand<float4>(const float4& a,
                                                           const float4& b) {
@@ -242,6 +253,12 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcmp_lt<float4>(const float4& a,
                      lt_mask(a.w, b.w));
 }
 template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcmp_le<float4>(const float4& a,
+                                                             const float4& b) {
+  return make_float4(le_mask(a.x, b.x), le_mask(a.y, b.y), le_mask(a.z, b.z),
+                     le_mask(a.w, b.w));
+}
+template <>
 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pcmp_eq<double2>(const double2& a,
                                                                const double2& b) {
   return make_double2(eq_mask(a.x, b.x), eq_mask(a.y, b.y));
@@ -251,6 +268,11 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2 pcmp_lt<double2>(const double2& a,
                                                                const double2& b) {
   return make_double2(lt_mask(a.x, b.x), lt_mask(a.y, b.y));
 }
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2
+pcmp_le<double2>(const double2& a, const double2& b) {
+  return make_double2(le_mask(a.x, b.x), le_mask(a.y, b.y));
+}
 #endif // defined(EIGEN_CUDA_ARCH) || defined(EIGEN_HIPCC) || (defined(EIGEN_CUDACC) && EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC)
 
 template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 plset<float4>(const float& a) {
@@ -676,6 +698,19 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_lt(const half2& a,
   return __halves2half2(eq1, eq2);
 }
 
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_le(const half2& a,
+                                                    const half2& b) {
+  half true_half = half_impl::raw_uint16_to_half(0xffffu);
+  half false_half = half_impl::raw_uint16_to_half(0x0000u);
+  half a1 = __low2half(a);
+  half a2 = __high2half(a);
+  half b1 = __low2half(b);
+  half b2 = __high2half(b);
+  half eq1 = __half2float(a1) <= __half2float(b1) ? true_half : false_half;
+  half eq2 = __half2float(a2) <= __half2float(b2) ? true_half : false_half;
+  return __halves2half2(eq1, eq2);
+}
+
 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pand(const half2& a,
                                                  const half2& b) {
   half a1 = __low2half(a);
@@ -1257,6 +1292,34 @@ pcmp_eq<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
   return r;
 }
 
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2
+pcmp_lt<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
+  Packet4h2 r;
+  half2* r_alias = reinterpret_cast<half2*>(&r);
+  const half2* a_alias = reinterpret_cast<const half2*>(&a);
+  const half2* b_alias = reinterpret_cast<const half2*>(&b);
+  r_alias[0] = pcmp_lt(a_alias[0], b_alias[0]);
+  r_alias[1] = pcmp_lt(a_alias[1], b_alias[1]);
+  r_alias[2] = pcmp_lt(a_alias[2], b_alias[2]);
+  r_alias[3] = pcmp_lt(a_alias[3], b_alias[3]);
+  return r;
+}
+
+template <>
+EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2
+pcmp_le<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
+  Packet4h2 r;
+  half2* r_alias = reinterpret_cast<half2*>(&r);
+  const half2* a_alias = reinterpret_cast<const half2*>(&a);
+  const half2* b_alias = reinterpret_cast<const half2*>(&b);
+  r_alias[0] = pcmp_le(a_alias[0], b_alias[0]);
+  r_alias[1] = pcmp_le(a_alias[1], b_alias[1]);
+  r_alias[2] = pcmp_le(a_alias[2], b_alias[2]);
+  r_alias[3] = pcmp_le(a_alias[3], b_alias[3]);
+  return r;
+}
+
 template <>
 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pand<Packet4h2>(
     const Packet4h2& a, const Packet4h2& b) {