mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-24 02:29:33 +08:00
Fix code and unit test for a few corner cases in vectorized pow()
(cherry picked from commit 7a87ed1b6a49bd0067856dcba9ad9a3a46186220)
This commit is contained in:
parent
61efca2e90
commit
a9490cd3c5
@ -1443,21 +1443,22 @@ EIGEN_STRONG_INLINE Packet generic_pow_impl(const Packet& x, const Packet& y) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generic implementation of pow(x,y).
|
// Generic implementation of pow(x,y).
|
||||||
template<typename Packet>
|
template <typename Packet>
|
||||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_pow(const Packet& x, const Packet& y) {
|
||||||
EIGEN_UNUSED
|
|
||||||
Packet generic_pow(const Packet& x, const Packet& y) {
|
|
||||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||||
|
|
||||||
const Packet cst_pos_inf = pset1<Packet>(NumTraits<Scalar>::infinity());
|
const Packet cst_pos_inf = pset1<Packet>(NumTraits<Scalar>::infinity());
|
||||||
|
const Packet cst_neg_inf = pset1<Packet>(-NumTraits<Scalar>::infinity());
|
||||||
const Packet cst_zero = pset1<Packet>(Scalar(0));
|
const Packet cst_zero = pset1<Packet>(Scalar(0));
|
||||||
const Packet cst_one = pset1<Packet>(Scalar(1));
|
const Packet cst_one = pset1<Packet>(Scalar(1));
|
||||||
const Packet cst_nan = pset1<Packet>(NumTraits<Scalar>::quiet_NaN());
|
const Packet cst_nan = pset1<Packet>(NumTraits<Scalar>::quiet_NaN());
|
||||||
|
|
||||||
const Packet abs_x = pabs(x);
|
const Packet abs_x = pabs(x);
|
||||||
// Predicates for sign and magnitude of x.
|
// Predicates for sign and magnitude of x.
|
||||||
const Packet x_is_zero = pcmp_eq(x, cst_zero);
|
const Packet abs_x_is_zero = pcmp_eq(abs_x, cst_zero);
|
||||||
const Packet x_is_neg = pcmp_lt(x, cst_zero);
|
const Packet x_has_signbit = pcmp_eq(por(pand(x, cst_neg_inf), cst_pos_inf), cst_neg_inf);
|
||||||
|
const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero);
|
||||||
|
const Packet x_is_neg_zero = pand(x_has_signbit, abs_x_is_zero);
|
||||||
const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
|
const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
|
||||||
const Packet abs_x_is_one = pcmp_eq(abs_x, cst_one);
|
const Packet abs_x_is_one = pcmp_eq(abs_x, cst_one);
|
||||||
const Packet abs_x_is_gt_one = pcmp_lt(cst_one, abs_x);
|
const Packet abs_x_is_gt_one = pcmp_lt(cst_one, abs_x);
|
||||||
@ -1467,15 +1468,15 @@ Packet generic_pow(const Packet& x, const Packet& y) {
|
|||||||
const Packet x_is_nan = pandnot(ptrue(x), pcmp_eq(x, x));
|
const Packet x_is_nan = pandnot(ptrue(x), pcmp_eq(x, x));
|
||||||
|
|
||||||
// Predicates for sign and magnitude of y.
|
// Predicates for sign and magnitude of y.
|
||||||
|
const Packet abs_y = pabs(y);
|
||||||
const Packet y_is_one = pcmp_eq(y, cst_one);
|
const Packet y_is_one = pcmp_eq(y, cst_one);
|
||||||
const Packet y_is_zero = pcmp_eq(y, cst_zero);
|
const Packet abs_y_is_zero = pcmp_eq(abs_y, cst_zero);
|
||||||
const Packet y_is_neg = pcmp_lt(y, cst_zero);
|
const Packet y_is_neg = pcmp_lt(y, cst_zero);
|
||||||
const Packet y_is_pos = pandnot(ptrue(y), por(y_is_zero, y_is_neg));
|
const Packet y_is_pos = pandnot(ptrue(y), por(abs_y_is_zero, y_is_neg));
|
||||||
const Packet y_is_nan = pandnot(ptrue(y), pcmp_eq(y, y));
|
const Packet y_is_nan = pandnot(ptrue(y), pcmp_eq(y, y));
|
||||||
const Packet abs_y_is_inf = pcmp_eq(pabs(y), cst_pos_inf);
|
const Packet abs_y_is_inf = pcmp_eq(abs_y, cst_pos_inf);
|
||||||
EIGEN_CONSTEXPR Scalar huge_exponent =
|
EIGEN_CONSTEXPR Scalar huge_exponent =
|
||||||
(NumTraits<Scalar>::max_exponent() * Scalar(EIGEN_LN2)) /
|
(NumTraits<Scalar>::max_exponent() * Scalar(EIGEN_LN2)) / NumTraits<Scalar>::epsilon();
|
||||||
NumTraits<Scalar>::epsilon();
|
|
||||||
const Packet abs_y_is_huge = pcmp_le(pset1<Packet>(huge_exponent), pabs(y));
|
const Packet abs_y_is_huge = pcmp_le(pset1<Packet>(huge_exponent), pabs(y));
|
||||||
|
|
||||||
// Predicates for whether y is integer and/or even.
|
// Predicates for whether y is integer and/or even.
|
||||||
@ -1484,39 +1485,31 @@ Packet generic_pow(const Packet& x, const Packet& y) {
|
|||||||
const Packet y_is_even = pcmp_eq(pround(y_div_2), y_div_2);
|
const Packet y_is_even = pcmp_eq(pround(y_div_2), y_div_2);
|
||||||
|
|
||||||
// Predicates encoding special cases for the value of pow(x,y)
|
// Predicates encoding special cases for the value of pow(x,y)
|
||||||
const Packet invalid_negative_x = pandnot(pandnot(pandnot(x_is_neg, abs_x_is_inf),
|
const Packet invalid_negative_x = pandnot(pandnot(pandnot(x_is_neg, abs_x_is_inf), y_is_int), abs_y_is_inf);
|
||||||
y_is_int),
|
|
||||||
abs_y_is_inf);
|
|
||||||
const Packet pow_is_one = por(por(x_is_one, y_is_zero),
|
|
||||||
pand(x_is_neg_one,
|
|
||||||
por(abs_y_is_inf, pandnot(y_is_even, invalid_negative_x))));
|
|
||||||
const Packet pow_is_nan = por(invalid_negative_x, por(x_is_nan, y_is_nan));
|
const Packet pow_is_nan = por(invalid_negative_x, por(x_is_nan, y_is_nan));
|
||||||
const Packet pow_is_zero = por(por(por(pand(x_is_zero, y_is_pos),
|
const Packet pow_is_one =
|
||||||
pand(abs_x_is_inf, y_is_neg)),
|
por(por(x_is_one, abs_y_is_zero), pand(x_is_neg_one, por(abs_y_is_inf, pandnot(y_is_even, invalid_negative_x))));
|
||||||
pand(pand(abs_x_is_lt_one, abs_y_is_huge),
|
const Packet pow_is_zero = por(por(por(pand(abs_x_is_zero, y_is_pos), pand(abs_x_is_inf, y_is_neg)),
|
||||||
y_is_pos)),
|
pand(pand(abs_x_is_lt_one, abs_y_is_huge), y_is_pos)),
|
||||||
pand(pand(abs_x_is_gt_one, abs_y_is_huge),
|
pand(pand(abs_x_is_gt_one, abs_y_is_huge), y_is_neg));
|
||||||
y_is_neg));
|
const Packet pow_is_inf = por(por(por(pand(abs_x_is_zero, y_is_neg), pand(abs_x_is_inf, y_is_pos)),
|
||||||
const Packet pow_is_inf = por(por(por(pand(x_is_zero, y_is_neg),
|
pand(pand(abs_x_is_lt_one, abs_y_is_huge), y_is_neg)),
|
||||||
pand(abs_x_is_inf, y_is_pos)),
|
pand(pand(abs_x_is_gt_one, abs_y_is_huge), y_is_pos));
|
||||||
pand(pand(abs_x_is_lt_one, abs_y_is_huge),
|
const Packet inf_val =
|
||||||
y_is_neg)),
|
pselect(pandnot(pand(por(pand(abs_x_is_inf, x_is_neg), pand(x_is_neg_zero, y_is_neg)), y_is_int), y_is_even),
|
||||||
pand(pand(abs_x_is_gt_one, abs_y_is_huge),
|
cst_neg_inf, cst_pos_inf);
|
||||||
y_is_pos));
|
|
||||||
|
|
||||||
// General computation of pow(x,y) for positive x or negative x and integer y.
|
// General computation of pow(x,y) for positive x or negative x and integer y.
|
||||||
const Packet negate_pow_abs = pandnot(x_is_neg, y_is_even);
|
const Packet negate_pow_abs = pandnot(x_is_neg, y_is_even);
|
||||||
const Packet pow_abs = generic_pow_impl(abs_x, y);
|
const Packet pow_abs = generic_pow_impl(abs_x, y);
|
||||||
return pselect(y_is_one, x,
|
return pselect(
|
||||||
|
y_is_one, x,
|
||||||
pselect(pow_is_one, cst_one,
|
pselect(pow_is_one, cst_one,
|
||||||
pselect(pow_is_nan, cst_nan,
|
pselect(pow_is_nan, cst_nan,
|
||||||
pselect(pow_is_inf, cst_pos_inf,
|
pselect(pow_is_inf, inf_val,
|
||||||
pselect(pow_is_zero, cst_zero,
|
pselect(pow_is_zero, cst_zero, pselect(negate_pow_abs, pnegate(pow_abs), pow_abs))))));
|
||||||
pselect(negate_pow_abs, pnegate(pow_abs), pow_abs))))));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/* polevl (modified for Eigen)
|
/* polevl (modified for Eigen)
|
||||||
*
|
*
|
||||||
* Evaluate polynomial
|
* Evaluate polynomial
|
||||||
|
@ -72,9 +72,9 @@ void pow_test() {
|
|||||||
for (int j = 0; j < num_cases; ++j) {
|
for (int j = 0; j < num_cases; ++j) {
|
||||||
Scalar e = static_cast<Scalar>(std::pow(x(i,j), y(i,j)));
|
Scalar e = static_cast<Scalar>(std::pow(x(i,j), y(i,j)));
|
||||||
Scalar a = actual(i, j);
|
Scalar a = actual(i, j);
|
||||||
bool fail = !(a==e) && !internal::isApprox(a, e, tol) && !((numext::isnan)(a) && (numext::isnan)(e));
|
bool success = (a==e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) || ((numext::isnan)(a) && (numext::isnan)(e));
|
||||||
all_pass &= !fail;
|
all_pass &= success;
|
||||||
if (fail) {
|
if (!success) {
|
||||||
std::cout << "pow(" << x(i,j) << "," << y(i,j) << ") = " << a << " != " << e << std::endl;
|
std::cout << "pow(" << x(i,j) << "," << y(i,j) << ") = " << a << " != " << e << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user