mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-10 02:39:03 +08:00
Optimize scalar_unary_pow_op error handling
This commit is contained in:
parent
316eab8deb
commit
1d80e23186
@ -1972,7 +1972,7 @@ struct pchebevl {
|
||||
namespace unary_pow {
|
||||
template <typename ScalarExponent, bool IsIntegerAtCompileTime = NumTraits<ScalarExponent>::IsInteger>
|
||||
struct is_odd {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent run(const ScalarExponent& x) {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(const ScalarExponent& x) {
|
||||
ScalarExponent xdiv2 = x / ScalarExponent(2);
|
||||
ScalarExponent floorxdiv2 = numext::floor(xdiv2);
|
||||
return xdiv2 != floorxdiv2;
|
||||
@ -1980,8 +1980,8 @@ struct is_odd {
|
||||
};
|
||||
template <typename ScalarExponent>
|
||||
struct is_odd<ScalarExponent, true> {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent run(const ScalarExponent& x) {
|
||||
return x % ScalarExponent(2);
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(const ScalarExponent& x) {
|
||||
return x % ScalarExponent(2) != 0;
|
||||
}
|
||||
};
|
||||
|
||||
@ -1989,7 +1989,7 @@ template <typename Packet, typename ScalarExponent,
|
||||
bool BaseIsIntegerType = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>
|
||||
struct do_div {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
using Scalar = typename unpacket_traits<Packet>::type;
|
||||
const Packet cst_pos_one = pset1<Packet>(Scalar(1));
|
||||
return exponent < 0 ? pdiv(cst_pos_one, x) : x;
|
||||
}
|
||||
@ -1998,15 +1998,14 @@ struct do_div {
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
struct do_div<Packet, ScalarExponent, true> {
|
||||
// pdiv not defined, nor necessary for integer base types
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||
EIGEN_UNUSED_VARIABLE(exponent);
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent&) {
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet int_pow(const Packet& x, const ScalarExponent& exponent) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
using Scalar = typename unpacket_traits<Packet>::type;
|
||||
const Packet cst_pos_one = pset1<Packet>(Scalar(1));
|
||||
if (exponent == 0) return cst_pos_one;
|
||||
Packet result = x;
|
||||
@ -2031,178 +2030,70 @@ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet gen_pow(const Packet& x,
|
||||
}
|
||||
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_int_errors(const Packet& x, const Packet& powx,
|
||||
const ScalarExponent& exponent) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(const Packet& x, const Packet& powx,
|
||||
const ScalarExponent& exponent) {
|
||||
using Scalar = typename unpacket_traits<Packet>::type;
|
||||
|
||||
// non-integer base, integer exponent case
|
||||
|
||||
const bool exponent_is_odd = is_odd<ScalarExponent>::run(exponent);
|
||||
const bool exponent_is_neg = exponent < 0;
|
||||
|
||||
const Packet exp_is_odd = exponent_is_odd ? ptrue(x) : pzero(x);
|
||||
const Packet exp_is_neg = exponent_is_neg ? ptrue(x) : pzero(x);
|
||||
// non-integer base and exponent case
|
||||
|
||||
const Scalar pos_zero = Scalar(0);
|
||||
const Scalar neg_zero = -Scalar(0);
|
||||
const Scalar all_ones = ptrue<Scalar>(Scalar());
|
||||
const Scalar pos_one = Scalar(1);
|
||||
const Scalar pos_inf = NumTraits<Scalar>::infinity();
|
||||
const Scalar neg_inf = -NumTraits<Scalar>::infinity();
|
||||
|
||||
const Packet cst_pos_zero = pset1<Packet>(pos_zero);
|
||||
const Packet cst_neg_zero = pset1<Packet>(neg_zero);
|
||||
const Packet cst_pos_zero = pzero(x);
|
||||
const Packet cst_pos_one = pset1<Packet>(pos_one);
|
||||
const Packet cst_pos_inf = pset1<Packet>(pos_inf);
|
||||
const Packet cst_neg_inf = pset1<Packet>(neg_inf);
|
||||
|
||||
const Packet abs_x = pabs(x);
|
||||
const Packet abs_x_is_zero = pcmp_eq(abs_x, cst_pos_zero);
|
||||
const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one);
|
||||
const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
|
||||
const bool exponent_is_nan = (numext::isnan)(exponent);
|
||||
const bool exponent_is_fin = (numext::isfinite)(exponent);
|
||||
const bool exponent_is_neg = exponent < ScalarExponent(0);
|
||||
|
||||
const Packet x_has_signbit = psignbit(x);
|
||||
const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero);
|
||||
const Packet x_is_neg_zero = pand(x_has_signbit, abs_x_is_zero);
|
||||
const Packet exp_is_nan = pset1<Packet>(exponent_is_nan ? all_ones : pos_zero);
|
||||
const Packet exp_is_fin = pset1<Packet>(exponent_is_fin ? all_ones : pos_zero);
|
||||
const Packet exp_is_neg = pset1<Packet>(exponent_is_neg ? all_ones : pos_zero);
|
||||
|
||||
if (exponent == 0) {
|
||||
return cst_pos_one;
|
||||
}
|
||||
const Packet x_is_gt_one = pcmp_lt(cst_pos_one, x);
|
||||
const Packet x_is_lt_one = pcmp_lt(x, cst_pos_one);
|
||||
const Packet x_is_zero = pcmp_eq(x, cst_pos_zero);
|
||||
const Packet x_is_not_one = por(x_is_gt_one, x_is_lt_one);
|
||||
|
||||
Packet pow_is_pos_inf = pand(pandnot(abs_x_is_zero, x_is_neg_zero), pand(exp_is_odd, exp_is_neg));
|
||||
pow_is_pos_inf = por(pow_is_pos_inf, pand(abs_x_is_zero, pandnot(exp_is_neg, exp_is_odd)));
|
||||
pow_is_pos_inf = por(pow_is_pos_inf, pand(pand(abs_x_is_inf, x_is_neg), pandnot(pnot(exp_is_neg), exp_is_odd)));
|
||||
pow_is_pos_inf = por(pow_is_pos_inf, pandnot(pandnot(abs_x_is_inf, x_is_neg), exp_is_neg));
|
||||
const Packet inf_if_neg_exp = pand(cst_pos_inf, exp_is_neg);
|
||||
const Packet inf_if_pos_exp = pandnot(cst_pos_inf, exp_is_neg);
|
||||
|
||||
Packet pow_is_neg_inf = pand(x_is_neg_zero, pand(exp_is_neg, exp_is_odd));
|
||||
pow_is_neg_inf = por(pow_is_neg_inf, pand(pand(abs_x_is_inf, x_is_neg), pandnot(exp_is_odd, exp_is_neg)));
|
||||
|
||||
Packet pow_is_pos_zero = pandnot(abs_x_is_zero, exp_is_neg);
|
||||
pow_is_pos_zero = por(pow_is_pos_zero, pand(pand(abs_x_is_inf, x_is_neg), pandnot(exp_is_neg, exp_is_odd)));
|
||||
pow_is_pos_zero = por(pow_is_pos_zero, pand(pandnot(abs_x_is_inf, x_is_neg), exp_is_neg));
|
||||
|
||||
Packet pow_is_neg_zero = pand(x_is_neg_zero, pandnot(exp_is_odd, exp_is_neg));
|
||||
pow_is_neg_zero = por(pow_is_neg_zero, pand(pand(abs_x_is_inf, x_is_neg), pand(exp_is_odd, exp_is_neg)));
|
||||
|
||||
Packet result = pselect(pow_is_neg_inf, cst_neg_inf, powx);
|
||||
result = pselect(pow_is_neg_zero, cst_neg_zero, result);
|
||||
result = pselect(pow_is_pos_zero, cst_pos_zero, result);
|
||||
result = pselect(pow_is_pos_inf, cst_pos_inf, result);
|
||||
result = pselect(pandnot(abs_x_is_one, x_is_neg), cst_pos_one, result);
|
||||
Packet result = powx;
|
||||
result = pselect(x_is_zero, inf_if_neg_exp, result);
|
||||
result = pselect(pandnot(x_is_gt_one, exp_is_fin), inf_if_pos_exp, result);
|
||||
result = pselect(pandnot(x_is_lt_one, exp_is_fin), inf_if_neg_exp, result);
|
||||
result = por(exp_is_nan, result);
|
||||
result = pselect(x_is_not_one, result, cst_pos_one);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(const Packet& x, const Packet& powx,
|
||||
const ScalarExponent& exponent) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_int_int(const Packet& x, const ScalarExponent& exponent) {
|
||||
using Scalar = typename unpacket_traits<Packet>::type;
|
||||
|
||||
// non-integer base and exponent case
|
||||
// integer base, integer exponent case
|
||||
|
||||
const bool exponent_is_fin = (numext::isfinite)(exponent);
|
||||
const bool exponent_is_nan = (numext::isnan)(exponent);
|
||||
const bool exponent_is_neg = exponent < 0;
|
||||
const bool exponent_is_inf = !exponent_is_fin && !exponent_is_nan;
|
||||
|
||||
const Packet exp_is_neg = exponent_is_neg ? ptrue(x) : pzero(x);
|
||||
const Packet exp_is_inf = exponent_is_inf ? ptrue(x) : pzero(x);
|
||||
// This routine handles negative exponents.
|
||||
// The return value is either 0, 1, or -1.
|
||||
|
||||
const Scalar pos_zero = Scalar(0);
|
||||
const Scalar all_ones = ptrue<Scalar>(Scalar());
|
||||
const Scalar pos_one = Scalar(1);
|
||||
const Scalar pos_inf = NumTraits<Scalar>::infinity();
|
||||
const Scalar nan = NumTraits<Scalar>::quiet_NaN();
|
||||
|
||||
const Packet cst_pos_zero = pset1<Packet>(pos_zero);
|
||||
const Packet cst_pos_one = pset1<Packet>(pos_one);
|
||||
const Packet cst_pos_inf = pset1<Packet>(pos_inf);
|
||||
const Packet cst_nan = pset1<Packet>(nan);
|
||||
|
||||
const Packet abs_x = pabs(x);
|
||||
const Packet abs_x_is_zero = pcmp_eq(abs_x, cst_pos_zero);
|
||||
const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_pos_one);
|
||||
const Packet abs_x_is_gt_one = pcmp_lt(cst_pos_one, abs_x);
|
||||
const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one);
|
||||
const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
|
||||
|
||||
const Packet x_has_signbit = psignbit(x);
|
||||
const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero);
|
||||
|
||||
if (exponent_is_nan) {
|
||||
return pselect(pandnot(abs_x_is_one, x_is_neg), cst_pos_one, cst_nan);
|
||||
}
|
||||
|
||||
Packet pow_is_pos_zero = pandnot(abs_x_is_zero, exp_is_neg);
|
||||
pow_is_pos_zero = por(pow_is_pos_zero, pand(abs_x_is_gt_one, pand(exp_is_inf, exp_is_neg)));
|
||||
pow_is_pos_zero = por(pow_is_pos_zero, pand(abs_x_is_lt_one, pandnot(exp_is_inf, exp_is_neg)));
|
||||
pow_is_pos_zero = por(pow_is_pos_zero, pand(abs_x_is_inf, exp_is_neg));
|
||||
|
||||
const Packet pow_is_pos_one = pand(abs_x_is_one, exp_is_inf);
|
||||
|
||||
Packet pow_is_pos_inf = pand(abs_x_is_zero, exp_is_neg);
|
||||
pow_is_pos_inf = por(pow_is_pos_inf, pand(abs_x_is_lt_one, pand(exp_is_inf, exp_is_neg)));
|
||||
pow_is_pos_inf = por(pow_is_pos_inf, pand(abs_x_is_gt_one, pandnot(exp_is_inf, exp_is_neg)));
|
||||
pow_is_pos_inf = por(pow_is_pos_inf, pandnot(abs_x_is_inf, exp_is_neg));
|
||||
|
||||
const Packet pow_is_nan = pandnot(pandnot(x_is_neg, abs_x_is_inf), exp_is_inf);
|
||||
|
||||
Packet result = pselect(pow_is_pos_inf, cst_pos_inf, powx);
|
||||
result = pselect(pow_is_pos_one, cst_pos_one, result);
|
||||
result = pselect(pow_is_pos_zero, cst_pos_zero, result);
|
||||
result = pselect(pow_is_nan, cst_nan, result);
|
||||
result = pselect(pandnot(abs_x_is_one, x_is_neg), cst_pos_one, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename Packet, typename ScalarExponent, std::enable_if_t<NumTraits<typename unpacket_traits<Packet>::type>::IsSigned, bool> = true>
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_int_int(const Packet& x, const ScalarExponent& exponent) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
|
||||
// signed integer base, integer exponent case
|
||||
|
||||
// This routine handles negative and very large positive exponents
|
||||
// Signed integer overflow and divide by zero is undefined behavior
|
||||
// Unsigned integers do not overflow
|
||||
|
||||
const bool exponent_is_odd = unary_pow::is_odd<ScalarExponent>::run(exponent);
|
||||
|
||||
const Scalar zero = Scalar(0);
|
||||
const Scalar pos_one = Scalar(1);
|
||||
|
||||
const Packet cst_zero = pset1<Packet>(zero);
|
||||
const Packet cst_pos_one = pset1<Packet>(pos_one);
|
||||
const Packet exp_is_odd = pset1<Packet>(exponent_is_odd ? all_ones : pos_zero);
|
||||
|
||||
const Packet abs_x = pabs(x);
|
||||
const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one);
|
||||
|
||||
const Packet pow_is_zero = exponent < 0 ? pcmp_lt(cst_pos_one, abs_x) : pzero(x);
|
||||
const Packet pow_is_one = pcmp_eq(cst_pos_one, abs_x);
|
||||
const Packet pow_is_neg = exponent_is_odd ? pcmp_lt(x, cst_zero) : pzero(x);
|
||||
|
||||
Packet result = pselect(pow_is_zero, cst_zero, x);
|
||||
result = pselect(pandnot(pow_is_one, pow_is_neg), cst_pos_one, result);
|
||||
result = pselect(pand(pow_is_one, pow_is_neg), pnegate(cst_pos_one), result);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename Packet, typename ScalarExponent, std::enable_if_t<!NumTraits<typename unpacket_traits<Packet>::type>::IsSigned, bool> = true>
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_int_int(const Packet& x, const ScalarExponent& exponent) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
|
||||
// unsigned integer base, integer exponent case
|
||||
|
||||
// This routine handles negative and very large positive exponents
|
||||
// Signed integer overflow and divide by zero is undefined behavior
|
||||
// Unsigned integers do not overflow
|
||||
|
||||
const Scalar zero = Scalar(0);
|
||||
const Scalar pos_one = Scalar(1);
|
||||
|
||||
const Packet cst_zero = pset1<Packet>(zero);
|
||||
const Packet cst_pos_one = pset1<Packet>(pos_one);
|
||||
|
||||
const Packet pow_is_zero = exponent < 0 ? pcmp_lt(cst_pos_one, x) : pzero(x);
|
||||
const Packet pow_is_one = pcmp_eq(cst_pos_one, x);
|
||||
|
||||
Packet result = pselect(pow_is_zero, cst_zero, x);
|
||||
result = pselect(pow_is_one, cst_pos_one, result);
|
||||
Packet result = pselect(exp_is_odd, x, abs_x);
|
||||
result = pand(abs_x_is_one, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -2219,9 +2110,7 @@ struct unary_pow_impl<Packet, ScalarExponent, false, false> {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||
const bool exponent_is_integer = (numext::isfinite)(exponent) && numext::round(exponent) == exponent;
|
||||
if (exponent_is_integer) {
|
||||
Packet result = unary_pow::int_pow(x, exponent);
|
||||
result = unary_pow::handle_nonint_int_errors(x, result, exponent);
|
||||
return result;
|
||||
return unary_pow::int_pow(x, exponent);
|
||||
} else {
|
||||
Packet result = unary_pow::gen_pow(x, exponent);
|
||||
result = unary_pow::handle_nonint_nonint_errors(x, result, exponent);
|
||||
@ -2234,23 +2123,20 @@ template <typename Packet, typename ScalarExponent>
|
||||
struct unary_pow_impl<Packet, ScalarExponent, false, true> {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||
Packet result = unary_pow::int_pow(x, exponent);
|
||||
result = unary_pow::handle_nonint_int_errors(x, result, exponent);
|
||||
return result;
|
||||
return unary_pow::int_pow(x, exponent);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Packet, typename ScalarExponent>
|
||||
struct unary_pow_impl<Packet, ScalarExponent, true, true> {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||
if (exponent < 0 || exponent > NumTraits<Scalar>::digits()) {
|
||||
return unary_pow::handle_int_int(x, exponent);
|
||||
}
|
||||
else {
|
||||
return unary_pow::int_pow(x, exponent);
|
||||
}
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||
if (exponent < ScalarExponent(0)) {
|
||||
return unary_pow::handle_int_int(x, exponent);
|
||||
} else {
|
||||
return unary_pow::int_pow(x, exponent);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
|
Loading…
x
Reference in New Issue
Block a user