Reduce max relative error of prsqrt from 3 to 2 ulps.

This commit is contained in:
Rasmus Munk Larsen 2023-06-04 22:25:33 +00:00
parent 1d80e23186
commit 7ac8897431

View File

@ -80,26 +80,26 @@ struct generic_rsqrt_newton_step {
using Scalar = typename unpacket_traits<Packet>::type; using Scalar = typename unpacket_traits<Packet>::type;
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet
run(const Packet& a, const Packet& approx_rsqrt) { run(const Packet& a, const Packet& approx_rsqrt) {
constexpr Scalar kMinusHalf = Scalar(-1)/Scalar(2);
const Packet cst_minus_half = pset1<Packet>(kMinusHalf);
const Packet cst_minus_one = pset1<Packet>(Scalar(-1)); const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
const Packet cst_minus_half = pset1<Packet>(Scalar(-1)/Scalar(2));
// Refine the approximation using one Newton-Raphson step:
// The approximation is expressed this way to avoid over/under-flows.
// x' = x - (x/2) * ( (a*x)*x - 1)
Packet x = approx_rsqrt; Packet inv_sqrt = approx_rsqrt;
for (int step = 0; step < Steps; ++step) { for (int step = 0; step < Steps; ++step) {
Packet minushalfx = pmul(cst_minus_half, x); // Refine the approximation using one Newton-Raphson step:
Packet ax = pmul(a, x); // h_n = x * (inv_sqrt * inv_sqrt) - 1 (so that h_n is nearly 0).
Packet ax2m1 = pmadd(ax, x, cst_minus_one); // inv_sqrt = inv_sqrt - 0.5 * inv_sqrt * h_n
x = pmadd(ax2m1, minushalfx, x); Packet r2 = pmul(inv_sqrt, inv_sqrt);
Packet half_r = pmul(inv_sqrt, cst_minus_half);
Packet h_n = pmadd(a, r2, cst_minus_one);
inv_sqrt = pmadd(half_r, h_n, inv_sqrt);
} }
// If x is NaN, then either: // If x is NaN, then either:
// 1) the input is NaN // 1) the input is NaN
// 2) zero and infinity were multiplied // 2) zero and infinity were multiplied
// In either of these cases, return approx_rsqrt // In either of these cases, return approx_rsqrt
return pselect(pisnan(x), approx_rsqrt, x); return pselect(pisnan(inv_sqrt), approx_rsqrt, inv_sqrt);
} }
}; };