Optimize various mathematical packet ops

This commit is contained in:
Charles Schlosser 2023-01-28 01:34:26 +00:00 committed by Rasmus Munk Larsen
parent 1aa6dc2007
commit 0471e61b4c
3 changed files with 79 additions and 67 deletions

View File

@ -1234,38 +1234,36 @@ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet patan2(const Packet& y, const Packe
// for how corner cases are supposed to be handled according to the
// IEEE floating-point standard (IEC 60559).
// bend two rules:
// 1) inf / inf == 1
// 2) 0 / 0 == 0
// otherwise, evaluate atan(y/x) as usual and shift to the appropriate quadrant
const Packet kSignMask = pset1<Packet>(-Scalar(0));
const Packet kZero = pzero(x);
const Packet kOne = pset1<Packet>(Scalar(1));
const Packet kPi = pset1<Packet>(Scalar(EIGEN_PI));
const Packet kInf = pset1<Packet>(NumTraits<Scalar>::infinity());
const Packet abs_x = pabs(x);
const Packet x_is_zero = pcmp_eq(abs_x, kZero);
const Packet x_is_inf = pcmp_eq(abs_x, kInf);
const Packet x_has_signbit = psignbit(x);
const Packet abs_y = pabs(y);
const Packet y_is_zero = pcmp_eq(abs_y, kZero);
const Packet y_is_inf = pcmp_eq(abs_y, kInf);
const Packet y_signmask = pand(y, kSignMask);
const Packet shift = por(pand(x_has_signbit, kPi), y_signmask);
const Packet arg_signmask = pand(pxor(x, y), kSignMask);
const Packet shift = pxor(pand(x_has_signbit, kPi), y_signmask);
const Packet xor_xy = pxor(x, y);
// if x and y have the same absolute value, then xor(x,y) is zero
// make sure that neither x nor y is nan
// furthermore, xor(x,y) has the sign of the result
const Packet x_and_y_are_same = pand(pcmp_eq(xor_xy, kZero), pcmp_eq(x, x));
// more strictly, if x and y are both zero, then or(x,y) is zero
// this implicitly checks for nan
// the sign of or(x,y) is not meaningful
const Packet x_and_y_are_zero = pcmp_eq(por(x, y), kZero);
// bend two rules:
// 1) 0 / 0 == 0
// 2) inf / inf == 1
// otherwise, evaluate atan(y/x) as usual and shift to the appropriate quadrant
Packet arg = pdiv(abs_y, abs_x);
arg = pselect(pand(x_is_zero, y_is_zero), kZero, arg);
arg = pselect(pand(x_is_inf, y_is_inf), kOne, arg);
Packet arg = pdiv(y, x);
arg = pselect(x_and_y_are_same, por(kOne, xor_xy), arg);
arg = pselect(x_and_y_are_zero, xor_xy, arg);
Packet result = patan(arg);
result = pxor(result, arg_signmask);
result = padd(result, shift);
return result;
}

View File

