// This file is part of Eigen, a lightweight C++ template library // for linear algebra. // // Copyright (C) 2017 Gael Guennebaud // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #include "main.h" template bool check_if_equal_or_nans(const T& actual, const U& expected) { return (numext::equal_strict(actual, expected) || ((numext::isnan)(actual) && (numext::isnan)(expected))); } template bool check_if_equal_or_nans(const std::complex& actual, const std::complex& expected) { return check_if_equal_or_nans(numext::real(actual), numext::real(expected)) && check_if_equal_or_nans(numext::imag(actual), numext::imag(expected)); } template bool test_is_equal_or_nans(const T& actual, const U& expected) { if (check_if_equal_or_nans(actual, expected)) { return true; } // false: std::cerr << "\n actual = " << actual << "\n expected = " << expected << "\n\n"; return false; } #define VERIFY_IS_EQUAL_OR_NANS(a, b) VERIFY(test_is_equal_or_nans(a, b)) template void check_negate() { Index size = 1000; for (Index i = 0; i < size; i++) { T val = i == 0 ? T(0) : internal::random(T(0), NumTraits::highest()); T neg_val = numext::negate(val); VERIFY_IS_EQUAL(T(val + neg_val), T(0)); VERIFY_IS_EQUAL(numext::negate(neg_val), val); } } template void check_abs() { typedef typename NumTraits::Real Real; Real zero(0); if (NumTraits::IsSigned) VERIFY_IS_EQUAL(numext::abs(numext::negate(T(1))), T(1)); VERIFY_IS_EQUAL(numext::abs(T(0)), T(0)); VERIFY_IS_EQUAL(numext::abs(T(1)), T(1)); for (int k = 0; k < 100; ++k) { T x = internal::random(); x = x / Real(2); if (NumTraits::IsSigned) { VERIFY_IS_EQUAL(numext::abs(x), numext::abs(numext::negate(x))); VERIFY(numext::abs(numext::negate(x)) >= zero); } VERIFY(numext::abs(x) >= zero); VERIFY_IS_APPROX(numext::abs2(x), numext::abs2(numext::abs(x))); } } template <> void check_abs() { for (bool x : {true, false}) { VERIFY_IS_EQUAL(numext::abs(x), x); VERIFY(numext::abs(x) >= false); VERIFY_IS_EQUAL(numext::abs2(x), numext::abs2(numext::abs(x))); } } template void check_arg() { typedef typename NumTraits::Real Real; VERIFY_IS_EQUAL(numext::abs(T(0)), T(0)); VERIFY_IS_EQUAL(numext::abs(T(1)), T(1)); for (int k = 0; k < 100; ++k) { T x = internal::random(); Real y = numext::arg(x); VERIFY_IS_APPROX(y, std::arg(x)); } } template struct check_sqrt_impl { static void run() { for (int i = 0; i < 1000; ++i) { const T x = numext::abs(internal::random()); const T sqrtx = numext::sqrt(x); VERIFY_IS_APPROX(sqrtx * sqrtx, x); } // Corner cases. const T zero = T(0); const T one = T(1); const T inf = std::numeric_limits::infinity(); const T nan = std::numeric_limits::quiet_NaN(); VERIFY_IS_EQUAL(numext::sqrt(zero), zero); VERIFY_IS_EQUAL(numext::sqrt(inf), inf); VERIFY((numext::isnan)(numext::sqrt(nan))); VERIFY((numext::isnan)(numext::sqrt(-one))); } }; template struct check_sqrt_impl > { static void run() { typedef typename std::complex ComplexT; for (int i = 0; i < 1000; ++i) { const ComplexT x = internal::random(); const ComplexT sqrtx = numext::sqrt(x); VERIFY_IS_APPROX(sqrtx * sqrtx, x); } // Corner cases. const T zero = T(0); const T one = T(1); const T inf = std::numeric_limits::infinity(); const T nan = std::numeric_limits::quiet_NaN(); // Set of corner cases from https://en.cppreference.com/w/cpp/numeric/complex/sqrt const int kNumCorners = 20; const ComplexT corners[kNumCorners][2] = { {ComplexT(zero, zero), ComplexT(zero, zero)}, {ComplexT(-zero, zero), ComplexT(zero, zero)}, {ComplexT(zero, -zero), ComplexT(zero, zero)}, {ComplexT(-zero, -zero), ComplexT(zero, zero)}, {ComplexT(one, inf), ComplexT(inf, inf)}, {ComplexT(nan, inf), ComplexT(inf, inf)}, {ComplexT(one, -inf), ComplexT(inf, -inf)}, {ComplexT(nan, -inf), ComplexT(inf, -inf)}, {ComplexT(-inf, one), ComplexT(zero, inf)}, {ComplexT(inf, one), ComplexT(inf, zero)}, {ComplexT(-inf, -one), ComplexT(zero, -inf)}, {ComplexT(inf, -one), ComplexT(inf, -zero)}, {ComplexT(-inf, nan), ComplexT(nan, inf)}, {ComplexT(inf, nan), ComplexT(inf, nan)}, {ComplexT(zero, nan), ComplexT(nan, nan)}, {ComplexT(one, nan), ComplexT(nan, nan)}, {ComplexT(nan, zero), ComplexT(nan, nan)}, {ComplexT(nan, one), ComplexT(nan, nan)}, {ComplexT(nan, -one), ComplexT(nan, nan)}, {ComplexT(nan, nan), ComplexT(nan, nan)}, }; for (int i = 0; i < kNumCorners; ++i) { const ComplexT& x = corners[i][0]; const ComplexT sqrtx = corners[i][1]; VERIFY_IS_EQUAL_OR_NANS(numext::sqrt(x), sqrtx); } } }; template void check_sqrt() { check_sqrt_impl::run(); } template struct check_rsqrt_impl { static void run() { const T zero = T(0); const T one = T(1); const T inf = std::numeric_limits::infinity(); const T nan = std::numeric_limits::quiet_NaN(); for (int i = 0; i < 1000; ++i) { const T x = numext::abs(internal::random()); const T rsqrtx = numext::rsqrt(x); const T invx = one / x; VERIFY_IS_APPROX(rsqrtx * rsqrtx, invx); } // Corner cases. VERIFY_IS_EQUAL(numext::rsqrt(zero), inf); VERIFY_IS_EQUAL(numext::rsqrt(inf), zero); VERIFY((numext::isnan)(numext::rsqrt(nan))); VERIFY((numext::isnan)(numext::rsqrt(-one))); } }; template struct check_rsqrt_impl > { static void run() { typedef typename std::complex ComplexT; const T zero = T(0); const T one = T(1); const T inf = std::numeric_limits::infinity(); const T nan = std::numeric_limits::quiet_NaN(); for (int i = 0; i < 1000; ++i) { const ComplexT x = internal::random(); const ComplexT invx = ComplexT(one, zero) / x; const ComplexT rsqrtx = numext::rsqrt(x); VERIFY_IS_APPROX(rsqrtx * rsqrtx, invx); } // GCC and MSVC differ in their treatment of 1/(0 + 0i) // GCC/clang = (inf, nan) // MSVC = (nan, nan) // and 1 / (x + inf i) // GCC/clang = (0, 0) // MSVC = (nan, nan) #if (EIGEN_COMP_GNUC) { const int kNumCorners = 20; const ComplexT corners[kNumCorners][2] = { // Only consistent across GCC, clang {ComplexT(zero, zero), ComplexT(zero, zero)}, {ComplexT(-zero, zero), ComplexT(zero, zero)}, {ComplexT(zero, -zero), ComplexT(zero, zero)}, {ComplexT(-zero, -zero), ComplexT(zero, zero)}, {ComplexT(one, inf), ComplexT(inf, inf)}, {ComplexT(nan, inf), ComplexT(inf, inf)}, {ComplexT(one, -inf), ComplexT(inf, -inf)}, {ComplexT(nan, -inf), ComplexT(inf, -inf)}, // Consistent across GCC, clang, MSVC {ComplexT(-inf, one), ComplexT(zero, inf)}, {ComplexT(inf, one), ComplexT(inf, zero)}, {ComplexT(-inf, -one), ComplexT(zero, -inf)}, {ComplexT(inf, -one), ComplexT(inf, -zero)}, {ComplexT(-inf, nan), ComplexT(nan, inf)}, {ComplexT(inf, nan), ComplexT(inf, nan)}, {ComplexT(zero, nan), ComplexT(nan, nan)}, {ComplexT(one, nan), ComplexT(nan, nan)}, {ComplexT(nan, zero), ComplexT(nan, nan)}, {ComplexT(nan, one), ComplexT(nan, nan)}, {ComplexT(nan, -one), ComplexT(nan, nan)}, {ComplexT(nan, nan), ComplexT(nan, nan)}, }; for (int i = 0; i < kNumCorners; ++i) { const ComplexT& x = corners[i][0]; const ComplexT rsqrtx = ComplexT(one, zero) / corners[i][1]; VERIFY_IS_EQUAL_OR_NANS(numext::rsqrt(x), rsqrtx); } } #endif } }; template void check_rsqrt() { check_rsqrt_impl::run(); } template struct check_signbit_impl { static void run() { T true_mask; std::memset(static_cast(&true_mask), 0xff, sizeof(T)); T false_mask; std::memset(static_cast(&false_mask), 0x00, sizeof(T)); std::vector negative_values; std::vector non_negative_values; if (NumTraits::IsInteger) { negative_values = {static_cast(-1), static_cast(NumTraits::lowest())}; non_negative_values = {static_cast(0), static_cast(1), static_cast(NumTraits::highest())}; } else { // does not have sign bit const T pos_zero = static_cast(0.0); const T pos_one = static_cast(1.0); const T pos_inf = std::numeric_limits::infinity(); const T pos_nan = std::numeric_limits::quiet_NaN(); // has sign bit const T neg_zero = numext::negate(pos_zero); const T neg_one = numext::negate(pos_one); const T neg_inf = numext::negate(pos_inf); const T neg_nan = numext::negate(pos_nan); negative_values = {neg_zero, neg_one, neg_inf, neg_nan}; non_negative_values = {pos_zero, pos_one, pos_inf, pos_nan}; } auto check_all = [](auto values, auto expected) { bool all_pass = true; for (T val : values) { const T numext_val = numext::signbit(val); bool not_same = internal::predux_any(internal::bitwise_helper::bitwise_xor(expected, numext_val)); all_pass = all_pass && !not_same; if (not_same) std::cout << "signbit(" << val << ") = " << numext_val << " != " << expected << std::endl; } return all_pass; }; bool check_all_pass = check_all(non_negative_values, false_mask); check_all_pass = check_all_pass && check_all(negative_values, (NumTraits::IsSigned ? true_mask : false_mask)); VERIFY(check_all_pass); } }; template void check_signbit() { check_signbit_impl::run(); } template void check_shift() { using SignedT = typename numext::get_integer_by_size::signed_type; using UnsignedT = typename numext::get_integer_by_size::unsigned_type; constexpr int kNumBits = CHAR_BIT * sizeof(T); for (int i = 0; i < 1000; ++i) { const T a = internal::random(); for (int s = 1; s < kNumBits; s++) { T a_bsll = numext::logical_shift_left(a, s); T a_bsll_ref = a << s; VERIFY_IS_EQUAL(a_bsll, a_bsll_ref); T a_bsrl = numext::logical_shift_right(a, s); T a_bsrl_ref = numext::bit_cast(numext::bit_cast(a) >> s); VERIFY_IS_EQUAL(a_bsrl, a_bsrl_ref); T a_bsra = numext::arithmetic_shift_right(a, s); T a_bsra_ref = numext::bit_cast(numext::bit_cast(a) >> s); VERIFY_IS_EQUAL(a_bsra, a_bsra_ref); } } } EIGEN_DECLARE_TEST(numext) { for (int k = 0; k < g_repeat; ++k) { CALL_SUBTEST(check_negate()); CALL_SUBTEST(check_negate()); CALL_SUBTEST(check_negate()); CALL_SUBTEST(check_negate()); CALL_SUBTEST(check_negate()); CALL_SUBTEST(check_negate()); CALL_SUBTEST(check_negate()); CALL_SUBTEST(check_negate()); CALL_SUBTEST(check_negate()); CALL_SUBTEST(check_negate()); CALL_SUBTEST(check_negate()); CALL_SUBTEST(check_negate()); CALL_SUBTEST(check_negate()); CALL_SUBTEST(check_negate >()); CALL_SUBTEST(check_negate >()); CALL_SUBTEST(check_abs()); CALL_SUBTEST(check_abs()); CALL_SUBTEST(check_abs()); CALL_SUBTEST(check_abs()); CALL_SUBTEST(check_abs()); CALL_SUBTEST(check_abs()); CALL_SUBTEST(check_abs()); CALL_SUBTEST(check_abs()); CALL_SUBTEST(check_abs()); CALL_SUBTEST(check_abs()); CALL_SUBTEST(check_abs()); CALL_SUBTEST(check_abs()); CALL_SUBTEST(check_abs()); CALL_SUBTEST(check_abs()); CALL_SUBTEST(check_abs >()); CALL_SUBTEST(check_abs >()); CALL_SUBTEST(check_arg >()); CALL_SUBTEST(check_arg >()); CALL_SUBTEST(check_sqrt()); CALL_SUBTEST(check_sqrt()); CALL_SUBTEST(check_sqrt >()); CALL_SUBTEST(check_sqrt >()); CALL_SUBTEST(check_rsqrt()); CALL_SUBTEST(check_rsqrt()); CALL_SUBTEST(check_rsqrt >()); CALL_SUBTEST(check_rsqrt >()); CALL_SUBTEST(check_signbit()); CALL_SUBTEST(check_signbit()); CALL_SUBTEST(check_signbit()); CALL_SUBTEST(check_signbit()); CALL_SUBTEST(check_signbit()); CALL_SUBTEST(check_signbit()); CALL_SUBTEST(check_signbit()); CALL_SUBTEST(check_signbit()); CALL_SUBTEST(check_signbit()); CALL_SUBTEST(check_signbit()); CALL_SUBTEST(check_signbit()); CALL_SUBTEST(check_signbit()); CALL_SUBTEST(check_shift()); CALL_SUBTEST(check_shift()); CALL_SUBTEST(check_shift()); CALL_SUBTEST(check_shift()); CALL_SUBTEST(check_shift()); CALL_SUBTEST(check_shift()); CALL_SUBTEST(check_shift()); CALL_SUBTEST(check_shift()); } }