add nextafter for bfloat16

This commit is contained in:
Peter Gavin 2024-10-21 21:23:41 +00:00
parent 53b83cddf9
commit b15ebb1c2d
2 changed files with 60 additions and 0 deletions

View File

@ -763,6 +763,31 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat1
return Eigen::bfloat16_impl::raw_bfloat16_as_uint16(src);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 nextafter(const bfloat16& from, const bfloat16& to) {
if (numext::isnan EIGEN_NOT_A_MACRO(from)) {
return from;
}
if (numext::isnan EIGEN_NOT_A_MACRO(to)) {
return to;
}
if (from == to) {
return to;
}
uint16_t from_bits = numext::bit_cast<uint16_t>(from);
bool from_sign = from_bits >> 15;
// Whether we are adjusting toward the infinity with the same sign as from.
bool toward_inf = (to > from) == !from_sign;
if (toward_inf) {
++from_bits;
} else if ((from_bits & 0x7fff) == 0) {
// Adjusting away from inf, but from is zero, so just toggle the sign.
from_bits ^= 0x8000;
} else {
--from_bits;
}
return numext::bit_cast<bfloat16>(from_bits);
}
} // namespace numext
} // namespace Eigen

View File

@ -353,6 +353,40 @@ void test_product() {
VERIFY_IS_APPROX(Ch.noalias() += Ah * Bh, (Cf.noalias() += Af * Bf).cast<bfloat16>());
}
void test_nextafter() {
VERIFY((numext::isnan)(numext::nextafter(std::numeric_limits<bfloat16>::quiet_NaN(), bfloat16(1.0f))));
VERIFY((numext::isnan)(numext::nextafter(bfloat16(1.0f), std::numeric_limits<bfloat16>::quiet_NaN())));
VERIFY(numext::nextafter(bfloat16(0.0f), bfloat16(0.0f)) == bfloat16(0.0f));
VERIFY(numext::nextafter(bfloat16(1.0f), bfloat16(1.0f)) == bfloat16(1.0f));
VERIFY(numext::nextafter(bfloat16(-1.0f), bfloat16(-1.0f)) == bfloat16(-1.0f));
VERIFY(numext::nextafter(std::numeric_limits<bfloat16>::infinity(), std::numeric_limits<bfloat16>::infinity()) ==
std::numeric_limits<bfloat16>::infinity());
VERIFY(numext::nextafter(std::numeric_limits<bfloat16>::infinity(), bfloat16(0.0f)) ==
(std::numeric_limits<bfloat16>::max)());
VERIFY(numext::nextafter(-std::numeric_limits<bfloat16>::infinity(), bfloat16(0.0f)) ==
-(std::numeric_limits<bfloat16>::max)());
VERIFY(numext::nextafter(bfloat16(1.0f), std::numeric_limits<bfloat16>::infinity()) ==
bfloat16(1.0f) + std::numeric_limits<bfloat16>::epsilon());
VERIFY(numext::nextafter(bfloat16(1.0f), -std::numeric_limits<bfloat16>::infinity()) ==
bfloat16(1.0f) - std::numeric_limits<bfloat16>::epsilon() / bfloat16(2.0f));
VERIFY(numext::nextafter(bfloat16(-1.0f), -std::numeric_limits<bfloat16>::infinity()) ==
bfloat16(-1.0f) - std::numeric_limits<bfloat16>::epsilon());
VERIFY(numext::nextafter(bfloat16(-1.0f), std::numeric_limits<bfloat16>::infinity()) ==
bfloat16(-1.0f) + std::numeric_limits<bfloat16>::epsilon() / bfloat16(2.0f));
VERIFY(numext::nextafter((std::numeric_limits<bfloat16>::max)(), std::numeric_limits<bfloat16>::infinity()) ==
std::numeric_limits<bfloat16>::infinity());
VERIFY(numext::nextafter(-(std::numeric_limits<bfloat16>::max)(), -std::numeric_limits<bfloat16>::infinity()) ==
-std::numeric_limits<bfloat16>::infinity());
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(0.0f), bfloat16(1.0f)), 0x0001);
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(-0.0f), bfloat16(1.0f)), 0x0000);
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(0.0f), bfloat16(-1.0f)), 0x8000);
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(-0.0f), bfloat16(-1.0f)), 0x8001);
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(0.0f), bfloat16(-0.0f)), 0x8000);
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(-0.0f), bfloat16(0.0f)), 0x0000);
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(0.0f), bfloat16(0.0f)), 0x0000);
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(-0.0f), bfloat16(-0.0f)), 0x8000);
}
EIGEN_DECLARE_TEST(bfloat16_float) {
CALL_SUBTEST(test_numtraits());
for (int i = 0; i < g_repeat; i++) {
@ -363,5 +397,6 @@ EIGEN_DECLARE_TEST(bfloat16_float) {
CALL_SUBTEST(test_trigonometric_functions());
CALL_SUBTEST(test_array());
CALL_SUBTEST(test_product());
CALL_SUBTEST(test_nextafter());
}
}