remove denormal flushing in fp32tobf16 for avx & avx512

This commit is contained in:
Gauri Deshpande 2021-08-09 22:15:21 +00:00 committed by Rasmus Munk Larsen
parent 09d7122468
commit e6a5a594a7
4 changed files with 8 additions and 117 deletions

View File

@ -1274,12 +1274,7 @@ EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) {
EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) {
Packet8bf r;
// Flush input denormals value to zero with hardware capability.
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
__m256 flush = _mm256_and_ps(a, a);
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);
__m256i input = _mm256_castps_si256(flush);
__m256i input = _mm256_castps_si256(a);
#ifdef EIGEN_VECTORIZE_AVX2
// uint32_t lsb = (input >> 16);
@ -1293,7 +1288,7 @@ EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) {
// input = input >> 16;
t = _mm256_srli_epi32(t, 16);
// Check NaN before converting back to bf16
__m256 mask = _mm256_cmp_ps(flush, flush, _CMP_ORD_Q);
__m256 mask = _mm256_cmp_ps(a, a, _CMP_ORD_Q);
__m256i nan = _mm256_set1_epi32(0x7fc0);
t = _mm256_blendv_epi8(nan, t, _mm256_castps_si256(mask));
// output = numext::bit_cast<uint16_t>(input);
@ -1316,7 +1311,7 @@ EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) {
lo = _mm_srli_epi32(lo, 16);
hi = _mm_srli_epi32(hi, 16);
// Check NaN before converting back to bf16
__m256 mask = _mm256_cmp_ps(flush, flush, _CMP_ORD_Q);
__m256 mask = _mm256_cmp_ps(a, a, _CMP_ORD_Q);
__m128i nan = _mm_set1_epi32(0x7fc0);
lo = _mm_blendv_epi8(nan, lo, _mm_castps_si128(_mm256_castps256_ps128(mask)));
hi = _mm_blendv_epi8(nan, hi, _mm_castps_si128(_mm256_extractf128_ps(mask, 1)));

View File

@ -1945,23 +1945,15 @@ EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) {
EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
Packet16bf r;
// Flush input denormals value to zero with hardware capability.
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
#if defined(EIGEN_VECTORIZE_AVX512DQ)
__m512 flush = _mm512_and_ps(a, a);
#else
__m512 flush = _mm512_max_ps(a, a);
#endif // EIGEN_VECTORIZE_AVX512DQ
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);
#if defined(EIGEN_VECTORIZE_AVX512BF16) && EIGEN_GNUC_AT_LEAST(10, 1)
// Since GCC 10.1 supports avx512bf16 and C style explicit cast
// (C++ static_cast is not supported yet), do converion via intrinsic
// and register path for performance.
r = (__m256i)(_mm512_cvtneps_pbh(flush));
r = (__m256i)(_mm512_cvtneps_pbh(a));
#else
__m512i t;
__m512i input = _mm512_castps_si512(flush);
__m512i input = _mm512_castps_si512(a);
__m512i nan = _mm512_set1_epi32(0x7fc0);
// uint32_t lsb = (input >> 16) & 1;
@ -1974,9 +1966,9 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
t = _mm512_srli_epi32(t, 16);
// Check NaN before converting back to bf16
__mmask16 mask = _mm512_cmp_ps_mask(flush, flush, _CMP_ORD_Q);
t = _mm512_mask_blend_epi32(mask, nan, t);
__mmask16 mask = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q);
t = _mm512_mask_blend_epi32(mask, nan, t);
// output.value = static_cast<uint16_t>(input);
r = _mm512_cvtepi32_epi16(t);
#endif // EIGEN_VECTORIZE_AVX512BF16

View File

@ -250,10 +250,6 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const
if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) {
output.value = std::signbit(v) ? 0xFFC0: 0x7FC0;
return output;
} else if (std::fabs(v) < std::numeric_limits<float>::min EIGEN_NOT_A_MACRO()) {
// Flush denormal to +/- 0.
output.value = std::signbit(v) ? 0x8000 : 0;
return output;
}
const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
@ -288,9 +284,6 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<fals
// qNaN magic: All exponent bits set + most significant bit of fraction
// set.
output.value = std::signbit(ff) ? 0xFFC0: 0x7FC0;
} else if (std::fabs(ff) < std::numeric_limits<float>::min EIGEN_NOT_A_MACRO()) {
// Flush denormal to +/- 0.0
output.value = std::signbit(ff) ? 0x8000 : 0;
} else {
// Fast rounding algorithm that rounds a half value to nearest even. This
// reduces expected error when we convert a large number of floats. Here

View File

@ -32,18 +32,6 @@ float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa,
return dest;
}
void test_truncate(float input, float expected_truncation, float expected_rounding){
bfloat16 truncated = Eigen::bfloat16_impl::truncate_to_bfloat16(input);
bfloat16 rounded = Eigen::bfloat16_impl::float_to_bfloat16_rtne<false>(input);
if ((numext::isnan)(input)){
VERIFY((numext::isnan)(static_cast<float>(truncated)) || (numext::isinf)(static_cast<float>(truncated)));
VERIFY((numext::isnan)(static_cast<float>(rounded)) || (numext::isinf)(static_cast<float>(rounded)));
return;
}
VERIFY_IS_EQUAL(expected_truncation, static_cast<float>(truncated));
VERIFY_IS_EQUAL(expected_rounding, static_cast<float>(rounded));
}
template<typename T>
void test_roundtrip() {
// Representable T round trip via bfloat16
@ -122,31 +110,6 @@ void test_conversion()
VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(0.0f), 0x0000);
VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(-0.0f), 0x8000);
// Flush denormals to zero
for (float denorm = -std::numeric_limits<float>::denorm_min();
denorm < std::numeric_limits<float>::denorm_min();
denorm = nextafterf(denorm, 1.0f)) {
bfloat16 bf_trunc = Eigen::bfloat16_impl::truncate_to_bfloat16(denorm);
VERIFY_IS_EQUAL(static_cast<float>(bf_trunc), 0.0f);
// Implicit conversion of denormls to bool is correct
VERIFY_IS_EQUAL(static_cast<bool>(bfloat16(denorm)), false);
VERIFY_IS_EQUAL(bfloat16(denorm), false);
if (std::signbit(denorm)) {
VERIFY_BFLOAT16_BITS_EQUAL(bf_trunc, 0x8000);
} else {
VERIFY_BFLOAT16_BITS_EQUAL(bf_trunc, 0x0000);
}
bfloat16 bf_round = Eigen::bfloat16_impl::float_to_bfloat16_rtne<false>(denorm);
VERIFY_IS_EQUAL(static_cast<float>(bf_round), 0.0f);
if (std::signbit(denorm)) {
VERIFY_BFLOAT16_BITS_EQUAL(bf_round, 0x8000);
} else {
VERIFY_BFLOAT16_BITS_EQUAL(bf_round, 0x0000);
}
}
// Default is zero
VERIFY_IS_EQUAL(static_cast<float>(bfloat16()), 0.0f);
@ -156,52 +119,6 @@ void test_conversion()
test_roundtrip<std::complex<float> >();
test_roundtrip<std::complex<double> >();
// Truncate test
test_truncate(
BinaryToFloat(0, 0x80, 0x48, 0xf5c3),
BinaryToFloat(0, 0x80, 0x48, 0x0000),
BinaryToFloat(0, 0x80, 0x49, 0x0000));
test_truncate(
BinaryToFloat(1, 0x80, 0x48, 0xf5c3),
BinaryToFloat(1, 0x80, 0x48, 0x0000),
BinaryToFloat(1, 0x80, 0x49, 0x0000));
test_truncate(
BinaryToFloat(0, 0x80, 0x48, 0x8000),
BinaryToFloat(0, 0x80, 0x48, 0x0000),
BinaryToFloat(0, 0x80, 0x48, 0x0000));
test_truncate(
BinaryToFloat(0, 0xff, 0x00, 0x0001),
BinaryToFloat(0, 0xff, 0x40, 0x0000),
BinaryToFloat(0, 0xff, 0x40, 0x0000));
test_truncate(
BinaryToFloat(0, 0xff, 0x7f, 0xffff),
BinaryToFloat(0, 0xff, 0x40, 0x0000),
BinaryToFloat(0, 0xff, 0x40, 0x0000));
test_truncate(
BinaryToFloat(1, 0x80, 0x48, 0xc000),
BinaryToFloat(1, 0x80, 0x48, 0x0000),
BinaryToFloat(1, 0x80, 0x49, 0x0000));
test_truncate(
BinaryToFloat(0, 0x80, 0x48, 0x0000),
BinaryToFloat(0, 0x80, 0x48, 0x0000),
BinaryToFloat(0, 0x80, 0x48, 0x0000));
test_truncate(
BinaryToFloat(0, 0x80, 0x48, 0x4000),
BinaryToFloat(0, 0x80, 0x48, 0x0000),
BinaryToFloat(0, 0x80, 0x48, 0x0000));
test_truncate(
BinaryToFloat(0, 0x80, 0x48, 0x8000),
BinaryToFloat(0, 0x80, 0x48, 0x0000),
BinaryToFloat(0, 0x80, 0x48, 0x0000));
test_truncate(
BinaryToFloat(0, 0x00, 0x48, 0x8000),
BinaryToFloat(0, 0x00, 0x00, 0x0000),
BinaryToFloat(0, 0x00, 0x00, 0x0000));
test_truncate(
BinaryToFloat(0, 0x00, 0x7f, 0xc000),
BinaryToFloat(0, 0x00, 0x00, 0x0000),
BinaryToFloat(0, 0x00, 0x00, 0x0000));
// Conversion
Array<float,1,100> a;
for (int i = 0; i < 100; i++) a(i) = i + 1.25;
@ -250,12 +167,6 @@ void test_conversion()
VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(BinaryToFloat(0x0, 0xff, 0x40, 0x0)), 0x7fc0);
VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(BinaryToFloat(0x1, 0xff, 0x40, 0x0)), 0xffc0);
VERIFY_BFLOAT16_BITS_EQUAL(Eigen::bfloat16_impl::truncate_to_bfloat16(
BinaryToFloat(0x0, 0xff, 0x40, 0x0)),
0x7fc0);
VERIFY_BFLOAT16_BITS_EQUAL(Eigen::bfloat16_impl::truncate_to_bfloat16(
BinaryToFloat(0x1, 0xff, 0x40, 0x0)),
0xffc0);
}
void test_numtraits()