Faster emulated half comparisons

This commit is contained in:
Charles Schlosser 2025-06-17 17:05:58 +00:00 committed by Antonio Sánchez
parent ac6955ebc6
commit bcce88c99e
3 changed files with 121 additions and 35 deletions

View File

@ -2249,24 +2249,64 @@ EIGEN_STRONG_INLINE Packet8h ptrunc<Packet8h>(const Packet8h& a) {
return float2half(ptrunc<Packet8f>(half2float(a)));
}
template <>
EIGEN_STRONG_INLINE Packet8h pisinf<Packet8h>(const Packet8h& a) {
constexpr uint16_t kInf = ((1 << 5) - 1) << 10;
constexpr uint16_t kAbsMask = (1 << 15) - 1;
return _mm_cmpeq_epi16(_mm_and_si128(a.m_val, _mm_set1_epi16(kAbsMask)), _mm_set1_epi16(kInf));
}
template <>
EIGEN_STRONG_INLINE Packet8h pisnan<Packet8h>(const Packet8h& a) {
constexpr uint16_t kInf = ((1 << 5) - 1) << 10;
constexpr uint16_t kAbsMask = (1 << 15) - 1;
return _mm_cmpgt_epi16(_mm_and_si128(a.m_val, _mm_set1_epi16(kAbsMask)), _mm_set1_epi16(kInf));
}
// convert the sign-magnitude representation to two's complement
EIGEN_STRONG_INLINE __m128i pmaptosigned(const __m128i& a) {
constexpr uint16_t kAbsMask = (1 << 15) - 1;
// if 'a' has the sign bit set, clear the sign bit and negate the result as if it were an integer
return _mm_sign_epi16(_mm_and_si128(a, _mm_set1_epi16(kAbsMask)), a);
}
// return true if both `a` and `b` are not NaN
EIGEN_STRONG_INLINE Packet8h pisordered(const Packet8h& a, const Packet8h& b) {
constexpr uint16_t kInf = ((1 << 5) - 1) << 10;
constexpr uint16_t kAbsMask = (1 << 15) - 1;
__m128i abs_a = _mm_and_si128(a.m_val, _mm_set1_epi16(kAbsMask));
__m128i abs_b = _mm_and_si128(b.m_val, _mm_set1_epi16(kAbsMask));
// check if both `abs_a <= kInf` and `abs_b <= kInf` by checking if max(abs_a, abs_b) <= kInf
// SSE has no `lesser or equal` instruction for integers, but comparing against kInf + 1 accomplishes the same goal
return _mm_cmplt_epi16(_mm_max_epu16(abs_a, abs_b), _mm_set1_epi16(kInf + 1));
}
template <>
EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a, const Packet8h& b) {
return Pack16To8(pcmp_eq(half2float(a), half2float(b)));
__m128i isOrdered = pisordered(a, b);
__m128i isEqual = _mm_cmpeq_epi16(pmaptosigned(a.m_val), pmaptosigned(b.m_val));
return _mm_and_si128(isOrdered, isEqual);
}
template <>
EIGEN_STRONG_INLINE Packet8h pcmp_le(const Packet8h& a, const Packet8h& b) {
return Pack16To8(pcmp_le(half2float(a), half2float(b)));
__m128i isOrdered = pisordered(a, b);
__m128i isGreater = _mm_cmpgt_epi16(pmaptosigned(a.m_val), pmaptosigned(b.m_val));
return _mm_andnot_si128(isGreater, isOrdered);
}
template <>
EIGEN_STRONG_INLINE Packet8h pcmp_lt(const Packet8h& a, const Packet8h& b) {
return Pack16To8(pcmp_lt(half2float(a), half2float(b)));
__m128i isOrdered = pisordered(a, b);
__m128i isLess = _mm_cmplt_epi16(pmaptosigned(a.m_val), pmaptosigned(b.m_val));
return _mm_and_si128(isOrdered, isLess);
}
template <>
EIGEN_STRONG_INLINE Packet8h pcmp_lt_or_nan(const Packet8h& a, const Packet8h& b) {
return Pack16To8(pcmp_lt_or_nan(half2float(a), half2float(b)));
__m128i isUnordered = por(pisnan(a), pisnan(b));
__m128i isLess = _mm_cmplt_epi16(pmaptosigned(a.m_val), pmaptosigned(b.m_val));
return _mm_or_si128(isUnordered, isLess);
}
template <>

