Improve pow(x,y): 25% speedup, increase accuracy for integer exponents.

This commit is contained in:
Rasmus Munk Larsen 2024-11-26 06:13:48 +00:00
parent 8ad4344ca7
commit 1ea61a5d26
2 changed files with 50 additions and 78 deletions

View File

@ -194,7 +194,7 @@ struct is_half {
static constexpr int Size = unpacket_traits<Packet>::size;
using DefaultPacket = typename packet_traits<Scalar>::type;
static constexpr int DefaultSize = unpacket_traits<DefaultPacket>::size;
static constexpr bool value = Size < DefaultSize;
static constexpr bool value = Size != 1 && Size < DefaultSize;
};
template <typename Src, typename Tgt>

View File

@ -1837,81 +1837,51 @@ struct accurate_log2 {
};
// This specialization uses a more accurate algorithm to compute log2(x) for
// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~6.42e-10.
// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~6.56508e-10.
// This additional accuracy is needed to counter the error-magnification
// inherent in multiplying by a potentially large exponent in pow(x,y).
// The minimax polynomial used was calculated using the Sollya tool.
// See sollya.org.
// The minimax polynomial used was calculated using the Rminimax tool,
// see https://gitlab.inria.fr/sfilip/rminimax.
// Command line:
// $ ratapprox --function="log2(1+x)/x" --dom='[-0.2929,0.41422]' --type=[10,0]
// --numF="[D,D,SG]" --denF="[SG]" --log --dispCoeff="dec"
//
// The resulting implementation of pow(x,y) is accurate to 3 ulps.
// Single-precision specialization of accurate_log2. Given z, it forms
// x = z - 1 and evaluates a polynomial approximation of log2(1+x)/x, then
// multiplies by x. The result is written to the output arguments as a
// (log2_x_hi, log2_x_lo) pair.
//
// NOTE(review): as listed, this body carries two coefficient sets
// (c0_*/c1_*/c[] and C0_*..C3_*/p0..p6) and two evaluation paths, and `p` is
// declared twice. It looks like superseded and replacement lines are shown
// together - confirm which path is live before relying on either.
template <>
struct accurate_log2<float> {
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(const Packet& z, Packet& log2_x_hi, Packet& log2_x_lo) {
// The function log2(1+x)/x is approximated in the interval
// [1/sqrt(2)-1;sqrt(2)-1] by a degree 10 polynomial of the form
// Q(x) = (C0 + x * (C1 + x * (C2 + x * (C3 + x * P(x))))),
// where the degree 6 polynomial P(x) is evaluated in single precision,
// while the remaining 4 terms of Q(x), as well as the final multiplication by x
// to reconstruct log2(1+x) are evaluated in extra precision using
// double word arithmetic. C0 through C3 are extra precise constants
// stored as double words.
//
// The polynomial coefficients were calculated using Sollya commands:
// > n = 10;
// > f = log2(1+x)/x;
// > interval = [sqrt(0.5)-1;sqrt(2)-1];
// > p = fpminimax(f,n,[|double,double,double,double,single...|],interval,relative,floating);
// Split the two lowest-order coefficients into a double-word representation:
// each double constant becomes a float `hi` part plus a float `lo` part
// holding the rounding residual (kC - hi).
constexpr double kC0 = 1.442695041742110273474963832995854318141937255859375e+00;
constexpr float kC0_hi = static_cast<float>(kC0);
constexpr float kC0_lo = static_cast<float>(kC0 - static_cast<double>(kC0_hi));
const Packet c0_hi = pset1<Packet>(kC0_hi);
const Packet c0_lo = pset1<Packet>(kC0_lo);
// Single-precision coefficients of P(x), highest order (p6) first.
const Packet p6 = pset1<Packet>(9.703654795885e-2f);
const Packet p5 = pset1<Packet>(-0.1690667718648f);
const Packet p4 = pset1<Packet>(0.1720575392246f);
const Packet p3 = pset1<Packet>(-0.1789081543684f);
const Packet p2 = pset1<Packet>(0.2050433009862f);
const Packet p1 = pset1<Packet>(-0.2404672354459f);
const Packet p0 = pset1<Packet>(0.2885761857032f);
constexpr double kC1 = -7.2134751588268664068692714863573201000690460205078125e-01;
constexpr float kC1_hi = static_cast<float>(kC1);
constexpr float kC1_lo = static_cast<float>(kC1 - static_cast<double>(kC1_hi));
const Packet c1_hi = pset1<Packet>(kC1_hi);
const Packet c1_lo = pset1<Packet>(kC1_lo);
// Double-word (hi, lo) constants C3 through C0 of Q(x).
const Packet C3_hi = pset1<Packet>(-0.360674142838f);
const Packet C3_lo = pset1<Packet>(-6.13283912543e-09f);
const Packet C2_hi = pset1<Packet>(0.480897903442f);
const Packet C2_lo = pset1<Packet>(-1.44861207474e-08f);
const Packet C1_hi = pset1<Packet>(-0.721347510815f);
const Packet C1_lo = pset1<Packet>(-4.84483164698e-09f);
const Packet C0_hi = pset1<Packet>(1.44269502163f);
const Packet C0_lo = pset1<Packet>(2.01711713999e-08f);
// Nine single-precision coefficients handed to ppolevl below
// (presumably ordered highest degree first - confirm against ppolevl).
constexpr float c[] = {
9.7010828554630279541015625e-02, -1.6896486282348632812500000e-01, 1.7200836539268493652343750e-01,
-1.7892272770404815673828125e-01, 2.0505344867706298828125000e-01, -2.4046677350997924804687500e-01,
2.8857553005218505859375000e-01, -3.6067414283752441406250000e-01, 4.8089790344238281250000000e-01};
// Evaluate the higher order terms in the polynomial using
// standard arithmetic.
const Packet one = pset1<Packet>(1.0f);
const Packet x = psub(z, one);
// Evaluate P(x) in working precision.
// We evaluate it in multiple parts to improve instruction level
// parallelism: the even-index and odd-index coefficients are accumulated
// separately in x^2 and then combined.
Packet x2 = pmul(x, x);
Packet p_even = pmadd(p6, x2, p4);
p_even = pmadd(p_even, x2, p2);
p_even = pmadd(p_even, x2, p0);
Packet p_odd = pmadd(p5, x2, p3);
p_odd = pmadd(p_odd, x2, p1);
Packet p = pmadd(p_odd, x, p_even);
// Now evaluate the low-order terms of Q(x) in double word precision.
// In the following, due to the alternating signs and the fact that
// |x| < sqrt(2)-1, we can assume that |C*_hi| >= q_i, and use
// fast_twosum instead of the slower twosum.
Packet q_hi, q_lo;
Packet t_hi, t_lo;
// q = C3 + x * p(x)
twoprod(p, x, t_hi, t_lo);
fast_twosum(C3_hi, C3_lo, t_hi, t_lo, q_hi, q_lo);
// q = C2 + x * q
twoprod(q_hi, q_lo, x, t_hi, t_lo);
fast_twosum(C2_hi, C2_lo, t_hi, t_lo, q_hi, q_lo);
// q = C1 + x * q
twoprod(q_hi, q_lo, x, t_hi, t_lo);
fast_twosum(C1_hi, C1_lo, t_hi, t_lo, q_hi, q_lo);
// q = C0 + x * q, which completes Q(x)
twoprod(q_hi, q_lo, x, t_hi, t_lo);
fast_twosum(C0_hi, C0_lo, t_hi, t_lo, q_hi, q_lo);
// log2(z) ~= x * Q(x)
twoprod(q_hi, q_lo, x, log2_x_hi, log2_x_lo);
Packet p = ppolevl<Packet, 8>::run(x, c);
// Evaluate the final two steps in Horner's rule using double-word arithmetic.
Packet p_hi, p_lo;
twoprod(x, p, p_hi, p_lo);
fast_twosum(c1_hi, c1_lo, p_hi, p_lo, p_hi, p_lo);
twoprod(p_hi, p_lo, x, p_hi, p_lo);
fast_twosum(c0_hi, c0_lo, p_hi, p_lo, p_hi, p_lo);
// Multiply by x to recover log2(z).
twoprod(p_hi, p_lo, x, log2_x_hi, log2_x_lo);
}
};
@ -2006,16 +1976,6 @@ struct accurate_log2<double> {
}
};
// Functor used by pow(x,y) to evaluate exp2(x) accurately for x in
// [-0.5:0.5]. It simply forwards its argument to generic_exp2.
template <typename Scalar>
struct fast_accurate_exp2 {
  template <typename Packet>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet operator()(const Packet& x) {
    Packet exp2_of_x = generic_exp2(x);
    return exp2_of_x;
  }
};
// This function implements the non-trivial case of pow(x,y) where x is
// positive and y is (possibly) non-integer.
// Formally, pow(x,y) = exp2(y * log2(x)), where exp2(x) is shorthand for 2^x.
@ -2343,7 +2303,13 @@ struct unary_pow_impl<Packet, ScalarExponent, false, false, ExponentIsSigned> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
const bool exponent_is_integer = (numext::isfinite)(exponent) && numext::round(exponent) == exponent;
if (exponent_is_integer) {
return unary_pow::int_pow(x, exponent);
// The simple recursive doubling implementation is only accurate to 3 ulps for
// integer exponents in [-3:7]. Since this is a common case, we specialize it here.
if (exponent <= ScalarExponent(7) && (!ExponentIsSigned || exponent >= ScalarExponent(-3))) {
return unary_pow::int_pow(x, exponent);
}
// TODO(rmlarsen): Implement more efficient special case handling.
return generic_pow(x, pset1<Packet>(exponent));
} else {
Packet result = unary_pow::gen_pow(x, exponent);
result = unary_pow::handle_nonint_nonint_errors(x, result, exponent);
@ -2356,7 +2322,13 @@ template <typename Packet, typename ScalarExponent, bool ExponentIsSigned>
// Specialization of unary_pow_impl whose run() raises every element of the
// packet x to a single scalar exponent. Small exponents use
// unary_pow::int_pow; all others are broadcast to a packet and handed to
// generic_pow.
struct unary_pow_impl<Packet, ScalarExponent, false, true, ExponentIsSigned> {
typedef typename unpacket_traits<Packet>::type Scalar;
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
// NOTE(review): as listed, this unconditional return makes everything below
// unreachable. It looks like the superseded line of a change shown next to
// its replacement - confirm which version is live.
return unary_pow::int_pow(x, exponent);
// The simple recursive doubling implementation is only accurate to 3 ulps for
// integer exponents in [-3:7]. Since this is a common case, we specialize it here.
// The lower bound is only checked when ExponentIsSigned is true.
if (exponent <= ScalarExponent(7) && (!ExponentIsSigned || exponent >= ScalarExponent(-3))) {
return unary_pow::int_pow(x, exponent);
}
// TODO(rmlarsen): Implement more efficient special case handling.
// Outside that range: convert the exponent to the packet's scalar type,
// broadcast it, and use the general pow path.
return generic_pow<Packet>(x, pset1<Packet>(Scalar(exponent)));
}
};