mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-30 07:44:10 +08:00
Faster conversion from integer types to bfloat16
Specialized `bfloat16_impl::float_to_bfloat16_rtne(float)` for normal floating point numbers, infinity and zero, in order to improve the performance of `bfloat16::bfloat16(const T&)` for integer argument types. A reduction of more than 20% of the runtime duration of conversion from int to bfloat16 was observed, using Visual C++ 2019 on Windows 10.
This commit is contained in:
parent
acab22c205
commit
0e1a33a461
@ -58,7 +58,14 @@ struct __bfloat16_raw {
|
|||||||
};
|
};
|
||||||
|
|
||||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value);
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw raw_uint16_to_bfloat16(unsigned short value);
|
||||||
|
template <bool AssumeArgumentIsNormalOrInfinityOrZero>
|
||||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff);
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff);
|
||||||
|
// Forward declarations of template specializations, to avoid Visual C++ 2019 errors, saying:
|
||||||
|
// > error C2908: explicit specialization; 'float_to_bfloat16_rtne' has already been instantiated
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff);
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff);
|
||||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h);
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float bfloat16_to_float(__bfloat16_raw h);
|
||||||
|
|
||||||
struct bfloat16_base : public __bfloat16_raw {
|
struct bfloat16_base : public __bfloat16_raw {
|
||||||
@ -81,14 +88,14 @@ struct bfloat16 : public bfloat16_impl::bfloat16_base {
|
|||||||
: bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}
|
: bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}
|
||||||
template<class T>
|
template<class T>
|
||||||
explicit EIGEN_DEVICE_FUNC bfloat16(const T& val)
|
explicit EIGEN_DEVICE_FUNC bfloat16(const T& val)
|
||||||
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(static_cast<float>(val))) {}
|
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<internal::is_integral<T>::value>(static_cast<float>(val))) {}
|
||||||
explicit EIGEN_DEVICE_FUNC bfloat16(float f)
|
explicit EIGEN_DEVICE_FUNC bfloat16(float f)
|
||||||
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(f)) {}
|
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(f)) {}
|
||||||
// Following the convention of numpy, converting between complex and
|
// Following the convention of numpy, converting between complex and
|
||||||
// float will lead to loss of imag value.
|
// float will lead to loss of imag value.
|
||||||
template<typename RealScalar>
|
template<typename RealScalar>
|
||||||
explicit EIGEN_DEVICE_FUNC bfloat16(const std::complex<RealScalar>& val)
|
explicit EIGEN_DEVICE_FUNC bfloat16(const std::complex<RealScalar>& val)
|
||||||
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne(static_cast<float>(val.real()))) {}
|
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(static_cast<float>(val.real()))) {}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const {
|
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const {
|
||||||
// +0.0 and -0.0 become false, everything else becomes true.
|
// +0.0 and -0.0 become false, everything else becomes true.
|
||||||
@ -326,11 +333,13 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw raw_uint16_to_bfloat16(unsi
|
|||||||
return h;
|
return h;
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(float ff) {
|
// float_to_bfloat16_rtne template specialization that does not make any
|
||||||
|
// assumption about the value of its function argument (ff).
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<false>(float ff) {
|
||||||
#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
|
#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
|
||||||
// Nothing to do here
|
// Nothing to do here
|
||||||
#else
|
#else
|
||||||
unsigned int input = numext::as_uint(ff);
|
|
||||||
__bfloat16_raw output;
|
__bfloat16_raw output;
|
||||||
|
|
||||||
if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(ff)) {
|
if (Eigen::numext::isnan EIGEN_NOT_A_MACRO(ff)) {
|
||||||
@ -491,14 +500,31 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne(floa
|
|||||||
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
// Sign | Exp (8 bit) | Frac (first 7 bit)
|
||||||
// S E E E E E E E E F F F F F F L
|
// S E E E E E E E E F F F F F F L
|
||||||
// 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
|
// 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0
|
||||||
//
|
|
||||||
//
|
// At this point, ff must be either a normal float, or +/-infinity.
|
||||||
|
output = float_to_bfloat16_rtne<true>(ff);
|
||||||
|
}
|
||||||
|
return output;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// float_to_bfloat16_rtne template specialization that assumes that its function
|
||||||
|
// argument (ff) is either a normal floating point number, or +/-infinity, or
|
||||||
|
// zero. Used to improve the runtime performance of conversion from an integer
|
||||||
|
// type to bfloat16.
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __bfloat16_raw float_to_bfloat16_rtne<true>(float ff) {
|
||||||
|
#if (defined(EIGEN_HAS_CUDA_BF16) && defined(EIGEN_HAS_HIP_BF16))
|
||||||
|
// Nothing to do here
|
||||||
|
#else
|
||||||
|
unsigned int input = numext::as_uint(ff);
|
||||||
|
__bfloat16_raw output;
|
||||||
|
|
||||||
// Least significant bit of resulting bfloat.
|
// Least significant bit of resulting bfloat.
|
||||||
unsigned int lsb = (input >> 16) & 1;
|
unsigned int lsb = (input >> 16) & 1;
|
||||||
unsigned int rounding_bias = 0x7fff + lsb;
|
unsigned int rounding_bias = 0x7fff + lsb;
|
||||||
input += rounding_bias;
|
input += rounding_bias;
|
||||||
output.value = static_cast<unsigned short>(input >> 16);
|
output.value = static_cast<unsigned short>(input >> 16);
|
||||||
}
|
|
||||||
return output;
|
return output;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
@ -31,7 +31,7 @@ float BinaryToFloat(uint32_t sign, uint32_t exponent, uint32_t high_mantissa,
|
|||||||
|
|
||||||
void test_truncate(float input, float expected_truncation, float expected_rounding){
|
void test_truncate(float input, float expected_truncation, float expected_rounding){
|
||||||
bfloat16 truncated = Eigen::bfloat16_impl::truncate_to_bfloat16(input);
|
bfloat16 truncated = Eigen::bfloat16_impl::truncate_to_bfloat16(input);
|
||||||
bfloat16 rounded = Eigen::bfloat16_impl::float_to_bfloat16_rtne(input);
|
bfloat16 rounded = Eigen::bfloat16_impl::float_to_bfloat16_rtne<false>(input);
|
||||||
if ((numext::isnan)(input)){
|
if ((numext::isnan)(input)){
|
||||||
VERIFY((numext::isnan)(static_cast<float>(truncated)) || (numext::isinf)(static_cast<float>(truncated)));
|
VERIFY((numext::isnan)(static_cast<float>(truncated)) || (numext::isinf)(static_cast<float>(truncated)));
|
||||||
VERIFY((numext::isnan)(static_cast<float>(rounded)) || (numext::isinf)(static_cast<float>(rounded)));
|
VERIFY((numext::isnan)(static_cast<float>(rounded)) || (numext::isinf)(static_cast<float>(rounded)));
|
||||||
@ -93,7 +93,7 @@ void test_conversion()
|
|||||||
} else {
|
} else {
|
||||||
VERIFY_IS_EQUAL(bf_trunc.value, 0x0000);
|
VERIFY_IS_EQUAL(bf_trunc.value, 0x0000);
|
||||||
}
|
}
|
||||||
bfloat16 bf_round = Eigen::bfloat16_impl::float_to_bfloat16_rtne(denorm);
|
bfloat16 bf_round = Eigen::bfloat16_impl::float_to_bfloat16_rtne<false>(denorm);
|
||||||
VERIFY_IS_EQUAL(static_cast<float>(bf_round), 0.0f);
|
VERIFY_IS_EQUAL(static_cast<float>(bf_round), 0.0f);
|
||||||
if (std::signbit(denorm)) {
|
if (std::signbit(denorm)) {
|
||||||
VERIFY_IS_EQUAL(bf_round.value, 0x8000);
|
VERIFY_IS_EQUAL(bf_round.value, 0x8000);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user