Add missing comparison operators for GPU packets.

This commit is contained in:
Rasmus Munk Larsen 2022-09-07 14:10:02 -07:00
parent 242325eca7
commit f9dfda28ab

View File

@ -176,11 +176,22 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float lt_mask(const float& a,
const float& b) {
return __int_as_float(a < b ? 0xffffffffu : 0u);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double lt_mask(const double& a,
const double& b) {
return __longlong_as_double(a < b ? 0xffffffffffffffffull : 0ull);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float le_mask(const float& a,
const float& b) {
return __int_as_float(a <= b ? 0xffffffffu : 0u);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double le_mask(const double& a,
const double& b) {
return __longlong_as_double(a <= b ? 0xffffffffffffffffull : 0ull);
}
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pand<float4>(const float4& a,
const float4& b) {
@ -242,6 +253,12 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcmp_lt<float4>(const float4& a,
lt_mask(a.w, b.w));
}
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcmp_le<float4>(const float4& a,
const float4& b) {
return make_float4(le_mask(a.x, b.x), le_mask(a.y, b.y), le_mask(a.z, b.z),
le_mask(a.w, b.w));
}
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2
pcmp_eq<double2>(const double2& a, const double2& b) {
return make_double2(eq_mask(a.x, b.x), eq_mask(a.y, b.y));
@ -251,6 +268,11 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2
pcmp_lt<double2>(const double2& a, const double2& b) {
return make_double2(lt_mask(a.x, b.x), lt_mask(a.y, b.y));
}
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2
pcmp_le<double2>(const double2& a, const double2& b) {
return make_double2(le_mask(a.x, b.x), le_mask(a.y, b.y));
}
#endif // defined(EIGEN_CUDA_ARCH) || defined(EIGEN_HIPCC) || (defined(EIGEN_CUDACC) && EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC)
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 plset<float4>(const float& a) {
@ -676,6 +698,19 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_lt(const half2& a,
return __halves2half2(eq1, eq2);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_le(const half2& a,
const half2& b) {
half true_half = half_impl::raw_uint16_to_half(0xffffu);
half false_half = half_impl::raw_uint16_to_half(0x0000u);
half a1 = __low2half(a);
half a2 = __high2half(a);
half b1 = __low2half(b);
half b2 = __high2half(b);
half eq1 = __half2float(a1) <= __half2float(b1) ? true_half : false_half;
half eq2 = __half2float(a2) <= __half2float(b2) ? true_half : false_half;
return __halves2half2(eq1, eq2);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pand(const half2& a,
const half2& b) {
half a1 = __low2half(a);
@ -1257,6 +1292,34 @@ pcmp_eq<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
return r;
}
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2
pcmp_lt<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
Packet4h2 r;
half2* r_alias = reinterpret_cast<half2*>(&r);
const half2* a_alias = reinterpret_cast<const half2*>(&a);
const half2* b_alias = reinterpret_cast<const half2*>(&b);
r_alias[0] = pcmp_lt(a_alias[0], b_alias[0]);
r_alias[1] = pcmp_lt(a_alias[1], b_alias[1]);
r_alias[2] = pcmp_lt(a_alias[2], b_alias[2]);
r_alias[3] = pcmp_lt(a_alias[3], b_alias[3]);
return r;
}
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2
pcmp_le<Packet4h2>(const Packet4h2& a, const Packet4h2& b) {
Packet4h2 r;
half2* r_alias = reinterpret_cast<half2*>(&r);
const half2* a_alias = reinterpret_cast<const half2*>(&a);
const half2* b_alias = reinterpret_cast<const half2*>(&b);
r_alias[0] = pcmp_le(a_alias[0], b_alias[0]);
r_alias[1] = pcmp_le(a_alias[1], b_alias[1]);
r_alias[2] = pcmp_le(a_alias[2], b_alias[2]);
r_alias[3] = pcmp_le(a_alias[3], b_alias[3]);
return r;
}
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4h2 pand<Packet4h2>(
const Packet4h2& a, const Packet4h2& b) {