mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-16 14:49:39 +08:00
remove denormal flushing in fp32tobf16 for avx & avx512
This commit is contained in:
parent
09d7122468
commit
e6a5a594a7
@ -1274,12 +1274,7 @@ EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) {
|
||||
EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) {
|
||||
Packet8bf r;
|
||||
|
||||
// Flush input denormals value to zero with hardware capability.
|
||||
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
|
||||
__m256 flush = _mm256_and_ps(a, a);
|
||||
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);
|
||||
|
||||
__m256i input = _mm256_castps_si256(flush);
|
||||
__m256i input = _mm256_castps_si256(a);
|
||||
|
||||
#ifdef EIGEN_VECTORIZE_AVX2
|
||||
// uint32_t lsb = (input >> 16);
|
||||
@ -1293,7 +1288,7 @@ EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) {
|
||||
// input = input >> 16;
|
||||
t = _mm256_srli_epi32(t, 16);
|
||||
// Check NaN before converting back to bf16
|
||||
__m256 mask = _mm256_cmp_ps(flush, flush, _CMP_ORD_Q);
|
||||
__m256 mask = _mm256_cmp_ps(a, a, _CMP_ORD_Q);
|
||||
__m256i nan = _mm256_set1_epi32(0x7fc0);
|
||||
t = _mm256_blendv_epi8(nan, t, _mm256_castps_si256(mask));
|
||||
// output = numext::bit_cast<uint16_t>(input);
|
||||
@ -1316,7 +1311,7 @@ EIGEN_STRONG_INLINE Packet8bf F32ToBf16(const Packet8f& a) {
|
||||
lo = _mm_srli_epi32(lo, 16);
|
||||
hi = _mm_srli_epi32(hi, 16);
|
||||
// Check NaN before converting back to bf16
|
||||
__m256 mask = _mm256_cmp_ps(flush, flush, _CMP_ORD_Q);
|
||||
__m256 mask = _mm256_cmp_ps(a, a, _CMP_ORD_Q);
|
||||
__m128i nan = _mm_set1_epi32(0x7fc0);
|
||||
lo = _mm_blendv_epi8(nan, lo, _mm_castps_si128(_mm256_castps256_ps128(mask)));
|
||||
hi = _mm_blendv_epi8(nan, hi, _mm_castps_si128(_mm256_extractf128_ps(mask, 1)));
|
||||
|
@ -1945,23 +1945,15 @@ EIGEN_STRONG_INLINE Packet16f Bf16ToF32(const Packet16bf& a) {
|
||||
EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
|
||||
Packet16bf r;
|
||||
|
||||
// Flush input denormals value to zero with hardware capability.
|
||||
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
|
||||
#if defined(EIGEN_VECTORIZE_AVX512DQ)
|
||||
__m512 flush = _mm512_and_ps(a, a);
|
||||
#else
|
||||
__m512 flush = _mm512_max_ps(a, a);
|
||||
#endif // EIGEN_VECTORIZE_AVX512DQ
|
||||
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_OFF);
|
||||
|
||||
#if defined(EIGEN_VECTORIZE_AVX512BF16) && EIGEN_GNUC_AT_LEAST(10, 1)
|
||||
// Since GCC 10.1 supports avx512bf16 and C style explicit cast
|
||||
// (C++ static_cast is not supported yet), do converion via intrinsic
|
||||
// and register path for performance.
|
||||
r = (__m256i)(_mm512_cvtneps_pbh(flush));
|
||||
r = (__m256i)(_mm512_cvtneps_pbh(a));
|
||||
|
||||
#else
|
||||
__m512i t;
|
||||
__m512i input = _mm512_castps_si512(flush);
|
||||
__m512i input = _mm512_castps_si512(a);
|
||||
__m512i nan = _mm512_set1_epi32(0x7fc0);
|
||||
|
||||
// uint32_t lsb = (input >> 16) & 1;
|
||||
@ -1974,9 +1966,9 @@ EIGEN_STRONG_INLINE Packet16bf F32ToBf16(const Packet16f& a) {
|
||||
t = _mm512_srli_epi32(t, 16);
|
||||
|
||||
// Check NaN before converting back to bf16
|
||||
__mmask16 mask = _mm512_cmp_ps_mask(flush, flush, _CMP_ORD_Q);
|
||||
t = _mm512_mask_blend_epi32(mask, nan, t);
|
||||
__mmask16 mask = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q);
|
||||
|
||||
t = _mm512_mask_blend_epi32(mask, nan, t);
|
||||
// output.value = static_cast<uint16_t>(input);
|
||||
r = _mm512_cvtepi32_epi16(t);
|
||||
#endif // EIGEN_VECTORIZE_AVX512BF16
|
||||
|
@ -250,10 +250,6 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw truncate_to_bfloat16(const
|
||||
if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(v)) {
|
||||
output.value = std::signbit(v) ? 0xFFC0: 0x7FC0;
|
||||
return output;
|
||||
} else if (std::fabs(v) < std::numeric_limits<float>::min EIGEN_NOT_A_MACRO()) {
|
||||
// Flush denormal to +/- 0.
|
||||
output.value = std::signbit(v) ? 0x8000 : 0;
|
||||
return output;
|
||||
}
|
||||
const uint16_t* p = reinterpret_cast<const uint16_t*>(&v);
|
||||
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
|
||||
@ -288,9 +284,6 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<fals
|
||||
// qNaN magic: All exponent bits set + most significant bit of fraction
|
||||
// set.
|
||||
output.value = std::signbit(ff) ? 0xFFC0: 0x7FC0;
|
||||
} else if (std::fabs(ff) < std::numeric_limits<float>::min EIGEN_NOT_A_MACRO()) {
|
||||
// Flush denormal to +/- 0.0
|
||||
output.value = std::signbit(ff) ? 0x8000 : 0;
|
||||
} else {
|
||||
// Fast rounding algorithm that rounds a half value to nearest even. This
|
||||
// reduces expected error when we convert a large number of floats. Here
|
||||
|
@ -32,18 +32,6 @@ float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa,
|
||||
return dest;
|
||||
}
|
||||
|
||||
void test_truncate(float input, float expected_truncation, float expected_rounding){
|
||||
bfloat16 truncated = Eigen::bfloat16_impl::truncate_to_bfloat16(input);
|
||||
bfloat16 rounded = Eigen::bfloat16_impl::float_to_bfloat16_rtne<false>(input);
|
||||
if ((numext::isnan)(input)){
|
||||
VERIFY((numext::isnan)(static_cast<float>(truncated)) || (numext::isinf)(static_cast<float>(truncated)));
|
||||
VERIFY((numext::isnan)(static_cast<float>(rounded)) || (numext::isinf)(static_cast<float>(rounded)));
|
||||
return;
|
||||
}
|
||||
VERIFY_IS_EQUAL(expected_truncation, static_cast<float>(truncated));
|
||||
VERIFY_IS_EQUAL(expected_rounding, static_cast<float>(rounded));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void test_roundtrip() {
|
||||
// Representable T round trip via bfloat16
|
||||
@ -122,31 +110,6 @@ void test_conversion()
|
||||
VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(0.0f), 0x0000);
|
||||
VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(-0.0f), 0x8000);
|
||||
|
||||
// Flush denormals to zero
|
||||
for (float denorm = -std::numeric_limits<float>::denorm_min();
|
||||
denorm < std::numeric_limits<float>::denorm_min();
|
||||
denorm = nextafterf(denorm, 1.0f)) {
|
||||
bfloat16 bf_trunc = Eigen::bfloat16_impl::truncate_to_bfloat16(denorm);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(bf_trunc), 0.0f);
|
||||
|
||||
// Implicit conversion of denormls to bool is correct
|
||||
VERIFY_IS_EQUAL(static_cast<bool>(bfloat16(denorm)), false);
|
||||
VERIFY_IS_EQUAL(bfloat16(denorm), false);
|
||||
|
||||
if (std::signbit(denorm)) {
|
||||
VERIFY_BFLOAT16_BITS_EQUAL(bf_trunc, 0x8000);
|
||||
} else {
|
||||
VERIFY_BFLOAT16_BITS_EQUAL(bf_trunc, 0x0000);
|
||||
}
|
||||
bfloat16 bf_round = Eigen::bfloat16_impl::float_to_bfloat16_rtne<false>(denorm);
|
||||
VERIFY_IS_EQUAL(static_cast<float>(bf_round), 0.0f);
|
||||
if (std::signbit(denorm)) {
|
||||
VERIFY_BFLOAT16_BITS_EQUAL(bf_round, 0x8000);
|
||||
} else {
|
||||
VERIFY_BFLOAT16_BITS_EQUAL(bf_round, 0x0000);
|
||||
}
|
||||
}
|
||||
|
||||
// Default is zero
|
||||
VERIFY_IS_EQUAL(static_cast<float>(bfloat16()), 0.0f);
|
||||
|
||||
@ -156,52 +119,6 @@ void test_conversion()
|
||||
test_roundtrip<std::complex<float> >();
|
||||
test_roundtrip<std::complex<double> >();
|
||||
|
||||
// Truncate test
|
||||
test_truncate(
|
||||
BinaryToFloat(0, 0x80, 0x48, 0xf5c3),
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x0000),
|
||||
BinaryToFloat(0, 0x80, 0x49, 0x0000));
|
||||
test_truncate(
|
||||
BinaryToFloat(1, 0x80, 0x48, 0xf5c3),
|
||||
BinaryToFloat(1, 0x80, 0x48, 0x0000),
|
||||
BinaryToFloat(1, 0x80, 0x49, 0x0000));
|
||||
test_truncate(
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x8000),
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x0000),
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x0000));
|
||||
test_truncate(
|
||||
BinaryToFloat(0, 0xff, 0x00, 0x0001),
|
||||
BinaryToFloat(0, 0xff, 0x40, 0x0000),
|
||||
BinaryToFloat(0, 0xff, 0x40, 0x0000));
|
||||
test_truncate(
|
||||
BinaryToFloat(0, 0xff, 0x7f, 0xffff),
|
||||
BinaryToFloat(0, 0xff, 0x40, 0x0000),
|
||||
BinaryToFloat(0, 0xff, 0x40, 0x0000));
|
||||
test_truncate(
|
||||
BinaryToFloat(1, 0x80, 0x48, 0xc000),
|
||||
BinaryToFloat(1, 0x80, 0x48, 0x0000),
|
||||
BinaryToFloat(1, 0x80, 0x49, 0x0000));
|
||||
test_truncate(
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x0000),
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x0000),
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x0000));
|
||||
test_truncate(
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x4000),
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x0000),
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x0000));
|
||||
test_truncate(
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x8000),
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x0000),
|
||||
BinaryToFloat(0, 0x80, 0x48, 0x0000));
|
||||
test_truncate(
|
||||
BinaryToFloat(0, 0x00, 0x48, 0x8000),
|
||||
BinaryToFloat(0, 0x00, 0x00, 0x0000),
|
||||
BinaryToFloat(0, 0x00, 0x00, 0x0000));
|
||||
test_truncate(
|
||||
BinaryToFloat(0, 0x00, 0x7f, 0xc000),
|
||||
BinaryToFloat(0, 0x00, 0x00, 0x0000),
|
||||
BinaryToFloat(0, 0x00, 0x00, 0x0000));
|
||||
|
||||
// Conversion
|
||||
Array<float,1,100> a;
|
||||
for (int i = 0; i < 100; i++) a(i) = i + 1.25;
|
||||
@ -250,12 +167,6 @@ void test_conversion()
|
||||
|
||||
VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(BinaryToFloat(0x0, 0xff, 0x40, 0x0)), 0x7fc0);
|
||||
VERIFY_BFLOAT16_BITS_EQUAL(bfloat16(BinaryToFloat(0x1, 0xff, 0x40, 0x0)), 0xffc0);
|
||||
VERIFY_BFLOAT16_BITS_EQUAL(Eigen::bfloat16_impl::truncate_to_bfloat16(
|
||||
BinaryToFloat(0x0, 0xff, 0x40, 0x0)),
|
||||
0x7fc0);
|
||||
VERIFY_BFLOAT16_BITS_EQUAL(Eigen::bfloat16_impl::truncate_to_bfloat16(
|
||||
BinaryToFloat(0x1, 0xff, 0x40, 0x0)),
|
||||
0xffc0);
|
||||
}
|
||||
|
||||
void test_numtraits()
|
||||
|
Loading…
x
Reference in New Issue
Block a user