Reduce max relative error of prsqrt from 3 to 2 ulps.

This commit is contained in:
Rasmus Munk Larsen 2023-06-04 22:25:33 +00:00
parent 1d80e23186
commit 7ac8897431

View File

@ -80,26 +80,26 @@ struct generic_rsqrt_newton_step {
using Scalar = typename unpacket_traits<Packet>::type;
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet
run(const Packet& a, const Packet& approx_rsqrt) {
constexpr Scalar kMinusHalf = Scalar(-1)/Scalar(2);
const Packet cst_minus_half = pset1<Packet>(kMinusHalf);
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
const Packet cst_minus_half = pset1<Packet>(Scalar(-1)/Scalar(2));
// Refine the approximation using one Newton-Raphson step:
// The approximation is expressed this way to avoid over/under-flows.
// x' = x - (x/2) * ( (a*x)*x - 1)
Packet x = approx_rsqrt;
Packet inv_sqrt = approx_rsqrt;
for (int step = 0; step < Steps; ++step) {
Packet minushalfx = pmul(cst_minus_half, x);
Packet ax = pmul(a, x);
Packet ax2m1 = pmadd(ax, x, cst_minus_one);
x = pmadd(ax2m1, minushalfx, x);
// Refine the approximation using one Newton-Raphson step:
// h_n = x * (inv_sqrt * inv_sqrt) - 1 (so that h_n is nearly 0).
// inv_sqrt = inv_sqrt - 0.5 * inv_sqrt * h_n
Packet r2 = pmul(inv_sqrt, inv_sqrt);
Packet half_r = pmul(inv_sqrt, cst_minus_half);
Packet h_n = pmadd(a, r2, cst_minus_one);
inv_sqrt = pmadd(half_r, h_n, inv_sqrt);
}
// If x is NaN, then either:
// 1) the input is NaN
// 2) zero and infinity were multiplied
// In either of these cases, return approx_rsqrt
return pselect(pisnan(x), approx_rsqrt, x);
return pselect(pisnan(inv_sqrt), approx_rsqrt, inv_sqrt);
}
};