mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-16 14:49:39 +08:00
Optimize various mathematical packet ops
This commit is contained in:
parent
1aa6dc2007
commit
0471e61b4c
@ -1234,38 +1234,36 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet patan2(const Packet& y, const Packe
|
||||
// for how corner cases are supposed to be handled according to the
|
||||
// IEEE floating-point standard (IEC 60559).
|
||||
|
||||
// bend two rules:
|
||||
// 1) inf / inf == 1
|
||||
// 2) 0 / 0 == 0
|
||||
// otherwise, evaluate atan(y/x) as usual and shift to the appropriate quadrant
|
||||
|
||||
const Packet kSignMask = pset1<Packet>(-Scalar(0));
|
||||
const Packet kZero = pzero(x);
|
||||
const Packet kOne = pset1<Packet>(Scalar(1));
|
||||
const Packet kPi = pset1<Packet>(Scalar(EIGEN_PI));
|
||||
const Packet kInf = pset1<Packet>(NumTraits<Scalar>::infinity());
|
||||
|
||||
const Packet abs_x = pabs(x);
|
||||
const Packet x_is_zero = pcmp_eq(abs_x, kZero);
|
||||
const Packet x_is_inf = pcmp_eq(abs_x, kInf);
|
||||
const Packet x_has_signbit = psignbit(x);
|
||||
|
||||
const Packet abs_y = pabs(y);
|
||||
const Packet y_is_zero = pcmp_eq(abs_y, kZero);
|
||||
const Packet y_is_inf = pcmp_eq(abs_y, kInf);
|
||||
const Packet y_signmask = pand(y, kSignMask);
|
||||
const Packet shift = por(pand(x_has_signbit, kPi), y_signmask);
|
||||
|
||||
const Packet arg_signmask = pand(pxor(x, y), kSignMask);
|
||||
const Packet shift = pxor(pand(x_has_signbit, kPi), y_signmask);
|
||||
const Packet xor_xy = pxor(x, y);
|
||||
// if x and y have the same absolute value, then xor(x,y) is zero
|
||||
// make sure that neither x nor y is nan
|
||||
// furthermore, xor(x,y) has the sign of the result
|
||||
const Packet x_and_y_are_same = pand(pcmp_eq(xor_xy, kZero), pcmp_eq(x, x));
|
||||
// more strictly, if x and y are both zero, then or(x,y) is zero
|
||||
// this implicitly checks for nan
|
||||
// the sign of or(x,y) is not meaningful
|
||||
const Packet x_and_y_are_zero = pcmp_eq(por(x, y), kZero);
|
||||
|
||||
// bend two rules:
|
||||
// 1) 0 / 0 == 0
|
||||
// 2) inf / inf == 1
|
||||
// otherwise, evaluate atan(y/x) as usual and shift to the appropriate quadrant
|
||||
|
||||
Packet arg = pdiv(abs_y, abs_x);
|
||||
arg = pselect(pand(x_is_zero, y_is_zero), kZero, arg);
|
||||
arg = pselect(pand(x_is_inf, y_is_inf), kOne, arg);
|
||||
Packet arg = pdiv(y, x);
|
||||
arg = pselect(x_and_y_are_same, por(kOne, xor_xy), arg);
|
||||
arg = pselect(x_and_y_are_zero, xor_xy, arg);
|
||||
|
||||
Packet result = patan(arg);
|
||||
result = pxor(result, arg_signmask);
|
||||
result = padd(result, shift);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -728,20 +728,19 @@ Packet pacos_float(const Packet& x_in) {
|
||||
|
||||
const Packet cst_one = pset1<Packet>(Scalar(1));
|
||||
const Packet cst_pi = pset1<Packet>(Scalar(EIGEN_PI));
|
||||
const Packet p6 = pset1<Packet>(Scalar(2.26911413483321666717529296875e-3));
|
||||
const Packet p5 = pset1<Packet>(Scalar(-1.1063250713050365447998046875e-2));
|
||||
const Packet p4 = pset1<Packet>(Scalar(2.680264413356781005859375e-2));
|
||||
const Packet p3 = pset1<Packet>(Scalar(-4.87488098442554473876953125e-2));
|
||||
const Packet p2 = pset1<Packet>(Scalar(8.874166011810302734375e-2));
|
||||
const Packet p1 = pset1<Packet>(Scalar(-0.2145837843418121337890625));
|
||||
const Packet p0 = pset1<Packet>(Scalar(1.57079613208770751953125));
|
||||
const Packet p6 = pset1<Packet>(Scalar(2.36423197202384471893310546875e-3));
|
||||
const Packet p5 = pset1<Packet>(Scalar(-1.1368644423782825469970703125e-2));
|
||||
const Packet p4 = pset1<Packet>(Scalar(2.717843465507030487060546875e-2));
|
||||
const Packet p3 = pset1<Packet>(Scalar(-4.8969544470310211181640625e-2));
|
||||
const Packet p2 = pset1<Packet>(Scalar(8.8804088532924652099609375e-2));
|
||||
const Packet p1 = pset1<Packet>(Scalar(-0.214591205120086669921875));
|
||||
const Packet p0 = pset1<Packet>(Scalar(1.57079637050628662109375));
|
||||
|
||||
// For x in [0:1], we approximate acos(x)/sqrt(1-x), which is a smooth
|
||||
// function, by a 6th-order polynomial.
|
||||
// For x in [-1:0) we use that acos(-x) = pi - acos(x).
|
||||
const Packet neg_mask = pcmp_lt(x_in, pzero(x_in));
|
||||
Packet x = pabs(x_in);
|
||||
const Packet invalid_mask = pcmp_lt(pset1<Packet>(1.0f), x);
|
||||
const Packet neg_mask = psignbit(x_in);
|
||||
const Packet abs_x = pabs(x_in);
|
||||
|
||||
// Evaluate the polynomial using Horner's rule:
|
||||
// P(x) = p0 + x * (p1 + x * (p2 + ... (p5 + x * p6)) ... ) .
|
||||
@ -753,19 +752,15 @@ Packet pacos_float(const Packet& x_in) {
|
||||
p_even = pmadd(p_even, x2, p2);
|
||||
p_odd = pmadd(p_odd, x2, p1);
|
||||
p_even = pmadd(p_even, x2, p0);
|
||||
Packet p = pmadd(p_odd, x, p_even);
|
||||
Packet p = pmadd(p_odd, abs_x, p_even);
|
||||
|
||||
// The polynomial approximates acos(x)/sqrt(1-x), so
|
||||
// multiply by sqrt(1-x) to get acos(x).
|
||||
Packet denom = psqrt(psub(cst_one, x));
|
||||
// Conveniently returns NaN for arguments outside [-1:1].
|
||||
Packet denom = psqrt(psub(cst_one, abs_x));
|
||||
Packet result = pmul(denom, p);
|
||||
|
||||
// Undo mapping for negative arguments.
|
||||
result = pselect(neg_mask, psub(cst_pi, result), result);
|
||||
// Return NaN for arguments outside [-1:1].
|
||||
return pselect(invalid_mask,
|
||||
pset1<Packet>(std::numeric_limits<float>::quiet_NaN()),
|
||||
result);
|
||||
return pselect(neg_mask, psub(cst_pi, result), result);
|
||||
}
|
||||
|
||||
// Generic implementation of asin(x).
|
||||
@ -834,6 +829,8 @@ Packet patan_reduced_float(const Packet& x) {
|
||||
// We evaluate even and odd terms in x^2 in parallel
|
||||
// to take advantage of instruction level parallelism
|
||||
// and hardware with multiple FMA units.
|
||||
|
||||
// note: if x == -0, this returns +0
|
||||
const Packet x2 = pmul(x, x);
|
||||
const Packet x4 = pmul(x2, x2);
|
||||
Packet q_odd = pmadd(q14, x4, q10);
|
||||
@ -852,20 +849,25 @@ Packet patan_float(const Packet& x_in) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
static_assert(std::is_same<Scalar, float>::value, "Scalar type must be float");
|
||||
|
||||
constexpr float kPiOverTwo = static_cast<float>(EIGEN_PI / 2);
|
||||
|
||||
const Packet cst_signmask = pset1<Packet>(-0.0f);
|
||||
const Packet cst_one = pset1<Packet>(1.0f);
|
||||
constexpr float kPiOverTwo = static_cast<float>(EIGEN_PI/2);
|
||||
const Packet cst_pi_over_two = pset1<Packet>(kPiOverTwo);
|
||||
|
||||
// "Large": For |x| > 1, use atan(1/x) = sign(x)*pi/2 - atan(x).
|
||||
// "Small": For |x| <= 1, approximate atan(x) directly by a polynomial
|
||||
// calculated using Sollya.
|
||||
const Packet neg_mask = pcmp_lt(x_in, pzero(x_in));
|
||||
const Packet large_mask = pcmp_lt(cst_one, pabs(x_in));
|
||||
const Packet large_shift = pselect(neg_mask, pset1<Packet>(-kPiOverTwo), pset1<Packet>(kPiOverTwo));
|
||||
const Packet x = pselect(large_mask, preciprocal(x_in), x_in);
|
||||
|
||||
const Packet abs_x = pabs(x_in);
|
||||
const Packet x_signmask = pand(x_in, cst_signmask);
|
||||
const Packet large_mask = pcmp_lt(cst_one, abs_x);
|
||||
const Packet x = pselect(large_mask, preciprocal(abs_x), abs_x);
|
||||
const Packet p = patan_reduced_float(x);
|
||||
|
||||
// Apply transformations according to the range reduction masks.
|
||||
return pselect(large_mask, psub(large_shift, p), p);
|
||||
Packet result = pselect(large_mask, psub(cst_pi_over_two, p), p);
|
||||
// Return the correct sign
|
||||
return pxor(result, x_signmask);
|
||||
}
|
||||
|
||||
// Computes elementwise atan(x) for x in [-tan(pi/8):tan(pi/8)]
|
||||
@ -920,16 +922,17 @@ Packet patan_double(const Packet& x_in) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
static_assert(std::is_same<Scalar, double>::value, "Scalar type must be double");
|
||||
|
||||
const Packet cst_one = pset1<Packet>(1.0);
|
||||
constexpr double kPiOverTwo = static_cast<double>(EIGEN_PI / 2);
|
||||
const Packet cst_pi_over_two = pset1<Packet>(kPiOverTwo);
|
||||
constexpr double kPiOverFour = static_cast<double>(EIGEN_PI / 4);
|
||||
const Packet cst_pi_over_four = pset1<Packet>(kPiOverFour);
|
||||
const Packet cst_large = pset1<Packet>(2.4142135623730950488016887); // tan(3*pi/8);
|
||||
const Packet cst_medium = pset1<Packet>(0.4142135623730950488016887); // tan(pi/8);
|
||||
constexpr double kTanPiOverEight = 0.4142135623730950488016887;
|
||||
constexpr double kTan3PiOverEight = 2.4142135623730950488016887;
|
||||
|
||||
const Packet neg_mask = pcmp_lt(x_in, pzero(x_in));
|
||||
Packet x = pabs(x_in);
|
||||
const Packet cst_signmask = pset1<Packet>(-0.0);
|
||||
const Packet cst_one = pset1<Packet>(1.0);
|
||||
const Packet cst_pi_over_two = pset1<Packet>(kPiOverTwo);
|
||||
const Packet cst_pi_over_four = pset1<Packet>(kPiOverFour);
|
||||
const Packet cst_large = pset1<Packet>(kTan3PiOverEight);
|
||||
const Packet cst_medium = pset1<Packet>(kTanPiOverEight);
|
||||
|
||||
// Use the same range reduction strategy (to [0:tan(pi/8)]) as the
|
||||
// Cephes library:
|
||||
@ -938,10 +941,15 @@ Packet patan_double(const Packet& x_in) {
|
||||
// use atan(x) = pi/4 + atan((x-1)/(x+1)).
|
||||
// "Small": For x < tan(pi/8), approximate atan(x) directly by a polynomial
|
||||
// calculated using Sollya.
|
||||
const Packet large_mask = pcmp_lt(cst_large, x);
|
||||
x = pselect(large_mask, preciprocal(x), x);
|
||||
const Packet medium_mask = pandnot(pcmp_lt(cst_medium, x), large_mask);
|
||||
x = pselect(medium_mask, pdiv(psub(x, cst_one), padd(x, cst_one)), x);
|
||||
|
||||
const Packet abs_x = pabs(x_in);
|
||||
const Packet x_signmask = pand(x_in, cst_signmask);
|
||||
const Packet large_mask = pcmp_lt(cst_large, abs_x);
|
||||
const Packet medium_mask = pandnot(pcmp_lt(cst_medium, abs_x), large_mask);
|
||||
|
||||
Packet x = abs_x;
|
||||
x = pselect(large_mask, preciprocal(abs_x), x);
|
||||
x = pselect(medium_mask, pdiv(psub(abs_x, cst_one), padd(abs_x, cst_one)), x);
|
||||
|
||||
// Compute approximation of p ~= atan(x') where x' is the argument reduced to
|
||||
// [0:tan(pi/8)].
|
||||
@ -950,7 +958,8 @@ Packet patan_double(const Packet& x_in) {
|
||||
// Apply transformations according to the range reduction masks.
|
||||
p = pselect(large_mask, psub(cst_pi_over_two, p), p);
|
||||
p = pselect(medium_mask, padd(cst_pi_over_four, p), p);
|
||||
return pselect(neg_mask, pnegate(p), p);
|
||||
// Return the correct sign
|
||||
return pxor(p, x_signmask);
|
||||
}
|
||||
|
||||
template<typename Packet>
|
||||
@ -1751,7 +1760,7 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_pow(const Pac
|
||||
const Packet abs_x = pabs(x);
|
||||
// Predicates for sign and magnitude of x.
|
||||
const Packet abs_x_is_zero = pcmp_eq(abs_x, cst_zero);
|
||||
const Packet x_has_signbit = pcmp_eq(por(pand(x, cst_neg_inf), cst_pos_inf), cst_neg_inf);
|
||||
const Packet x_has_signbit = psignbit(x);
|
||||
const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero);
|
||||
const Packet x_is_neg_zero = pand(x_has_signbit, abs_x_is_zero);
|
||||
const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
|
||||
@ -1790,19 +1799,21 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_pow(const Pac
|
||||
const Packet pow_is_inf = por(por(por(pand(abs_x_is_zero, y_is_neg), pand(abs_x_is_inf, y_is_pos)),
|
||||
pand(pand(abs_x_is_lt_one, abs_y_is_huge), y_is_neg)),
|
||||
pand(pand(abs_x_is_gt_one, abs_y_is_huge), y_is_pos));
|
||||
const Packet pow_is_neg_zero = pand(pandnot(y_is_int, y_is_even),
|
||||
por(pand(y_is_neg, pand(abs_x_is_inf, x_is_neg)), pand(y_is_pos, x_is_neg_zero)));
|
||||
const Packet inf_val =
|
||||
pselect(pandnot(pand(por(pand(abs_x_is_inf, x_is_neg), pand(x_is_neg_zero, y_is_neg)), y_is_int), y_is_even),
|
||||
cst_neg_inf, cst_pos_inf);
|
||||
|
||||
// General computation of pow(x,y) for positive x or negative x and integer y.
|
||||
const Packet negate_pow_abs = pandnot(x_is_neg, y_is_even);
|
||||
const Packet pow_abs = generic_pow_impl(abs_x, y);
|
||||
return pselect(
|
||||
y_is_one, x,
|
||||
pselect(pow_is_one, cst_one,
|
||||
pselect(pow_is_nan, cst_nan,
|
||||
pselect(pow_is_inf, inf_val,
|
||||
pselect(pow_is_zero, cst_zero, pselect(negate_pow_abs, pnegate(pow_abs), pow_abs))))));
|
||||
return pselect(y_is_one, x,
|
||||
pselect(pow_is_one, cst_one,
|
||||
pselect(pow_is_nan, cst_nan,
|
||||
pselect(pow_is_inf, inf_val,
|
||||
pselect(pow_is_neg_zero, pnegate(cst_zero),
|
||||
pselect(pow_is_zero, cst_zero,
|
||||
pselect(negate_pow_abs, pnegate(pow_abs), pow_abs)))))));
|
||||
}
|
||||
|
||||
/* polevl (modified for Eigen)
|
||||
@ -2022,7 +2033,7 @@ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_int_errors(con
|
||||
const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one);
|
||||
const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
|
||||
|
||||
const Packet x_has_signbit = pcmp_eq(por(pand(x, cst_neg_inf), cst_pos_inf), cst_neg_inf);
|
||||
const Packet x_has_signbit = psignbit(x);
|
||||
const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero);
|
||||
const Packet x_is_neg_zero = pand(x_has_signbit, abs_x_is_zero);
|
||||
|
||||
@ -2087,7 +2098,7 @@ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(
|
||||
const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one);
|
||||
const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
|
||||
|
||||
const Packet x_has_signbit = pcmp_eq(por(pand(x, cst_neg_inf), cst_pos_inf), cst_neg_inf);
|
||||
const Packet x_has_signbit = psignbit(x);
|
||||
const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero);
|
||||
|
||||
if (exponent_is_nan) {
|
||||
|
@ -83,6 +83,7 @@ void binary_op_test(std::string name, Fn fun, RefFn ref) {
|
||||
Scalar e = static_cast<Scalar>(ref(x(i,j), y(i,j)));
|
||||
Scalar a = actual(i, j);
|
||||
bool success = (a==e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) || ((numext::isnan)(a) && (numext::isnan)(e));
|
||||
if ((a == a) && (e == e)) success &= (bool)numext::signbit(e) == (bool)numext::signbit(a);
|
||||
all_pass &= success;
|
||||
if (!success) {
|
||||
std::cout << name << "(" << x(i,j) << "," << y(i,j) << ") = " << a << " != " << e << std::endl;
|
||||
@ -125,6 +126,7 @@ void pow_scalar_exponent_test() {
|
||||
Scalar a = eigenPow(j);
|
||||
bool success = (a == e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) ||
|
||||
((numext::isnan)(a) && (numext::isnan)(e));
|
||||
if ((a == a) && (e == e)) success &= (bool)numext::signbit(e) == (bool)numext::signbit(a);
|
||||
all_pass &= success;
|
||||
if (!success) {
|
||||
std::cout << "pow(" << bases(j) << "," << exponent << ") = " << a << " != " << e << std::endl;
|
||||
@ -138,6 +140,7 @@ void pow_scalar_exponent_test() {
|
||||
Scalar a = eigenPow(j);
|
||||
bool success = (a == e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) ||
|
||||
((numext::isnan)(a) && (numext::isnan)(e));
|
||||
if ((a == a) && (e == e)) success &= (bool)numext::signbit(e) == (bool)numext::signbit(a);
|
||||
all_pass &= success;
|
||||
if (!success) {
|
||||
std::cout << "pow(" << bases(j) << "," << exponent << ") = " << a << " != " << e << std::endl;
|
||||
|
Loading…
x
Reference in New Issue
Block a user