@ -728,20 +728,19 @@ Packet pacos_float(const Packet& x_in) {
const Packet cst_one = pset1<Packet>(Scalar(1));
const Packet cst_pi = pset1<Packet>(Scalar(EIGEN_PI));
const Packet p6 = pset1<Packet>(Scalar(2.26911413483321666717529296875e-3));
const Packet p5 = pset1<Packet>(Scalar(-1.1063250713050365447998046875e-2));
const Packet p4 = pset1<Packet>(Scalar(2.680264413356781005859375e-2));
const Packet p3 = pset1<Packet>(Scalar(-4.87488098442554473876953125e-2));
const Packet p2 = pset1<Packet>(Scalar(8.874166011810302734375e-2));
const Packet p1 = pset1<Packet>(Scalar(-0.2145837843418121337890625));
const Packet p0 = pset1<Packet>(Scalar(1.57079613208770751953125));
const Packet p6 = pset1<Packet>(Scalar(2.36423197202384471893310546875e-3));
const Packet p5 = pset1<Packet>(Scalar(-1.1368644423782825469970703125e-2));
const Packet p4 = pset1<Packet>(Scalar(2.717843465507030487060546875e-2));
const Packet p3 = pset1<Packet>(Scalar(-4.8969544470310211181640625e-2));
const Packet p2 = pset1<Packet>(Scalar(8.8804088532924652099609375e-2));
const Packet p1 = pset1<Packet>(Scalar(-0.214591205120086669921875));
const Packet p0 = pset1<Packet>(Scalar(1.57079637050628662109375));
// For x in [0:1], we approximate acos(x)/sqrt(1-x), which is a smooth
// function, by a 6'th order polynomial.
// For x in [-1:0) we use that acos(-x) = pi - acos(x).
const Packet neg_mask = pcmp_lt(x_in, pzero(x_in));
Packet x = pabs(x_in);
const Packet invalid_mask = pcmp_lt(pset1<Packet>(1.0f), x);
const Packet neg_mask = psignbit(x_in);
const Packet abs_x = pabs(x_in);
// Evaluate the polynomial using Horner's rule:
// P(x) = p0 + x * (p1 + x * (p2 + ... (p5 + x * p6)) ... ) .
@ -753,19 +752,15 @@ Packet pacos_float(const Packet& x_in) {
p_even = pmadd(p_even, x2, p2);
p_odd = pmadd(p_odd, x2, p1);
p_even = pmadd(p_even, x2, p0);
Packet p = pmadd(p_odd, x, p_even);
Packet p = pmadd(p_odd, abs_x, p_even);
// The polynomial approximates acos(x)/sqrt(1-x), so
// multiply by sqrt(1-x) to get acos(x).
Packet denom = psqrt(psub(cst_one, x));
// Conveniently returns NaN for arguments outside [-1:1].
Packet denom = psqrt(psub(cst_one, abs_x));
Packet result = pmul(denom, p);
// Undo mapping for negative arguments.
result = pselect(neg_mask, psub(cst_pi, result), result);
// Return NaN for arguments outside [-1:1].
return pselect(invalid_mask,
pset1<Packet>(std::numeric_limits<float>::quiet_NaN()),
result);
return pselect(neg_mask, psub(cst_pi, result), result);
}
// Generic implementation of asin(x).
@ -834,6 +829,8 @@ Packet patan_reduced_float(const Packet& x) {
// We evaluate even and odd terms in x^2 in parallel
// to take advantage of instruction level parallelism
// and hardware with multiple FMA units.
// note: if x == -0, this returns +0
const Packet x2 = pmul(x, x);
const Packet x4 = pmul(x2, x2);
Packet q_odd = pmadd(q14, x4, q10);
@ -852,20 +849,25 @@ Packet patan_float(const Packet& x_in) {
typedef typename unpacket_traits<Packet>::type Scalar;
static_assert(std::is_same<Scalar, float>::value, "Scalar type must be float");
constexpr float kPiOverTwo = static_cast<float>(EIGEN_PI / 2);
const Packet cst_signmask = pset1<Packet>(-0.0f);
const Packet cst_one = pset1<Packet>(1.0f);
constexpr float kPiOverTwo = static_cast<float>(EIGEN_PI/2);
const Packet cst_pi_over_two = pset1<Packet>(kPiOverTwo);
// "Large": For |x| > 1, use atan(1/x) = sign(x)*pi/2 - atan(x).
// "Small": For |x| <= 1, approximate atan(x) directly by a polynomial
// calculated using Sollya.
const Packet neg_mask = pcmp_lt(x_in, pzero(x_in));
const Packet large_mask = pcmp_lt(cst_one, pabs(x_in));
const Packet large_shift = pselect(neg_mask, pset1<Packet>(-kPiOverTwo), pset1<Packet>(kPiOverTwo));
const Packet x = pselect(large_mask, preciprocal(x_in), x_in);
const Packet abs_x = pabs(x_in);
const Packet x_signmask = pand(x_in, cst_signmask);
const Packet large_mask = pcmp_lt(cst_one, abs_x);
const Packet x = pselect(large_mask, preciprocal(abs_x), abs_x);
const Packet p = patan_reduced_float(x);
// Apply transformations according to the range reduction masks.
return pselect(large_mask, psub(large_shift, p), p);
Packet result = pselect(large_mask, psub(cst_pi_over_two, p), p);
// Return correct sign
return pxor(result, x_signmask);
}
// Computes elementwise atan(x) for x in [-tan(pi/8):tan(pi/8)]
@ -920,16 +922,17 @@ Packet patan_double(const Packet& x_in) {
typedef typename unpacket_traits<Packet>::type Scalar;
static_assert(std::is_same<Scalar, double>::value, "Scalar type must be double");
const Packet cst_one = pset1<Packet>(1.0);
constexpr double kPiOverTwo = static_cast<double>(EIGEN_PI / 2);
const Packet cst_pi_over_two = pset1<Packet>(kPiOverTwo);
constexpr double kPiOverFour = static_cast<double>(EIGEN_PI / 4);
const Packet cst_pi_over_four = pset1<Packet>(kPiOverFour);
const Packet cst_large = pset1<Packet>(2.4142135623730950488016887); // tan(3*pi/8);
const Packet cst_medium = pset1<Packet>(0.4142135623730950488016887); // tan(pi/8);
constexpr double kTanPiOverEight = 0.4142135623730950488016887;
constexpr double kTan3PiOverEight = 2.4142135623730950488016887;
const Packet neg_mask = pcmp_lt(x_in, pzero(x_in));
Packet x = pabs(x_in);
const Packet cst_signmask = pset1<Packet>(-0.0);
const Packet cst_one = pset1<Packet>(1.0);
const Packet cst_pi_over_two = pset1<Packet>(kPiOverTwo);
const Packet cst_pi_over_four = pset1<Packet>(kPiOverFour);
const Packet cst_large = pset1<Packet>(kTan3PiOverEight);
const Packet cst_medium = pset1<Packet>(kTanPiOverEight);
// Use the same range reduction strategy (to [0:tan(pi/8)]) as the
// Cephes library:
@ -938,10 +941,15 @@ Packet patan_double(const Packet& x_in) {
// use atan(x) = pi/4 + atan((x-1)/(x+1)).
// "Small": For x < tan(pi/8), approximate atan(x) directly by a polynomial
// calculated using Sollya.
const Packet large_mask = pcmp_lt(cst_large, x);
x = pselect(large_mask, preciprocal(x), x);
const Packet medium_mask = pandnot(pcmp_lt(cst_medium, x), large_mask);
x = pselect(medium_mask, pdiv(psub(x, cst_one), padd(x, cst_one)), x);
const Packet abs_x = pabs(x_in);
const Packet x_signmask = pand(x_in, cst_signmask);
const Packet large_mask = pcmp_lt(cst_large, abs_x);
const Packet medium_mask = pandnot(pcmp_lt(cst_medium, abs_x), large_mask);
Packet x = abs_x;
x = pselect(large_mask, preciprocal(abs_x), x);
x = pselect(medium_mask, pdiv(psub(abs_x, cst_one), padd(abs_x, cst_one)), x);
// Compute approximation of p ~= atan(x') where x' is the argument reduced to
// [0:tan(pi/8)].
@ -950,7 +958,8 @@ Packet patan_double(const Packet& x_in) {
// Apply transformations according to the range reduction masks.
p = pselect(large_mask, psub(cst_pi_over_two, p), p);
p = pselect(medium_mask, padd(cst_pi_over_four, p), p);
return pselect(neg_mask, pnegate(p), p);
// Return the correct sign
return pxor(p, x_signmask);
}
template<typename Packet>
@ -1751,7 +1760,7 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_pow(const Pac
const Packet abs_x = pabs(x);
// Predicates for sign and magnitude of x.
const Packet abs_x_is_zero = pcmp_eq(abs_x, cst_zero);
const Packet x_has_signbit = pcmp_eq(por(pand(x, cst_neg_inf), cst_pos_inf), cst_neg_inf);
const Packet x_has_signbit = psignbit(x);
const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero);
const Packet x_is_neg_zero = pand(x_has_signbit, abs_x_is_zero);
const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
@ -1790,19 +1799,21 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_pow(const Pac
const Packet pow_is_inf = por(por(por(pand(abs_x_is_zero, y_is_neg), pand(abs_x_is_inf, y_is_pos)),
pand(pand(abs_x_is_lt_one, abs_y_is_huge), y_is_neg)),
pand(pand(abs_x_is_gt_one, abs_y_is_huge), y_is_pos));
const Packet pow_is_neg_zero = pand(pandnot(y_is_int, y_is_even),
por(pand(y_is_neg, pand(abs_x_is_inf, x_is_neg)), pand(y_is_pos, x_is_neg_zero)));
const Packet inf_val =
pselect(pandnot(pand(por(pand(abs_x_is_inf, x_is_neg), pand(x_is_neg_zero, y_is_neg)), y_is_int), y_is_even),
cst_neg_inf, cst_pos_inf);
// General computation of pow(x,y) for positive x or negative x and integer y.
const Packet negate_pow_abs = pandnot(x_is_neg, y_is_even);
const Packet pow_abs = generic_pow_impl(abs_x, y);
return pselect(
y_is_one, x,
pselect(pow_is_one, cst_one,
pselect(pow_is_nan, cst_nan,
pselect(pow_is_inf, inf_val,
pselect(pow_is_zero, cst_zero, pselect(negate_pow_abs, pnegate(pow_abs), pow_abs))))));
return pselect(y_is_one, x,
pselect(pow_is_one, cst_one,
pselect(pow_is_nan, cst_nan,
pselect(pow_is_inf, inf_val,
pselect(pow_is_neg_zero, pnegate(cst_zero),
pselect(pow_is_zero, cst_zero,
pselect(negate_pow_abs, pnegate(pow_abs), pow_abs)))))));
}
/* polevl (modified for Eigen)
@ -2022,7 +2033,7 @@ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_int_errors(con
const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one);
const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
const Packet x_has_signbit = pcmp_eq(por(pand(x, cst_neg_inf), cst_pos_inf), cst_neg_inf);
const Packet x_has_signbit = psignbit(x);
const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero);
const Packet x_is_neg_zero = pand(x_has_signbit, abs_x_is_zero);
@ -2087,7 +2098,7 @@ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(
const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one);
const Packet abs_x_is_inf = pcmp_eq(abs_x, cst_pos_inf);
const Packet x_has_signbit = pcmp_eq(por(pand(x, cst_neg_inf), cst_pos_inf), cst_neg_inf);
const Packet x_has_signbit = psignbit(x);
const Packet x_is_neg = pandnot(x_has_signbit, abs_x_is_zero);
if (exponent_is_nan) {

View File

@ -83,6 +83,7 @@ void binary_op_test(std::string name, Fn fun, RefFn ref) {
Scalar e = static_cast<Scalar>(ref(x(i,j), y(i,j)));
Scalar a = actual(i, j);
bool success = (a==e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) || ((numext::isnan)(a) && (numext::isnan)(e));
if ((a == a) && (e == e)) success &= (bool)numext::signbit(e) == (bool)numext::signbit(a);
all_pass &= success;
if (!success) {
std::cout << name << "(" << x(i,j) << "," << y(i,j) << ") = " << a << " != " << e << std::endl;
@ -125,6 +126,7 @@ void pow_scalar_exponent_test() {
Scalar a = eigenPow(j);
bool success = (a == e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) ||
((numext::isnan)(a) && (numext::isnan)(e));
if ((a == a) && (e == e)) success &= (bool)numext::signbit(e) == (bool)numext::signbit(a);
all_pass &= success;
if (!success) {
std::cout << "pow(" << bases(j) << "," << exponent << ") = " << a << " != " << e << std::endl;
@ -138,6 +140,7 @@ void pow_scalar_exponent_test() {
Scalar a = eigenPow(j);
bool success = (a == e) || ((numext::isfinite)(e) && internal::isApprox(a, e, tol)) ||
((numext::isnan)(a) && (numext::isnan)(e));
if ((a == a) && (e == e)) success &= (bool)numext::signbit(e) == (bool)numext::signbit(a);
all_pass &= success;
if (!success) {
std::cout << "pow(" << bases(j) << "," << exponent << ") = " << a << " != " << e << std::endl;