From 1d80e23186a1ab82bc0dadc10cd1826f546cdacc Mon Sep 17 00:00:00 2001 From: Charles Schlosser Date: Fri, 2 Jun 2023 18:53:06 +0000 Subject: [PATCH] Optimize scalar_unary_pow_op error handling --- .../arch/Default/GenericPacketMathFunctions.h | 212 ++++-------------- 1 file changed, 49 insertions(+), 163 deletions(-) diff --git a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h index 13fafe309..70c83d64a 100644 --- a/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h +++ b/Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h @@ -1972,7 +1972,7 @@ struct pchebevl { namespace unary_pow { template ::IsInteger> struct is_odd { - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent run(const ScalarExponent& x) { + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(const ScalarExponent& x) { ScalarExponent xdiv2 = x / ScalarExponent(2); ScalarExponent floorxdiv2 = numext::floor(xdiv2); return xdiv2 != floorxdiv2; @@ -1980,8 +1980,8 @@ struct is_odd { }; template struct is_odd { - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ScalarExponent run(const ScalarExponent& x) { - return x % ScalarExponent(2); + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool run(const ScalarExponent& x) { + return x % ScalarExponent(2) != 0; } }; @@ -1989,7 +1989,7 @@ template ::type>::IsInteger> struct do_div { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { - typedef typename unpacket_traits::type Scalar; + using Scalar = typename unpacket_traits::type; const Packet cst_pos_one = pset1(Scalar(1)); return exponent < 0 ? pdiv(cst_pos_one, x) : x; } @@ -1998,15 +1998,14 @@ struct do_div { template struct do_div { // pdiv not defined, nor necessary for integer base types - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { - EIGEN_UNUSED_VARIABLE(exponent); + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent&) { return x; } }; template static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet int_pow(const Packet& x, const ScalarExponent& exponent) { - typedef typename unpacket_traits::type Scalar; + using Scalar = typename unpacket_traits::type; const Packet cst_pos_one = pset1(Scalar(1)); if (exponent == 0) return cst_pos_one; Packet result = x; @@ -2031,178 +2030,70 @@ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet gen_pow(const Packet& x, } template -static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_int_errors(const Packet& x, const Packet& powx, - const ScalarExponent& exponent) { - typedef typename unpacket_traits::type Scalar; +static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(const Packet& x, const Packet& powx, + const ScalarExponent& exponent) { + using Scalar = typename unpacket_traits::type; - // non-integer base, integer exponent case - - const bool exponent_is_odd = is_odd::run(exponent); - const bool exponent_is_neg = exponent < 0; - - const Packet exp_is_odd = exponent_is_odd ? ptrue(x) : pzero(x); - const Packet exp_is_neg = exponent_is_neg ? ptrue(x) : pzero(x); + // non-integer base and exponent case const Scalar pos_zero = Scalar(0); - const Scalar neg_zero = -Scalar(0); + const Scalar all_ones = ptrue(Scalar()); const Scalar pos_one = Scalar(1); const Scalar pos_inf = NumTraits::infinity(); - const Scalar neg_inf = -NumTraits::infinity(); - const Packet cst_pos_zero = pset1(pos_zero); - const Packet cst_neg_zero = pset1(neg_zero); + const Packet cst_pos_zero = pzero(x); const Packet cst_pos_one = pset1(pos_one); const Packet cst_pos_inf = pset1(pos_inf); - const Packet cst_neg_inf = pset1(neg_inf); - const Packet abs_x = pabs(x); - const Packet abs_x_is_zero = pcmp_eq(abs_x, cst_pos_zero); - const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one); - const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf); + const bool exponent_is_nan = (numext::isnan)(exponent); + const bool exponent_is_fin = (numext::isfinite)(exponent); + const bool exponent_is_neg = exponent < ScalarExponent(0); - const Packet x_has_signbit = psignbit(x); - const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero); - const Packet x_is_neg_zero = pand(x_has_signbit, abs_x_is_zero); + const Packet exp_is_nan = pset1(exponent_is_nan ? all_ones : pos_zero); + const Packet exp_is_fin = pset1(exponent_is_fin ? all_ones : pos_zero); + const Packet exp_is_neg = pset1(exponent_is_neg ? all_ones : pos_zero); - if (exponent == 0) { - return cst_pos_one; - } + const Packet x_is_gt_one = pcmp_lt(cst_pos_one, x); + const Packet x_is_lt_one = pcmp_lt(x, cst_pos_one); + const Packet x_is_zero = pcmp_eq(x, cst_pos_zero); + const Packet x_is_not_one = por(x_is_gt_one, x_is_lt_one); - Packet pow_is_pos_inf = pand(pandnot(abs_x_is_zero, x_is_neg_zero), pand(exp_is_odd, exp_is_neg)); - pow_is_pos_inf = por(pow_is_pos_inf, pand(abs_x_is_zero, pandnot(exp_is_neg, exp_is_odd))); - pow_is_pos_inf = por(pow_is_pos_inf, pand(pand(abs_x_is_inf, x_is_neg), pandnot(pnot(exp_is_neg), exp_is_odd))); - pow_is_pos_inf = por(pow_is_pos_inf, pandnot(pandnot(abs_x_is_inf, x_is_neg), exp_is_neg)); + const Packet inf_if_neg_exp = pand(cst_pos_inf, exp_is_neg); + const Packet inf_if_pos_exp = pandnot(cst_pos_inf, exp_is_neg); - Packet pow_is_neg_inf = pand(x_is_neg_zero, pand(exp_is_neg, exp_is_odd)); - pow_is_neg_inf = por(pow_is_neg_inf, pand(pand(abs_x_is_inf, x_is_neg), pandnot(exp_is_odd, exp_is_neg))); - - Packet pow_is_pos_zero = pandnot(abs_x_is_zero, exp_is_neg); - pow_is_pos_zero = por(pow_is_pos_zero, pand(pand(abs_x_is_inf, x_is_neg), pandnot(exp_is_neg, exp_is_odd))); - pow_is_pos_zero = por(pow_is_pos_zero, pand(pandnot(abs_x_is_inf, x_is_neg), exp_is_neg)); - - Packet pow_is_neg_zero = pand(x_is_neg_zero, pandnot(exp_is_odd, exp_is_neg)); - pow_is_neg_zero = por(pow_is_neg_zero, pand(pand(abs_x_is_inf, x_is_neg), pand(exp_is_odd, exp_is_neg))); - - Packet result = pselect(pow_is_neg_inf, cst_neg_inf, powx); - result = pselect(pow_is_neg_zero, cst_neg_zero, result); - result = pselect(pow_is_pos_zero, cst_pos_zero, result); - result = pselect(pow_is_pos_inf, cst_pos_inf, result); - result = pselect(pandnot(abs_x_is_one, x_is_neg), cst_pos_one, result); + Packet result = powx; + result = pselect(x_is_zero, inf_if_neg_exp, result); + result = pselect(pandnot(x_is_gt_one, exp_is_fin), inf_if_pos_exp, result); + result = pselect(pandnot(x_is_lt_one, exp_is_fin), inf_if_neg_exp, result); + result = por(exp_is_nan, result); + result = pselect(x_is_not_one, result, cst_pos_one); return result; } template -static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(const Packet& x, const Packet& powx, - const ScalarExponent& exponent) { - typedef typename unpacket_traits::type Scalar; +static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_int_int(const Packet& x, const ScalarExponent& exponent) { + using Scalar = typename unpacket_traits::type; - // non-integer base and exponent case + // integer base, integer exponent case - const bool exponent_is_fin = (numext::isfinite)(exponent); - const bool exponent_is_nan = (numext::isnan)(exponent); - const bool exponent_is_neg = exponent < 0; - const bool exponent_is_inf = !exponent_is_fin && !exponent_is_nan; - - const Packet exp_is_neg = exponent_is_neg ? ptrue(x) : pzero(x); - const Packet exp_is_inf = exponent_is_inf ? ptrue(x) : pzero(x); + // This routine handles negative exponents. + // The return value is either 0, 1, or -1. const Scalar pos_zero = Scalar(0); + const Scalar all_ones = ptrue(Scalar()); const Scalar pos_one = Scalar(1); - const Scalar pos_inf = NumTraits::infinity(); - const Scalar nan = NumTraits::quiet_NaN(); - const Packet cst_pos_zero = pset1(pos_zero); const Packet cst_pos_one = pset1(pos_one); - const Packet cst_pos_inf = pset1(pos_inf); - const Packet cst_nan = pset1(nan); - - const Packet abs_x = pabs(x); - const Packet abs_x_is_zero = pcmp_eq(abs_x, cst_pos_zero); - const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_pos_one); - const Packet abs_x_is_gt_one = pcmp_lt(cst_pos_one, abs_x); - const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one); - const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf); - - const Packet x_has_signbit = psignbit(x); - const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero); - - if (exponent_is_nan) { - return pselect(pandnot(abs_x_is_one, x_is_neg), cst_pos_one, cst_nan); - } - - Packet pow_is_pos_zero = pandnot(abs_x_is_zero, exp_is_neg); - pow_is_pos_zero = por(pow_is_pos_zero, pand(abs_x_is_gt_one, pand(exp_is_inf, exp_is_neg))); - pow_is_pos_zero = por(pow_is_pos_zero, pand(abs_x_is_lt_one, pandnot(exp_is_inf, exp_is_neg))); - pow_is_pos_zero = por(pow_is_pos_zero, pand(abs_x_is_inf, exp_is_neg)); - - const Packet pow_is_pos_one = pand(abs_x_is_one, exp_is_inf); - - Packet pow_is_pos_inf = pand(abs_x_is_zero, exp_is_neg); - pow_is_pos_inf = por(pow_is_pos_inf, pand(abs_x_is_lt_one, pand(exp_is_inf, exp_is_neg))); - pow_is_pos_inf = por(pow_is_pos_inf, pand(abs_x_is_gt_one, pandnot(exp_is_inf, exp_is_neg))); - pow_is_pos_inf = por(pow_is_pos_inf, pandnot(abs_x_is_inf, exp_is_neg)); - - const Packet pow_is_nan = pandnot(pandnot(x_is_neg, abs_x_is_inf), exp_is_inf); - - Packet result = pselect(pow_is_pos_inf, cst_pos_inf, powx); - result = pselect(pow_is_pos_one, cst_pos_one, result); - result = pselect(pow_is_pos_zero, cst_pos_zero, result); - result = pselect(pow_is_nan, cst_nan, result); - result = pselect(pandnot(abs_x_is_one, x_is_neg), cst_pos_one, result); - return result; -} - -template ::type>::IsSigned, bool> = true> -static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_int_int(const Packet& x, const ScalarExponent& exponent) { - typedef typename unpacket_traits::type Scalar; - - // signed integer base, integer exponent case - - // This routine handles negative and very large positive exponents - // Signed integer overflow and divide by zero is undefined behavior - // Unsigned integers do not overflow const bool exponent_is_odd = unary_pow::is_odd::run(exponent); - const Scalar zero = Scalar(0); - const Scalar pos_one = Scalar(1); - - const Packet cst_zero = pset1(zero); - const Packet cst_pos_one = pset1(pos_one); + const Packet exp_is_odd = pset1(exponent_is_odd ? all_ones : pos_zero); const Packet abs_x = pabs(x); + const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one); - const Packet pow_is_zero = exponent < 0 ? pcmp_lt(cst_pos_one, abs_x) : pzero(x); - const Packet pow_is_one = pcmp_eq(cst_pos_one, abs_x); - const Packet pow_is_neg = exponent_is_odd ? pcmp_lt(x, cst_zero) : pzero(x); - - Packet result = pselect(pow_is_zero, cst_zero, x); - result = pselect(pandnot(pow_is_one, pow_is_neg), cst_pos_one, result); - result = pselect(pand(pow_is_one, pow_is_neg), pnegate(cst_pos_one), result); - return result; -} - -template ::type>::IsSigned, bool> = true> -static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_int_int(const Packet& x, const ScalarExponent& exponent) { - typedef typename unpacket_traits::type Scalar; - - // unsigned integer base, integer exponent case - - // This routine handles negative and very large positive exponents - // Signed integer overflow and divide by zero is undefined behavior - // Unsigned integers do not overflow - - const Scalar zero = Scalar(0); - const Scalar pos_one = Scalar(1); - - const Packet cst_zero = pset1(zero); - const Packet cst_pos_one = pset1(pos_one); - - const Packet pow_is_zero = exponent < 0 ? pcmp_lt(cst_pos_one, x) : pzero(x); - const Packet pow_is_one = pcmp_eq(cst_pos_one, x); - - Packet result = pselect(pow_is_zero, cst_zero, x); - result = pselect(pow_is_one, cst_pos_one, result); + Packet result = pselect(exp_is_odd, x, abs_x); + result = pand(abs_x_is_one, result); return result; } @@ -2219,9 +2110,7 @@ struct unary_pow_impl { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { const bool exponent_is_integer = (numext::isfinite)(exponent) && numext::round(exponent) == exponent; if (exponent_is_integer) { - Packet result = unary_pow::int_pow(x, exponent); - result = unary_pow::handle_nonint_int_errors(x, result, exponent); - return result; + return unary_pow::int_pow(x, exponent); } else { Packet result = unary_pow::gen_pow(x, exponent); result = unary_pow::handle_nonint_nonint_errors(x, result, exponent); @@ -2234,23 +2123,20 @@ template struct unary_pow_impl { typedef typename unpacket_traits::type Scalar; static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { - Packet result = unary_pow::int_pow(x, exponent); - result = unary_pow::handle_nonint_int_errors(x, result, exponent); - return result; + return unary_pow::int_pow(x, exponent); } }; template struct unary_pow_impl { - typedef typename unpacket_traits::type Scalar; - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { - if (exponent < 0 || exponent > NumTraits::digits()) { - return unary_pow::handle_int_int(x, exponent); - } - else { - return unary_pow::int_pow(x, exponent); - } + typedef typename unpacket_traits::type Scalar; + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) { + if (exponent < ScalarExponent(0)) { + return unary_pow::handle_int_int(x, exponent); + } else { + return unary_pow::int_pow(x, exponent); } + } }; } // end namespace internal