mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-07-22 12:54:26 +08:00
Improve pow(x,y): 25% speedup, increase accuracy for integer exponents.
This commit is contained in:
parent
8ad4344ca7
commit
1ea61a5d26
@ -194,7 +194,7 @@ struct is_half {
|
||||
static constexpr int Size = unpacket_traits<Packet>::size;
|
||||
using DefaultPacket = typename packet_traits<Scalar>::type;
|
||||
static constexpr int DefaultSize = unpacket_traits<DefaultPacket>::size;
|
||||
static constexpr bool value = Size < DefaultSize;
|
||||
static constexpr bool value = Size != 1 && Size < DefaultSize;
|
||||
};
|
||||
|
||||
template <typename Src, typename Tgt>
|
||||
|
@ -1837,81 +1837,51 @@ struct accurate_log2 {
|
||||
};
|
||||
|
||||
// This specialization uses a more accurate algorithm to compute log2(x) for
|
||||
// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~6.42e-10.
|
||||
// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~6.56508e-10.
|
||||
// This additional accuracy is needed to counter the error-magnification
|
||||
// inherent in multiplying by a potentially large exponent in pow(x,y).
|
||||
// The minimax polynomial used was calculated using the Sollya tool.
|
||||
// See sollya.org.
|
||||
// The minimax polynomial used was calculated using the Rminimax tool,
|
||||
// see https://gitlab.inria.fr/sfilip/rminimax.
|
||||
// Command line:
|
||||
// $ ratapprox --function="log2(1+x)/x" --dom='[-0.2929,0.41422]' --type=[10,0]
|
||||
// --numF="[D,D,SG]" --denF="[SG]" --log --dispCoeff="dec"
|
||||
//
|
||||
// The resulting implementation of pow(x,y) is accurate to 3 ulps.
|
||||
template <>
|
||||
struct accurate_log2<float> {
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void operator()(const Packet& z, Packet& log2_x_hi, Packet& log2_x_lo) {
|
||||
// The function log2(1+x)/x is approximated in the interval
|
||||
// [1/sqrt(2)-1;sqrt(2)-1] by a degree 10 polynomial of the form
|
||||
// Q(x) = (C0 + x * (C1 + x * (C2 + x * (C3 + x * P(x))))),
|
||||
// where the degree 6 polynomial P(x) is evaluated in single precision,
|
||||
// while the remaining 4 terms of Q(x), as well as the final multiplication by x
|
||||
// to reconstruct log(1+x) are evaluated in extra precision using
|
||||
// double word arithmetic. C0 through C3 are extra precise constants
|
||||
// stored as double words.
|
||||
//
|
||||
// The polynomial coefficients were calculated using Sollya commands:
|
||||
// > n = 10;
|
||||
// > f = log2(1+x)/x;
|
||||
// > interval = [sqrt(0.5)-1;sqrt(2)-1];
|
||||
// > p = fpminimax(f,n,[|double,double,double,double,single...|],interval,relative,floating);
|
||||
// // Split the two lowest order constant coefficients into double-word representation.
|
||||
constexpr double kC0 = 1.442695041742110273474963832995854318141937255859375e+00;
|
||||
constexpr float kC0_hi = static_cast<float>(kC0);
|
||||
constexpr float kC0_lo = static_cast<float>(kC0 - static_cast<double>(kC0_hi));
|
||||
const Packet c0_hi = pset1<Packet>(kC0_hi);
|
||||
const Packet c0_lo = pset1<Packet>(kC0_lo);
|
||||
|
||||
const Packet p6 = pset1<Packet>(9.703654795885e-2f);
|
||||
const Packet p5 = pset1<Packet>(-0.1690667718648f);
|
||||
const Packet p4 = pset1<Packet>(0.1720575392246f);
|
||||
const Packet p3 = pset1<Packet>(-0.1789081543684f);
|
||||
const Packet p2 = pset1<Packet>(0.2050433009862f);
|
||||
const Packet p1 = pset1<Packet>(-0.2404672354459f);
|
||||
const Packet p0 = pset1<Packet>(0.2885761857032f);
|
||||
constexpr double kC1 = -7.2134751588268664068692714863573201000690460205078125e-01;
|
||||
constexpr float kC1_hi = static_cast<float>(kC1);
|
||||
constexpr float kC1_lo = static_cast<float>(kC1 - static_cast<double>(kC1_hi));
|
||||
const Packet c1_hi = pset1<Packet>(kC1_hi);
|
||||
const Packet c1_lo = pset1<Packet>(kC1_lo);
|
||||
|
||||
const Packet C3_hi = pset1<Packet>(-0.360674142838f);
|
||||
const Packet C3_lo = pset1<Packet>(-6.13283912543e-09f);
|
||||
const Packet C2_hi = pset1<Packet>(0.480897903442f);
|
||||
const Packet C2_lo = pset1<Packet>(-1.44861207474e-08f);
|
||||
const Packet C1_hi = pset1<Packet>(-0.721347510815f);
|
||||
const Packet C1_lo = pset1<Packet>(-4.84483164698e-09f);
|
||||
const Packet C0_hi = pset1<Packet>(1.44269502163f);
|
||||
const Packet C0_lo = pset1<Packet>(2.01711713999e-08f);
|
||||
constexpr float c[] = {
|
||||
9.7010828554630279541015625e-02, -1.6896486282348632812500000e-01, 1.7200836539268493652343750e-01,
|
||||
-1.7892272770404815673828125e-01, 2.0505344867706298828125000e-01, -2.4046677350997924804687500e-01,
|
||||
2.8857553005218505859375000e-01, -3.6067414283752441406250000e-01, 4.8089790344238281250000000e-01};
|
||||
|
||||
// Evaluate the higher order terms in the polynomial using
|
||||
// standard arithmetic.
|
||||
const Packet one = pset1<Packet>(1.0f);
|
||||
|
||||
const Packet x = psub(z, one);
|
||||
// Evaluate P(x) in working precision.
|
||||
// We evaluate it in multiple parts to improve instruction level
|
||||
// parallelism.
|
||||
Packet x2 = pmul(x, x);
|
||||
Packet p_even = pmadd(p6, x2, p4);
|
||||
p_even = pmadd(p_even, x2, p2);
|
||||
p_even = pmadd(p_even, x2, p0);
|
||||
Packet p_odd = pmadd(p5, x2, p3);
|
||||
p_odd = pmadd(p_odd, x2, p1);
|
||||
Packet p = pmadd(p_odd, x, p_even);
|
||||
|
||||
// Now evaluate the low-order terms of Q(x) in double word precision.
|
||||
// In the following, due to the alternating signs and the fact that
|
||||
// |x| < sqrt(2)-1, we can assume that |C*_hi| >= q_i, and use
|
||||
// fast_twosum instead of the slower twosum.
|
||||
Packet q_hi, q_lo;
|
||||
Packet t_hi, t_lo;
|
||||
// C3 + x * p(x)
|
||||
twoprod(p, x, t_hi, t_lo);
|
||||
fast_twosum(C3_hi, C3_lo, t_hi, t_lo, q_hi, q_lo);
|
||||
// C2 + x * q(x), where q is the double-word result of the previous step
|
||||
twoprod(q_hi, q_lo, x, t_hi, t_lo);
|
||||
fast_twosum(C2_hi, C2_lo, t_hi, t_lo, q_hi, q_lo);
|
||||
// C1 + x * q(x), where q is the double-word result of the previous step
|
||||
twoprod(q_hi, q_lo, x, t_hi, t_lo);
|
||||
fast_twosum(C1_hi, C1_lo, t_hi, t_lo, q_hi, q_lo);
|
||||
// C0 + x * q(x), where q is the double-word result of the previous step
|
||||
twoprod(q_hi, q_lo, x, t_hi, t_lo);
|
||||
fast_twosum(C0_hi, C0_lo, t_hi, t_lo, q_hi, q_lo);
|
||||
|
||||
// log2(z) ~= x * Q(x)
|
||||
twoprod(q_hi, q_lo, x, log2_x_hi, log2_x_lo);
|
||||
Packet p = ppolevl<Packet, 8>::run(x, c);
|
||||
// Evaluate the final two steps in Horner's rule using double-word arithmetic.
|
||||
Packet p_hi, p_lo;
|
||||
twoprod(x, p, p_hi, p_lo);
|
||||
fast_twosum(c1_hi, c1_lo, p_hi, p_lo, p_hi, p_lo);
|
||||
twoprod(p_hi, p_lo, x, p_hi, p_lo);
|
||||
fast_twosum(c0_hi, c0_lo, p_hi, p_lo, p_hi, p_lo);
|
||||
// Multiply by x to recover log2(z).
|
||||
twoprod(p_hi, p_lo, x, log2_x_hi, log2_x_lo);
|
||||
}
|
||||
};
|
||||
|
||||
@ -2006,16 +1976,6 @@ struct accurate_log2<double> {
|
||||
}
|
||||
};
|
||||
|
||||
// This function accurately computes exp2(x) for x in [-0.5:0.5], which is
|
||||
// needed in pow(x,y).
|
||||
template <typename Scalar>
|
||||
struct fast_accurate_exp2 {
|
||||
template <typename Packet>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet operator()(const Packet& x) {
|
||||
return generic_exp2(x);
|
||||
}
|
||||
};
|
||||
|
||||
// This function implements the non-trivial case of pow(x,y) where x is
|
||||
// positive and y is (possibly) non-integer.
|
||||
// Formally, pow(x,y) = exp2(y * log2(x)), where exp2(x) is shorthand for 2^x.
|
||||
@ -2343,7 +2303,13 @@ struct unary_pow_impl<Packet, ScalarExponent, false, false, ExponentIsSigned> {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||
const bool exponent_is_integer = (numext::isfinite)(exponent) && numext::round(exponent) == exponent;
|
||||
if (exponent_is_integer) {
|
||||
return unary_pow::int_pow(x, exponent);
|
||||
// The simple recursive doubling implementation is only accurate to 3 ulps for
|
||||
// integer exponents in [-3:7]. Since this is a common case, we specialize it here.
|
||||
if (exponent <= ScalarExponent(7) && (!ExponentIsSigned || exponent >= ScalarExponent(-3))) {
|
||||
return unary_pow::int_pow(x, exponent);
|
||||
}
|
||||
// TODO(rmlarsen): Implement more efficient special case handling.
|
||||
return generic_pow(x, pset1<Packet>(exponent));
|
||||
} else {
|
||||
Packet result = unary_pow::gen_pow(x, exponent);
|
||||
result = unary_pow::handle_nonint_nonint_errors(x, result, exponent);
|
||||
@ -2356,7 +2322,13 @@ template <typename Packet, typename ScalarExponent, bool ExponentIsSigned>
|
||||
struct unary_pow_impl<Packet, ScalarExponent, false, true, ExponentIsSigned> {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& x, const ScalarExponent& exponent) {
|
||||
return unary_pow::int_pow(x, exponent);
|
||||
// The simple recursive doubling implementation is only accurate to 3 ulps for
|
||||
// integer exponents in [-3:7]. Since this is a common case, we specialize it here.
|
||||
if (exponent <= ScalarExponent(7) && (!ExponentIsSigned || exponent >= ScalarExponent(-3))) {
|
||||
return unary_pow::int_pow(x, exponent);
|
||||
}
|
||||
// TODO(rmlarsen): Implement more efficient special case handling.
|
||||
return generic_pow<Packet>(x, pset1<Packet>(Scalar(exponent)));
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user