View File

@ -497,16 +497,56 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator/=(half& a, const half& b) {
a = half(float(a) / float(b));
return a;
}
// Non-negative floating point numbers have a monotonic mapping to non-negative integers.
// This property allows floating point numbers to be reinterpreted as integers for comparisons, which is useful if there
// is no native floating point comparison operator. Floating point signedness is handled by the sign-magnitude
// representation, whereas integers typically use two's complement. Converting the bit pattern from sign-magnitude to
// two's complement allows the transformed bit patterns be compared as signed integers. All edge cases (+/-0 and +/-
// infinity) are handled automatically, except NaN.
//
// fp16 uses 1 sign bit, 5 exponent bits, and 10 mantissa bits. The bit pattern conveys NaN when all the exponent
// bits (5) are set, and at least one mantissa bit is set. The sign bit is irrelevant for determining NaN. To check for
// NaN, clear the sign bit and check if the integral representation is greater than 01111100000000. To test
// for non-NaN, clear the sign bit and check if the integeral representation is less than or equal to 01111100000000.
// convert sign-magnitude representation to two's complement
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC int16_t mapToSigned(uint16_t a) {
constexpr uint16_t kAbsMask = (1 << 15) - 1;
// If the sign bit is set, clear the sign bit and return the (integer) negation. Otherwise, return the input.
return (a >> 15) ? -(a & kAbsMask) : a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool isOrdered(const half& a, const half& b) {
constexpr uint16_t kInf = ((1 << 5) - 1) << 10;
constexpr uint16_t kAbsMask = (1 << 15) - 1;
return numext::maxi(a.x & kAbsMask, b.x & kAbsMask) <= kInf;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator==(const half& a, const half& b) {
return numext::equal_strict(float(a), float(b));
bool result = mapToSigned(a.x) == mapToSigned(b.x);
result &= isOrdered(a, b);
return result;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator!=(const half& a, const half& b) {
return numext::not_equal_strict(float(a), float(b));
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator!=(const half& a, const half& b) { return !(a == b); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const half& a, const half& b) {
bool result = mapToSigned(a.x) < mapToSigned(b.x);
result &= isOrdered(a, b);
return result;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const half& a, const half& b) {
bool result = mapToSigned(a.x) <= mapToSigned(b.x);
result &= isOrdered(a, b);
return result;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const half& a, const half& b) {
bool result = mapToSigned(a.x) > mapToSigned(b.x);
result &= isOrdered(a, b);
return result;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half& b) {
bool result = mapToSigned(a.x) >= mapToSigned(b.x);
result &= isOrdered(a, b);
return result;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const half& a, const half& b) { return float(a) < float(b); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const half& a, const half& b) { return float(a) <= float(b); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const half& a, const half& b) { return float(a) > float(b); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half& b) { return float(a) >= float(b); }
#if EIGEN_COMP_CLANG && defined(EIGEN_GPUCC)
#pragma pop_macro("EIGEN_DEVICE_FUNC")
@ -706,7 +746,11 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isnan)(const half& a) {
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isfinite)(const half& a) {
return !(isinf EIGEN_NOT_A_MACRO(a)) && !(isnan EIGEN_NOT_A_MACRO(a));
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16)
return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) < 0x7c00;
#else
return (a.x & 0x7fff) < 0x7c00;
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half abs(const half& a) {

View File

@ -72,17 +72,16 @@ void test_conversion() {
// NaNs and infinities.
VERIFY(!(numext::isinf)(float(half(65504.0f)))); // Largest finite number.
VERIFY(!(numext::isnan)(float(half(0.0f))));
VERIFY((numext::isfinite)(float(half(65504.0f))));
VERIFY((numext::isfinite)(float(half(0.0f))));
VERIFY((numext::isinf)(float(half(__half_raw(0xfc00)))));
VERIFY((numext::isnan)(float(half(__half_raw(0xfc01)))));
VERIFY((numext::isinf)(float(half(__half_raw(0x7c00)))));
VERIFY((numext::isnan)(float(half(__half_raw(0x7c01)))));
#if !EIGEN_COMP_MSVC
// Visual Studio errors out on divisions by 0
VERIFY((numext::isnan)(float(half(0.0 / 0.0))));
VERIFY((numext::isinf)(float(half(1.0 / 0.0))));
VERIFY((numext::isinf)(float(half(-1.0 / 0.0))));
#endif
VERIFY((numext::isnan)(float(NumTraits<half>::quiet_NaN())));
VERIFY((numext::isinf)(float(NumTraits<half>::infinity())));
VERIFY((numext::isinf)(float(-NumTraits<half>::infinity())));
// Exactly same checks as above, just directly on the half representation.
VERIFY(!(numext::isinf)(half(__half_raw(0x7bff))));
@ -92,12 +91,9 @@ void test_conversion() {
VERIFY((numext::isinf)(half(__half_raw(0x7c00))));
VERIFY((numext::isnan)(half(__half_raw(0x7c01))));
#if !EIGEN_COMP_MSVC
// Visual Studio errors out on divisions by 0
VERIFY((numext::isnan)(half(0.0 / 0.0)));
VERIFY((numext::isinf)(half(1.0 / 0.0)));
VERIFY((numext::isinf)(half(-1.0 / 0.0)));
#endif
VERIFY((numext::isnan)(NumTraits<half>::quiet_NaN()));
VERIFY((numext::isinf)(NumTraits<half>::infinity()));
VERIFY((numext::isinf)(-NumTraits<half>::infinity()));
// Conversion to bool
VERIFY(!static_cast<bool>(half(0.0)));
@ -204,19 +200,25 @@ void test_comparison() {
VERIFY(half(1.0f) != half(2.0f));
// Comparisons with NaNs and infinities.
#if !EIGEN_COMP_MSVC
// Visual Studio errors out on divisions by 0
VERIFY(!(half(0.0 / 0.0) == half(0.0 / 0.0)));
VERIFY(half(0.0 / 0.0) != half(0.0 / 0.0));
VERIFY(!(NumTraits<half>::quiet_NaN() == NumTraits<half>::quiet_NaN()));
VERIFY(NumTraits<half>::quiet_NaN() != NumTraits<half>::quiet_NaN());
VERIFY(!(half(1.0) == half(0.0 / 0.0)));
VERIFY(!(half(1.0) < half(0.0 / 0.0)));
VERIFY(!(half(1.0) > half(0.0 / 0.0)));
VERIFY(half(1.0) != half(0.0 / 0.0));
VERIFY(!(internal::random<half>() == NumTraits<half>::quiet_NaN()));
VERIFY(!(internal::random<half>() < NumTraits<half>::quiet_NaN()));
VERIFY(!(internal::random<half>() > NumTraits<half>::quiet_NaN()));
VERIFY(!(internal::random<half>() <= NumTraits<half>::quiet_NaN()));
VERIFY(!(internal::random<half>() >= NumTraits<half>::quiet_NaN()));
VERIFY(internal::random<half>() != NumTraits<half>::quiet_NaN());
VERIFY(half(1.0) < half(1.0 / 0.0));
VERIFY(half(1.0) > half(-1.0 / 0.0));
#endif
VERIFY(!(NumTraits<half>::quiet_NaN() == internal::random<half>()));
VERIFY(!(NumTraits<half>::quiet_NaN() < internal::random<half>()));
VERIFY(!(NumTraits<half>::quiet_NaN() > internal::random<half>()));
VERIFY(!(NumTraits<half>::quiet_NaN() <= internal::random<half>()));
VERIFY(!(NumTraits<half>::quiet_NaN() >= internal::random<half>()));
VERIFY(NumTraits<half>::quiet_NaN() != internal::random<half>());
VERIFY(internal::random<half>() < NumTraits<half>::infinity());
VERIFY(internal::random<half>() > -NumTraits<half>::infinity());
}
void test_basic_functions() {