diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 70c83d64a..1bfee7d67 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -1970,24 +1970,49 @@ struct pchebevl { }; namespace unary_pow { -template ::IsInteger> -struct is_odd { - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(const ScalarExponent& x) { - ScalarExponent xdiv2 = x / ScalarExponent(2); - ScalarExponent floorxdiv2 = numext::floor(xdiv2); - return xdiv2 != floorxdiv2; + +template ::IsInteger> +struct exponent_helper { + using safe_abs_type = ScalarExponent; + static constexpr ScalarExponent one_half = ScalarExponent(0.5); + // these routines assume that exp is an integer stored as a floating point type + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent safe_abs(const ScalarExponent& exp) { + return numext::abs(exp); + } + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool is_odd(const ScalarExponent& exp) { + eigen_assert(((numext::isfinite)(exp) && exp == numext::floor(exp)) && "exp must be an integer"); + ScalarExponent exp_div_2 = exp * one_half; + ScalarExponent floor_exp_div_2 = numext::floor(exp_div_2); + return exp_div_2 != floor_exp_div_2; + } + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent floor_div_two(const ScalarExponent& exp) { + ScalarExponent exp_div_2 = exp * one_half; + return numext::floor(exp_div_2); } }; + template -struct is_odd { - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(const ScalarExponent& x) { - return x % ScalarExponent(2) != 0; +struct exponent_helper { + // if `exp` is a signed integer type, cast it to its unsigned counterpart to safely store its absolute value + // consider the (rare) case where `exp` is an int32_t: abs(-2147483648) != 2147483648 + using safe_abs_type = typename numext::get_integer_by_size::unsigned_type; + static EIGEN_DEVICE_FUNC 
EIGEN_STRONG_INLINE safe_abs_type safe_abs(const ScalarExponent& exp) { + ScalarExponent mask = exp ^ numext::abs(exp); + safe_abs_type result = static_cast(exp); + return result ^ mask; + } + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool is_odd(const safe_abs_type& exp) { + return exp % safe_abs_type(2) != safe_abs_type(0); + } + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_abs_type floor_div_two(const safe_abs_type& exp) { + return exp >> safe_abs_type(1); } }; template ::type>::IsInteger> -struct do_div { + bool ReciprocateIfExponentIsNegative = + !NumTraits::type>::IsInteger && NumTraits::IsSigned> +struct reciprocate { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { using Scalar = typename unpacket_traits::type; const Packet cst_pos_one = pset1(Scalar(1)); @@ -1996,41 +2021,43 @@ struct do_div { }; template -struct do_div { +struct reciprocate { // pdiv not defined, nor necessary for integer base types - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent&) { - return x; - } + // if the exponent is unsigned, then the exponent cannot be negative + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent&) { return x; } }; template -static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet int_pow(const Packet& x, const ScalarExponent& exponent) { +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet int_pow(const Packet& x, const ScalarExponent& exponent) { using Scalar = typename unpacket_traits::type; + using ExponentHelper = exponent_helper; + using AbsExponentType = typename ExponentHelper::safe_abs_type; const Packet cst_pos_one = pset1(Scalar(1)); - if (exponent == 0) return cst_pos_one; - Packet result = x; + if (exponent == ScalarExponent(0)) return cst_pos_one; + + Packet result = reciprocate::run(x, exponent); Packet y = cst_pos_one; - ScalarExponent m = numext::abs(exponent); + AbsExponentType m = 
ExponentHelper::safe_abs(exponent); + while (m > 1) { - bool odd = is_odd::run(m); + bool odd = ExponentHelper::is_odd(m); if (odd) y = pmul(y, result); result = pmul(result, result); - m = numext::floor(m / ScalarExponent(2)); + m = ExponentHelper::floor_div_two(m); } - result = pmul(y, result); - result = do_div::run(result, exponent); - return result; + + return pmul(y, result); } template -static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet gen_pow(const Packet& x, +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet gen_pow(const Packet& x, const typename unpacket_traits::type& exponent) { const Packet exponent_packet = pset1(exponent); return generic_pow_impl(x, exponent_packet); } template -static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(const Packet& x, const Packet& powx, +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(const Packet& x, const Packet& powx, const ScalarExponent& exponent) { using Scalar = typename unpacket_traits::type; @@ -2045,36 +2072,45 @@ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors( const Packet cst_pos_one = pset1(pos_one); const Packet cst_pos_inf = pset1(pos_inf); - const bool exponent_is_nan = (numext::isnan)(exponent); - const bool exponent_is_fin = (numext::isfinite)(exponent); + const bool exponent_is_not_fin = !(numext::isfinite)(exponent); const bool exponent_is_neg = exponent < ScalarExponent(0); + const bool exponent_is_pos = exponent > ScalarExponent(0); - const Packet exp_is_nan = pset1(exponent_is_nan ? all_ones : pos_zero); - const Packet exp_is_fin = pset1(exponent_is_fin ? all_ones : pos_zero); + const Packet exp_is_not_fin = pset1(exponent_is_not_fin ? all_ones : pos_zero); const Packet exp_is_neg = pset1(exponent_is_neg ? all_ones : pos_zero); + const Packet exp_is_pos = pset1(exponent_is_pos ? 
all_ones : pos_zero); + const Packet exp_is_inf = pand(exp_is_not_fin, por(exp_is_neg, exp_is_pos)); + const Packet exp_is_nan = pandnot(exp_is_not_fin, por(exp_is_neg, exp_is_pos)); - const Packet x_is_gt_one = pcmp_lt(cst_pos_one, x); - const Packet x_is_lt_one = pcmp_lt(x, cst_pos_one); - const Packet x_is_zero = pcmp_eq(x, cst_pos_zero); - const Packet x_is_not_one = por(x_is_gt_one, x_is_lt_one); + const Packet x_is_le_zero = pcmp_le(x, cst_pos_zero); + const Packet x_is_ge_zero = pcmp_le(cst_pos_zero, x); + const Packet x_is_zero = pand(x_is_le_zero, x_is_ge_zero); - const Packet inf_if_neg_exp = pand(cst_pos_inf, exp_is_neg); - const Packet inf_if_pos_exp = pandnot(cst_pos_inf, exp_is_neg); + const Packet abs_x = pabs(x); + const Packet abs_x_is_le_one = pcmp_le(abs_x, cst_pos_one); + const Packet abs_x_is_ge_one = pcmp_le(cst_pos_one, abs_x); + const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf); + const Packet abs_x_is_one = pand(abs_x_is_le_one, abs_x_is_ge_one); + + Packet pow_is_inf_if_exp_is_neg = por(x_is_zero, pand(abs_x_is_le_one, exp_is_inf)); + Packet pow_is_inf_if_exp_is_pos = por(abs_x_is_inf, pand(abs_x_is_ge_one, exp_is_inf)); + Packet pow_is_one = pand(abs_x_is_one, por(exp_is_inf, x_is_ge_zero)); Packet result = powx; - result = pselect(x_is_zero, inf_if_neg_exp, result); - result = pselect(pandnot(x_is_gt_one, exp_is_fin), inf_if_pos_exp, result); - result = pselect(pandnot(x_is_lt_one, exp_is_fin), inf_if_neg_exp, result); + result = por(x_is_le_zero, result); + result = pselect(pow_is_inf_if_exp_is_neg, pand(cst_pos_inf, exp_is_neg), result); + result = pselect(pow_is_inf_if_exp_is_pos, pand(cst_pos_inf, exp_is_pos), result); result = por(exp_is_nan, result); - result = pselect(x_is_not_one, result, cst_pos_one); + result = pselect(pow_is_one, cst_pos_one, result); return result; } -template -static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_int_int(const Packet& x, const ScalarExponent& exponent) { +template 
::type>::IsSigned, bool> = true> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_negative_exponent(const Packet& x, const ScalarExponent& exponent) { using Scalar = typename unpacket_traits::type; - // integer base, integer exponent case + // signed integer base, signed integer exponent case // This routine handles negative exponents. // The return value is either 0, 1, or -1. @@ -2085,7 +2121,7 @@ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_int_int(const Packet& const Packet cst_pos_one = pset1(pos_one); - const bool exponent_is_odd = unary_pow::is_odd::run(exponent); + const bool exponent_is_odd = exponent % ScalarExponent(2) != ScalarExponent(0); const Packet exp_is_odd = pset1(exponent_is_odd ? all_ones : pos_zero); @@ -2097,15 +2133,36 @@ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_int_int(const Packet& return result; } +template ::type>::IsSigned, bool> = true> +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_negative_exponent(const Packet& x, const ScalarExponent&) { using Scalar = typename unpacket_traits::type; + + // unsigned integer base, signed integer exponent case + + // This routine handles negative exponents. 
+ // The return value is either 0 or 1 + + const Scalar pos_one = Scalar(1); + + const Packet cst_pos_one = pset1(pos_one); + + const Packet x_is_one = pcmp_eq(x, cst_pos_one); + + return pand(x_is_one, x); +} + + } // end namespace unary_pow template ::type>::IsInteger, - bool ExponentIsIntegerType = NumTraits::IsInteger> + bool ExponentIsIntegerType = NumTraits::IsInteger, + bool ExponentIsSigned = NumTraits::IsSigned> struct unary_pow_impl; -template -struct unary_pow_impl { +template +struct unary_pow_impl { typedef typename unpacket_traits::type Scalar; static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { const bool exponent_is_integer = (numext::isfinite)(exponent) && numext::round(exponent) == exponent; @@ -2119,8 +2176,8 @@ struct unary_pow_impl { } }; -template -struct unary_pow_impl { +template +struct unary_pow_impl { typedef typename unpacket_traits::type Scalar; static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { return unary_pow::int_pow(x, exponent); @@ -2128,17 +2185,25 @@ struct unary_pow_impl { }; template -struct unary_pow_impl { +struct unary_pow_impl { typedef typename unpacket_traits::type Scalar; static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { if (exponent < ScalarExponent(0)) { - return unary_pow::handle_int_int(x, exponent); + return unary_pow::handle_negative_exponent(x, exponent); } else { return unary_pow::int_pow(x, exponent); } } }; +template +struct unary_pow_impl { + typedef typename unpacket_traits::type Scalar; + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { + return unary_pow::int_pow(x, exponent); + } +}; + } // end namespace internal } // end namespace Eigen diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index 760112569..989359fe1 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ 
-191,53 +191,97 @@ void unary_ops_test() { */ } +template ::IsInteger> +struct ref_pow { + static Base run(Base base, Exponent exponent) { + EIGEN_USING_STD(pow); + return static_cast(pow(base, static_cast(exponent))); + } +}; -template -void pow_scalar_exponent_test() { - using Int_t = typename internal::make_integer::type; - const Scalar tol = test_precision(); +template +struct ref_pow { + static Base run(Base base, Exponent exponent) { + EIGEN_USING_STD(pow); + return static_cast(pow(base, exponent)); + } +}; - std::vector abs_vals = special_values(); - const Index num_vals = (Index)abs_vals.size(); - Map> bases(abs_vals.data(), num_vals); +template ::IsInteger> +struct pow_helper { + static bool is_integer_impl(const Exponent& exp) { return (numext::isfinite)(exp) && exp == numext::floor(exp); } + static bool is_odd_impl(const Exponent& exp) { + Exponent exp_div_2 = exp / Exponent(2); + Exponent floor_exp_div_2 = numext::floor(exp_div_2); + return exp_div_2 != floor_exp_div_2; + } +}; +template +struct pow_helper { + static bool is_integer_impl(const Exponent&) { return true; } + static bool is_odd_impl(const Exponent& exp) { return exp % 2 != 0; } +}; +template +bool is_integer(const Exponent& exp) { + return pow_helper::is_integer_impl(exp); +} +template +bool is_odd(const Exponent& exp) { + return pow_helper::is_odd_impl(exp); +} +template +void float_pow_test_impl() { + const Base tol = test_precision(); + std::vector abs_base_vals = special_values(); + std::vector abs_exponent_vals = special_values(); + for (int i = 0; i < 100; i++) { + abs_base_vals.push_back(internal::random(Base(0), Base(10))); + abs_exponent_vals.push_back(internal::random(Exponent(0), Exponent(10))); + } + const Index num_repeats = internal::packet_traits::size + 1; + ArrayX bases(num_repeats), eigenPow(num_repeats); bool all_pass = true; - for (Scalar abs_exponent : abs_vals) { - for (Scalar exponent : {-abs_exponent, abs_exponent}) { - // test integer exponent code path - bool 
exponent_is_integer = (numext::isfinite)(exponent) && (numext::round(exponent) == exponent) && - (numext::abs(exponent) < static_cast(NumTraits::highest())); - if (exponent_is_integer) { - Int_t exponent_as_int = static_cast(exponent); - Array eigenPow = bases.pow(exponent_as_int); - for (Index j = 0; j < num_vals; j++) { - Scalar e = static_cast(std::pow(bases(j), exponent)); - Scalar a = eigenPow(j); - bool success = (a == e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) || - ((numext::isnan)(a) && (numext::isnan)(e)); - if ((a == a) && (e == e)) success &= (bool)numext::signbit(e) == (bool)numext::signbit(a); - all_pass &= success; - if (!success) { - std::cout << "pow(" << bases(j) << "," << exponent << ") = " << a << " != " << e << std::endl; - } - } - } else { - // test floating point exponent code path - Array eigenPow = bases.pow(exponent); - for (Index j = 0; j < num_vals; j++) { - Scalar e = static_cast(std::pow(bases(j), exponent)); - Scalar a = eigenPow(j); - bool success = (a == e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) || - ((numext::isnan)(a) && (numext::isnan)(e)); - if ((a == a) && (e == e)) success &= (bool)numext::signbit(e) == (bool)numext::signbit(a); - all_pass &= success; - if (!success) { - std::cout << "pow(" << bases(j) << "," << exponent << ") = " << a << " != " << e << std::endl; + for (Base abs_base : abs_base_vals) + for (Base base : {negative_or_zero(abs_base), abs_base}) { + bases.setConstant(base); + for (Exponent abs_exponent : abs_exponent_vals) { + for (Exponent exponent : {negative_or_zero(abs_exponent), abs_exponent}) { + eigenPow = bases.pow(exponent); + for (Index j = 0; j < num_repeats; j++) { + Base e = ref_pow::run(bases(j), exponent); + if (is_integer(exponent)) { + // std::pow may return an incorrect result for a very large integral exponent + // if base is negative and the exponent is odd, then the result must be negative + // if std::pow returns otherwise, flip the sign + bool 
exp_is_odd = is_odd(exponent); + bool base_is_neg = !(numext::isnan)(base) && (bool)numext::signbit(base); + bool result_is_neg = exp_is_odd && base_is_neg; + bool ref_is_neg = !(numext::isnan)(e) && (bool)numext::signbit(e); + bool flip_sign = result_is_neg != ref_is_neg; + if (flip_sign) e = -e; + } + + Base a = eigenPow(j); + #ifdef EIGEN_COMP_MSVC + // Work around MSVC return value on underflow. + // if std::pow returns 0 and Eigen returns a denormalized value, then skip the test + int fpclass = std::fpclassify(a); + if (e == Base(0) && fpclass == FP_SUBNORMAL) continue; + #endif + + bool both_nan = (numext::isnan)(a) && (numext::isnan)(e); + bool exact_or_approx = (a == e) || internal::isApprox(a, e, tol); + bool same_sign = (bool)numext::signbit(e) == (bool)numext::signbit(a); + bool success = both_nan || (exact_or_approx && same_sign); + all_pass &= success; + if (!success) { + std::cout << "pow(" << bases(j) << "," << exponent << ") = " << a << " != " << e << std::endl; + } } } } } - } VERIFY(all_pass); } @@ -259,24 +303,9 @@ Scalar calc_overflow_threshold(const ScalarExponent exponent) { } } -template ::IsInteger> -struct ref_pow { - static Base run(Base base, Exponent exponent) { - EIGEN_USING_STD(pow); - return static_cast(pow(base, static_cast(exponent))); - } -}; - -template -struct ref_pow { - static Base run(Base base, Exponent exponent) { - EIGEN_USING_STD(pow); - return static_cast(pow(base, exponent)); - } -}; - template void test_exponent(Exponent exponent) { + EIGEN_STATIC_ASSERT(NumTraits::IsInteger,THIS TEST IS ONLY INTENDED FOR BASE INTEGER TYPES) const Base max_abs_bases = static_cast(10000); // avoid integer overflow in Base type Base threshold = calc_overflow_threshold(numext::abs(exponent)); @@ -300,10 +329,10 @@ void test_exponent(Exponent exponent) { for (Base a : y) { Base e = ref_pow::run(base, exponent); bool pass = (a == e); - if (!NumTraits::IsInteger) { - pass = pass || (((numext::isfinite)(e) && internal::isApprox(a, e)) || - 
((numext::isnan)(a) && (numext::isnan)(e))); - } + //if (!NumTraits::IsInteger) { + // pass = pass || (((numext::isfinite)(e) && internal::isApprox(a, e)) || + // ((numext::isnan)(a) && (numext::isnan)(e))); + //} all_pass &= pass; if (!pass) { std::cout << "pow(" << base << "," << exponent << ") = " << a << " != " << e << std::endl; @@ -314,7 +343,7 @@ void test_exponent(Exponent exponent) { } template -void unary_pow_test() { +void int_pow_test_impl() { Exponent max_exponent = static_cast(NumTraits::digits()); Exponent min_exponent = negative_or_zero(max_exponent); @@ -323,21 +352,26 @@ void unary_pow_test() { } } +void float_pow_test() { + float_pow_test_impl(); + float_pow_test_impl(); +} + void mixed_pow_test() { // The following cases will test promoting a smaller exponent type // to a wider base type. - unary_pow_test(); - unary_pow_test(); - unary_pow_test(); - unary_pow_test(); - unary_pow_test(); - unary_pow_test(); + float_pow_test_impl(); + float_pow_test_impl(); + float_pow_test_impl(); + float_pow_test_impl(); + float_pow_test_impl(); + float_pow_test_impl(); // Although in the following cases the exponent cannot be represented exactly // in the base type, we do not perform a conversion, but implement // the operation using repeated squaring. - unary_pow_test(); - unary_pow_test(); + float_pow_test_impl(); + float_pow_test_impl(); // The following cases will test promoting a wider exponent type // to a narrower base type. This should compile but would generate a @@ -346,20 +380,20 @@ void mixed_pow_test() { } void int_pow_test() { - unary_pow_test(); - unary_pow_test(); - unary_pow_test(); - unary_pow_test(); + int_pow_test_impl(); + int_pow_test_impl(); + int_pow_test_impl(); + int_pow_test_impl(); // Although in the following cases the exponent cannot be represented exactly // in the base type, we do not perform a conversion, but implement the // operation using repeated squaring. 
- unary_pow_test(); - unary_pow_test(); - unary_pow_test(); - unary_pow_test(); - unary_pow_test(); - unary_pow_test(); + int_pow_test_impl(); + int_pow_test_impl(); + int_pow_test_impl(); + int_pow_test_impl(); + int_pow_test_impl(); + int_pow_test_impl(); } namespace Eigen { @@ -849,7 +883,6 @@ template void array_real(const ArrayType& m) // Test pow and atan2 on special IEEE values. unary_ops_test(); binary_ops_test(); - pow_scalar_exponent_test(); VERIFY_IS_APPROX(log10(m3), log(m3)/numext::log(Scalar(10))); VERIFY_IS_APPROX(log2(m3), log(m3)/numext::log(Scalar(2))); @@ -1223,6 +1256,7 @@ EIGEN_DECLARE_TEST(array_cwise) } for(int i = 0; i < g_repeat; i++) { + CALL_SUBTEST_5( float_pow_test() ); CALL_SUBTEST_6( int_pow_test() ); CALL_SUBTEST_7( mixed_pow_test() ); CALL_SUBTEST_8( signbit_tests() );