mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-07-08 22:21:49 +08:00
remove denormal flushing in fp32tobf16 for avx & avx512
(cherry picked from commit e6a5a594a7f3cbe2f9843d4ef57a10d478cbb818)
This commit is contained in:
parent
4e0357c6dd
commit
93bff85a42
@ -1274,12 +1274,7 @@ EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) {
|
|||||||
EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) {
|
EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) {
|
||||||
Packet8bf r;
|
Packet8bf r;
|
||||||
|
|
||||||
// Flush input denormal values to zero with hardware capability.
|
__m256i input = _mm256_castps_si256(a);
|
||||||
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
|
|
||||||
__m256 flush = _mm256_and_ps(a, a);
|
|
||||||
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);
|
|
||||||
|
|
||||||
__m256i input = _mm256_castps_si256(flush);
|
|
||||||
|
|
||||||
#ifdef EIGEN_VECTORIZE_AVX2
|
#ifdef EIGEN_VECTORIZE_AVX2
|
||||||
// uint32_t lsb = (input >> 16);
|
// uint32_t lsb = (input >> 16);
|
||||||
@ -1293,7 +1288,7 @@ EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) {
|
|||||||
// input = input >> 16;
|
// input = input >> 16;
|
||||||
t = _mm256_srli_epi32(t, 16);
|
t = _mm256_srli_epi32(t, 16);
|
||||||
// Check NaN before converting back to bf16
|
// Check NaN before converting back to bf16
|
||||||
__m256 mask = _mm256_cmp_ps(flush, flush, _CMP_ORD_Q);
|
__m256 mask = _mm256_cmp_ps(a, a, _CMP_ORD_Q);
|
||||||
__m256i nan = _mm256_set1_epi32(0x7fc0);
|
__m256i nan = _mm256_set1_epi32(0x7fc0);
|
||||||
t = _mm256_blendv_epi8(nan, t, _mm256_castps_si256(mask));
|
t = _mm256_blendv_epi8(nan, t, _mm256_castps_si256(mask));
|
||||||
// output = numext::bit_cast<uint16_t>(input);
|
// output = numext::bit_cast<uint16_t>(input);
|
||||||
@ -1316,7 +1311,7 @@ EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) {
|
|||||||
lo = _mm_srli_epi32(lo, 16);
|
lo = _mm_srli_epi32(lo, 16);
|
||||||
hi = _mm_srli_epi32(hi, 16);
|
hi = _mm_srli_epi32(hi, 16);
|
||||||
// Check NaN before converting back to bf16
|
// Check NaN before converting back to bf16
|
||||||
__m256 mask = _mm256_cmp_ps(flush, flush, _CMP_ORD_Q);
|
__m256 mask = _mm256_cmp_ps(a, a, _CMP_ORD_Q);
|
||||||
__m128i nan = _mm_set1_epi32(0x7fc0);
|
__m128i nan = _mm_set1_epi32(0x7fc0);
|
||||||
lo = _mm_blendv_epi8(nan, lo, _mm_castps_si128(_mm256_castps256_ps128(mask)));
|
lo = _mm_blendv_epi8(nan, lo, _mm_castps_si128(_mm256_castps256_ps128(mask)));
|
||||||
hi = _mm_blendv_epi8(nan, hi, _mm_castps_si128(_mm256_extractf128_ps(mask, 1)));
|
hi = _mm_blendv_epi8(nan, hi, _mm_castps_si128(_mm256_extractf128_ps(mask, 1)));
|
||||||
|
@ -1945,23 +1945,15 @@ EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) {
|
|||||||
EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
|
EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
|
||||||
Packet16bf r;
|
Packet16bf r;
|
||||||
|
|
||||||
// Flush input denormal values to zero with hardware capability.
|
|
||||||
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
|
|
||||||
#if defined(EIGEN_VECTORIZE_AVX512DQ)
|
|
||||||
__m512 flush = _mm512_and_ps(a, a);
|
|
||||||
#else
|
|
||||||
__m512 flush = _mm512_max_ps(a, a);
|
|
||||||
#endif // EIGEN_VECTORIZE_AVX512DQ
|
|
||||||
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);
|
|
||||||
|
|
||||||
#if defined(EIGEN_VECTORIZE_AVX512BF16) && EIGEN_GNUC_AT_LEAST(10, 1)
|
#if defined(EIGEN_VECTORIZE_AVX512BF16) && EIGEN_GNUC_AT_LEAST(10, 1)
|
||||||
// Since GCC 10.1 supports avx512bf16 and C style explicit cast
|
// Since GCC 10.1 supports avx512bf16 and C style explicit cast
|
||||||
// (C++ static_cast is not supported yet), do conversion via intrinsic
|
// (C++ static_cast is not supported yet), do conversion via intrinsic
|
||||||
// and register path for performance.
|
// and register path for performance.
|
||||||
r = (__m256i)(_mm512_cvtneps_pbh(flush));
|
r = (__m256i)(_mm512_cvtneps_pbh(a));
|
||||||
|
|
||||||
#else
|
#else
|
||||||
__m512i t;
|
__m512i t;
|
||||||
__m512i input = _mm512_castps_si512(flush);
|
__m512i input = _mm512_castps_si512(a);
|
||||||
__m512i nan = _mm512_set1_epi32(0x7fc0);
|
__m512i nan = _mm512_set1_epi32(0x7fc0);
|
||||||
|
|
||||||
// uint32_t lsb = (input >> 16) & 1;
|
// uint32_t lsb = (input >> 16) & 1;
|
||||||
@ -1974,9 +1966,9 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
|
|||||||
t = _mm512_srli_epi32(t, 16);
|
t = _mm512_srli_epi32(t, 16);
|
||||||
|
|
||||||
// Check NaN before converting back to bf16
|
// Check NaN before converting back to bf16
|
||||||
__mmask16 mask = _mm512_cmp_ps_mask(flush, flush, _CMP_ORD_Q);
|
__mmask16 mask = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q);
|
||||||
t = _mm512_mask_blend_epi32(mask, nan, t);
|
|
||||||
|
|
||||||
|
t = _mm512_mask_blend_epi32(mask, nan, t);
|
||||||
// output.value = static_cast<uint16_t>(input);
|
// output.value = static_cast<uint16_t>(input);
|
||||||
r = _mm512_cvtepi32_epi16(t);
|
r = _mm512_cvtepi32_epi16(t);
|
||||||
#endif // EIGEN_VECTORIZE_AVX512BF16
|
#endif // EIGEN_VECTORIZE_AVX512BF16
|
||||||
|
@ -250,10 +250,6 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const
|
|||||||
if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) {
|
if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) {
|
||||||
output.value = std::signbit(v) ? 0xFFC0: 0x7FC0;
|
output.value = std::signbit(v) ? 0xFFC0: 0x7FC0;
|
||||||
return output;
|
return output;
|
||||||
} else if (std::fabs(v) < std::numeric_limits<float>::min EIGEN_NOT_A_MACRO()) {
|
|
||||||
// Flush denormal to +/- 0.
|
|
||||||
output.value = std::signbit(v) ? 0x8000 : 0;
|
|
||||||
return output;
|
|
||||||
}
|
}
|
||||||
const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
|
const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
|
||||||
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
||||||
@ -288,9 +284,6 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<fals
|
|||||||
// qNaN magic: All exponent bits set + most significant bit of fraction
|
// qNaN magic: All exponent bits set + most significant bit of fraction
|
||||||
// set.
|
// set.
|
||||||
output.value = std::signbit(ff) ? 0xFFC0: 0x7FC0;
|
output.value = std::signbit(ff) ? 0xFFC0: 0x7FC0;
|
||||||
} else if (std::fabs(ff) < std::numeric_limits<float>::min EIGEN_NOT_A_MACRO()) {
|
|
||||||
// Flush denormal to +/- 0.0
|
|
||||||
output.value = std::signbit(ff) ? 0x8000 : 0;
|
|
||||||
} else {
|
} else {
|
||||||
// Fast rounding algorithm that rounds a half value to nearest even. This
|
// Fast rounding algorithm that rounds a half value to nearest even. This
|
||||||
// reduces expected error when we convert a large number of floats. Here
|
// reduces expected error when we convert a large number of floats. Here
|
||||||
|
@ -32,18 +32,6 @@ float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa,
|
|||||||
return dest;
|
return dest;
|
||||||
}
|
}
|
||||||
|
|
||||||
void test_truncate(float input, float expected_truncation, float expected_rounding){
|
|
||||||
bfloat16 truncated = Eigen::bfloat16_impl::truncate_to_bfloat16(input);
|
|
||||||
bfloat16 rounded = Eigen::bfloat16_impl::float_to_bfloat16_rtne<false>(input);
|
|
||||||
if ((numext::isnan)(input)){
|
|
||||||
VERIFY((numext::isnan)(static_cast<float>(truncated)) || (numext::isinf)(static_cast<float>(truncated)));
|
|
||||||
VERIFY((numext::isnan)(static_cast<float>(rounded)) || (numext::isinf)(static_cast<float>(rounded)));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
VERIFY_IS_EQUAL(expected_truncation, static_cast<float>(truncated));
|
|
||||||
VERIFY_IS_EQUAL(expected_rounding, static_cast<float>(rounded));
|
|
||||||
}
|
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
void test_roundtrip() {
|
void test_roundtrip() {
|
||||||
// Representable T round trip via bfloat16
|
// Representable T round trip via bfloat16
|
||||||
@ -122,31 +110,6 @@ void test_conversion()
|
|||||||
VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(0.0f), 0x0000);
|
VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(0.0f), 0x0000);
|
||||||
VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(-0.0f), 0x8000);
|
VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(-0.0f), 0x8000);
|
||||||
|
|
||||||
// Flush denormals to zero
|
|
||||||
for (float denorm = -std::numeric_limits<float>::denorm_min();
|
|
||||||
denorm < std::numeric_limits<float>::denorm_min();
|
|
||||||
denorm = nextafterf(denorm, 1.0f)) {
|
|
||||||
bfloat16 bf_trunc = Eigen::bfloat16_impl::truncate_to_bfloat16(denorm);
|
|
||||||
VERIFY_IS_EQUAL(static_cast<float>(bf_trunc), 0.0f);
|
|
||||||
|
|
||||||
// Implicit conversion of denormals to bool is correct
|
|
||||||
VERIFY_IS_EQUAL(static_cast<bool>(bfloat16(denorm)), false);
|
|
||||||
VERIFY_IS_EQUAL(bfloat16(denorm), false);
|
|
||||||
|
|
||||||
if (std::signbit(denorm)) {
|
|
||||||
VERIFY_BFLOAT16_BITS_EQUAL(bf_trunc, 0x8000);
|
|
||||||
} else {
|
|
||||||
VERIFY_BFLOAT16_BITS_EQUAL(bf_trunc, 0x0000);
|
|
||||||
}
|
|
||||||
bfloat16 bf_round = Eigen::bfloat16_impl::float_to_bfloat16_rtne<false>(denorm);
|
|
||||||
VERIFY_IS_EQUAL(static_cast<float>(bf_round), 0.0f);
|
|
||||||
if (std::signbit(denorm)) {
|
|
||||||
VERIFY_BFLOAT16_BITS_EQUAL(bf_round, 0x8000);
|
|
||||||
} else {
|
|
||||||
VERIFY_BFLOAT16_BITS_EQUAL(bf_round, 0x0000);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default is zero
|
// Default is zero
|
||||||
VERIFY_IS_EQUAL(static_cast<float>(bfloat16()), 0.0f);
|
VERIFY_IS_EQUAL(static_cast<float>(bfloat16()), 0.0f);
|
||||||
|
|
||||||
@ -156,52 +119,6 @@ void test_conversion()
|
|||||||
test_roundtrip<std::complex<float> >();
|
test_roundtrip<std::complex<float> >();
|
||||||
test_roundtrip<std::complex<double> >();
|
test_roundtrip<std::complex<double> >();
|
||||||
|
|
||||||
// Truncate test
|
|
||||||
test_truncate(
|
|
||||||
BinaryToFloat(0, 0x80, 0x48, 0xf5c3),
|
|
||||||
BinaryToFloat(0, 0x80, 0x48, 0x0000),
|
|
||||||
BinaryToFloat(0, 0x80, 0x49, 0x0000));
|
|
||||||
test_truncate(
|
|
||||||
BinaryToFloat(1, 0x80, 0x48, 0xf5c3),
|
|
||||||
BinaryToFloat(1, 0x80, 0x48, 0x0000),
|
|
||||||
BinaryToFloat(1, 0x80, 0x49, 0x0000));
|
|
||||||
test_truncate(
|
|
||||||
BinaryToFloat(0, 0x80, 0x48, 0x8000),
|
|
||||||
BinaryToFloat(0, 0x80, 0x48, 0x0000),
|
|
||||||
BinaryToFloat(0, 0x80, 0x48, 0x0000));
|
|
||||||
test_truncate(
|
|
||||||
BinaryToFloat(0, 0xff, 0x00, 0x0001),
|
|
||||||
BinaryToFloat(0, 0xff, 0x40, 0x0000),
|
|
||||||
BinaryToFloat(0, 0xff, 0x40, 0x0000));
|
|
||||||
test_truncate(
|
|
||||||
BinaryToFloat(0, 0xff, 0x7f, 0xffff),
|
|
||||||
BinaryToFloat(0, 0xff, 0x40, 0x0000),
|
|
||||||
BinaryToFloat(0, 0xff, 0x40, 0x0000));
|
|
||||||
test_truncate(
|
|
||||||
BinaryToFloat(1, 0x80, 0x48, 0xc000),
|
|
||||||
BinaryToFloat(1, 0x80, 0x48, 0x0000),
|
|
||||||
BinaryToFloat(1, 0x80, 0x49, 0x0000));
|
|
||||||
test_truncate(
|
|
||||||
BinaryToFloat(0, 0x80, 0x48, 0x0000),
|
|
||||||
BinaryToFloat(0, 0x80, 0x48, 0x0000),
|
|
||||||
BinaryToFloat(0, 0x80, 0x48, 0x0000));
|
|
||||||
test_truncate(
|
|
||||||
BinaryToFloat(0, 0x80, 0x48, 0x4000),
|
|
||||||
BinaryToFloat(0, 0x80, 0x48, 0x0000),
|
|
||||||
BinaryToFloat(0, 0x80, 0x48, 0x0000));
|
|
||||||
test_truncate(
|
|
||||||
BinaryToFloat(0, 0x80, 0x48, 0x8000),
|
|
||||||
BinaryToFloat(0, 0x80, 0x48, 0x0000),
|
|
||||||
BinaryToFloat(0, 0x80, 0x48, 0x0000));
|
|
||||||
test_truncate(
|
|
||||||
BinaryToFloat(0, 0x00, 0x48, 0x8000),
|
|
||||||
BinaryToFloat(0, 0x00, 0x00, 0x0000),
|
|
||||||
BinaryToFloat(0, 0x00, 0x00, 0x0000));
|
|
||||||
test_truncate(
|
|
||||||
BinaryToFloat(0, 0x00, 0x7f, 0xc000),
|
|
||||||
BinaryToFloat(0, 0x00, 0x00, 0x0000),
|
|
||||||
BinaryToFloat(0, 0x00, 0x00, 0x0000));
|
|
||||||
|
|
||||||
// Conversion
|
// Conversion
|
||||||
Array<float,1,100> a;
|
Array<float,1,100> a;
|
||||||
for (int i = 0; i < 100; i++) a(i) = i + 1.25;
|
for (int i = 0; i < 100; i++) a(i) = i + 1.25;
|
||||||
@ -250,12 +167,6 @@ void test_conversion()
|
|||||||
|
|
||||||
VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(BinaryToFloat(0x0, 0xff, 0x40, 0x0)), 0x7fc0);
|
VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(BinaryToFloat(0x0, 0xff, 0x40, 0x0)), 0x7fc0);
|
||||||
VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(BinaryToFloat(0x1, 0xff, 0x40, 0x0)), 0xffc0);
|
VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(BinaryToFloat(0x1, 0xff, 0x40, 0x0)), 0xffc0);
|
||||||
VERIFY_BFLOAT16_BITS_EQUAL(Eigen::bfloat16_impl::truncate_to_bfloat16(
|
|
||||||
BinaryToFloat(0x0, 0xff, 0x40, 0x0)),
|
|
||||||
0x7fc0);
|
|
||||||
VERIFY_BFLOAT16_BITS_EQUAL(Eigen::bfloat16_impl::truncate_to_bfloat16(
|
|
||||||
BinaryToFloat(0x1, 0xff, 0x40, 0x0)),
|
|
||||||
0xffc0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void test_numtraits()
|
void test_numtraits()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user