mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Fix bfloat16 casts
If we have explicit conversion operators available (C++11) we define explicit casts from bfloat16 to other types. If not (C++03), we don't define conversion operators but rely on implicit conversion chains from bfloat16 over float to other types.
This commit is contained in:
parent
2ce2f51989
commit
c1ffe452fc
@ -13,16 +13,9 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
|
||||||
#ifndef EIGEN_BFLOAT16_H
|
#ifndef EIGEN_BFLOAT16_H
|
||||||
#define EIGEN_BFLOAT16_H
|
#define EIGEN_BFLOAT16_H
|
||||||
|
|
||||||
#if __cplusplus > 199711L
|
|
||||||
#define EIGEN_EXPLICIT_CAST(tgt_type) explicit operator tgt_type()
|
|
||||||
#else
|
|
||||||
#define EIGEN_EXPLICIT_CAST(tgt_type) operator tgt_type()
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD) \
|
#define BF16_PACKET_FUNCTION(PACKET_F, PACKET_BF16, METHOD) \
|
||||||
template <> \
|
template <> \
|
||||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED \
|
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED \
|
||||||
@ -34,20 +27,6 @@ namespace Eigen {
|
|||||||
|
|
||||||
struct bfloat16;
|
struct bfloat16;
|
||||||
|
|
||||||
// explicit conversion operators are no available before C++11 so we first cast
|
|
||||||
// bfloat16 to RealScalar rather than to std::complex<RealScalar> directly
|
|
||||||
#if !EIGEN_HAS_CXX11
|
|
||||||
namespace internal {
|
|
||||||
template <typename RealScalar>
|
|
||||||
struct cast_impl<bfloat16, std::complex<RealScalar> > {
|
|
||||||
EIGEN_DEVICE_FUNC static inline std::complex<RealScalar> run(const bfloat16 &x)
|
|
||||||
{
|
|
||||||
return static_cast<std::complex<RealScalar> >(static_cast<RealScalar>(x));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // namespace internal
|
|
||||||
#endif // EIGEN_HAS_CXX11
|
|
||||||
|
|
||||||
namespace bfloat16_impl {
|
namespace bfloat16_impl {
|
||||||
|
|
||||||
// Make our own __bfloat16_raw definition.
|
// Make our own __bfloat16_raw definition.
|
||||||
@ -86,66 +65,32 @@ struct bfloat16 : public bfloat16_impl::bfloat16_base {
|
|||||||
|
|
||||||
explicit EIGEN_DEVICE_FUNC bfloat16(bool b)
|
explicit EIGEN_DEVICE_FUNC bfloat16(bool b)
|
||||||
: bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}
|
: bfloat16_impl::bfloat16_base(bfloat16_impl::raw_uint16_to_bfloat16(b ? 0x3f80 : 0)) {}
|
||||||
|
|
||||||
template<class T>
|
template<class T>
|
||||||
explicit EIGEN_DEVICE_FUNC bfloat16(const T& val)
|
explicit EIGEN_DEVICE_FUNC bfloat16(const T& val)
|
||||||
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<internal::is_integral<T>::value>(static_cast<float>(val))) {}
|
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<internal::is_integral<T>::value>(static_cast<float>(val))) {}
|
||||||
|
|
||||||
explicit EIGEN_DEVICE_FUNC bfloat16(float f)
|
explicit EIGEN_DEVICE_FUNC bfloat16(float f)
|
||||||
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(f)) {}
|
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(f)) {}
|
||||||
|
|
||||||
// Following the convention of numpy, converting between complex and
|
// Following the convention of numpy, converting between complex and
|
||||||
// float will lead to loss of imag value.
|
// float will lead to loss of imag value.
|
||||||
template<typename RealScalar>
|
template<typename RealScalar>
|
||||||
explicit EIGEN_DEVICE_FUNC bfloat16(const std::complex<RealScalar>& val)
|
explicit EIGEN_DEVICE_FUNC bfloat16(const std::complex<RealScalar>& val)
|
||||||
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(static_cast<float>(val.real()))) {}
|
: bfloat16_impl::bfloat16_base(bfloat16_impl::float_to_bfloat16_rtne<false>(static_cast<float>(val.real()))) {}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC operator float() const {
|
||||||
|
return bfloat16_impl::bfloat16_to_float(*this);
|
||||||
|
}
|
||||||
|
|
||||||
|
#if EIGEN_HAS_CXX11
|
||||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const {
|
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const {
|
||||||
// +0.0 and -0.0 become false, everything else becomes true.
|
// +0.0 and -0.0 become false, everything else becomes true.
|
||||||
return (value & 0x7fff) != 0;
|
return (value & 0x7fff) != 0;
|
||||||
}
|
}
|
||||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(signed char) const {
|
#endif
|
||||||
return static_cast<signed char>(bfloat16_impl::bfloat16_to_float(*this));
|
|
||||||
}
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned char) const {
|
|
||||||
return static_cast<unsigned char>(bfloat16_impl::bfloat16_to_float(*this));
|
|
||||||
}
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(short) const {
|
|
||||||
return static_cast<short>(bfloat16_impl::bfloat16_to_float(*this));
|
|
||||||
}
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned short) const {
|
|
||||||
return static_cast<unsigned short>(bfloat16_impl::bfloat16_to_float(*this));
|
|
||||||
}
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(int) const {
|
|
||||||
return static_cast<int>(bfloat16_impl::bfloat16_to_float(*this));
|
|
||||||
}
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned int) const {
|
|
||||||
return static_cast<unsigned int>(bfloat16_impl::bfloat16_to_float(*this));
|
|
||||||
}
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(long) const {
|
|
||||||
return static_cast<long>(bfloat16_impl::bfloat16_to_float(*this));
|
|
||||||
}
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned long) const {
|
|
||||||
return static_cast<unsigned long>(bfloat16_impl::bfloat16_to_float(*this));
|
|
||||||
}
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(long long) const {
|
|
||||||
return static_cast<long long>(bfloat16_impl::bfloat16_to_float(*this));
|
|
||||||
}
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned long long) const {
|
|
||||||
return static_cast<unsigned long long>(bfloat16_to_float(*this));
|
|
||||||
}
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const {
|
|
||||||
return bfloat16_impl::bfloat16_to_float(*this);
|
|
||||||
}
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const {
|
|
||||||
return static_cast<double>(bfloat16_impl::bfloat16_to_float(*this));
|
|
||||||
}
|
|
||||||
template<typename RealScalar>
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(std::complex<RealScalar>) const {
|
|
||||||
return std::complex<RealScalar>(static_cast<RealScalar>(bfloat16_impl::bfloat16_to_float(*this)), RealScalar(0));
|
|
||||||
}
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(Eigen::half) const {
|
|
||||||
return static_cast<Eigen::half>(bfloat16_impl::bfloat16_to_float(*this));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
|
};
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|
||||||
namespace std {
|
namespace std {
|
||||||
|
@ -84,10 +84,20 @@ void test_conversion()
|
|||||||
VERIFY_IS_EQUAL(bfloat16(false).value, 0x0000);
|
VERIFY_IS_EQUAL(bfloat16(false).value, 0x0000);
|
||||||
VERIFY_IS_EQUAL(bfloat16(true).value, 0x3f80);
|
VERIFY_IS_EQUAL(bfloat16(true).value, 0x3f80);
|
||||||
|
|
||||||
// Conversion to float.
|
// Conversion to bool
|
||||||
|
VERIFY_IS_EQUAL(static_cast<bool>(bfloat16(3)), true);
|
||||||
|
VERIFY_IS_EQUAL(static_cast<bool>(bfloat16(0.33333f)), true);
|
||||||
|
VERIFY_IS_EQUAL(bfloat16(-0.0), false);
|
||||||
|
VERIFY_IS_EQUAL(static_cast<bool>(bfloat16(0.0)), false);
|
||||||
|
|
||||||
|
// Explicit conversion to float.
|
||||||
VERIFY_IS_EQUAL(static_cast<float>(bfloat16(__bfloat16_raw(0x0000))), 0.0f);
|
VERIFY_IS_EQUAL(static_cast<float>(bfloat16(__bfloat16_raw(0x0000))), 0.0f);
|
||||||
VERIFY_IS_EQUAL(static_cast<float>(bfloat16(__bfloat16_raw(0x3f80))), 1.0f);
|
VERIFY_IS_EQUAL(static_cast<float>(bfloat16(__bfloat16_raw(0x3f80))), 1.0f);
|
||||||
|
|
||||||
|
// Implicit conversion to float
|
||||||
|
VERIFY_IS_EQUAL(bfloat16(__bfloat16_raw(0x0000)), 0.0f);
|
||||||
|
VERIFY_IS_EQUAL(bfloat16(__bfloat16_raw(0x3f80)), 1.0f);
|
||||||
|
|
||||||
// Zero representations
|
// Zero representations
|
||||||
VERIFY_IS_EQUAL(bfloat16(0.0f), bfloat16(0.0f));
|
VERIFY_IS_EQUAL(bfloat16(0.0f), bfloat16(0.0f));
|
||||||
VERIFY_IS_EQUAL(bfloat16(-0.0f), bfloat16(0.0f));
|
VERIFY_IS_EQUAL(bfloat16(-0.0f), bfloat16(0.0f));
|
||||||
@ -101,6 +111,11 @@ void test_conversion()
|
|||||||
denorm = nextafterf(denorm, 1.0f)) {
|
denorm = nextafterf(denorm, 1.0f)) {
|
||||||
bfloat16 bf_trunc = Eigen::bfloat16_impl::truncate_to_bfloat16(denorm);
|
bfloat16 bf_trunc = Eigen::bfloat16_impl::truncate_to_bfloat16(denorm);
|
||||||
VERIFY_IS_EQUAL(static_cast<float>(bf_trunc), 0.0f);
|
VERIFY_IS_EQUAL(static_cast<float>(bf_trunc), 0.0f);
|
||||||
|
|
||||||
|
// Implicit conversion of denormls to bool is correct
|
||||||
|
VERIFY_IS_EQUAL(static_cast<bool>(bfloat16(denorm)), false);
|
||||||
|
VERIFY_IS_EQUAL(bfloat16(denorm), false);
|
||||||
|
|
||||||
if (std::signbit(denorm)) {
|
if (std::signbit(denorm)) {
|
||||||
VERIFY_IS_EQUAL(bf_trunc.value, 0x8000);
|
VERIFY_IS_EQUAL(bf_trunc.value, 0x8000);
|
||||||
} else {
|
} else {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user