Optimize generic_rsqrt_newton_step

This commit is contained in:
Charles Schlosser 2023-03-24 22:42:57 +00:00 committed by Rasmus Munk Larsen
parent b8b8a26145
commit a08649994f

View File

@ -77,26 +77,29 @@ struct generic_reciprocal_newton_step<Packet, 0> {
template <typename Packet, int Steps>
struct generic_rsqrt_newton_step {
static_assert(Steps > 0, "Steps must be at least 1.");
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet
using Scalar = typename unpacket_traits<Packet>::type;
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet
run(const Packet& a, const Packet& approx_rsqrt) {
using Scalar = typename unpacket_traits<Packet>::type;
const Packet one_point_five = pset1<Packet>(Scalar(1.5));
const Packet minus_half = pset1<Packet>(Scalar(-0.5));
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
const Packet cst_minus_half = pset1<Packet>(Scalar(-1)/Scalar(2));
// Refine the approximation using one Newton-Raphson step:
// x_{n+1} = x_n * (1.5 + (-0.5 * x_n) * (a * x_n)).
// The approximation is expressed this way to avoid over/under-flows.
Packet x_newton = pmul(approx_rsqrt, pmadd(pmul(minus_half, approx_rsqrt), pmul(a, approx_rsqrt), one_point_five));
for (int step = 1; step < Steps; ++step) {
x_newton = pmul(x_newton, pmadd(pmul(minus_half, x_newton), pmul(a, x_newton), one_point_five));
// x' = x - (x/2) * ( (a*x)*x - 1)
Packet x = approx_rsqrt;
for (int step = 0; step < Steps; ++step) {
Packet minushalfx = pmul(cst_minus_half, x);
Packet ax = pmul(a, x);
Packet ax2m1 = pmadd(ax, x, cst_minus_one);
x = pmadd(ax2m1, minushalfx, x);
}
// If approx_rsqrt is 0 or +/-inf, we should return it as is. Note:
// on intel, approx_rsqrt can be inf for small denormal values.
const Packet return_approx = por(pcmp_eq(approx_rsqrt, pzero(a)),
pcmp_eq(pabs(approx_rsqrt), pset1<Packet>(NumTraits<Scalar>::infinity())));
return pselect(return_approx, approx_rsqrt, x_newton);
// If x is NaN, then either:
// 1) the input is NaN
// 2) zero and infinity were multiplied
// In either of these cases, return approx_rsqrt
return pselect(pisnan(x), approx_rsqrt, x);
}
};
@ -108,7 +111,6 @@ struct generic_rsqrt_newton_step<Packet, 0> {
}
};
/** \internal Fast sqrt using Newton-Raphson's method.
Preconditions: