Use numext::signbit instead of std::signbit, which is not defined for bfloat16.

This commit is contained in:
Rasmus Munk Larsen 2022-12-15 18:41:46 +00:00
parent 37de432907
commit 3717854a21

View File

@ -287,27 +287,9 @@ struct functor_traits<test_signbit_op<Scalar>> {
} // namespace internal
} // namespace Eigen
template <typename T, bool IsInteger = NumTraits<T>::IsInteger>
struct ref_signbit_func_impl {
static bool run(const T& x) { return std::signbit(x); }
};
template <typename T>
struct ref_signbit_func_impl<T, true> {
// MSVC (perhaps others) does not have a std::signbit overload for integers
static bool run(const T& x) { return x < T(0); }
};
template <typename T>
bool ref_signbit_func(const T& x) {
return ref_signbit_func_impl<T>::run(x);
}
template <typename Scalar>
void signbit_test() {
Scalar true_mask;
std::memset(static_cast<void*>(&true_mask), 0xff, sizeof(Scalar));
Scalar false_mask;
std::memset(static_cast<void*>(&false_mask), 0x00, sizeof(Scalar));
const size_t size = 100 * internal::packet_traits<Scalar>::size;
ArrayX<Scalar> x(size), y(size);
x.setRandom();
@ -320,7 +302,7 @@ void signbit_test() {
bool all_pass = true;
for (size_t i = 0; i < size; i++) {
const Scalar ref_val = ref_signbit_func(x(i)) ? true_mask : false_mask;
const Scalar ref_val = numext::signbit(x(i));
bool not_same = internal::predux_any(internal::bitwise_helper<Scalar>::bitwise_xor(ref_val, y(i)));
if (not_same) std::cout << "signbit(" << x(i) << ") != " << y(i) << "\n";
all_pass = all_pass && !not_same;