mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-07-19 19:34:29 +08:00
Faster emulated half comparisons
This commit is contained in:
parent
ac6955ebc6
commit
bcce88c99e
@ -2249,24 +2249,64 @@ EIGEN_STRONG_INLINE Packet8h ptrunc<Packet8h>(const Packet8h& a) {
|
||||
return float2half(ptrunc<Packet8f>(half2float(a)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h pisinf<Packet8h>(const Packet8h& a) {
|
||||
constexpr uint16_t kInf = ((1 << 5) - 1) << 10;
|
||||
constexpr uint16_t kAbsMask = (1 << 15) - 1;
|
||||
return _mm_cmpeq_epi16(_mm_and_si128(a.m_val, _mm_set1_epi16(kAbsMask)), _mm_set1_epi16(kInf));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h pisnan<Packet8h>(const Packet8h& a) {
|
||||
constexpr uint16_t kInf = ((1 << 5) - 1) << 10;
|
||||
constexpr uint16_t kAbsMask = (1 << 15) - 1;
|
||||
return _mm_cmpgt_epi16(_mm_and_si128(a.m_val, _mm_set1_epi16(kAbsMask)), _mm_set1_epi16(kInf));
|
||||
}
|
||||
|
||||
// convert the sign-magnitude representation to two's complement
|
||||
EIGEN_STRONG_INLINE __m128i pmaptosigned(const __m128i& a) {
|
||||
constexpr uint16_t kAbsMask = (1 << 15) - 1;
|
||||
// if 'a' has the sign bit set, clear the sign bit and negate the result as if it were an integer
|
||||
return _mm_sign_epi16(_mm_and_si128(a, _mm_set1_epi16(kAbsMask)), a);
|
||||
}
|
||||
|
||||
// return true if both `a` and `b` are not NaN
|
||||
EIGEN_STRONG_INLINE Packet8h pisordered(const Packet8h& a, const Packet8h& b) {
|
||||
constexpr uint16_t kInf = ((1 << 5) - 1) << 10;
|
||||
constexpr uint16_t kAbsMask = (1 << 15) - 1;
|
||||
__m128i abs_a = _mm_and_si128(a.m_val, _mm_set1_epi16(kAbsMask));
|
||||
__m128i abs_b = _mm_and_si128(b.m_val, _mm_set1_epi16(kAbsMask));
|
||||
// check if both `abs_a <= kInf` and `abs_b <= kInf` by checking if max(abs_a, abs_b) <= kInf
|
||||
// SSE has no `lesser or equal` instruction for integers, but comparing against kInf + 1 accomplishes the same goal
|
||||
return _mm_cmplt_epi16(_mm_max_epu16(abs_a, abs_b), _mm_set1_epi16(kInf + 1));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a, const Packet8h& b) {
|
||||
return Pack16To8(pcmp_eq(half2float(a), half2float(b)));
|
||||
__m128i isOrdered = pisordered(a, b);
|
||||
__m128i isEqual = _mm_cmpeq_epi16(pmaptosigned(a.m_val), pmaptosigned(b.m_val));
|
||||
return _mm_and_si128(isOrdered, isEqual);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h pcmp_le(const Packet8h& a, const Packet8h& b) {
|
||||
return Pack16To8(pcmp_le(half2float(a), half2float(b)));
|
||||
__m128i isOrdered = pisordered(a, b);
|
||||
__m128i isGreater = _mm_cmpgt_epi16(pmaptosigned(a.m_val), pmaptosigned(b.m_val));
|
||||
return _mm_andnot_si128(isGreater, isOrdered);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h pcmp_lt(const Packet8h& a, const Packet8h& b) {
|
||||
return Pack16To8(pcmp_lt(half2float(a), half2float(b)));
|
||||
__m128i isOrdered = pisordered(a, b);
|
||||
__m128i isLess = _mm_cmplt_epi16(pmaptosigned(a.m_val), pmaptosigned(b.m_val));
|
||||
return _mm_and_si128(isOrdered, isLess);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h pcmp_lt_or_nan(const Packet8h& a, const Packet8h& b) {
|
||||
return Pack16To8(pcmp_lt_or_nan(half2float(a), half2float(b)));
|
||||
__m128i isUnordered = por(pisnan(a), pisnan(b));
|
||||
__m128i isLess = _mm_cmplt_epi16(pmaptosigned(a.m_val), pmaptosigned(b.m_val));
|
||||
return _mm_or_si128(isUnordered, isLess);
|
||||
}
|
||||
|
||||
template <>
|
||||
|
@ -497,16 +497,56 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator/=(half& a, const half& b) {
|
||||
a = half(float(a) / float(b));
|
||||
return a;
|
||||
}
|
||||
|
||||
// Non-negative floating point numbers have a monotonic mapping to non-negative integers.
|
||||
// This property allows floating point numbers to be reinterpreted as integers for comparisons, which is useful if there
|
||||
// is no native floating point comparison operator. Floating point signedness is handled by the sign-magnitude
|
||||
// representation, whereas integers typically use two's complement. Converting the bit pattern from sign-magnitude to
|
||||
// two's complement allows the transformed bit patterns be compared as signed integers. All edge cases (+/-0 and +/-
|
||||
// infinity) are handled automatically, except NaN.
|
||||
//
|
||||
// fp16 uses 1 sign bit, 5 exponent bits, and 10 mantissa bits. The bit pattern conveys NaN when all the exponent
|
||||
// bits (5) are set, and at least one mantissa bit is set. The sign bit is irrelevant for determining NaN. To check for
|
||||
// NaN, clear the sign bit and check if the integral representation is greater than 01111100000000. To test
|
||||
// for non-NaN, clear the sign bit and check if the integeral representation is less than or equal to 01111100000000.
|
||||
|
||||
// convert sign-magnitude representation to two's complement
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC int16_t mapToSigned(uint16_t a) {
|
||||
constexpr uint16_t kAbsMask = (1 << 15) - 1;
|
||||
// If the sign bit is set, clear the sign bit and return the (integer) negation. Otherwise, return the input.
|
||||
return (a >> 15) ? -(a & kAbsMask) : a;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool isOrdered(const half& a, const half& b) {
|
||||
constexpr uint16_t kInf = ((1 << 5) - 1) << 10;
|
||||
constexpr uint16_t kAbsMask = (1 << 15) - 1;
|
||||
return numext::maxi(a.x & kAbsMask, b.x & kAbsMask) <= kInf;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator==(const half& a, const half& b) {
|
||||
return numext::equal_strict(float(a), float(b));
|
||||
bool result = mapToSigned(a.x) == mapToSigned(b.x);
|
||||
result &= isOrdered(a, b);
|
||||
return result;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator!=(const half& a, const half& b) {
|
||||
return numext::not_equal_strict(float(a), float(b));
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator!=(const half& a, const half& b) { return !(a == b); }
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const half& a, const half& b) {
|
||||
bool result = mapToSigned(a.x) < mapToSigned(b.x);
|
||||
result &= isOrdered(a, b);
|
||||
return result;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const half& a, const half& b) {
|
||||
bool result = mapToSigned(a.x) <= mapToSigned(b.x);
|
||||
result &= isOrdered(a, b);
|
||||
return result;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const half& a, const half& b) {
|
||||
bool result = mapToSigned(a.x) > mapToSigned(b.x);
|
||||
result &= isOrdered(a, b);
|
||||
return result;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half& b) {
|
||||
bool result = mapToSigned(a.x) >= mapToSigned(b.x);
|
||||
result &= isOrdered(a, b);
|
||||
return result;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const half& a, const half& b) { return float(a) < float(b); }
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const half& a, const half& b) { return float(a) <= float(b); }
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const half& a, const half& b) { return float(a) > float(b); }
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half& b) { return float(a) >= float(b); }
|
||||
|
||||
#if EIGEN_COMP_CLANG && defined(EIGEN_GPUCC)
|
||||
#pragma pop_macro("EIGEN_DEVICE_FUNC")
|
||||
@ -706,7 +746,11 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isnan)(const half& a) {
|
||||
#endif
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isfinite)(const half& a) {
|
||||
return !(isinf EIGEN_NOT_A_MACRO(a)) && !(isnan EIGEN_NOT_A_MACRO(a));
|
||||
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16)
|
||||
return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) < 0x7c00;
|
||||
#else
|
||||
return (a.x & 0x7fff) < 0x7c00;
|
||||
#endif
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half abs(const half& a) {
|
||||
|
@ -72,17 +72,16 @@ void test_conversion() {
|
||||
// NaNs and infinities.
|
||||
VERIFY(!(numext::isinf)(float(half(65504.0f)))); // Largest finite number.
|
||||
VERIFY(!(numext::isnan)(float(half(0.0f))));
|
||||
VERIFY((numext::isfinite)(float(half(65504.0f))));
|
||||
VERIFY((numext::isfinite)(float(half(0.0f))));
|
||||
VERIFY((numext::isinf)(float(half(__half_raw(0xfc00)))));
|
||||
VERIFY((numext::isnan)(float(half(__half_raw(0xfc01)))));
|
||||
VERIFY((numext::isinf)(float(half(__half_raw(0x7c00)))));
|
||||
VERIFY((numext::isnan)(float(half(__half_raw(0x7c01)))));
|
||||
|
||||
#if !EIGEN_COMP_MSVC
|
||||
// Visual Studio errors out on divisions by 0
|
||||
VERIFY((numext::isnan)(float(half(0.0 / 0.0))));
|
||||
VERIFY((numext::isinf)(float(half(1.0 / 0.0))));
|
||||
VERIFY((numext::isinf)(float(half(-1.0 / 0.0))));
|
||||
#endif
|
||||
VERIFY((numext::isnan)(float(NumTraits<half>::quiet_NaN())));
|
||||
VERIFY((numext::isinf)(float(NumTraits<half>::infinity())));
|
||||
VERIFY((numext::isinf)(float(-NumTraits<half>::infinity())));
|
||||
|
||||
// Exactly same checks as above, just directly on the half representation.
|
||||
VERIFY(!(numext::isinf)(half(__half_raw(0x7bff))));
|
||||
@ -92,12 +91,9 @@ void test_conversion() {
|
||||
VERIFY((numext::isinf)(half(__half_raw(0x7c00))));
|
||||
VERIFY((numext::isnan)(half(__half_raw(0x7c01))));
|
||||
|
||||
#if !EIGEN_COMP_MSVC
|
||||
// Visual Studio errors out on divisions by 0
|
||||
VERIFY((numext::isnan)(half(0.0 / 0.0)));
|
||||
VERIFY((numext::isinf)(half(1.0 / 0.0)));
|
||||
VERIFY((numext::isinf)(half(-1.0 / 0.0)));
|
||||
#endif
|
||||
VERIFY((numext::isnan)(NumTraits<half>::quiet_NaN()));
|
||||
VERIFY((numext::isinf)(NumTraits<half>::infinity()));
|
||||
VERIFY((numext::isinf)(-NumTraits<half>::infinity()));
|
||||
|
||||
// Conversion to bool
|
||||
VERIFY(!static_cast<bool>(half(0.0)));
|
||||
@ -204,19 +200,25 @@ void test_comparison() {
|
||||
VERIFY(half(1.0f) != half(2.0f));
|
||||
|
||||
// Comparisons with NaNs and infinities.
|
||||
#if !EIGEN_COMP_MSVC
|
||||
// Visual Studio errors out on divisions by 0
|
||||
VERIFY(!(half(0.0 / 0.0) == half(0.0 / 0.0)));
|
||||
VERIFY(half(0.0 / 0.0) != half(0.0 / 0.0));
|
||||
VERIFY(!(NumTraits<half>::quiet_NaN() == NumTraits<half>::quiet_NaN()));
|
||||
VERIFY(NumTraits<half>::quiet_NaN() != NumTraits<half>::quiet_NaN());
|
||||
|
||||
VERIFY(!(half(1.0) == half(0.0 / 0.0)));
|
||||
VERIFY(!(half(1.0) < half(0.0 / 0.0)));
|
||||
VERIFY(!(half(1.0) > half(0.0 / 0.0)));
|
||||
VERIFY(half(1.0) != half(0.0 / 0.0));
|
||||
VERIFY(!(internal::random<half>() == NumTraits<half>::quiet_NaN()));
|
||||
VERIFY(!(internal::random<half>() < NumTraits<half>::quiet_NaN()));
|
||||
VERIFY(!(internal::random<half>() > NumTraits<half>::quiet_NaN()));
|
||||
VERIFY(!(internal::random<half>() <= NumTraits<half>::quiet_NaN()));
|
||||
VERIFY(!(internal::random<half>() >= NumTraits<half>::quiet_NaN()));
|
||||
VERIFY(internal::random<half>() != NumTraits<half>::quiet_NaN());
|
||||
|
||||
VERIFY(half(1.0) < half(1.0 / 0.0));
|
||||
VERIFY(half(1.0) > half(-1.0 / 0.0));
|
||||
#endif
|
||||
VERIFY(!(NumTraits<half>::quiet_NaN() == internal::random<half>()));
|
||||
VERIFY(!(NumTraits<half>::quiet_NaN() < internal::random<half>()));
|
||||
VERIFY(!(NumTraits<half>::quiet_NaN() > internal::random<half>()));
|
||||
VERIFY(!(NumTraits<half>::quiet_NaN() <= internal::random<half>()));
|
||||
VERIFY(!(NumTraits<half>::quiet_NaN() >= internal::random<half>()));
|
||||
VERIFY(NumTraits<half>::quiet_NaN() != internal::random<half>());
|
||||
|
||||
VERIFY(internal::random<half>() < NumTraits<half>::infinity());
|
||||
VERIFY(internal::random<half>() > -NumTraits<half>::infinity());
|
||||
}
|
||||
|
||||
void test_basic_functions() {
|
||||
|
Loading…
x
Reference in New Issue
Block a user