mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-19 08:09:36 +08:00
Revert change that made conversion from bfloat16 to {float, double} implicit.
Add roundtrip tests for casting between bfloat16 and complex types.
This commit is contained in:
parent
38b91f256b
commit
1b84f21e32
@ -34,8 +34,9 @@ namespace Eigen {
|
|||||||
|
|
||||||
struct bfloat16;
|
struct bfloat16;
|
||||||
|
|
||||||
// Since we allow implicit conversion of bfloat16 to float and double, we
|
// explicit conversion operators are no available before C++11 so we first cast
|
||||||
// need to make the cast to complex a bit more explicit
|
// bfloat16 to RealScalar rather than to std::complex<RealScalar> directly
|
||||||
|
#if !EIGEN_HAS_CXX11
|
||||||
namespace internal {
|
namespace internal {
|
||||||
template <typename RealScalar>
|
template <typename RealScalar>
|
||||||
struct cast_impl<bfloat16, std::complex<RealScalar> > {
|
struct cast_impl<bfloat16, std::complex<RealScalar> > {
|
||||||
@ -45,6 +46,7 @@ struct cast_impl<bfloat16, std::complex<RealScalar> > {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
|
#endif // EIGEN_HAS_CXX11
|
||||||
|
|
||||||
namespace bfloat16_impl {
|
namespace bfloat16_impl {
|
||||||
|
|
||||||
@ -129,10 +131,10 @@ struct bfloat16 : public bfloat16_impl::bfloat16_base {
|
|||||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned long long) const {
|
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned long long) const {
|
||||||
return static_cast<unsigned long long>(bfloat16_to_float(*this));
|
return static_cast<unsigned long long>(bfloat16_to_float(*this));
|
||||||
}
|
}
|
||||||
EIGEN_DEVICE_FUNC operator float() const {
|
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const {
|
||||||
return bfloat16_impl::bfloat16_to_float(*this);
|
return bfloat16_impl::bfloat16_to_float(*this);
|
||||||
}
|
}
|
||||||
EIGEN_DEVICE_FUNC operator double() const {
|
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(double) const {
|
||||||
return static_cast<double>(bfloat16_impl::bfloat16_to_float(*this));
|
return static_cast<double>(bfloat16_impl::bfloat16_to_float(*this));
|
||||||
}
|
}
|
||||||
template<typename RealScalar>
|
template<typename RealScalar>
|
||||||
|
@ -41,6 +41,19 @@ void test_truncate(float input, float expected_truncation, float expected_roundi
|
|||||||
VERIFY_IS_EQUAL(expected_rounding, static_cast<float>(rounded));
|
VERIFY_IS_EQUAL(expected_rounding, static_cast<float>(rounded));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void test_roundtrip() {
|
||||||
|
// Representable T round trip via bfloat16
|
||||||
|
VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(-std::numeric_limits<T>::infinity())), -std::numeric_limits<T>::infinity());
|
||||||
|
VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(std::numeric_limits<T>::infinity())), std::numeric_limits<T>::infinity());
|
||||||
|
VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(-1.0))), T(-1.0));
|
||||||
|
VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(-0.5))), T(-0.5));
|
||||||
|
VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(-0.0))), T(-0.0));
|
||||||
|
VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(1.0))), T(1.0));
|
||||||
|
VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(0.5))), T(0.5));
|
||||||
|
VERIFY_IS_EQUAL(static_cast<T>(static_cast<bfloat16>(T(0.0))), T(0.0));
|
||||||
|
}
|
||||||
|
|
||||||
void test_conversion()
|
void test_conversion()
|
||||||
{
|
{
|
||||||
using Eigen::bfloat16_impl::__bfloat16_raw;
|
using Eigen::bfloat16_impl::__bfloat16_raw;
|
||||||
@ -53,9 +66,9 @@ void test_conversion()
|
|||||||
VERIFY_IS_EQUAL(bfloat16(3.40e38f).value, 0x7f80); // Becomes infinity.
|
VERIFY_IS_EQUAL(bfloat16(3.40e38f).value, 0x7f80); // Becomes infinity.
|
||||||
|
|
||||||
// Verify round-to-nearest-even behavior.
|
// Verify round-to-nearest-even behavior.
|
||||||
float val1 = bfloat16(__bfloat16_raw(0x3c00));
|
float val1 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c00)));
|
||||||
float val2 = bfloat16(__bfloat16_raw(0x3c01));
|
float val2 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c01)));
|
||||||
float val3 = bfloat16(__bfloat16_raw(0x3c02));
|
float val3 = static_cast<float>(bfloat16(__bfloat16_raw(0x3c02)));
|
||||||
VERIFY_IS_EQUAL(bfloat16(0.5f * (val1 + val2)).value, 0x3c00);
|
VERIFY_IS_EQUAL(bfloat16(0.5f * (val1 + val2)).value, 0x3c00);
|
||||||
VERIFY_IS_EQUAL(bfloat16(0.5f * (val2 + val3)).value, 0x3c02);
|
VERIFY_IS_EQUAL(bfloat16(0.5f * (val2 + val3)).value, 0x3c02);
|
||||||
|
|
||||||
@ -106,14 +119,10 @@ void test_conversion()
|
|||||||
VERIFY_IS_EQUAL(static_cast<float>(bfloat16()), 0.0f);
|
VERIFY_IS_EQUAL(static_cast<float>(bfloat16()), 0.0f);
|
||||||
|
|
||||||
// Representable floats round trip via bfloat16
|
// Representable floats round trip via bfloat16
|
||||||
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-std::numeric_limits<float>::infinity())), -std::numeric_limits<float>::infinity());
|
test_roundtrip<float>();
|
||||||
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(std::numeric_limits<float>::infinity())), std::numeric_limits<float>::infinity());
|
test_roundtrip<double>();
|
||||||
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-1.0f)), -1.0f);
|
test_roundtrip<std::complex<float> >();
|
||||||
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-0.5f)), -0.5f);
|
test_roundtrip<std::complex<double> >();
|
||||||
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(-0.0f)), -0.0f);
|
|
||||||
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(1.0f)), 1.0f);
|
|
||||||
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(0.5f)), 0.5f);
|
|
||||||
VERIFY_IS_EQUAL(static_cast<float>(static_cast<bfloat16>(0.0f)), 0.0f);
|
|
||||||
|
|
||||||
// Truncate test
|
// Truncate test
|
||||||
test_truncate(
|
test_truncate(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user