From e1aee4ab3942d1bae16f829ca2b840830d95ac14 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Thu, 15 Dec 2022 11:39:32 -0800 Subject: [PATCH] Update test of numext::signbit. --- test/numext.cpp | 73 ++++++++++++++++++++++++++----------------------- 1 file changed, 39 insertions(+), 34 deletions(-) diff --git a/test/numext.cpp b/test/numext.cpp index 5483e5c4a..e99eddc2f 100644 --- a/test/numext.cpp +++ b/test/numext.cpp @@ -239,19 +239,6 @@ void check_rsqrt() { check_rsqrt_impl::run(); } -template ::IsInteger> -struct ref_signbit_func_impl { - static bool run(const T& x) { return std::signbit(x); } -}; -template -struct ref_signbit_func_impl { - // MSVC (perhaps others) does not have a std::signbit overload for integers - static bool run(const T& x) { return x < T(0); } -}; -template -bool ref_signbit_func(const T& x) { - return ref_signbit_func_impl::run(x); -} template struct check_signbit_impl { @@ -261,28 +248,46 @@ struct check_signbit_impl { T false_mask; std::memset(static_cast(&false_mask), 0x00, sizeof(T)); - // has sign bit - const T neg_zero = static_cast(-0.0); - const T neg_one = static_cast(-1.0); - const T neg_inf = -std::numeric_limits::infinity(); - const T neg_nan = -std::numeric_limits::quiet_NaN(); - // does not have sign bit - const T pos_zero = static_cast(0.0); - const T pos_one = static_cast(1.0); - const T pos_inf = std::numeric_limits::infinity(); - const T pos_nan = std::numeric_limits::quiet_NaN(); + std::vector negative_values; + std::vector non_negative_values; - std::vector values = {neg_zero, neg_one, neg_inf, neg_nan, pos_zero, pos_one, pos_inf, pos_nan}; - - bool all_pass = true; - - for (T val : values) { - const T numext_val = numext::signbit(val); - const T ref_val = ref_signbit_func(val) ? true_mask : false_mask; - bool not_same = internal::predux_any(internal::bitwise_helper::bitwise_xor(ref_val, numext_val)); - all_pass = all_pass && !not_same; - if (not_same) std::cout << "signbit(" << val << ") != " << numext_val << "\n"; + if (NumTraits::IsInteger) { + negative_values = {static_cast(-1), + static_cast(NumTraits::lowest())}; + non_negative_values = {static_cast(0), static_cast(1), + static_cast(NumTraits::highest())}; + } else { + // has sign bit + const T neg_zero = static_cast(-0.0); + const T neg_one = static_cast(-1.0); + const T neg_inf = -std::numeric_limits::infinity(); + const T neg_nan = -std::numeric_limits::quiet_NaN(); + // does not have sign bit + const T pos_zero = static_cast(0.0); + const T pos_one = static_cast(1.0); + const T pos_inf = std::numeric_limits::infinity(); + const T pos_nan = std::numeric_limits::quiet_NaN(); + negative_values = {neg_zero, neg_one, neg_inf, neg_nan}; + non_negative_values = {pos_zero, pos_one, pos_inf, pos_nan}; } + + + auto check_all = [](auto values, auto expected) { + bool all_pass = true; + for (T val : values) { + const T numext_val = numext::signbit(val); + bool not_same = internal::predux_any( + internal::bitwise_helper::bitwise_xor(expected, numext_val)); + all_pass = all_pass && !not_same; + if (not_same) + std::cout << "signbit(" << val << ") = " << numext_val + << " != " << expected << std::endl; + } + return all_pass; + }; + + bool all_pass = check_all(non_negative_values, false_mask); + all_pass = all_pass && check_all(negative_values, (NumTraits::IsSigned ? true_mask : false_mask)); VERIFY(all_pass); } }; @@ -318,7 +323,7 @@ EIGEN_DECLARE_TEST(numext) { CALL_SUBTEST( check_sqrt() ); CALL_SUBTEST( check_sqrt >() ); CALL_SUBTEST( check_sqrt >() ); - + CALL_SUBTEST( check_rsqrt() ); CALL_SUBTEST( check_rsqrt() ); CALL_SUBTEST( check_rsqrt >() );