diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h index b9f886fee..b8d3b4f69 100644 --- a/Eigen/src/Core/arch/Default/BFloat16.h +++ b/Eigen/src/Core/arch/Default/BFloat16.h @@ -763,6 +763,31 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast(from); + bool from_sign = from_bits >> 15; + // Whether we are adjusting toward the infinity with the same sign as from. + bool toward_inf = (to > from) == !from_sign; + if (toward_inf) { + ++from_bits; + } else if ((from_bits & 0x7fff) == 0) { + // Adjusting away from inf, but from is zero, so just toggle the sign. + from_bits ^= 0x8000; + } else { + --from_bits; + } + return numext::bit_cast(from_bits); +} + } // namespace numext } // namespace Eigen diff --git a/test/bfloat16_float.cpp b/test/bfloat16_float.cpp index 5be49d910..12d6d86e3 100644 --- a/test/bfloat16_float.cpp +++ b/test/bfloat16_float.cpp @@ -353,6 +353,40 @@ void test_product() { VERIFY_IS_APPROX(Ch.noalias() += Ah * Bh, (Cf.noalias() += Af * Bf).cast()); } +void test_nextafter() { + VERIFY((numext::isnan)(numext::nextafter(std::numeric_limits::quiet_NaN(), bfloat16(1.0f)))); + VERIFY((numext::isnan)(numext::nextafter(bfloat16(1.0f), std::numeric_limits::quiet_NaN()))); + VERIFY(numext::nextafter(bfloat16(0.0f), bfloat16(0.0f)) == bfloat16(0.0f)); + VERIFY(numext::nextafter(bfloat16(1.0f), bfloat16(1.0f)) == bfloat16(1.0f)); + VERIFY(numext::nextafter(bfloat16(-1.0f), bfloat16(-1.0f)) == bfloat16(-1.0f)); + VERIFY(numext::nextafter(std::numeric_limits::infinity(), std::numeric_limits::infinity()) == + std::numeric_limits::infinity()); + VERIFY(numext::nextafter(std::numeric_limits::infinity(), bfloat16(0.0f)) == + (std::numeric_limits::max)()); + VERIFY(numext::nextafter(-std::numeric_limits::infinity(), bfloat16(0.0f)) == + -(std::numeric_limits::max)()); + VERIFY(numext::nextafter(bfloat16(1.0f), std::numeric_limits::infinity()) == + bfloat16(1.0f) + std::numeric_limits::epsilon()); + VERIFY(numext::nextafter(bfloat16(1.0f), -std::numeric_limits::infinity()) == + bfloat16(1.0f) - std::numeric_limits::epsilon() / bfloat16(2.0f)); + VERIFY(numext::nextafter(bfloat16(-1.0f), -std::numeric_limits::infinity()) == + bfloat16(-1.0f) - std::numeric_limits::epsilon()); + VERIFY(numext::nextafter(bfloat16(-1.0f), std::numeric_limits::infinity()) == + bfloat16(-1.0f) + std::numeric_limits::epsilon() / bfloat16(2.0f)); + VERIFY(numext::nextafter((std::numeric_limits::max)(), std::numeric_limits::infinity()) == + std::numeric_limits::infinity()); + VERIFY(numext::nextafter(-(std::numeric_limits::max)(), -std::numeric_limits::infinity()) == + -std::numeric_limits::infinity()); + VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(0.0f), bfloat16(1.0f)), 0x0001); + VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(-0.0f), bfloat16(1.0f)), 0x0000); + VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(0.0f), bfloat16(-1.0f)), 0x8000); + VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(-0.0f), bfloat16(-1.0f)), 0x8001); + VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(0.0f), bfloat16(-0.0f)), 0x8000); + VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(-0.0f), bfloat16(0.0f)), 0x0000); + VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(0.0f), bfloat16(0.0f)), 0x0000); + VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(-0.0f), bfloat16(-0.0f)), 0x8000); +} + EIGEN_DECLARE_TEST(bfloat16_float) { CALL_SUBTEST(test_numtraits()); for (int i = 0; i < g_repeat; i++) { @@ -363,5 +397,6 @@ EIGEN_DECLARE_TEST(bfloat16_float) { CALL_SUBTEST(test_trigonometric_functions()); CALL_SUBTEST(test_array()); CALL_SUBTEST(test_product()); + CALL_SUBTEST(test_nextafter()); } }