mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 03:39:01 +08:00
New accurate algorithm for pow(x,y). This version is accurate to 1.4 ulps for float, while still being 10x faster than std::pow for AVX512. A future change will introduce a specialization for double.
This commit is contained in:
parent
7ff0b7a980
commit
be0574e215
@ -1,3 +1,4 @@
|
||||
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
@ -36,7 +37,7 @@ template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
|
||||
Packet pfrexp_generic_get_biased_exponent(const Packet& a) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
|
||||
EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1;
|
||||
enum { mantissa_bits = numext::numeric_limits<Scalar>::digits - 1};
|
||||
return pcast<PacketI, Packet>(plogical_shift_right<mantissa_bits>(preinterpret<PacketI>(pabs(a))));
|
||||
}
|
||||
|
||||
@ -46,28 +47,29 @@ template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
|
||||
Packet pfrexp_generic(const Packet& a, Packet& exponent) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
typedef typename make_unsigned<typename make_integer<Scalar>::type>::type ScalarUI;
|
||||
|
||||
EIGEN_CONSTEXPR int total_bits = sizeof(Scalar) * CHAR_BIT;
|
||||
EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1;
|
||||
EIGEN_CONSTEXPR int exponent_bits = total_bits - mantissa_bits - 1;
|
||||
|
||||
enum {
|
||||
TotalBits = sizeof(Scalar) * CHAR_BIT,
|
||||
MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
|
||||
ExponentBits = int(TotalBits) - int(MantissaBits) - 1
|
||||
};
|
||||
|
||||
EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask =
|
||||
~(((ScalarUI(1) << exponent_bits) - ScalarUI(1)) << mantissa_bits); // ~0x7f800000
|
||||
~(((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)) << int(MantissaBits)); // ~0x7f800000
|
||||
const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask));
|
||||
const Packet half = pset1<Packet>(Scalar(0.5));
|
||||
const Packet zero = pzero(a);
|
||||
const Packet normal_min = pset1<Packet>((numext::numeric_limits<Scalar>::min)()); // Minimum normal value, 2^-126
|
||||
|
||||
// To handle denormals, normalize by multiplying by 2^(mantissa_bits+1).
|
||||
// To handle denormals, normalize by multiplying by 2^(int(MantissaBits)+1).
|
||||
const Packet is_denormal = pcmp_lt(pabs(a), normal_min);
|
||||
EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(mantissa_bits + 1); // 24
|
||||
EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(int(MantissaBits) + 1); // 24
|
||||
// The following cannot be constexpr because bfloat16(uint16_t) is not constexpr.
|
||||
const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24
|
||||
const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor);
|
||||
const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a);
|
||||
|
||||
// Determine exponent offset: -126 if normal, -126-24 if denormal
|
||||
const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(exponent_bits-1)) - ScalarUI(2)); // -126
|
||||
const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(int(ExponentBits)-1)) - ScalarUI(2)); // -126
|
||||
Packet exponent_offset = pset1<Packet>(scalar_exponent_offset);
|
||||
const Packet normalization_offset = pset1<Packet>(-Scalar(scalar_normalization_offset)); // -24
|
||||
exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset);
|
||||
@ -76,7 +78,7 @@ Packet pfrexp_generic(const Packet& a, Packet& exponent) {
|
||||
exponent = pfrexp_generic_get_biased_exponent(normalized_a);
|
||||
// Zero, Inf and NaN return 'a' unmodified, exponent is zero
|
||||
// (technically the exponent is unspecified for inf/NaN, but GCC/Clang set it to zero)
|
||||
const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << exponent_bits) - ScalarUI(1)); // 255
|
||||
const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << int(ExponentBits)) - ScalarUI(1)); // 255
|
||||
const Packet non_finite_exponent = pset1<Packet>(scalar_non_finite_exponent);
|
||||
const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent));
|
||||
const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half));
|
||||
@ -113,18 +115,20 @@ Packet pldexp_generic(const Packet& a, const Packet& exponent) {
|
||||
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
typedef typename unpacket_traits<PacketI>::type ScalarI;
|
||||
EIGEN_CONSTEXPR int total_bits = sizeof(Scalar) * CHAR_BIT;
|
||||
EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1;
|
||||
EIGEN_CONSTEXPR int exponent_bits = total_bits - mantissa_bits - 1;
|
||||
enum {
|
||||
TotalBits = sizeof(Scalar) * CHAR_BIT,
|
||||
MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
|
||||
ExponentBits = int(TotalBits) - int(MantissaBits) - 1
|
||||
};
|
||||
|
||||
const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<exponent_bits) + ScalarI(mantissa_bits - 1))); // 278
|
||||
const PacketI bias = pset1<PacketI>((ScalarI(1)<<(exponent_bits-1)) - ScalarI(1)); // 127
|
||||
const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<int(ExponentBits)) + ScalarI(int(MantissaBits) - 1))); // 278
|
||||
const PacketI bias = pset1<PacketI>((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1)); // 127
|
||||
const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
|
||||
PacketI b = parithmetic_shift_right<2>(e); // floor(e/4);
|
||||
Packet c = preinterpret<Packet>(plogical_shift_left<mantissa_bits>(padd(b, bias))); // 2^b
|
||||
Packet c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^b
|
||||
Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
|
||||
b = psub(psub(psub(e, b), b), b); // e - 3b
|
||||
c = preinterpret<Packet>(plogical_shift_left<mantissa_bits>(padd(b, bias))); // 2^(e-3*b)
|
||||
c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^(e-3*b)
|
||||
out = pmul(out, c);
|
||||
return out;
|
||||
}
|
||||
@ -890,112 +894,355 @@ Packet psqrt_complex(const Packet& a) {
|
||||
pselect(is_real_inf, real_inf_result,result));
|
||||
}
|
||||
|
||||
|
||||
// This function implements the Veltkamp splitting. Given a floating point
|
||||
// number x it returns the pair {x_hi, x_lo} such that x_hi + x_lo = x holds
|
||||
// exactly and that half of the significand of x fits in x_hi.
|
||||
// This code corresponds to Algorithms 3 and 4 in
|
||||
// https://hal.inria.fr/hal-01774587v2/document
|
||||
template<typename Packet>
|
||||
EIGEN_STRONG_INLINE
|
||||
void veltkamp_splitting(const Packet& x, Packet& x_hi, Packet& x_lo) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
EIGEN_CONSTEXPR int shift = (NumTraits<Scalar>::digits() + 1) / 2;
|
||||
Scalar shift_scale = Scalar(uint64_t(1) << shift); // Scalar constructor not necessarily constexpr.
|
||||
Packet gamma = pmul(pset1<Packet>(shift_scale + Scalar(1)), x);
|
||||
#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
|
||||
x_hi = pmadd(pset1<Packet>(-shift_scale), x, gamma);
|
||||
#else
|
||||
Packet rho = psub(x, gamma);
|
||||
x_hi = padd(rho, gamma);
|
||||
#endif
|
||||
x_lo = psub(x, x_hi);
|
||||
}
|
||||
// TODO(rmlarsen): The following set of utilities for double word arithmetic
|
||||
// should perhaps be refactored as a separate file, since it would be generally
|
||||
// useful for special function implementation etc. Writing the algorithms in
|
||||
// terms of a double word type would also make the code more readable.
|
||||
|
||||
// This function splits x into the nearest integer n and fractional part r,
|
||||
// such that x = n + r holds exactly.
|
||||
template<typename Packet>
|
||||
EIGEN_STRONG_INLINE
|
||||
void integer_split(const Packet& x, Packet& n, Packet& r) {
|
||||
void absolute_split(const Packet& x, Packet& n, Packet& r) {
|
||||
n = pround(x);
|
||||
r = psub(x, n);
|
||||
}
|
||||
|
||||
// This function implements Dekker's algorithm for two products {x * y1, x * y2} with
|
||||
// a shared factor. Given floating point numbers {x, y1, y2} computes the pairs
|
||||
// {p1, r1} and {p2, r2} such that x * y1 = p1 + r1 holds exactly and
|
||||
// p1 = fl(x * y1), and x * y2 = p2 + r2 holds exactly and p2 = fl(x * y2).
|
||||
// This function computes the sum {s, r}, such that x + y = s_hi + s_lo
|
||||
// holds exactly, and s_hi = fl(x+y), if |x| >= |y|.
|
||||
template<typename Packet>
|
||||
EIGEN_STRONG_INLINE
|
||||
void double_dekker(const Packet& x, const Packet& y1, const Packet& y2,
|
||||
Packet& p1, Packet& r1, Packet& p2, Packet& r2) {
|
||||
Packet x_hi, x_lo, y1_hi, y1_lo, y2_hi, y2_lo;
|
||||
veltkamp_splitting(x, x_hi, x_lo);
|
||||
veltkamp_splitting(y1, y1_hi, y1_lo);
|
||||
veltkamp_splitting(y2, y2_hi, y2_lo);
|
||||
|
||||
p1 = pmul(x, y1);
|
||||
r1 = pmadd(x_hi, y1_hi, pnegate(p1));
|
||||
r1 = pmadd(x_hi, y1_lo, r1);
|
||||
r1 = pmadd(x_lo, y1_hi, r1);
|
||||
r1 = pmadd(x_lo, y1_lo, r1);
|
||||
|
||||
p2 = pmul(x, y2);
|
||||
r2 = pmadd(x_hi, y2_hi, pnegate(p2));
|
||||
r2 = pmadd(x_hi, y2_lo, r2);
|
||||
r2 = pmadd(x_lo, y2_hi, r2);
|
||||
r2 = pmadd(x_lo, y2_lo, r2);
|
||||
void fast_twosum(const Packet& x, const Packet& y, Packet& s_hi, Packet& s_lo) {
|
||||
s_hi = padd(x, y);
|
||||
const Packet t = psub(s_hi, x);
|
||||
s_lo = psub(y, t);
|
||||
}
|
||||
|
||||
#ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD
|
||||
// This function implements the extended precision product of
|
||||
// a pair of floating point numbers. Given {x, y}, it computes the pair
|
||||
// {p_hi, p_lo} such that x * y = p_hi + p_lo holds exactly and
|
||||
// p_hi = fl(x * y).
|
||||
template<typename Packet>
|
||||
EIGEN_STRONG_INLINE
|
||||
void twoprod(const Packet& x, const Packet& y,
|
||||
Packet& p_hi, Packet& p_lo) {
|
||||
p_hi = pmul(x, y);
|
||||
p_lo = pmadd(x, y, pnegate(p_hi));
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
// This function implements the Veltkamp splitting. Given a floating point
|
||||
// number x it returns the pair {x_hi, x_lo} such that x_hi + x_lo = x holds
|
||||
// exactly and that half of the significant of x fits in x_hi.
|
||||
// This is Algorithm 3 from Jean-Michel Muller, "Elementary Functions",
|
||||
// 3rd edition, Birkh\"auser, 2016.
|
||||
template<typename Packet>
|
||||
EIGEN_STRONG_INLINE
|
||||
void veltkamp_splitting(const Packet& x, Packet& x_hi, Packet& x_lo) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
EIGEN_CONSTEXPR int shift = (NumTraits<Scalar>::digits() + 1) / 2;
|
||||
const Scalar shift_scale = Scalar(uint64_t(1) << shift); // Scalar constructor not necessarily constexpr.
|
||||
const Packet gamma = pmul(pset1<Packet>(shift_scale + Scalar(1)), x);
|
||||
Packet rho = psub(x, gamma);
|
||||
x_hi = padd(rho, gamma);
|
||||
x_lo = psub(x, x_hi);
|
||||
}
|
||||
|
||||
// This function implements Dekker's algorithm for products x * y.
|
||||
// Given floating point numbers {x, y} computes the pair
|
||||
// {p_hi, p_lo} such that x * y = p_hi + p_lo holds exactly and
|
||||
// p_hi = fl(x * y).
|
||||
template<typename Packet>
|
||||
EIGEN_STRONG_INLINE
|
||||
void twoprod(const Packet& x, const Packet& y,
|
||||
Packet& p_hi, Packet& p_lo) {
|
||||
Packet x_hi, x_lo, y_hi, y_lo;
|
||||
veltkamp_splitting(x, x_hi, x_lo);
|
||||
veltkamp_splitting(y, y_hi, y_lo);
|
||||
|
||||
p_hi = pmul(x, y);
|
||||
p_lo = pmadd(x_hi, y_hi, pnegate(p_hi));
|
||||
p_lo = pmadd(x_hi, y_lo, p_lo);
|
||||
p_lo = pmadd(x_lo, y_hi, p_lo);
|
||||
p_lo = pmadd(x_lo, y_lo, p_lo);
|
||||
}
|
||||
|
||||
#endif // EIGEN_HAS_SINGLE_INSTRUCTION_MADD
|
||||
|
||||
|
||||
// This function implements Dekker's algorithm for the addition
|
||||
// of two double word numbers represented by {x_hi, x_lo} and {y_hi, y_lo}.
|
||||
// It returns the result as a pair {s_hi, s_lo} such that
|
||||
// x_hi + x_lo + y_hi + y_lo = s_hi + s_lo holds exactly.
|
||||
// This is Algorithm 5 from Jean-Michel Muller, "Elementary Functions",
|
||||
// 3rd edition, Birkh\"auser, 2016.
|
||||
template<typename Packet>
|
||||
EIGEN_STRONG_INLINE
|
||||
void twosum(const Packet& x_hi, const Packet& x_lo,
|
||||
const Packet& y_hi, const Packet& y_lo,
|
||||
Packet& s_hi, Packet& s_lo) {
|
||||
const Packet x_greater_mask = pcmp_lt(pabs(y_hi), pabs(x_hi));
|
||||
Packet r_hi_1, r_lo_1;
|
||||
fast_twosum(x_hi, y_hi,r_hi_1, r_lo_1);
|
||||
Packet r_hi_2, r_lo_2;
|
||||
fast_twosum(y_hi, x_hi,r_hi_2, r_lo_2);
|
||||
const Packet r_hi = pselect(x_greater_mask, r_hi_1, r_hi_2);
|
||||
|
||||
const Packet s1 = padd(padd(y_lo, r_lo_1), x_lo);
|
||||
const Packet s2 = padd(padd(x_lo, r_lo_2), y_lo);
|
||||
const Packet s = pselect(x_greater_mask, s1, s2);
|
||||
|
||||
fast_twosum(r_hi, s, s_hi, s_lo);
|
||||
}
|
||||
|
||||
// This is a version of twosum for double word numbers,
|
||||
// which assumes that |x_hi| >= |y_hi|.
|
||||
template<typename Packet>
|
||||
EIGEN_STRONG_INLINE
|
||||
void fast_twosum(const Packet& x_hi, const Packet& x_lo,
|
||||
const Packet& y_hi, const Packet& y_lo,
|
||||
Packet& s_hi, Packet& s_lo) {
|
||||
Packet r_hi, r_lo;
|
||||
fast_twosum(x_hi, y_hi, r_hi, r_lo);
|
||||
const Packet s = padd(padd(y_lo, r_lo), x_lo);
|
||||
fast_twosum(r_hi, s, s_hi, s_lo);
|
||||
}
|
||||
|
||||
// This function implements the multiplication of a double word
|
||||
// number represented by {x_hi, x_lo} by a floating point number y.
|
||||
// It returns the result as a pair {p_hi, p_lo} such that
|
||||
// (x_hi + x_lo) * y = p_hi + p_lo hold with a relative error
|
||||
// of less than 2*2^{-2p}, where p is the number of significand bit
|
||||
// in the floating point type.
|
||||
// This is Algorithm 7 from Jean-Michel Muller, "Elementary Functions",
|
||||
// 3rd edition, Birkh\"auser, 2016.
|
||||
template<typename Packet>
|
||||
EIGEN_STRONG_INLINE
|
||||
void twoprod(const Packet& x_hi, const Packet& x_lo, const Packet& y,
|
||||
Packet& p_hi, Packet& p_lo) {
|
||||
Packet c_hi, c_lo1;
|
||||
twoprod(x_hi, y, c_hi, c_lo1);
|
||||
const Packet c_lo2 = pmul(x_lo, y);
|
||||
Packet t_hi, t_lo1;
|
||||
fast_twosum(c_hi, c_lo2, t_hi, t_lo1);
|
||||
const Packet t_lo2 = padd(t_lo1, c_lo1);
|
||||
fast_twosum(t_hi, t_lo2, p_hi, p_lo);
|
||||
}
|
||||
|
||||
// This function computes log2(x) and returns the result as a double word.
|
||||
template <typename Scalar>
|
||||
struct accurate_log2 {
|
||||
template <typename Packet>
|
||||
EIGEN_STRONG_INLINE
|
||||
void operator()(const Packet& x, Packet& log2_x_hi, Packet& log2_x_lo) {
|
||||
log2_x_hi = plog2(x);
|
||||
log2_x_lo = pzero(x);
|
||||
}
|
||||
};
|
||||
|
||||
// This specialization uses a more accurate algorithm to compute log2(x) for
|
||||
// floats in [1/sqrt(2);sqrt(2)] with a relative accuracy of ~6.42e-10.
|
||||
// This additional accuracy is needed to counter the error-magnification
|
||||
// inherent in multiplying by a potentially large exponent in pow(x,y).
|
||||
// The minimax polynomial used was calculated using the Sollya tool.
|
||||
// See sollya.org.
|
||||
template <>
|
||||
struct accurate_log2<float> {
|
||||
template <typename Packet>
|
||||
EIGEN_STRONG_INLINE
|
||||
void operator()(const Packet& z, Packet& log2_x_hi, Packet& log2_x_lo) {
|
||||
// The function log(1+x)/x is approximated in the interval
|
||||
// [1/sqrt(2)-1;sqrt(2)-1] by a degree 10 polynomial of the form
|
||||
// Q(x) = (C0 + x * (C1 + x * (C2 + x * (C3 + x * P(x))))),
|
||||
// where the degree 6 polynomial P(x) is evaluated in single precision,
|
||||
// while the remaining 4 terms of Q(x), as well as the final multiplication by x
|
||||
// to reconstruct log(1+x) are evaluated in extra precision using
|
||||
// double word arithmetic. C0 through C3 are extra precise constants
|
||||
// stored as double words.
|
||||
//
|
||||
// The polynomial coefficients were calculated using Sollya commands:
|
||||
// > n = 10;
|
||||
// > f = log2(1+x)/x;
|
||||
// > interval = [sqrt(0.5)-1;sqrt(2)-1];
|
||||
// > p = fpminimax(f,n,[|double,double,double,double,single...|],interval,relative,floating);
|
||||
|
||||
const Packet p6 = pset1<Packet>( 9.703654795885e-2f);
|
||||
const Packet p5 = pset1<Packet>(-0.1690667718648f);
|
||||
const Packet p4 = pset1<Packet>( 0.1720575392246f);
|
||||
const Packet p3 = pset1<Packet>(-0.1789081543684f);
|
||||
const Packet p2 = pset1<Packet>( 0.2050433009862f);
|
||||
const Packet p1 = pset1<Packet>(-0.2404672354459f);
|
||||
const Packet p0 = pset1<Packet>( 0.2885761857032f);
|
||||
|
||||
const Packet C3_hi = pset1<Packet>(-0.360674142838f);
|
||||
const Packet C3_lo = pset1<Packet>(-6.13283912543e-09f);
|
||||
const Packet C2_hi = pset1<Packet>(0.480897903442f);
|
||||
const Packet C2_lo = pset1<Packet>(-1.44861207474e-08f);
|
||||
const Packet C1_hi = pset1<Packet>(-0.721347510815f);
|
||||
const Packet C1_lo = pset1<Packet>(-4.84483164698e-09f);
|
||||
const Packet C0_hi = pset1<Packet>(1.44269502163f);
|
||||
const Packet C0_lo = pset1<Packet>(2.01711713999e-08f);
|
||||
const Packet one = pset1<Packet>(1.0f);
|
||||
|
||||
const Packet x = psub(z, one);
|
||||
// Evaluate P(x) in working precision.
|
||||
// We evaluate it in multiple parts to improve instruction level
|
||||
// parallelism.
|
||||
Packet x2 = pmul(x,x);
|
||||
Packet p_even = pmadd(p6, x2, p4);
|
||||
p_even = pmadd(p_even, x2, p2);
|
||||
p_even = pmadd(p_even, x2, p0);
|
||||
Packet p_odd = pmadd(p5, x2, p3);
|
||||
p_odd = pmadd(p_odd, x2, p1);
|
||||
Packet p = pmadd(p_odd, x, p_even);
|
||||
|
||||
// Now evaluate the low-order tems of Q(x) in double word precision.
|
||||
// In the following, due to the alternating signs and the fact that
|
||||
// |x| < sqrt(2)-1, we can assume that |C*_hi| >= q_i, and use
|
||||
// fast_twosum instead of the slower twosum.
|
||||
Packet q_hi, q_lo;
|
||||
Packet t_hi, t_lo;
|
||||
// C3 + x * p(x)
|
||||
twoprod(p, x, t_hi, t_lo);
|
||||
fast_twosum(C3_hi, C3_lo, t_hi, t_lo, q_hi, q_lo);
|
||||
// C2 + x * p(x)
|
||||
twoprod(q_hi, q_lo, x, t_hi, t_lo);
|
||||
fast_twosum(C2_hi, C2_lo, t_hi, t_lo, q_hi, q_lo);
|
||||
// C1 + x * p(x)
|
||||
twoprod(q_hi, q_lo, x, t_hi, t_lo);
|
||||
fast_twosum(C1_hi, C1_lo, t_hi, t_lo, q_hi, q_lo);
|
||||
// C0 + x * p(x)
|
||||
twoprod(q_hi, q_lo, x, t_hi, t_lo);
|
||||
fast_twosum(C0_hi, C0_lo, t_hi, t_lo, q_hi, q_lo);
|
||||
|
||||
// log(z) ~= x * Q(x)
|
||||
twoprod(q_hi, q_lo, x, log2_x_hi, log2_x_lo);
|
||||
}
|
||||
};
|
||||
|
||||
// This function computes exp2(x) (i.e. 2**x).
|
||||
template <typename Scalar>
|
||||
struct fast_accurate_exp2 {
|
||||
template <typename Packet>
|
||||
EIGEN_STRONG_INLINE
|
||||
Packet operator()(const Packet& x) {
|
||||
// TODO(rmlarsen): Add a pexp2 packetop.
|
||||
return pexp(pmul(pset1<Packet>(Scalar(EIGEN_LN2)), x));
|
||||
}
|
||||
};
|
||||
|
||||
// This specialization uses a faster algorithm to compute exp2(x) for floats
|
||||
// in [-0.5;0.5] with a relative accuracy of 1 ulp.
|
||||
// The minimax polynomial used was calculated using the Sollya tool.
|
||||
// See sollya.org.
|
||||
template <>
|
||||
struct fast_accurate_exp2<float> {
|
||||
template <typename Packet>
|
||||
EIGEN_STRONG_INLINE
|
||||
Packet operator()(const Packet& x) {
|
||||
// This function approximates exp2(x) by a degree 6 polynomial of the form
|
||||
// Q(x) = 1 + x * (C + x * P(x)), where the degree 4 polynomial P(x) is evaluated in
|
||||
// single precision, and the remaining steps are evaluated with extra precision using
|
||||
// double word arithmetic. C is an extra precise constant stored as a double word.
|
||||
//
|
||||
// The polynomial coefficients were calculated using Sollya commands:
|
||||
// > n = 6;
|
||||
// > f = 2^x;
|
||||
// > interval = [-0.5;0.5];
|
||||
// > p = fpminimax(f,n,[|1,double,single...|],interval,relative,floating);
|
||||
|
||||
const Packet p4 = pset1<Packet>(1.539513905e-4f);
|
||||
const Packet p3 = pset1<Packet>(1.340007293e-3f);
|
||||
const Packet p2 = pset1<Packet>(9.618283249e-3f);
|
||||
const Packet p1 = pset1<Packet>(5.550328270e-2f);
|
||||
const Packet p0 = pset1<Packet>(0.2402264923f);
|
||||
|
||||
const Packet C_hi = pset1<Packet>(0.6931471825f);
|
||||
const Packet C_lo = pset1<Packet>(2.36836577e-08f);
|
||||
const Packet one = pset1<Packet>(1.0f);
|
||||
|
||||
// Evaluate P(x) in working precision.
|
||||
// We evaluate even and odd parts of the polynomial separately
|
||||
// to gain some instruction level parallelism.
|
||||
Packet x2 = pmul(x,x);
|
||||
Packet p_even = pmadd(p4, x2, p2);
|
||||
p_even = pmadd(p_even, x2, p0);
|
||||
Packet p_odd = pmadd(p3, x2, p1);
|
||||
Packet p = pmadd(p_odd, x, p_even);
|
||||
|
||||
// Evaluate the remaining terms of Q(x) with extra precision using
|
||||
// double word arithmetic.
|
||||
Packet p_hi, p_lo;
|
||||
// x * p(x)
|
||||
twoprod(p, x, p_hi, p_lo);
|
||||
// C + x * p(x)
|
||||
Packet q1_hi, q1_lo;
|
||||
twosum(p_hi, p_lo, C_hi, C_lo, q1_hi, q1_lo);
|
||||
// x * (C + x * p(x))
|
||||
Packet q2_hi, q2_lo;
|
||||
twoprod(q1_hi, q1_lo, x, q2_hi, q2_lo);
|
||||
// 1 + x * (C + x * p(x))
|
||||
Packet q3_hi, q3_lo;
|
||||
// Since |q2_hi| <= sqrt(2)-1 < 1, we can use fast_twosum
|
||||
// for adding it to unity here.
|
||||
fast_twosum(one, q2_hi, q3_hi, q3_lo);
|
||||
return padd(q3_hi, padd(q2_lo, q3_lo));
|
||||
}
|
||||
};
|
||||
|
||||
// This function implements the non-trivial case of pow(x,y) where x is
|
||||
// positive and y is (possibly) non-integer.
|
||||
// Formally, pow(x,y) = 2**(y * log2(x))
|
||||
template<typename Packet>
|
||||
EIGEN_STRONG_INLINE
|
||||
Packet generic_pow_impl(const Packet& x, const Packet& y) {
|
||||
// Formally, pow(x,y) = exp2(y * log2(x)), where exp2(x) is shorthand for 2^x.
|
||||
// TODO(rmlarsen): We should probably add this as a packet op 'ppow', to make it
|
||||
// easier to specialize or turn off for specific types and/or backends.
|
||||
template <typename Packet>
|
||||
EIGEN_STRONG_INLINE Packet generic_pow_impl(const Packet& x, const Packet& y) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
// Split x into exponent e_x and mantissa m_x.
|
||||
Packet e_x;
|
||||
Packet m_x = pfrexp(x, e_x);
|
||||
|
||||
// Adjust m_x to lie in [0.75:1.5) to minimize absolute error in log2(m_x).
|
||||
Packet m_x_scale_mask = pcmp_lt(m_x, pset1<Packet>(Scalar(0.75)));
|
||||
// Adjust m_x to lie in [1/sqrt(2):sqrt(2)] to minimize absolute error in log2(m_x).
|
||||
EIGEN_CONSTEXPR Scalar sqrt_half = Scalar(0.70710678118654752440);
|
||||
const Packet m_x_scale_mask = pcmp_lt(m_x, pset1<Packet>(sqrt_half));
|
||||
m_x = pselect(m_x_scale_mask, pmul(pset1<Packet>(Scalar(2)), m_x), m_x);
|
||||
e_x = pselect(m_x_scale_mask, psub(e_x, pset1<Packet>(Scalar(1))), e_x);
|
||||
|
||||
Packet r_x = plog2(m_x);
|
||||
// Compute log2(m_x) with 6 extra bits of accuracy.
|
||||
Packet rx_hi, rx_lo;
|
||||
accurate_log2<Scalar>()(m_x, rx_hi, rx_lo);
|
||||
|
||||
// Compute the two terms {y * e_x, y * r_x} in f = y * log2(x) with doubled
|
||||
// precision using Dekker's algorithm.
|
||||
// precision using double word arithmetic.
|
||||
Packet f1_hi, f1_lo, f2_hi, f2_lo;
|
||||
double_dekker(y, e_x, r_x, f1_hi, f1_lo, f2_hi, f2_lo);
|
||||
twoprod(e_x, y, f1_hi, f1_lo);
|
||||
twoprod(rx_hi, rx_lo, y, f2_hi, f2_lo);
|
||||
// Sum the two terms in f using double word arithmetic. We know
|
||||
// that |e_x| > |log2(m_x)|, except for the case where e_x==0.
|
||||
// This means that we can use fast_twosum(f1,f2).
|
||||
// In the case e_x == 0, e_x * y = f1 = 0, so we don't lose any
|
||||
// accuracy by violating the assumption of fast_twosum, because
|
||||
// it's a no-op.
|
||||
Packet f_hi, f_lo;
|
||||
fast_twosum(f1_hi, f1_lo, f2_hi, f2_lo, f_hi, f_lo);
|
||||
|
||||
// Separate f into integer and fractional parts, keeping f1_hi, and f2_hi
|
||||
// separate to avoid cancellation.
|
||||
Packet n1, r1, n2, r2;
|
||||
integer_split(f1_hi, n1, r1);
|
||||
integer_split(f2_hi, n2, r2);
|
||||
|
||||
// Add up integer parts and sum the remainders.
|
||||
Packet n_z = padd(n1, n2);
|
||||
// Notice: I experimented with using compensated (Kahan) summation here,
|
||||
// but it does not seem to matter.
|
||||
Packet rem = padd(padd(f1_lo, f2_lo), padd(r1, r2));
|
||||
|
||||
// Extract any additional integer part that may have accumulated in rem.
|
||||
Packet nrem, r_z;
|
||||
integer_split(rem, nrem, r_z);
|
||||
n_z = padd(n_z, nrem);
|
||||
// Split f into integer and fractional parts.
|
||||
Packet n_z, r_z;
|
||||
absolute_split(f_hi, n_z, r_z);
|
||||
r_z = padd(r_z, f_lo);
|
||||
Packet n_r;
|
||||
absolute_split(r_z, n_r, r_z);
|
||||
n_z = padd(n_z, n_r);
|
||||
|
||||
// We now have an accurate split of f = n_z + r_z and can compute
|
||||
// x^y = 2**{n_z + r_z} = exp(ln(2) * r_z) * 2**{n_z}.
|
||||
// The first factor we compute by calling pexp(), while multiplication
|
||||
// by an integer power of 2 can be done exactly using pldexp().
|
||||
// Note: I experimented with using Dekker's algorithms for the
|
||||
// multiplication by ln(2) here, but did not see any difference.
|
||||
Packet e_r = pexp(pmul(pset1<Packet>(Scalar(EIGEN_LN2)), r_z));
|
||||
// TODO: investigate bounds of e_r and n_z, potentially using faster
|
||||
// implementation of ldexp.
|
||||
// x^y = 2**{n_z + r_z} = exp2(r_z) * 2**{n_z}.
|
||||
// Since r_z is in [-0.5;0.5], we compute the first factor to high accuracy
|
||||
// using a specialized algorithm. Multiplication by the second factor can
|
||||
// be done exactly using pldexp(), since it is an integer power of 2.
|
||||
// Packet e_r = fast_accurate_exp2<Scalar>()(r_z);
|
||||
const Packet e_r = fast_accurate_exp2<Scalar>()(r_z);
|
||||
return pldexp(e_r, n_z);
|
||||
}
|
||||
|
||||
@ -1005,66 +1252,75 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
EIGEN_UNUSED
|
||||
Packet generic_pow(const Packet& x, const Packet& y) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
|
||||
const Packet cst_pos_inf = pset1<Packet>(NumTraits<Scalar>::infinity());
|
||||
const Packet cst_zero = pset1<Packet>(Scalar(0));
|
||||
const Packet cst_one = pset1<Packet>(Scalar(1));
|
||||
const Packet cst_half = pset1<Packet>(Scalar(0.5));
|
||||
const Packet cst_nan = pset1<Packet>(NumTraits<Scalar>::quiet_NaN());
|
||||
|
||||
Packet abs_x = pabs(x);
|
||||
const Packet abs_x = pabs(x);
|
||||
// Predicates for sign and magnitude of x.
|
||||
Packet x_is_zero = pcmp_eq(x, cst_zero);
|
||||
Packet x_is_neg = pcmp_lt(x, cst_zero);
|
||||
Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
|
||||
Packet abs_x_is_one = pcmp_eq(abs_x, cst_one);
|
||||
Packet abs_x_is_gt_one = pcmp_lt(cst_one, abs_x);
|
||||
Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_one);
|
||||
Packet x_is_one = pandnot(abs_x_is_one, x_is_neg);
|
||||
Packet x_is_neg_one = pand(abs_x_is_one, x_is_neg);
|
||||
Packet x_is_nan = pandnot(ptrue(x), pcmp_eq(x, x));
|
||||
const Packet x_is_zero = pcmp_eq(x, cst_zero);
|
||||
const Packet x_is_neg = pcmp_lt(x, cst_zero);
|
||||
const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
|
||||
const Packet abs_x_is_one = pcmp_eq(abs_x, cst_one);
|
||||
const Packet abs_x_is_gt_one = pcmp_lt(cst_one, abs_x);
|
||||
const Packet abs_x_is_lt_one = pcmp_lt(abs_x, cst_one);
|
||||
const Packet x_is_one = pandnot(abs_x_is_one, x_is_neg);
|
||||
const Packet x_is_neg_one = pand(abs_x_is_one, x_is_neg);
|
||||
const Packet x_is_nan = pandnot(ptrue(x), pcmp_eq(x, x));
|
||||
|
||||
// Predicates for sign and magnitude of y.
|
||||
Packet y_is_zero = pcmp_eq(y, cst_zero);
|
||||
Packet y_is_neg = pcmp_lt(y, cst_zero);
|
||||
Packet y_is_pos = pandnot(ptrue(y), por(y_is_zero, y_is_neg));
|
||||
Packet y_is_nan = pandnot(ptrue(y), pcmp_eq(y, y));
|
||||
Packet abs_y_is_inf = pcmp_eq(pabs(y), cst_pos_inf);
|
||||
// |y| is so large that (1+eps)^y over- or underflows.
|
||||
const Packet y_is_one = pcmp_eq(y, cst_one);
|
||||
const Packet y_is_zero = pcmp_eq(y, cst_zero);
|
||||
const Packet y_is_neg = pcmp_lt(y, cst_zero);
|
||||
const Packet y_is_pos = pandnot(ptrue(y), por(y_is_zero, y_is_neg));
|
||||
const Packet y_is_nan = pandnot(ptrue(y), pcmp_eq(y, y));
|
||||
const Packet abs_y_is_inf = pcmp_eq(pabs(y), cst_pos_inf);
|
||||
EIGEN_CONSTEXPR Scalar huge_exponent =
|
||||
(std::numeric_limits<Scalar>::digits * Scalar(EIGEN_LOG2E)) /
|
||||
(std::numeric_limits<Scalar>::max_exponent * Scalar(EIGEN_LN2)) /
|
||||
std::numeric_limits<Scalar>::epsilon();
|
||||
Packet abs_y_is_huge = pcmp_lt(pset1<Packet>(huge_exponent), pabs(y));
|
||||
const Packet abs_y_is_huge = pcmp_le(pset1<Packet>(huge_exponent), pabs(y));
|
||||
|
||||
// Predicates for whether y is integer and/or even.
|
||||
Packet y_is_int = pcmp_eq(pfloor(y), y);
|
||||
Packet y_div_2 = pmul(y, cst_half);
|
||||
Packet y_is_even = pcmp_eq(pround(y_div_2), y_div_2);
|
||||
const Packet y_is_int = pcmp_eq(pfloor(y), y);
|
||||
const Packet y_div_2 = pmul(y, pset1<Packet>(Scalar(0.5)));
|
||||
const Packet y_is_even = pcmp_eq(pround(y_div_2), y_div_2);
|
||||
|
||||
// Predicates encoding special cases for the value of pow(x,y)
|
||||
Packet invalid_negative_x = pandnot(pandnot(pandnot(x_is_neg, abs_x_is_inf), y_is_int), abs_y_is_inf);
|
||||
Packet pow_is_one = por(por(x_is_one, y_is_zero),
|
||||
pand(x_is_neg_one,
|
||||
por(abs_y_is_inf, pandnot(y_is_even, invalid_negative_x))));
|
||||
Packet pow_is_nan = por(invalid_negative_x, por(x_is_nan, y_is_nan));
|
||||
Packet pow_is_zero = por(por(por(pand(x_is_zero, y_is_pos), pand(abs_x_is_inf, y_is_neg)),
|
||||
pand(pand(abs_x_is_lt_one, abs_y_is_huge), y_is_pos)),
|
||||
pand(pand(abs_x_is_gt_one, abs_y_is_huge), y_is_neg));
|
||||
Packet pow_is_inf = por(por(por(pand(x_is_zero, y_is_neg), pand(abs_x_is_inf, y_is_pos)),
|
||||
pand(pand(abs_x_is_lt_one, abs_y_is_huge), y_is_neg)),
|
||||
pand(pand(abs_x_is_gt_one, abs_y_is_huge), y_is_pos));
|
||||
const Packet invalid_negative_x = pandnot(pandnot(pandnot(x_is_neg, abs_x_is_inf),
|
||||
y_is_int),
|
||||
abs_y_is_inf);
|
||||
const Packet pow_is_one = por(por(x_is_one, y_is_zero),
|
||||
pand(x_is_neg_one,
|
||||
por(abs_y_is_inf, pandnot(y_is_even, invalid_negative_x))));
|
||||
const Packet pow_is_nan = por(invalid_negative_x, por(x_is_nan, y_is_nan));
|
||||
const Packet pow_is_zero = por(por(por(pand(x_is_zero, y_is_pos),
|
||||
pand(abs_x_is_inf, y_is_neg)),
|
||||
pand(pand(abs_x_is_lt_one, abs_y_is_huge),
|
||||
y_is_pos)),
|
||||
pand(pand(abs_x_is_gt_one, abs_y_is_huge),
|
||||
y_is_neg));
|
||||
const Packet pow_is_inf = por(por(por(pand(x_is_zero, y_is_neg),
|
||||
pand(abs_x_is_inf, y_is_pos)),
|
||||
pand(pand(abs_x_is_lt_one, abs_y_is_huge),
|
||||
y_is_neg)),
|
||||
pand(pand(abs_x_is_gt_one, abs_y_is_huge),
|
||||
y_is_pos));
|
||||
|
||||
// General computation of pow(x,y) for positive x or negative x and integer y.
|
||||
Packet negate_pow_abs = pandnot(x_is_neg, y_is_even);
|
||||
Packet pow_abs = generic_pow_impl(abs_x, y);
|
||||
|
||||
return pselect(pow_is_one, cst_one,
|
||||
pselect(pow_is_nan, cst_nan,
|
||||
pselect(pow_is_inf, cst_pos_inf,
|
||||
pselect(pow_is_zero, cst_zero,
|
||||
pselect(negate_pow_abs, pnegate(pow_abs), pow_abs)))));
|
||||
const Packet negate_pow_abs = pandnot(x_is_neg, y_is_even);
|
||||
const Packet pow_abs = generic_pow_impl(abs_x, y);
|
||||
return pselect(y_is_one, x,
|
||||
pselect(pow_is_one, cst_one,
|
||||
pselect(pow_is_nan, cst_nan,
|
||||
pselect(pow_is_inf, cst_pos_inf,
|
||||
pselect(pow_is_zero, cst_zero,
|
||||
pselect(negate_pow_abs, pnegate(pow_abs), pow_abs))))));
|
||||
}
|
||||
|
||||
|
||||
|
||||
/* polevl (modified for Eigen)
|
||||
*
|
||||
* Evaluate polynomial
|
||||
|
@ -14,6 +14,7 @@
|
||||
template<typename Scalar>
|
||||
void pow_test() {
|
||||
const Scalar zero = Scalar(0);
|
||||
const Scalar eps = std::numeric_limits<Scalar>::epsilon();
|
||||
const Scalar one = Scalar(1);
|
||||
const Scalar two = Scalar(2);
|
||||
const Scalar three = Scalar(3);
|
||||
@ -21,20 +22,25 @@ void pow_test() {
|
||||
const Scalar sqrt2 = Scalar(std::sqrt(2));
|
||||
const Scalar inf = std::numeric_limits<Scalar>::infinity();
|
||||
const Scalar nan = std::numeric_limits<Scalar>::quiet_NaN();
|
||||
const Scalar denorm_min = std::numeric_limits<Scalar>::denorm_min();
|
||||
const Scalar min = (std::numeric_limits<Scalar>::min)();
|
||||
const Scalar max = (std::numeric_limits<Scalar>::max)();
|
||||
const Scalar max_exp = (static_cast<Scalar>(std::numeric_limits<Scalar>::max_exponent) * Scalar(EIGEN_LN2)) / eps;
|
||||
|
||||
const static Scalar abs_vals[] = {zero,
|
||||
denorm_min,
|
||||
min,
|
||||
eps,
|
||||
sqrt_half,
|
||||
one,
|
||||
sqrt2,
|
||||
two,
|
||||
three,
|
||||
min,
|
||||
max_exp,
|
||||
max,
|
||||
inf,
|
||||
nan};
|
||||
|
||||
const int abs_cases = 10;
|
||||
const int abs_cases = 13;
|
||||
const int num_cases = 2*abs_cases * 2*abs_cases;
|
||||
// Repeat the same value to make sure we hit the vectorized path.
|
||||
const int num_repeats = 32;
|
||||
@ -64,10 +70,7 @@ void pow_test() {
|
||||
bool all_pass = true;
|
||||
for (int i = 0; i < 1; ++i) {
|
||||
for (int j = 0; j < num_cases; ++j) {
|
||||
// TODO(rmlarsen): Skip tests that trigger a known bug in pldexp for now.
|
||||
if (std::abs(x(i,j)) == max || std::abs(x(i,j)) == min) continue;
|
||||
|
||||
Scalar e = numext::pow(x(i,j), y(i,j));
|
||||
Scalar e = static_cast<Scalar>(std::pow(x(i,j), y(i,j)));
|
||||
Scalar a = actual(i, j);
|
||||
bool fail = !(a==e) && !internal::isApprox(a, e, tol) && !((numext::isnan)(a) && (numext::isnan)(e));
|
||||
all_pass &= !fail;
|
||||
@ -79,7 +82,6 @@ void pow_test() {
|
||||
VERIFY(all_pass);
|
||||
}
|
||||
|
||||
|
||||
template<typename ArrayType> void array(const ArrayType& m)
|
||||
{
|
||||
typedef typename ArrayType::Scalar Scalar;
|
||||
|
Loading…
x
Reference in New Issue
Block a user