diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index 1b1d326b3..eb5da53d0 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -2249,24 +2249,64 @@ EIGEN_STRONG_INLINE Packet8h ptrunc(const Packet8h& a) { return float2half(ptrunc(half2float(a))); } +template <> +EIGEN_STRONG_INLINE Packet8h pisinf(const Packet8h& a) { + constexpr uint16_t kInf = ((1 << 5) - 1) << 10; + constexpr uint16_t kAbsMask = (1 << 15) - 1; + return _mm_cmpeq_epi16(_mm_and_si128(a.m_val, _mm_set1_epi16(kAbsMask)), _mm_set1_epi16(kInf)); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pisnan(const Packet8h& a) { + constexpr uint16_t kInf = ((1 << 5) - 1) << 10; + constexpr uint16_t kAbsMask = (1 << 15) - 1; + return _mm_cmpgt_epi16(_mm_and_si128(a.m_val, _mm_set1_epi16(kAbsMask)), _mm_set1_epi16(kInf)); +} + +// convert the sign-magnitude representation to two's complement +EIGEN_STRONG_INLINE __m128i pmaptosigned(const __m128i& a) { + constexpr uint16_t kAbsMask = (1 << 15) - 1; + // if 'a' has the sign bit set, clear the sign bit and negate the result as if it were an integer + return _mm_sign_epi16(_mm_and_si128(a, _mm_set1_epi16(kAbsMask)), a); +} + +// return true if both `a` and `b` are not NaN +EIGEN_STRONG_INLINE Packet8h pisordered(const Packet8h& a, const Packet8h& b) { + constexpr uint16_t kInf = ((1 << 5) - 1) << 10; + constexpr uint16_t kAbsMask = (1 << 15) - 1; + __m128i abs_a = _mm_and_si128(a.m_val, _mm_set1_epi16(kAbsMask)); + __m128i abs_b = _mm_and_si128(b.m_val, _mm_set1_epi16(kAbsMask)); + // check if both `abs_a <= kInf` and `abs_b <= kInf` by checking if max(abs_a, abs_b) <= kInf + // SSE has no `lesser or equal` instruction for integers, but comparing against kInf + 1 accomplishes the same goal + return _mm_cmplt_epi16(_mm_max_epu16(abs_a, abs_b), _mm_set1_epi16(kInf + 1)); +} + template <> EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a, const Packet8h& b) { - return Pack16To8(pcmp_eq(half2float(a), half2float(b))); + __m128i isOrdered = pisordered(a, b); + __m128i isEqual = _mm_cmpeq_epi16(pmaptosigned(a.m_val), pmaptosigned(b.m_val)); + return _mm_and_si128(isOrdered, isEqual); } template <> EIGEN_STRONG_INLINE Packet8h pcmp_le(const Packet8h& a, const Packet8h& b) { - return Pack16To8(pcmp_le(half2float(a), half2float(b))); + __m128i isOrdered = pisordered(a, b); + __m128i isGreater = _mm_cmpgt_epi16(pmaptosigned(a.m_val), pmaptosigned(b.m_val)); + return _mm_andnot_si128(isGreater, isOrdered); } template <> EIGEN_STRONG_INLINE Packet8h pcmp_lt(const Packet8h& a, const Packet8h& b) { - return Pack16To8(pcmp_lt(half2float(a), half2float(b))); + __m128i isOrdered = pisordered(a, b); + __m128i isLess = _mm_cmplt_epi16(pmaptosigned(a.m_val), pmaptosigned(b.m_val)); + return _mm_and_si128(isOrdered, isLess); } template <> EIGEN_STRONG_INLINE Packet8h pcmp_lt_or_nan(const Packet8h& a, const Packet8h& b) { - return Pack16To8(pcmp_lt_or_nan(half2float(a), half2float(b))); + __m128i isUnordered = por(pisnan(a), pisnan(b)); + __m128i isLess = _mm_cmplt_epi16(pmaptosigned(a.m_val), pmaptosigned(b.m_val)); + return _mm_or_si128(isUnordered, isLess); } template <> diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h index ba70d5fab..c073fe8ec 100644 --- a/Eigen/src/Core/arch/Default/Half.h +++ b/Eigen/src/Core/arch/Default/Half.h @@ -497,16 +497,56 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator/=(half& a, const half& b) { a = half(float(a) / float(b)); return a; } + +// Non-negative floating point numbers have a monotonic mapping to non-negative integers. +// This property allows floating point numbers to be reinterpreted as integers for comparisons, which is useful if there +// is no native floating point comparison operator. Floating point signedness is handled by the sign-magnitude +// representation, whereas integers typically use two's complement. Converting the bit pattern from sign-magnitude to +// two's complement allows the transformed bit patterns be compared as signed integers. All edge cases (+/-0 and +/- +// infinity) are handled automatically, except NaN. +// +// fp16 uses 1 sign bit, 5 exponent bits, and 10 mantissa bits. The bit pattern conveys NaN when all the exponent +// bits (5) are set, and at least one mantissa bit is set. The sign bit is irrelevant for determining NaN. To check for +// NaN, clear the sign bit and check if the integral representation is greater than 01111100000000. To test +// for non-NaN, clear the sign bit and check if the integeral representation is less than or equal to 01111100000000. + +// convert sign-magnitude representation to two's complement +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC int16_t mapToSigned(uint16_t a) { + constexpr uint16_t kAbsMask = (1 << 15) - 1; + // If the sign bit is set, clear the sign bit and return the (integer) negation. Otherwise, return the input. + return (a >> 15) ? -(a & kAbsMask) : a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool isOrdered(const half& a, const half& b) { + constexpr uint16_t kInf = ((1 << 5) - 1) << 10; + constexpr uint16_t kAbsMask = (1 << 15) - 1; + return numext::maxi(a.x & kAbsMask, b.x & kAbsMask) <= kInf; +} EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator==(const half& a, const half& b) { - return numext::equal_strict(float(a), float(b)); + bool result = mapToSigned(a.x) == mapToSigned(b.x); + result &= isOrdered(a, b); + return result; } -EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator!=(const half& a, const half& b) { - return numext::not_equal_strict(float(a), float(b)); +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator!=(const half& a, const half& b) { return !(a == b); } +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const half& a, const half& b) { + bool result = mapToSigned(a.x) < mapToSigned(b.x); + result &= isOrdered(a, b); + return result; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const half& a, const half& b) { + bool result = mapToSigned(a.x) <= mapToSigned(b.x); + result &= isOrdered(a, b); + return result; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const half& a, const half& b) { + bool result = mapToSigned(a.x) > mapToSigned(b.x); + result &= isOrdered(a, b); + return result; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half& b) { + bool result = mapToSigned(a.x) >= mapToSigned(b.x); + result &= isOrdered(a, b); + return result; } -EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const half& a, const half& b) { return float(a) < float(b); } -EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const half& a, const half& b) { return float(a) <= float(b); } -EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const half& a, const half& b) { return float(a) > float(b); } -EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half& b) { return float(a) >= float(b); } #if EIGEN_COMP_CLANG && defined(EIGEN_GPUCC) #pragma pop_macro("EIGEN_DEVICE_FUNC") @@ -706,7 +746,11 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isnan)(const half& a) { #endif } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isfinite)(const half& a) { - return !(isinf EIGEN_NOT_A_MACRO(a)) && !(isnan EIGEN_NOT_A_MACRO(a)); +#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16) + return (numext::bit_cast(a.x) & 0x7fff) < 0x7c00; +#else + return (a.x & 0x7fff) < 0x7c00; +#endif } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half abs(const half& a) { diff --git a/test/half_float.cpp b/test/half_float.cpp index 90ac825d6..aabe896a9 100644 --- a/test/half_float.cpp +++ b/test/half_float.cpp @@ -72,17 +72,16 @@ void test_conversion() { // NaNs and infinities. VERIFY(!(numext::isinf)(float(half(65504.0f)))); // Largest finite number. VERIFY(!(numext::isnan)(float(half(0.0f)))); + VERIFY((numext::isfinite)(float(half(65504.0f)))); + VERIFY((numext::isfinite)(float(half(0.0f)))); VERIFY((numext::isinf)(float(half(__half_raw(0xfc00))))); VERIFY((numext::isnan)(float(half(__half_raw(0xfc01))))); VERIFY((numext::isinf)(float(half(__half_raw(0x7c00))))); VERIFY((numext::isnan)(float(half(__half_raw(0x7c01))))); -#if !EIGEN_COMP_MSVC - // Visual Studio errors out on divisions by 0 - VERIFY((numext::isnan)(float(half(0.0 / 0.0)))); - VERIFY((numext::isinf)(float(half(1.0 / 0.0)))); - VERIFY((numext::isinf)(float(half(-1.0 / 0.0)))); -#endif + VERIFY((numext::isnan)(float(NumTraits::quiet_NaN()))); + VERIFY((numext::isinf)(float(NumTraits::infinity()))); + VERIFY((numext::isinf)(float(-NumTraits::infinity()))); // Exactly same checks as above, just directly on the half representation. VERIFY(!(numext::isinf)(half(__half_raw(0x7bff)))); @@ -92,12 +91,9 @@ void test_conversion() { VERIFY((numext::isinf)(half(__half_raw(0x7c00)))); VERIFY((numext::isnan)(half(__half_raw(0x7c01)))); -#if !EIGEN_COMP_MSVC - // Visual Studio errors out on divisions by 0 - VERIFY((numext::isnan)(half(0.0 / 0.0))); - VERIFY((numext::isinf)(half(1.0 / 0.0))); - VERIFY((numext::isinf)(half(-1.0 / 0.0))); -#endif + VERIFY((numext::isnan)(NumTraits::quiet_NaN())); + VERIFY((numext::isinf)(NumTraits::infinity())); + VERIFY((numext::isinf)(-NumTraits::infinity())); // Conversion to bool VERIFY(!static_cast(half(0.0))); @@ -204,19 +200,25 @@ void test_comparison() { VERIFY(half(1.0f) != half(2.0f)); // Comparisons with NaNs and infinities. -#if !EIGEN_COMP_MSVC - // Visual Studio errors out on divisions by 0 - VERIFY(!(half(0.0 / 0.0) == half(0.0 / 0.0))); - VERIFY(half(0.0 / 0.0) != half(0.0 / 0.0)); + VERIFY(!(NumTraits::quiet_NaN() == NumTraits::quiet_NaN())); + VERIFY(NumTraits::quiet_NaN() != NumTraits::quiet_NaN()); - VERIFY(!(half(1.0) == half(0.0 / 0.0))); - VERIFY(!(half(1.0) < half(0.0 / 0.0))); - VERIFY(!(half(1.0) > half(0.0 / 0.0))); - VERIFY(half(1.0) != half(0.0 / 0.0)); + VERIFY(!(internal::random() == NumTraits::quiet_NaN())); + VERIFY(!(internal::random() < NumTraits::quiet_NaN())); + VERIFY(!(internal::random() > NumTraits::quiet_NaN())); + VERIFY(!(internal::random() <= NumTraits::quiet_NaN())); + VERIFY(!(internal::random() >= NumTraits::quiet_NaN())); + VERIFY(internal::random() != NumTraits::quiet_NaN()); - VERIFY(half(1.0) < half(1.0 / 0.0)); - VERIFY(half(1.0) > half(-1.0 / 0.0)); -#endif + VERIFY(!(NumTraits::quiet_NaN() == internal::random())); + VERIFY(!(NumTraits::quiet_NaN() < internal::random())); + VERIFY(!(NumTraits::quiet_NaN() > internal::random())); + VERIFY(!(NumTraits::quiet_NaN() <= internal::random())); + VERIFY(!(NumTraits::quiet_NaN() >= internal::random())); + VERIFY(NumTraits::quiet_NaN() != internal::random()); + + VERIFY(internal::random() < NumTraits::infinity()); + VERIFY(internal::random() > -NumTraits::infinity()); } void test_basic_functions() {