diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index dd3f243d2..7fc32fd71 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -1274,12 +1274,7 @@ EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) { EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) { Packet8bf r; - // Flush input denormals value to zero with hardware capability. - _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); - __m256 flush = _mm256_and_ps(a, a); - _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF); - - __m256i input = _mm256_castps_si256(flush); + __m256i input = _mm256_castps_si256(a); #ifdef EIGEN_VECTORIZE_AVX2 // uint32_t lsb = (input >> 16); @@ -1293,7 +1288,7 @@ EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) { // input = input >> 16; t = _mm256_srli_epi32(t, 16); // Check NaN before converting back to bf16 - __m256 mask = _mm256_cmp_ps(flush, flush, _CMP_ORD_Q); + __m256 mask = _mm256_cmp_ps(a, a, _CMP_ORD_Q); __m256i nan = _mm256_set1_epi32(0x7fc0); t = _mm256_blendv_epi8(nan, t, _mm256_castps_si256(mask)); // output = numext::bit_cast(input); @@ -1316,7 +1311,7 @@ EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) { lo = _mm_srli_epi32(lo, 16); hi = _mm_srli_epi32(hi, 16); // Check NaN before converting back to bf16 - __m256 mask = _mm256_cmp_ps(flush, flush, _CMP_ORD_Q); + __m256 mask = _mm256_cmp_ps(a, a, _CMP_ORD_Q); __m128i nan = _mm_set1_epi32(0x7fc0); lo = _mm_blendv_epi8(nan, lo, _mm_castps_si128(_mm256_castps256_ps128(mask))); hi = _mm_blendv_epi8(nan, hi, _mm_castps_si128(_mm256_extractf128_ps(mask, 1))); diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index 59bbef0d1..34d49ab66 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -1945,23 +1945,15 @@ EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) { EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) { Packet16bf r; - // Flush input denormals value to zero with hardware capability. - _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); -#if defined(EIGEN_VECTORIZE_AVX512DQ) - __m512 flush = _mm512_and_ps(a, a); -#else - __m512 flush = _mm512_max_ps(a, a); -#endif // EIGEN_VECTORIZE_AVX512DQ - _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF); - #if defined(EIGEN_VECTORIZE_AVX512BF16) && EIGEN_GNUC_AT_LEAST(10, 1) // Since GCC 10.1 supports avx512bf16 and C style explicit cast // (C++ static_cast is not supported yet), do converion via intrinsic // and register path for performance. - r = (__m256i)(_mm512_cvtneps_pbh(flush)); + r = (__m256i)(_mm512_cvtneps_pbh(a)); + #else __m512i t; - __m512i input = _mm512_castps_si512(flush); + __m512i input = _mm512_castps_si512(a); __m512i nan = _mm512_set1_epi32(0x7fc0); // uint32_t lsb = (input >> 16) & 1; @@ -1974,9 +1966,9 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) { t = _mm512_srli_epi32(t, 16); // Check NaN before converting back to bf16 - __mmask16 mask = _mm512_cmp_ps_mask(flush, flush, _CMP_ORD_Q); - t = _mm512_mask_blend_epi32(mask, nan, t); + __mmask16 mask = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q); + t = _mm512_mask_blend_epi32(mask, nan, t); // output.value = static_cast(input); r = _mm512_cvtepi32_epi16(t); #endif // EIGEN_VECTORIZE_AVX512BF16 diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h index aac60f15c..1c28f4f95 100644 --- a/Eigen/src/Core/arch/Default/BFloat16.h +++ b/Eigen/src/Core/arch/Default/BFloat16.h @@ -250,10 +250,6 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) { output.value = std::signbit(v) ? 0xFFC0: 0x7FC0; return output; - } else if (std::fabs(v) < std::numeric_limits::min EIGEN_NOT_A_MACRO()) { - // Flush denormal to +/- 0. - output.value = std::signbit(v) ? 0x8000 : 0; - return output; } const uint16_t* p = reinterpret_cast(&v); #if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ @@ -288,9 +284,6 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne::min EIGEN_NOT_A_MACRO()) { - // Flush denormal to +/- 0.0 - output.value = std::signbit(ff) ? 0x8000 : 0; } else { // Fast rounding algorithm that rounds a half value to nearest even. This // reduces expected error when we convert a large number of floats. Here diff --git a/test/bfloat16_float.cpp b/test/bfloat16_float.cpp index 1df22f73e..c3de0b19a 100644 --- a/test/bfloat16_float.cpp +++ b/test/bfloat16_float.cpp @@ -32,18 +32,6 @@ float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa, return dest; } -void test_truncate(float input, float expected_truncation, float expected_rounding){ - bfloat16 truncated = Eigen::bfloat16_impl::truncate_to_bfloat16(input); - bfloat16 rounded = Eigen::bfloat16_impl::float_to_bfloat16_rtne(input); - if ((numext::isnan)(input)){ - VERIFY((numext::isnan)(static_cast(truncated)) || (numext::isinf)(static_cast(truncated))); - VERIFY((numext::isnan)(static_cast(rounded)) || (numext::isinf)(static_cast(rounded))); - return; - } - VERIFY_IS_EQUAL(expected_truncation, static_cast(truncated)); - VERIFY_IS_EQUAL(expected_rounding, static_cast(rounded)); -} - template void test_roundtrip() { // Representable T round trip via bfloat16 @@ -122,31 +110,6 @@ void test_conversion() VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(0.0f), 0x0000); VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(-0.0f), 0x8000); - // Flush denormals to zero - for (float denorm = -std::numeric_limits::denorm_min(); - denorm < std::numeric_limits::denorm_min(); - denorm = nextafterf(denorm, 1.0f)) { - bfloat16 bf_trunc = Eigen::bfloat16_impl::truncate_to_bfloat16(denorm); - VERIFY_IS_EQUAL(static_cast(bf_trunc), 0.0f); - - // Implicit conversion of denormls to bool is correct - VERIFY_IS_EQUAL(static_cast(bfloat16(denorm)), false); - VERIFY_IS_EQUAL(bfloat16(denorm), false); - - if (std::signbit(denorm)) { - VERIFY_BFLOAT16_BITS_EQUAL(bf_trunc, 0x8000); - } else { - VERIFY_BFLOAT16_BITS_EQUAL(bf_trunc, 0x0000); - } - bfloat16 bf_round = Eigen::bfloat16_impl::float_to_bfloat16_rtne(denorm); - VERIFY_IS_EQUAL(static_cast(bf_round), 0.0f); - if (std::signbit(denorm)) { - VERIFY_BFLOAT16_BITS_EQUAL(bf_round, 0x8000); - } else { - VERIFY_BFLOAT16_BITS_EQUAL(bf_round, 0x0000); - } - } - // Default is zero VERIFY_IS_EQUAL(static_cast(bfloat16()), 0.0f); @@ -156,52 +119,6 @@ void test_conversion() test_roundtrip >(); test_roundtrip >(); - // Truncate test - test_truncate( - BinaryToFloat(0, 0x80, 0x48, 0xf5c3), - BinaryToFloat(0, 0x80, 0x48, 0x0000), - BinaryToFloat(0, 0x80, 0x49, 0x0000)); - test_truncate( - BinaryToFloat(1, 0x80, 0x48, 0xf5c3), - BinaryToFloat(1, 0x80, 0x48, 0x0000), - BinaryToFloat(1, 0x80, 0x49, 0x0000)); - test_truncate( - BinaryToFloat(0, 0x80, 0x48, 0x8000), - BinaryToFloat(0, 0x80, 0x48, 0x0000), - BinaryToFloat(0, 0x80, 0x48, 0x0000)); - test_truncate( - BinaryToFloat(0, 0xff, 0x00, 0x0001), - BinaryToFloat(0, 0xff, 0x40, 0x0000), - BinaryToFloat(0, 0xff, 0x40, 0x0000)); - test_truncate( - BinaryToFloat(0, 0xff, 0x7f, 0xffff), - BinaryToFloat(0, 0xff, 0x40, 0x0000), - BinaryToFloat(0, 0xff, 0x40, 0x0000)); - test_truncate( - BinaryToFloat(1, 0x80, 0x48, 0xc000), - BinaryToFloat(1, 0x80, 0x48, 0x0000), - BinaryToFloat(1, 0x80, 0x49, 0x0000)); - test_truncate( - BinaryToFloat(0, 0x80, 0x48, 0x0000), - BinaryToFloat(0, 0x80, 0x48, 0x0000), - BinaryToFloat(0, 0x80, 0x48, 0x0000)); - test_truncate( - BinaryToFloat(0, 0x80, 0x48, 0x4000), - BinaryToFloat(0, 0x80, 0x48, 0x0000), - BinaryToFloat(0, 0x80, 0x48, 0x0000)); - test_truncate( - BinaryToFloat(0, 0x80, 0x48, 0x8000), - BinaryToFloat(0, 0x80, 0x48, 0x0000), - BinaryToFloat(0, 0x80, 0x48, 0x0000)); - test_truncate( - BinaryToFloat(0, 0x00, 0x48, 0x8000), - BinaryToFloat(0, 0x00, 0x00, 0x0000), - BinaryToFloat(0, 0x00, 0x00, 0x0000)); - test_truncate( - BinaryToFloat(0, 0x00, 0x7f, 0xc000), - BinaryToFloat(0, 0x00, 0x00, 0x0000), - BinaryToFloat(0, 0x00, 0x00, 0x0000)); - // Conversion Array a; for (int i = 0; i < 100; i++) a(i) = i + 1.25; @@ -250,12 +167,6 @@ void test_conversion() VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(BinaryToFloat(0x0, 0xff, 0x40, 0x0)), 0x7fc0); VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(BinaryToFloat(0x1, 0xff, 0x40, 0x0)), 0xffc0); - VERIFY_BFLOAT16_BITS_EQUAL(Eigen::bfloat16_impl::truncate_to_bfloat16( - BinaryToFloat(0x0, 0xff, 0x40, 0x0)), - 0x7fc0); - VERIFY_BFLOAT16_BITS_EQUAL(Eigen::bfloat16_impl::truncate_to_bfloat16( - BinaryToFloat(0x1, 0xff, 0x40, 0x0)), - 0xffc0); } void test_numtraits()