From 7a87ed1b6a49bd0067856dcba9ad9a3a46186220 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Mon, 8 Aug 2022 18:48:36 +0000 Subject: [PATCH] Fix code and unit test for a few corner cases in vectorized pow() --- .../arch/Default/GenericPacketMathFunctions.h | 72 +++++++++---------- test/array_cwise.cpp | 6 +- 2 files changed, 36 insertions(+), 42 deletions(-) diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 822113b9b..5949d2ce2 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -1424,38 +1424,40 @@ EIGEN_STRONG_INLINE Packet generic_pow_impl(const Packet& x, const Packet& y) { } // Generic implementation of pow(x,y). -template -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS -Packet generic_pow(const Packet& x, const Packet& y) { +template +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_pow(const Packet& x, const Packet& y) { typedef typename unpacket_traits::type Scalar; const Packet cst_pos_inf = pset1(NumTraits::infinity()); + const Packet cst_neg_inf = pset1(-NumTraits::infinity()); const Packet cst_zero = pset1(Scalar(0)); const Packet cst_one = pset1(Scalar(1)); const Packet cst_nan = pset1(NumTraits::quiet_NaN()); const Packet abs_x = pabs(x); // Predicates for sign and magnitude of x. - const Packet x_is_zero = pcmp_eq(x, cst_zero); - const Packet x_is_neg = pcmp_lt(x, cst_zero); + const Packet abs_x_is_zero = pcmp_eq(abs_x, cst_zero); + const Packet x_has_signbit = pcmp_eq(por(pand(x, cst_neg_inf), cst_pos_inf), cst_neg_inf); + const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero); + const Packet x_is_neg_zero = pand(x_has_signbit, abs_x_is_zero); const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf); - const Packet abs_x_is_one = pcmp_eq(abs_x, cst_one); + const Packet abs_x_is_one = pcmp_eq(abs_x, cst_one); const Packet abs_x_is_gt_one = pcmp_lt(cst_one, abs_x); const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_one); - const Packet x_is_one = pandnot(abs_x_is_one, x_is_neg); - const Packet x_is_neg_one = pand(abs_x_is_one, x_is_neg); + const Packet x_is_one = pandnot(abs_x_is_one, x_is_neg); + const Packet x_is_neg_one = pand(abs_x_is_one, x_is_neg); const Packet x_is_nan = pandnot(ptrue(x), pcmp_eq(x, x)); // Predicates for sign and magnitude of y. + const Packet abs_y = pabs(y); const Packet y_is_one = pcmp_eq(y, cst_one); - const Packet y_is_zero = pcmp_eq(y, cst_zero); + const Packet abs_y_is_zero = pcmp_eq(abs_y, cst_zero); const Packet y_is_neg = pcmp_lt(y, cst_zero); - const Packet y_is_pos = pandnot(ptrue(y), por(y_is_zero, y_is_neg)); + const Packet y_is_pos = pandnot(ptrue(y), por(abs_y_is_zero, y_is_neg)); const Packet y_is_nan = pandnot(ptrue(y), pcmp_eq(y, y)); - const Packet abs_y_is_inf = pcmp_eq(pabs(y), cst_pos_inf); + const Packet abs_y_is_inf = pcmp_eq(abs_y, cst_pos_inf); EIGEN_CONSTEXPR Scalar huge_exponent = - (NumTraits::max_exponent() * Scalar(EIGEN_LN2)) / - NumTraits::epsilon(); + (NumTraits::max_exponent() * Scalar(EIGEN_LN2)) / NumTraits::epsilon(); const Packet abs_y_is_huge = pcmp_le(pset1(huge_exponent), pabs(y)); // Predicates for whether y is integer and/or even. @@ -1464,39 +1466,31 @@ Packet generic_pow(const Packet& x, const Packet& y) { const Packet y_is_even = pcmp_eq(pround(y_div_2), y_div_2); // Predicates encoding special cases for the value of pow(x,y) - const Packet invalid_negative_x = pandnot(pandnot(pandnot(x_is_neg, abs_x_is_inf), - y_is_int), - abs_y_is_inf); - const Packet pow_is_one = por(por(x_is_one, y_is_zero), - pand(x_is_neg_one, - por(abs_y_is_inf, pandnot(y_is_even, invalid_negative_x)))); + const Packet invalid_negative_x = pandnot(pandnot(pandnot(x_is_neg, abs_x_is_inf), y_is_int), abs_y_is_inf); const Packet pow_is_nan = por(invalid_negative_x, por(x_is_nan, y_is_nan)); - const Packet pow_is_zero = por(por(por(pand(x_is_zero, y_is_pos), - pand(abs_x_is_inf, y_is_neg)), - pand(pand(abs_x_is_lt_one, abs_y_is_huge), - y_is_pos)), - pand(pand(abs_x_is_gt_one, abs_y_is_huge), - y_is_neg)); - const Packet pow_is_inf = por(por(por(pand(x_is_zero, y_is_neg), - pand(abs_x_is_inf, y_is_pos)), - pand(pand(abs_x_is_lt_one, abs_y_is_huge), - y_is_neg)), - pand(pand(abs_x_is_gt_one, abs_y_is_huge), - y_is_pos)); + const Packet pow_is_one = + por(por(x_is_one, abs_y_is_zero), pand(x_is_neg_one, por(abs_y_is_inf, pandnot(y_is_even, invalid_negative_x)))); + const Packet pow_is_zero = por(por(por(pand(abs_x_is_zero, y_is_pos), pand(abs_x_is_inf, y_is_neg)), + pand(pand(abs_x_is_lt_one, abs_y_is_huge), y_is_pos)), + pand(pand(abs_x_is_gt_one, abs_y_is_huge), y_is_neg)); + const Packet pow_is_inf = por(por(por(pand(abs_x_is_zero, y_is_neg), pand(abs_x_is_inf, y_is_pos)), + pand(pand(abs_x_is_lt_one, abs_y_is_huge), y_is_neg)), + pand(pand(abs_x_is_gt_one, abs_y_is_huge), y_is_pos)); + const Packet inf_val = + pselect(pandnot(pand(por(pand(abs_x_is_inf, x_is_neg), pand(x_is_neg_zero, y_is_neg)), y_is_int), y_is_even), + cst_neg_inf, cst_pos_inf); // General computation of pow(x,y) for positive x or negative x and integer y. const Packet negate_pow_abs = pandnot(x_is_neg, y_is_even); const Packet pow_abs = generic_pow_impl(abs_x, y); - return pselect(y_is_one, x, - pselect(pow_is_one, cst_one, - pselect(pow_is_nan, cst_nan, - pselect(pow_is_inf, cst_pos_inf, - pselect(pow_is_zero, cst_zero, - pselect(negate_pow_abs, pnegate(pow_abs), pow_abs)))))); + return pselect( + y_is_one, x, + pselect(pow_is_one, cst_one, + pselect(pow_is_nan, cst_nan, + pselect(pow_is_inf, inf_val, + pselect(pow_is_zero, cst_zero, pselect(negate_pow_abs, pnegate(pow_abs), pow_abs)))))); } - - /* polevl (modified for Eigen) * * Evaluate polynomial diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index 298351eef..48290a171 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ -72,9 +72,9 @@ void pow_test() { for (int j = 0; j < num_cases; ++j) { Scalar e = static_cast(std::pow(x(i,j), y(i,j))); Scalar a = actual(i, j); - bool fail = !(a==e) && !internal::isApprox(a, e, tol) && !((numext::isnan)(a) && (numext::isnan)(e)); - all_pass &= !fail; - if (fail) { + bool success = (a==e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) || ((numext::isnan)(a) && (numext::isnan)(e)); + all_pass &= success; + if (!success) { std::cout << "pow(" << x(i,j) << "," << y(i,j) << ") = " << a << " != " << e << std::endl; } }