mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-01 14:10:52 +08:00
add nextafter for bfloat16
This commit is contained in:
parent
53b83cddf9
commit
b15ebb1c2d
@ -763,6 +763,31 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::bfloat1
|
||||
return Eigen::bfloat16_impl::raw_bfloat16_as_uint16(src);
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 nextafter(const bfloat16& from, const bfloat16& to) {
|
||||
if (numext::isnan EIGEN_NOT_A_MACRO(from)) {
|
||||
return from;
|
||||
}
|
||||
if (numext::isnan EIGEN_NOT_A_MACRO(to)) {
|
||||
return to;
|
||||
}
|
||||
if (from == to) {
|
||||
return to;
|
||||
}
|
||||
uint16_t from_bits = numext::bit_cast<uint16_t>(from);
|
||||
bool from_sign = from_bits >> 15;
|
||||
// Whether we are adjusting toward the infinity with the same sign as from.
|
||||
bool toward_inf = (to > from) == !from_sign;
|
||||
if (toward_inf) {
|
||||
++from_bits;
|
||||
} else if ((from_bits & 0x7fff) == 0) {
|
||||
// Adjusting away from inf, but from is zero, so just toggle the sign.
|
||||
from_bits ^= 0x8000;
|
||||
} else {
|
||||
--from_bits;
|
||||
}
|
||||
return numext::bit_cast<bfloat16>(from_bits);
|
||||
}
|
||||
|
||||
} // namespace numext
|
||||
} // namespace Eigen
|
||||
|
||||
|
@ -353,6 +353,40 @@ void test_product() {
|
||||
VERIFY_IS_APPROX(Ch.noalias() += Ah * Bh, (Cf.noalias() += Af * Bf).cast<bfloat16>());
|
||||
}
|
||||
|
||||
void test_nextafter() {
|
||||
VERIFY((numext::isnan)(numext::nextafter(std::numeric_limits<bfloat16>::quiet_NaN(), bfloat16(1.0f))));
|
||||
VERIFY((numext::isnan)(numext::nextafter(bfloat16(1.0f), std::numeric_limits<bfloat16>::quiet_NaN())));
|
||||
VERIFY(numext::nextafter(bfloat16(0.0f), bfloat16(0.0f)) == bfloat16(0.0f));
|
||||
VERIFY(numext::nextafter(bfloat16(1.0f), bfloat16(1.0f)) == bfloat16(1.0f));
|
||||
VERIFY(numext::nextafter(bfloat16(-1.0f), bfloat16(-1.0f)) == bfloat16(-1.0f));
|
||||
VERIFY(numext::nextafter(std::numeric_limits<bfloat16>::infinity(), std::numeric_limits<bfloat16>::infinity()) ==
|
||||
std::numeric_limits<bfloat16>::infinity());
|
||||
VERIFY(numext::nextafter(std::numeric_limits<bfloat16>::infinity(), bfloat16(0.0f)) ==
|
||||
(std::numeric_limits<bfloat16>::max)());
|
||||
VERIFY(numext::nextafter(-std::numeric_limits<bfloat16>::infinity(), bfloat16(0.0f)) ==
|
||||
-(std::numeric_limits<bfloat16>::max)());
|
||||
VERIFY(numext::nextafter(bfloat16(1.0f), std::numeric_limits<bfloat16>::infinity()) ==
|
||||
bfloat16(1.0f) + std::numeric_limits<bfloat16>::epsilon());
|
||||
VERIFY(numext::nextafter(bfloat16(1.0f), -std::numeric_limits<bfloat16>::infinity()) ==
|
||||
bfloat16(1.0f) - std::numeric_limits<bfloat16>::epsilon() / bfloat16(2.0f));
|
||||
VERIFY(numext::nextafter(bfloat16(-1.0f), -std::numeric_limits<bfloat16>::infinity()) ==
|
||||
bfloat16(-1.0f) - std::numeric_limits<bfloat16>::epsilon());
|
||||
VERIFY(numext::nextafter(bfloat16(-1.0f), std::numeric_limits<bfloat16>::infinity()) ==
|
||||
bfloat16(-1.0f) + std::numeric_limits<bfloat16>::epsilon() / bfloat16(2.0f));
|
||||
VERIFY(numext::nextafter((std::numeric_limits<bfloat16>::max)(), std::numeric_limits<bfloat16>::infinity()) ==
|
||||
std::numeric_limits<bfloat16>::infinity());
|
||||
VERIFY(numext::nextafter(-(std::numeric_limits<bfloat16>::max)(), -std::numeric_limits<bfloat16>::infinity()) ==
|
||||
-std::numeric_limits<bfloat16>::infinity());
|
||||
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(0.0f), bfloat16(1.0f)), 0x0001);
|
||||
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(-0.0f), bfloat16(1.0f)), 0x0000);
|
||||
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(0.0f), bfloat16(-1.0f)), 0x8000);
|
||||
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(-0.0f), bfloat16(-1.0f)), 0x8001);
|
||||
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(0.0f), bfloat16(-0.0f)), 0x8000);
|
||||
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(-0.0f), bfloat16(0.0f)), 0x0000);
|
||||
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(0.0f), bfloat16(0.0f)), 0x0000);
|
||||
VERIFY_BFLOAT16_BITS_EQUAL(numext::nextafter(bfloat16(-0.0f), bfloat16(-0.0f)), 0x8000);
|
||||
}
|
||||
|
||||
EIGEN_DECLARE_TEST(bfloat16_float) {
|
||||
CALL_SUBTEST(test_numtraits());
|
||||
for (int i = 0; i < g_repeat; i++) {
|
||||
@ -363,5 +397,6 @@ EIGEN_DECLARE_TEST(bfloat16_float) {
|
||||
CALL_SUBTEST(test_trigonometric_functions());
|
||||
CALL_SUBTEST(test_array());
|
||||
CALL_SUBTEST(test_product());
|
||||
CALL_SUBTEST(test_nextafter());
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user