From 1b84f21e321e9daa1efcd4422ae92c1782c5582c Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Wed, 22 Jul 2020 18:09:00 -0700 Subject: [PATCH] Revert change that made conversion from bfloat16 to {float, double} implicit. Add roundtrip tests for casting between bfloat16 and complex types. --- Eigen/src/Core/arch/Default/BFloat16.h | 10 +++++---- test/bfloat16_float.cpp | 31 +++++++++++++++++--------- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h index f9c6e76a9..30c998249 100644 --- a/Eigen/src/Core/arch/Default/BFloat16.h +++ b/Eigen/src/Core/arch/Default/BFloat16.h @@ -34,8 +34,9 @@ namespace Eigen { struct bfloat16; -// Since we allow implicit conversion of bfloat16 to float and double, we -// need to make the cast to complex a bit more explicit +// explicit conversion operators are no available before C++11 so we first cast +// bfloat16 to RealScalar rather than to std::complex directly +#if !EIGEN_HAS_CXX11 namespace internal { template struct cast_impl > { @@ -45,6 +46,7 @@ struct cast_impl > { } }; } // namespace internal +#endif // EIGEN_HAS_CXX11 namespace bfloat16_impl { @@ -129,10 +131,10 @@ struct bfloat16 : public bfloat16_impl::bfloat16_base { EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned long long) const { return static_cast(bfloat16_to_float(*this)); } - EIGEN_DEVICE_FUNC operator float() const { + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const { return bfloat16_impl::bfloat16_to_float(*this); } - EIGEN_DEVICE_FUNC operator double() const { + EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const { return static_cast(bfloat16_impl::bfloat16_to_float(*this)); } template diff --git a/test/bfloat16_float.cpp b/test/bfloat16_float.cpp index 96341929a..11fc31363 100644 --- a/test/bfloat16_float.cpp +++ b/test/bfloat16_float.cpp @@ -41,6 +41,19 @@ void test_truncate(float input, float expected_truncation, float expected_roundi VERIFY_IS_EQUAL(expected_rounding, static_cast(rounded)); } +template + void test_roundtrip() { + // Representable T round trip via bfloat16 + VERIFY_IS_EQUAL(static_cast(static_cast(-std::numeric_limits::infinity())), -std::numeric_limits::infinity()); + VERIFY_IS_EQUAL(static_cast(static_cast(std::numeric_limits::infinity())), std::numeric_limits::infinity()); + VERIFY_IS_EQUAL(static_cast(static_cast(T(-1.0))), T(-1.0)); + VERIFY_IS_EQUAL(static_cast(static_cast(T(-0.5))), T(-0.5)); + VERIFY_IS_EQUAL(static_cast(static_cast(T(-0.0))), T(-0.0)); + VERIFY_IS_EQUAL(static_cast(static_cast(T(1.0))), T(1.0)); + VERIFY_IS_EQUAL(static_cast(static_cast(T(0.5))), T(0.5)); + VERIFY_IS_EQUAL(static_cast(static_cast(T(0.0))), T(0.0)); +} + void test_conversion() { using Eigen::bfloat16_impl::__bfloat16_raw; @@ -53,9 +66,9 @@ void test_conversion() VERIFY_IS_EQUAL(bfloat16(3.40e38f).value, 0x7f80); // Becomes infinity. // Verify round-to-nearest-even behavior. - float val1 = bfloat16(__bfloat16_raw(0x3c00)); - float val2 = bfloat16(__bfloat16_raw(0x3c01)); - float val3 = bfloat16(__bfloat16_raw(0x3c02)); + float val1 = static_cast(bfloat16(__bfloat16_raw(0x3c00))); + float val2 = static_cast(bfloat16(__bfloat16_raw(0x3c01))); + float val3 = static_cast(bfloat16(__bfloat16_raw(0x3c02))); VERIFY_IS_EQUAL(bfloat16(0.5f * (val1 + val2)).value, 0x3c00); VERIFY_IS_EQUAL(bfloat16(0.5f * (val2 + val3)).value, 0x3c02); @@ -106,14 +119,10 @@ void test_conversion() VERIFY_IS_EQUAL(static_cast(bfloat16()), 0.0f); // Representable floats round trip via bfloat16 - VERIFY_IS_EQUAL(static_cast(static_cast(-std::numeric_limits::infinity())), -std::numeric_limits::infinity()); - VERIFY_IS_EQUAL(static_cast(static_cast(std::numeric_limits::infinity())), std::numeric_limits::infinity()); - VERIFY_IS_EQUAL(static_cast(static_cast(-1.0f)), -1.0f); - VERIFY_IS_EQUAL(static_cast(static_cast(-0.5f)), -0.5f); - VERIFY_IS_EQUAL(static_cast(static_cast(-0.0f)), -0.0f); - VERIFY_IS_EQUAL(static_cast(static_cast(1.0f)), 1.0f); - VERIFY_IS_EQUAL(static_cast(static_cast(0.5f)), 0.5f); - VERIFY_IS_EQUAL(static_cast(static_cast(0.0f)), 0.0f); + test_roundtrip(); + test_roundtrip(); + test_roundtrip >(); + test_roundtrip >(); // Truncate test test_truncate(