Fix unary pow error handling and test

This commit is contained in:
Charles Schlosser 2023-06-06 18:46:55 +00:00 committed by Rasmus Munk Larsen
parent 7ac8897431
commit b7151ffaab
2 changed files with 229 additions and 130 deletions

View File

@ -1970,24 +1970,49 @@ struct pchebevl {
}; };
namespace unary_pow { namespace unary_pow {
template <typename ScalarExponent, bool IsIntegerAtCompileTime = NumTraits<ScalarExponent>::IsInteger>
struct is_odd { template <typename ScalarExponent, bool IsInteger = NumTraits<ScalarExponent>::IsInteger>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(const ScalarExponent& x) { struct exponent_helper {
ScalarExponent xdiv2 = x / ScalarExponent(2); using safe_abs_type = ScalarExponent;
ScalarExponent floorxdiv2 = numext::floor(xdiv2); static constexpr ScalarExponent one_half = ScalarExponent(0.5);
return xdiv2 != floorxdiv2; // these routines assume that exp is an integer stored as a floating point type
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent safe_abs(const ScalarExponent& exp) {
return numext::abs(exp);
}
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool is_odd(const ScalarExponent& exp) {
eigen_assert(((numext::isfinite)(exp) && exp == numext::floor(exp)) && "exp must be an integer");
ScalarExponent exp_div_2 = exp * one_half;
ScalarExponent floor_exp_div_2 = numext::floor(exp_div_2);
return exp_div_2 != floor_exp_div_2;
}
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent floor_div_two(const ScalarExponent& exp) {
ScalarExponent exp_div_2 = exp * one_half;
return numext::floor(exp_div_2);
} }
}; };
template <typename ScalarExponent> template <typename ScalarExponent>
struct is_odd<ScalarExponent, true> { struct exponent_helper<ScalarExponent, true> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(const ScalarExponent& x) { // if `exp` is a signed integer type, cast it to its unsigned counterpart to safely store its absolute value
return x % ScalarExponent(2) != 0; // consider the (rare) case where `exp` is an int32_t: abs(-2147483648) != 2147483648
using safe_abs_type = typename numext::get_integer_by_size<sizeof(ScalarExponent)>::unsigned_type;
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_abs_type safe_abs(const ScalarExponent& exp) {
ScalarExponent mask = exp ^ numext::abs(exp);
safe_abs_type result = static_cast<safe_abs_type>(exp);
return result ^ mask;
}
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool is_odd(const safe_abs_type& exp) {
return exp % safe_abs_type(2) != safe_abs_type(0);
}
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_abs_type floor_div_two(const safe_abs_type& exp) {
return exp >> safe_abs_type(1);
} }
}; };
template <typename Packet, typename ScalarExponent, template <typename Packet, typename ScalarExponent,
bool BaseIsIntegerType = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger> bool ReciprocateIfExponentIsNegative =
struct do_div { !NumTraits<typename unpacket_traits<Packet>::type>::IsInteger && NumTraits<ScalarExponent>::IsSigned>
struct reciprocate {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
using Scalar = typename unpacket_traits<Packet>::type; using Scalar = typename unpacket_traits<Packet>::type;
const Packet cst_pos_one = pset1<Packet>(Scalar(1)); const Packet cst_pos_one = pset1<Packet>(Scalar(1));
@ -1996,41 +2021,43 @@ struct do_div {
}; };
template <typename Packet, typename ScalarExponent> template <typename Packet, typename ScalarExponent>
struct do_div<Packet, ScalarExponent, true> { struct reciprocate<Packet, ScalarExponent, false> {
// pdiv not defined, nor necessary for integer base types // pdiv not defined, nor necessary for integer base types
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent&) { // if the exponent is unsigned, then the exponent cannot be negative
return x; static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent&) { return x; }
}
}; };
template <typename Packet, typename ScalarExponent> template <typename Packet, typename ScalarExponent>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet int_pow(const Packet& x, const ScalarExponent& exponent) { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet int_pow(const Packet& x, const ScalarExponent& exponent) {
using Scalar = typename unpacket_traits<Packet>::type; using Scalar = typename unpacket_traits<Packet>::type;
using ExponentHelper = exponent_helper<ScalarExponent>;
using AbsExponentType = typename ExponentHelper::safe_abs_type;
const Packet cst_pos_one = pset1<Packet>(Scalar(1)); const Packet cst_pos_one = pset1<Packet>(Scalar(1));
if (exponent == 0) return cst_pos_one; if (exponent == ScalarExponent(0)) return cst_pos_one;
Packet result = x;
Packet result = reciprocate<Packet, ScalarExponent>::run(x, exponent);
Packet y = cst_pos_one; Packet y = cst_pos_one;
ScalarExponent m = numext::abs(exponent); AbsExponentType m = ExponentHelper::safe_abs(exponent);
while (m > 1) { while (m > 1) {
bool odd = is_odd<ScalarExponent>::run(m); bool odd = ExponentHelper::is_odd(m);
if (odd) y = pmul(y, result); if (odd) y = pmul(y, result);
result = pmul(result, result); result = pmul(result, result);
m = numext::floor(m / ScalarExponent(2)); m = ExponentHelper::floor_div_two(m);
} }
result = pmul(y, result);
result = do_div<Packet, ScalarExponent>::run(result, exponent); return pmul(y, result);
return result;
} }
template <typename Packet> template <typename Packet>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet gen_pow(const Packet& x, EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet gen_pow(const Packet& x,
const typename unpacket_traits<Packet>::type& exponent) { const typename unpacket_traits<Packet>::type& exponent) {
const Packet exponent_packet = pset1<Packet>(exponent); const Packet exponent_packet = pset1<Packet>(exponent);
return generic_pow_impl(x, exponent_packet); return generic_pow_impl(x, exponent_packet);
} }
template <typename Packet, typename ScalarExponent> template <typename Packet, typename ScalarExponent>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(const Packet& x, const Packet& powx, EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(const Packet& x, const Packet& powx,
const ScalarExponent& exponent) { const ScalarExponent& exponent) {
using Scalar = typename unpacket_traits<Packet>::type; using Scalar = typename unpacket_traits<Packet>::type;
@ -2045,36 +2072,45 @@ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(
const Packet cst_pos_one = pset1<Packet>(pos_one); const Packet cst_pos_one = pset1<Packet>(pos_one);
const Packet cst_pos_inf = pset1<Packet>(pos_inf); const Packet cst_pos_inf = pset1<Packet>(pos_inf);
const bool exponent_is_nan = (numext::isnan)(exponent); const bool exponent_is_not_fin = !(numext::isfinite)(exponent);
const bool exponent_is_fin = (numext::isfinite)(exponent);
const bool exponent_is_neg = exponent < ScalarExponent(0); const bool exponent_is_neg = exponent < ScalarExponent(0);
const bool exponent_is_pos = exponent > ScalarExponent(0);
const Packet exp_is_nan = pset1<Packet>(exponent_is_nan ? all_ones : pos_zero); const Packet exp_is_not_fin = pset1<Packet>(exponent_is_not_fin ? all_ones : pos_zero);
const Packet exp_is_fin = pset1<Packet>(exponent_is_fin ? all_ones : pos_zero);
const Packet exp_is_neg = pset1<Packet>(exponent_is_neg ? all_ones : pos_zero); const Packet exp_is_neg = pset1<Packet>(exponent_is_neg ? all_ones : pos_zero);
const Packet exp_is_pos = pset1<Packet>(exponent_is_pos ? all_ones : pos_zero);
const Packet exp_is_inf = pand(exp_is_not_fin, por(exp_is_neg, exp_is_pos));
const Packet exp_is_nan = pandnot(exp_is_not_fin, por(exp_is_neg, exp_is_pos));
const Packet x_is_gt_one = pcmp_lt(cst_pos_one, x); const Packet x_is_le_zero = pcmp_le(x, cst_pos_zero);
const Packet x_is_lt_one = pcmp_lt(x, cst_pos_one); const Packet x_is_ge_zero = pcmp_le(cst_pos_zero, x);
const Packet x_is_zero = pcmp_eq(x, cst_pos_zero); const Packet x_is_zero = pand(x_is_le_zero, x_is_ge_zero);
const Packet x_is_not_one = por(x_is_gt_one, x_is_lt_one);
const Packet inf_if_neg_exp = pand(cst_pos_inf, exp_is_neg); const Packet abs_x = pabs(x);
const Packet inf_if_pos_exp = pandnot(cst_pos_inf, exp_is_neg); const Packet abs_x_is_le_one = pcmp_le(abs_x, cst_pos_one);
const Packet abs_x_is_ge_one = pcmp_le(cst_pos_one, abs_x);
const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
const Packet abs_x_is_one = pand(abs_x_is_le_one, abs_x_is_ge_one);
Packet pow_is_inf_if_exp_is_neg = por(x_is_zero, pand(abs_x_is_le_one, exp_is_inf));
Packet pow_is_inf_if_exp_is_pos = por(abs_x_is_inf, pand(abs_x_is_ge_one, exp_is_inf));
Packet pow_is_one = pand(abs_x_is_one, por(exp_is_inf, x_is_ge_zero));
Packet result = powx; Packet result = powx;
result = pselect(x_is_zero, inf_if_neg_exp, result); result = por(x_is_le_zero, result);
result = pselect(pandnot(x_is_gt_one, exp_is_fin), inf_if_pos_exp, result); result = pselect(pow_is_inf_if_exp_is_neg, pand(cst_pos_inf, exp_is_neg), result);
result = pselect(pandnot(x_is_lt_one, exp_is_fin), inf_if_neg_exp, result); result = pselect(pow_is_inf_if_exp_is_pos, pand(cst_pos_inf, exp_is_pos), result);
result = por(exp_is_nan, result); result = por(exp_is_nan, result);
result = pselect(x_is_not_one, result, cst_pos_one); result = pselect(pow_is_one, cst_pos_one, result);
return result; return result;
} }
template <typename Packet, typename ScalarExponent> template <typename Packet, typename ScalarExponent,
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_int_int(const Packet& x, const ScalarExponent& exponent) { std::enable_if_t<NumTraits<typename unpacket_traits<Packet>::type>::IsSigned, bool> = true>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_negative_exponent(const Packet& x, const ScalarExponent& exponent) {
using Scalar = typename unpacket_traits<Packet>::type; using Scalar = typename unpacket_traits<Packet>::type;
// integer base, integer exponent case // singed integer base, signed integer exponent case
// This routine handles negative exponents. // This routine handles negative exponents.
// The return value is either 0, 1, or -1. // The return value is either 0, 1, or -1.
@ -2085,7 +2121,7 @@ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_int_int(const Packet&
const Packet cst_pos_one = pset1<Packet>(pos_one); const Packet cst_pos_one = pset1<Packet>(pos_one);
const bool exponent_is_odd = unary_pow::is_odd<ScalarExponent>::run(exponent); const bool exponent_is_odd = exponent % ScalarExponent(2) != ScalarExponent(0);
const Packet exp_is_odd = pset1<Packet>(exponent_is_odd ? all_ones : pos_zero); const Packet exp_is_odd = pset1<Packet>(exponent_is_odd ? all_ones : pos_zero);
@ -2097,15 +2133,36 @@ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_int_int(const Packet&
return result; return result;
} }
template <typename Packet, typename ScalarExponent,
std::enable_if_t<!NumTraits<typename unpacket_traits<Packet>::type>::IsSigned, bool> = true>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_negative_exponent(const Packet& x, const ScalarExponent&) {
using Scalar = typename unpacket_traits<Packet>::type;
// unsigned integer base, signed integer exponent case
// This routine handles negative exponents.
// The return value is either 0 or 1
const Scalar pos_one = Scalar(1);
const Packet cst_pos_one = pset1<Packet>(pos_one);
const Packet x_is_one = pcmp_eq(x, cst_pos_one);
return pand(x_is_one, x);
}
} // end namespace unary_pow } // end namespace unary_pow
template <typename Packet, typename ScalarExponent, template <typename Packet, typename ScalarExponent,
bool BaseIsIntegerType = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger, bool BaseIsIntegerType = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger,
bool ExponentIsIntegerType = NumTraits<ScalarExponent>::IsInteger> bool ExponentIsIntegerType = NumTraits<ScalarExponent>::IsInteger,
bool ExponentIsSigned = NumTraits<ScalarExponent>::IsSigned>
struct unary_pow_impl; struct unary_pow_impl;
template <typename Packet, typename ScalarExponent> template <typename Packet, typename ScalarExponent, bool ExponentIsSigned>
struct unary_pow_impl<Packet, ScalarExponent, false, false> { struct unary_pow_impl<Packet, ScalarExponent, false, false, ExponentIsSigned> {
typedef typename unpacket_traits<Packet>::type Scalar; typedef typename unpacket_traits<Packet>::type Scalar;
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
const bool exponent_is_integer = (numext::isfinite)(exponent) && numext::round(exponent) == exponent; const bool exponent_is_integer = (numext::isfinite)(exponent) && numext::round(exponent) == exponent;
@ -2119,8 +2176,8 @@ struct unary_pow_impl<Packet, ScalarExponent, false, false> {
} }
}; };
template <typename Packet, typename ScalarExponent> template <typename Packet, typename ScalarExponent, bool ExponentIsSigned>
struct unary_pow_impl<Packet, ScalarExponent, false, true> { struct unary_pow_impl<Packet, ScalarExponent, false, true, ExponentIsSigned> {
typedef typename unpacket_traits<Packet>::type Scalar; typedef typename unpacket_traits<Packet>::type Scalar;
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
return unary_pow::int_pow(x, exponent); return unary_pow::int_pow(x, exponent);
@ -2128,17 +2185,25 @@ struct unary_pow_impl<Packet, ScalarExponent, false, true> {
}; };
template <typename Packet, typename ScalarExponent> template <typename Packet, typename ScalarExponent>
struct unary_pow_impl<Packet, ScalarExponent, true, true> { struct unary_pow_impl<Packet, ScalarExponent, true, true, true> {
typedef typename unpacket_traits<Packet>::type Scalar; typedef typename unpacket_traits<Packet>::type Scalar;
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
if (exponent < ScalarExponent(0)) { if (exponent < ScalarExponent(0)) {
return unary_pow::handle_int_int(x, exponent); return unary_pow::handle_negative_exponent(x, exponent);
} else { } else {
return unary_pow::int_pow(x, exponent); return unary_pow::int_pow(x, exponent);
} }
} }
}; };
template <typename Packet, typename ScalarExponent>
struct unary_pow_impl<Packet, ScalarExponent, true, true, false> {
typedef typename unpacket_traits<Packet>::type Scalar;
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
return unary_pow::int_pow(x, exponent);
}
};
} // end namespace internal } // end namespace internal
} // end namespace Eigen } // end namespace Eigen

View File

@ -191,45 +191,89 @@ void unary_ops_test() {
*/ */
} }
template <typename Base, typename Exponent, bool ExpIsInteger = NumTraits<Exponent>::IsInteger>
struct ref_pow {
static Base run(Base base, Exponent exponent) {
EIGEN_USING_STD(pow);
return static_cast<Base>(pow(base, static_cast<Base>(exponent)));
}
};
template <typename Scalar> template <typename Base, typename Exponent>
void pow_scalar_exponent_test() { struct ref_pow<Base, Exponent, true> {
using Int_t = typename internal::make_integer<Scalar>::type; static Base run(Base base, Exponent exponent) {
const Scalar tol = test_precision<Scalar>(); EIGEN_USING_STD(pow);
return static_cast<Base>(pow(base, exponent));
}
};
std::vector<Scalar> abs_vals = special_values<Scalar>(); template <typename Exponent, bool ExpIsInteger = NumTraits<Exponent>::IsInteger>
const Index num_vals = (Index)abs_vals.size(); struct pow_helper {
Map<Array<Scalar, Dynamic, 1>> bases(abs_vals.data(), num_vals); static bool is_integer_impl(const Exponent& exp) { return (numext::isfinite)(exp) && exp == numext::floor(exp); }
static bool is_odd_impl(const Exponent& exp) {
Exponent exp_div_2 = exp / Exponent(2);
Exponent floor_exp_div_2 = numext::floor(exp_div_2);
return exp_div_2 != floor_exp_div_2;
}
};
template <typename Exponent>
struct pow_helper<Exponent, true> {
static bool is_integer_impl(const Exponent&) { return true; }
static bool is_odd_impl(const Exponent& exp) { return exp % 2 != 0; }
};
template <typename Exponent>
bool is_integer(const Exponent& exp) {
return pow_helper<Exponent>::is_integer_impl(exp);
}
template <typename Exponent>
bool is_odd(const Exponent& exp) {
return pow_helper<Exponent>::is_odd_impl(exp);
}
template <typename Base, typename Exponent>
void float_pow_test_impl() {
const Base tol = test_precision<Base>();
std::vector<Base> abs_base_vals = special_values<Base>();
std::vector<Exponent> abs_exponent_vals = special_values<Exponent>();
for (int i = 0; i < 100; i++) {
abs_base_vals.push_back(internal::random<Base>(Base(0), Base(10)));
abs_exponent_vals.push_back(internal::random<Exponent>(Exponent(0), Exponent(10)));
}
const Index num_repeats = internal::packet_traits<Base>::size + 1;
ArrayX<Base> bases(num_repeats), eigenPow(num_repeats);
bool all_pass = true; bool all_pass = true;
for (Scalar abs_exponent : abs_vals) { for (Base abs_base : abs_base_vals)
for (Scalar exponent : {-abs_exponent, abs_exponent}) { for (Base base : {negative_or_zero(abs_base), abs_base}) {
// test integer exponent code path bases.setConstant(base);
bool exponent_is_integer = (numext::isfinite)(exponent) && (numext::round(exponent) == exponent) && for (Exponent abs_exponent : abs_exponent_vals) {
(numext::abs(exponent) < static_cast<Scalar>(NumTraits<Int_t>::highest())); for (Exponent exponent : {negative_or_zero(abs_exponent), abs_exponent}) {
if (exponent_is_integer) { eigenPow = bases.pow(exponent);
Int_t exponent_as_int = static_cast<Int_t>(exponent); for (Index j = 0; j < num_repeats; j++) {
Array<Scalar, Dynamic, 1> eigenPow = bases.pow(exponent_as_int); Base e = ref_pow<Base, Exponent>::run(bases(j), exponent);
for (Index j = 0; j < num_vals; j++) { if (is_integer(exponent)) {
Scalar e = static_cast<Scalar>(std::pow(bases(j), exponent)); // std::pow may return an incorrect result for a very large integral exponent
Scalar a = eigenPow(j); // if base is negative and the exponent is odd, then the result must be negative
bool success = (a == e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) || // if std::pow returns otherwise, flip the sign
((numext::isnan)(a) && (numext::isnan)(e)); bool exp_is_odd = is_odd(exponent);
if ((a == a) && (e == e)) success &= (bool)numext::signbit(e) == (bool)numext::signbit(a); bool base_is_neg = !(numext::isnan)(base) && (bool)numext::signbit(base);
all_pass &= success; bool result_is_neg = exp_is_odd && base_is_neg;
if (!success) { bool ref_is_neg = !(numext::isnan)(e) && (bool)numext::signbit(e);
std::cout << "pow(" << bases(j) << "," << exponent << ") = " << a << " != " << e << std::endl; bool flip_sign = result_is_neg != ref_is_neg;
if (flip_sign) e = -e;
} }
}
} else { Base a = eigenPow(j);
// test floating point exponent code path #ifdef EIGEN_COMP_MSVC
Array<Scalar, Dynamic, 1> eigenPow = bases.pow(exponent); // Work around MSVC return value on underflow.
for (Index j = 0; j < num_vals; j++) { // if std::pow returns 0 and Eigen returns a denormalized value, then skip the test
Scalar e = static_cast<Scalar>(std::pow(bases(j), exponent)); int fpclass = std::fpclassify(a);
Scalar a = eigenPow(j); if (e == Base(0) && fpclass == FP_SUBNORMAL) continue;
bool success = (a == e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) || #endif
((numext::isnan)(a) && (numext::isnan)(e));
if ((a == a) && (e == e)) success &= (bool)numext::signbit(e) == (bool)numext::signbit(a); bool both_nan = (numext::isnan)(a) && (numext::isnan)(e);
bool exact_or_approx = (a == e) || internal::isApprox(a, e, tol);
bool same_sign = (bool)numext::signbit(e) == (bool)numext::signbit(a);
bool success = both_nan || (exact_or_approx && same_sign);
all_pass &= success; all_pass &= success;
if (!success) { if (!success) {
std::cout << "pow(" << bases(j) << "," << exponent << ") = " << a << " != " << e << std::endl; std::cout << "pow(" << bases(j) << "," << exponent << ") = " << a << " != " << e << std::endl;
@ -259,24 +303,9 @@ Scalar calc_overflow_threshold(const ScalarExponent exponent) {
} }
} }
template <typename Base, typename Exponent, bool ExpIsInteger = NumTraits<Exponent>::IsInteger>
struct ref_pow {
static Base run(Base base, Exponent exponent) {
EIGEN_USING_STD(pow);
return static_cast<Base>(pow(base, static_cast<Base>(exponent)));
}
};
template <typename Base, typename Exponent>
struct ref_pow<Base, Exponent, true> {
static Base run(Base base, Exponent exponent) {
EIGEN_USING_STD(pow);
return static_cast<Base>(pow(base, exponent));
}
};
template <typename Base, typename Exponent> template <typename Base, typename Exponent>
void test_exponent(Exponent exponent) { void test_exponent(Exponent exponent) {
EIGEN_STATIC_ASSERT(NumTraits<Base>::IsInteger,THIS TEST IS ONLY INTENDED FOR BASE INTEGER TYPES)
const Base max_abs_bases = static_cast<Base>(10000); const Base max_abs_bases = static_cast<Base>(10000);
// avoid integer overflow in Base type // avoid integer overflow in Base type
Base threshold = calc_overflow_threshold<Base, Exponent>(numext::abs(exponent)); Base threshold = calc_overflow_threshold<Base, Exponent>(numext::abs(exponent));
@ -300,10 +329,10 @@ void test_exponent(Exponent exponent) {
for (Base a : y) { for (Base a : y) {
Base e = ref_pow<Base, Exponent>::run(base, exponent); Base e = ref_pow<Base, Exponent>::run(base, exponent);
bool pass = (a == e); bool pass = (a == e);
if (!NumTraits<Base>::IsInteger) { //if (!NumTraits<Base>::IsInteger) {
pass = pass || (((numext::isfinite)(e) && internal::isApprox(a, e)) || // pass = pass || (((numext::isfinite)(e) && internal::isApprox(a, e)) ||
((numext::isnan)(a) && (numext::isnan)(e))); // ((numext::isnan)(a) && (numext::isnan)(e)));
} //}
all_pass &= pass; all_pass &= pass;
if (!pass) { if (!pass) {
std::cout << "pow(" << base << "," << exponent << ") = " << a << " != " << e << std::endl; std::cout << "pow(" << base << "," << exponent << ") = " << a << " != " << e << std::endl;
@ -314,7 +343,7 @@ void test_exponent(Exponent exponent) {
} }
template <typename Base, typename Exponent> template <typename Base, typename Exponent>
void unary_pow_test() { void int_pow_test_impl() {
Exponent max_exponent = static_cast<Exponent>(NumTraits<Base>::digits()); Exponent max_exponent = static_cast<Exponent>(NumTraits<Base>::digits());
Exponent min_exponent = negative_or_zero(max_exponent); Exponent min_exponent = negative_or_zero(max_exponent);
@ -323,21 +352,26 @@ void unary_pow_test() {
} }
} }
void float_pow_test() {
float_pow_test_impl<float, float>();
float_pow_test_impl<double, double>();
}
void mixed_pow_test() { void mixed_pow_test() {
// The following cases will test promoting a smaller exponent type // The following cases will test promoting a smaller exponent type
// to a wider base type. // to a wider base type.
unary_pow_test<double, int>(); float_pow_test_impl<double, int>();
unary_pow_test<double, float>(); float_pow_test_impl<double, float>();
unary_pow_test<float, half>(); float_pow_test_impl<float, half>();
unary_pow_test<double, half>(); float_pow_test_impl<double, half>();
unary_pow_test<float, bfloat16>(); float_pow_test_impl<float, bfloat16>();
unary_pow_test<double, bfloat16>(); float_pow_test_impl<double, bfloat16>();
// Although in the following cases the exponent cannot be represented exactly // Although in the following cases the exponent cannot be represented exactly
// in the base type, we do not perform a conversion, but implement // in the base type, we do not perform a conversion, but implement
// the operation using repeated squaring. // the operation using repeated squaring.
unary_pow_test<float, int>(); float_pow_test_impl<float, int>();
unary_pow_test<double, long long>(); float_pow_test_impl<double, long long>();
// The following cases will test promoting a wider exponent type // The following cases will test promoting a wider exponent type
// to a narrower base type. This should compile but would generate a // to a narrower base type. This should compile but would generate a
@ -346,20 +380,20 @@ void mixed_pow_test() {
} }
void int_pow_test() { void int_pow_test() {
unary_pow_test<int, int>(); int_pow_test_impl<int, int>();
unary_pow_test<unsigned int, unsigned int>(); int_pow_test_impl<unsigned int, unsigned int>();
unary_pow_test<long long, long long>(); int_pow_test_impl<long long, long long>();
unary_pow_test<unsigned long long, unsigned long long>(); int_pow_test_impl<unsigned long long, unsigned long long>();
// Although in the following cases the exponent cannot be represented exactly // Although in the following cases the exponent cannot be represented exactly
// in the base type, we do not perform a conversion, but implement the // in the base type, we do not perform a conversion, but implement the
// operation using repeated squaring. // operation using repeated squaring.
unary_pow_test<long long, int>(); int_pow_test_impl<long long, int>();
unary_pow_test<int, unsigned int>(); int_pow_test_impl<int, unsigned int>();
unary_pow_test<unsigned int, int>(); int_pow_test_impl<unsigned int, int>();
unary_pow_test<long long, unsigned long long>(); int_pow_test_impl<long long, unsigned long long>();
unary_pow_test<unsigned long long, long long>(); int_pow_test_impl<unsigned long long, long long>();
unary_pow_test<long long, int>(); int_pow_test_impl<long long, int>();
} }
namespace Eigen { namespace Eigen {
@ -849,7 +883,6 @@ template<typename ArrayType> void array_real(const ArrayType& m)
// Test pow and atan2 on special IEEE values. // Test pow and atan2 on special IEEE values.
unary_ops_test<Scalar>(); unary_ops_test<Scalar>();
binary_ops_test<Scalar>(); binary_ops_test<Scalar>();
pow_scalar_exponent_test<Scalar>();
VERIFY_IS_APPROX(log10(m3), log(m3)/numext::log(Scalar(10))); VERIFY_IS_APPROX(log10(m3), log(m3)/numext::log(Scalar(10)));
VERIFY_IS_APPROX(log2(m3), log(m3)/numext::log(Scalar(2))); VERIFY_IS_APPROX(log2(m3), log(m3)/numext::log(Scalar(2)));
@ -1223,6 +1256,7 @@ EIGEN_DECLARE_TEST(array_cwise)
} }
for(int i = 0; i < g_repeat; i++) { for(int i = 0; i < g_repeat; i++) {
CALL_SUBTEST_5( float_pow_test() );
CALL_SUBTEST_6( int_pow_test() ); CALL_SUBTEST_6( int_pow_test() );
CALL_SUBTEST_7( mixed_pow_test() ); CALL_SUBTEST_7( mixed_pow_test() );
CALL_SUBTEST_8( signbit_tests() ); CALL_SUBTEST_8( signbit_tests() );