Add generic fast psqrt and prsqrt impls and make them correct for 0, +Inf, NaN, and negative arguments.

This commit is contained in:
Rasmus Munk Larsen 2022-02-05 00:20:13 +00:00
parent 4bffbe84f9
commit 979fdd58a4
5 changed files with 135 additions and 169 deletions

View File

@ -29,7 +29,7 @@ namespace internal {
If the preconditions are satisfied, which they are for for the _*_rcp_ps
instructions on x86, the result has a maximum relative error of 2 ulps,
and correctly handles reciprocals of zero and infinity.
and correctly handles reciprocals of zero, infinity, and NaN.
*/
template <typename Packet, int Steps>
struct generic_reciprocal_newton_step {
@ -53,11 +53,109 @@ struct generic_reciprocal_newton_step {
template<typename Packet>
struct generic_reciprocal_newton_step<Packet, 0> {
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet
run(const Packet& /*unused*/, const Packet& approx_a_recip) {
return approx_a_recip;
run(const Packet& /*unused*/, const Packet& approx_rsqrt) {
return approx_rsqrt;
}
};
/** \internal Fast reciprocal sqrt using Newton-Raphson's method.
Preconditions:
1. The starting guess provided in approx_a_recip must have at least half
the leading mantissa bits in the correct result, such that a single
Newton-Raphson step is sufficient to get within 1-2 ulps of the currect
result.
2. If a is zero, approx_a_recip must be infinite with the same sign as a.
3. If a is infinite, approx_a_recip must be zero with the same sign as a.
If the preconditions are satisfied, which they are for for the _*_rcp_ps
instructions on x86, the result has a maximum relative error of 2 ulps,
and correctly handles zero, infinity, and NaN. Positive denormals are
treated as zero.
*/
template <typename Packet, int Steps>
struct generic_rsqrt_newton_step {
static_assert(Steps > 0, "Steps must be at least 1.");
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet
run(const Packet& a, const Packet& approx_rsqrt) {
using Scalar = typename unpacket_traits<Packet>::type;
const Packet one_point_five = pset1<Packet>(Scalar(1.5));
const Packet minus_half = pset1<Packet>(Scalar(-0.5));
const Packet minus_half_a = pmul(minus_half, a);
const Packet neg_mask = pcmp_lt(a, pzero(a));
Packet x =
generic_rsqrt_newton_step<Packet,Steps - 1>::run(a, approx_rsqrt);
const Packet tmp = pmul(minus_half_a, x);
// If tmp is NaN, it means that a is either 0 or Inf.
// In this case return the approximation directly.
const Packet is_not_nan = pcmp_eq(tmp, tmp);
// If a is negative, return NaN.
x = por(x, neg_mask);
// Refine the approximation using one Newton-Raphson step:
// x_{n+1} = x_n * (1.5 - x_n * ((0.5 * a) * x_n)).
const Packet x_newton = pmul(x, pmadd(tmp, x, one_point_five));
return pselect(is_not_nan, x_newton, x);
}
};
template<typename Packet>
struct generic_rsqrt_newton_step<Packet, 0> {
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet
run(const Packet& /*unused*/, const Packet& approx_rsqrt) {
return approx_rsqrt;
}
};
/** \internal Fast sqrt using Newton-Raphson's method.
Preconditions:
1. The starting guess for the reciprocal sqrt provided in approx_rsqrt must
have at least half the leading mantissa bits in the correct result, such
that a single Newton-Raphson step is sufficient to get within 1-2 ulps of
the currect result.
2. If a is zero, approx_rsqrt must be infinite.
3. If a is infinite, approx_rsqrt must be zero.
If the preconditions are satisfied, which they are for for the _*_rsqrt_ps
instructions on x86, the result has a maximum relative error of 2 ulps,
and correctly handles zero and infinity, and NaN. Positive denormal inputs
are treated as zero.
*/
template <typename Packet, int Steps=1>
struct generic_sqrt_newton_step {
static_assert(Steps > 0, "Steps must be at least 1.");
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet
run(const Packet& a, const Packet& approx_rsqrt) {
using Scalar = typename unpacket_traits<Packet>::type;
const Packet one_point_five = pset1<Packet>(Scalar(1.5));
const Packet negative_mask = pcmp_lt(a, pzero(a));
const Packet minus_half_a = pmul(a, pset1<Packet>(Scalar(-0.5)));
// Set negative arguments to NaN.
const Packet a_poisoned = por(a, negative_mask);
// Do a single step of Newton's iteration for reciprocal square root:
// x_{n+1} = x_n * (1.5 - x_n * ((0.5 * a) * x_n)).
const Packet tmp = pmul(approx_rsqrt, minus_half_a);
// If tmp is NaN, it means that the argument was either 0 or +inf,
// and we should return the argument itself as the result.
const Packet return_rsqrt = pcmp_eq(tmp, tmp);
Packet rsqrt = pmul(approx_rsqrt, pmadd(tmp, approx_rsqrt, one_point_five));
for (int step = 1; step < Steps; ++step) {
rsqrt = pmul(rsqrt, pmadd(pmul(rsqrt, minus_half_a), rsqrt, one_point_five));
}
// Return sqrt(x) = x * rsqrt(x) for non-zero finite positive arguments.
// Return a itself for 0 or +inf, NaN for negative arguments.
return pselect(return_rsqrt, pmul(a_poisoned, rsqrt), por(a, negative_mask));
}
};
/** \internal \returns the hyperbolic tan of \a a (coeff-wise)
Doesn't do anything fancy, just a 13/6-degree rational interpolant which
is accurate up to a couple of ulps in the (approximate) range [-8, 8],

View File

@ -89,34 +89,12 @@ pexp<Packet4d>(const Packet4d& _x) {
return pexp_double(_x);
}
// Functions for sqrt.
// The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step
// of Newton's method, at a cost of 1-2 bits of precision as opposed to the
// exact solution. It does not handle +inf, or denormalized numbers correctly.
// The main advantage of this approach is not just speed, but also the fact that
// it can be inlined and pipelined with other computations, further reducing its
// effective latency. This is similar to Quake3's fast inverse square root.
// For detail see here: http://www.beyond3d.com/content/articles/8/
#if EIGEN_FAST_MATH
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet8f psqrt<Packet8f>(const Packet8f& _x) {
const Packet8f minus_half_x = pmul(_x, pset1<Packet8f>(-0.5f));
const Packet8f negative_mask = pcmp_lt(_x, pzero(_x));
const Packet8f denormal_mask =
pandnot(pcmp_lt(_x, pset1<Packet8f>((std::numeric_limits<float>::min)())),
negative_mask);
// Compute approximate reciprocal sqrt.
Packet8f rs = _mm256_rsqrt_ps(_x);
// Flush negative arguments to zero. This is a workaround which ensures
// that sqrt of a negative denormal returns -NaN, despite _mm256_rsqrt_ps
// returning -Inf for such values.
const Packet8f x_flushed = pandnot(_x, negative_mask);
// Do a single step of Newton's iteration.
rs = pmul(rs, pmadd(minus_half_x, pmul(rs,rs), pset1<Packet8f>(1.5f)));
// Flush results for denormals to zero.
return pandnot(pmul(x_flushed, rs), denormal_mask);
return generic_sqrt_newton_step<Packet8f>::run(_x, _mm256_rsqrt_ps(_x));
}
#else
@ -135,35 +113,8 @@ Packet4d psqrt<Packet4d>(const Packet4d& _x) {
#if EIGEN_FAST_MATH
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet8f prsqrt<Packet8f>(const Packet8f& _x) {
EIGEN_DECLARE_CONST_Packet8f_FROM_INT(inf, 0x7f800000);
EIGEN_DECLARE_CONST_Packet8f(one_point_five, 1.5f);
EIGEN_DECLARE_CONST_Packet8f(minus_half, -0.5f);
EIGEN_DECLARE_CONST_Packet8f_FROM_INT(flt_min, 0x00800000);
Packet8f neg_half = pmul(_x, p8f_minus_half);
// select only the inverse sqrt of positive normal inputs (denormals are
// flushed to zero and cause infs as well).
Packet8f lt_min_mask = _mm256_cmp_ps(_x, p8f_flt_min, _CMP_LT_OQ);
Packet8f inf_mask = _mm256_cmp_ps(_x, p8f_inf, _CMP_EQ_OQ);
Packet8f not_normal_finite_mask = _mm256_or_ps(lt_min_mask, inf_mask);
// Compute an approximate result using the rsqrt intrinsic.
Packet8f y_approx = _mm256_rsqrt_ps(_x);
// Do a single step of Newton-Raphson iteration to improve the approximation.
// This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
// It is essential to evaluate the inner term like this because forming
// y_n^2 may over- or underflow.
Packet8f y_newton = pmul(y_approx, pmadd(y_approx, pmul(neg_half, y_approx), p8f_one_point_five));
// Select the result of the Newton-Raphson step for positive normal arguments.
// For other arguments, choose the output of the intrinsic. This will
// return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
// x is zero or a positive denormalized float (equivalent to flushing positive
// denormalized inputs to zero).
return pselect<Packet8f>(not_normal_finite_mask, y_approx, y_newton);
Packet8f prsqrt<Packet8f>(const Packet8f& a) {
return generic_rsqrt_newton_step<Packet8f, /*Steps=*/1>::run(a, _mm256_rsqrt_ps(a));
}
template<> EIGEN_STRONG_INLINE Packet8f preciprocal<Packet8f>(const Packet8f& a) {

View File

@ -155,49 +155,18 @@ EIGEN_STRONG_INLINE Packet16bf pldexp(const Packet16bf& a, const Packet16bf& exp
return F32ToBf16(pldexp<Packet16f>(Bf16ToF32(a), Bf16ToF32(exponent)));
}
// Functions for sqrt.
// The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step
// of Newton's method, at a cost of 1-2 bits of precision as opposed to the
// exact solution. The main advantage of this approach is not just speed, but
// also the fact that it can be inlined and pipelined with other computations,
// further reducing its effective latency.
#if EIGEN_FAST_MATH
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
psqrt<Packet16f>(const Packet16f& _x) {
Packet16f neg_half = pmul(_x, pset1<Packet16f>(-.5f));
__mmask16 denormal_mask = _mm512_kand(
_mm512_cmp_ps_mask(_x, pset1<Packet16f>((std::numeric_limits<float>::min)()),
_CMP_LT_OQ),
_mm512_cmp_ps_mask(_x, _mm512_setzero_ps(), _CMP_GE_OQ));
Packet16f x = _mm512_rsqrt14_ps(_x);
// Do a single step of Newton's iteration.
x = pmul(x, pmadd(neg_half, pmul(x, x), pset1<Packet16f>(1.5f)));
// Flush results for denormals to zero.
return _mm512_mask_blend_ps(denormal_mask, pmul(_x,x), _mm512_setzero_ps());
return generic_sqrt_newton_step<Packet16f>::run(_x, _mm512_rsqrt14_ps(_x));
}
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8d
psqrt<Packet8d>(const Packet8d& _x) {
Packet8d neg_half = pmul(_x, pset1<Packet8d>(-.5));
__mmask16 denormal_mask = _mm512_kand(
_mm512_cmp_pd_mask(_x, pset1<Packet8d>((std::numeric_limits<double>::min)()),
_CMP_LT_OQ),
_mm512_cmp_pd_mask(_x, _mm512_setzero_pd(), _CMP_GE_OQ));
Packet8d x = _mm512_rsqrt14_pd(_x);
// Do a single step of Newton's iteration.
x = pmul(x, pmadd(neg_half, pmul(x, x), pset1<Packet8d>(1.5)));
// Do a second step of Newton's iteration.
x = pmul(x, pmadd(neg_half, pmul(x, x), pset1<Packet8d>(1.5)));
return _mm512_mask_blend_pd(denormal_mask, pmul(_x,x), _mm512_setzero_pd());
// Double requires 2 Newton-Raphson steps for convergence.
return generic_sqrt_newton_step<Packet8d, /*Steps=*/2>::run(_x, _mm512_rsqrt14_pd(_x));
}
#else
template <>
@ -226,31 +195,7 @@ EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
prsqrt<Packet16f>(const Packet16f& _x) {
EIGEN_DECLARE_CONST_Packet16f_FROM_INT(inf, 0x7f800000);
EIGEN_DECLARE_CONST_Packet16f(one_point_five, 1.5f);
EIGEN_DECLARE_CONST_Packet16f(minus_half, -0.5f);
Packet16f neg_half = pmul(_x, p16f_minus_half);
// Identity infinite, negative and denormal arguments.
__mmask16 inf_mask = _mm512_cmp_ps_mask(_x, p16f_inf, _CMP_EQ_OQ);
__mmask16 not_pos_mask = _mm512_cmp_ps_mask(_x, _mm512_setzero_ps(), _CMP_LE_OQ);
__mmask16 not_finite_pos_mask = not_pos_mask | inf_mask;
// Compute an approximate result using the rsqrt intrinsic, forcing +inf
// for denormals for consistency with AVX and SSE implementations.
Packet16f y_approx = _mm512_rsqrt14_ps(_x);
// Do a single step of Newton-Raphson iteration to improve the approximation.
// This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
// It is essential to evaluate the inner term like this because forming
// y_n^2 may over- or underflow.
Packet16f y_newton = pmul(y_approx, pmadd(y_approx, pmul(neg_half, y_approx), p16f_one_point_five));
// Select the result of the Newton-Raphson step for positive finite arguments.
// For other arguments, choose the output of the intrinsic. This will
// return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(0) = +inf.
return _mm512_mask_blend_ps(not_finite_pos_mask, y_newton, y_approx);
return generic_rsqrt_newton_step<Packet16f, /*Steps=*/1>::run(_x, _mm512_rsqrt14_ps(_x));
}
#endif

View File

@ -75,30 +75,12 @@ Packet4f pcos<Packet4f>(const Packet4f& _x)
return pcos_float(_x);
}
// Functions for sqrt.
// The EIGEN_FAST_MATH version uses the _mm_rsqrt_ps approximation and one step
// of Newton's method, at a cost of 1-2 bits of precision as opposed to the
// exact solution. It does not handle +inf, or denormalized numbers correctly.
// The main advantage of this approach is not just speed, but also the fact that
// it can be inlined and pipelined with other computations, further reducing its
// effective latency. This is similar to Quake3's fast inverse square root.
// For detail see here: http://www.beyond3d.com/content/articles/8/
#if EIGEN_FAST_MATH
template<>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f psqrt<Packet4f>(const Packet4f& _x)
{
const Packet4f minus_half_x = pmul(_x, pset1<Packet4f>(-0.5f));
const Packet4f denormal_mask = pandnot(
pcmp_lt(_x, pset1<Packet4f>((std::numeric_limits<float>::min)())),
pcmp_lt(_x, pzero(_x)));
// Compute approximate reciprocal sqrt.
Packet4f x = _mm_rsqrt_ps(_x);
// Do a single step of Newton's iteration.
x = pmul(x, pmadd(minus_half_x, pmul(x,x), pset1<Packet4f>(1.5f)));
// Flush results for denormals to zero.
return pandnot(pmul(_x,x), denormal_mask);
return generic_sqrt_newton_step<Packet4f>::run(_x, _mm_rsqrt_ps(_x));
}
#else
@ -117,43 +99,16 @@ Packet16b psqrt<Packet16b>(const Packet16b& x) { return x; }
#if EIGEN_FAST_MATH
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f prsqrt<Packet4f>(const Packet4f& _x) {
EIGEN_DECLARE_CONST_Packet4f(one_point_five, 1.5f);
EIGEN_DECLARE_CONST_Packet4f(minus_half, -0.5f);
EIGEN_DECLARE_CONST_Packet4f_FROM_INT(inf, 0x7f800000u);
EIGEN_DECLARE_CONST_Packet4f_FROM_INT(flt_min, 0x00800000u);
Packet4f neg_half = pmul(_x, p4f_minus_half);
// Identity infinite, zero, negative and denormal arguments.
Packet4f lt_min_mask = _mm_cmplt_ps(_x, p4f_flt_min);
Packet4f inf_mask = _mm_cmpeq_ps(_x, p4f_inf);
Packet4f not_normal_finite_mask = _mm_or_ps(lt_min_mask, inf_mask);
// Compute an approximate result using the rsqrt intrinsic.
Packet4f y_approx = _mm_rsqrt_ps(_x);
// Do a single step of Newton-Raphson iteration to improve the approximation.
// This uses the formula y_{n+1} = y_n * (1.5 - y_n * (0.5 * x) * y_n).
// It is essential to evaluate the inner term like this because forming
// y_n^2 may over- or underflow.
Packet4f y_newton = pmul(
y_approx, pmadd(y_approx, pmul(neg_half, y_approx), p4f_one_point_five));
// Select the result of the Newton-Raphson step for positive normal arguments.
// For other arguments, choose the output of the intrinsic. This will
// return rsqrt(+inf) = 0, rsqrt(x) = NaN if x < 0, and rsqrt(x) = +inf if
// x is zero or a positive denormalized float (equivalent to flushing positive
// denormalized inputs to zero).
return pselect<Packet4f>(not_normal_finite_mask, y_approx, y_newton);
Packet4f prsqrt<Packet4f>(const Packet4f& x) {
return generic_rsqrt_newton_step<Packet4f, /*Steps=*/1>::run(x, _mm_rsqrt_ps(x));
}
#ifdef EIGEN_VECTORIZE_FMA
// Trying to speed up reciprocal using Newton-Raphson is counterproductive
// unless FMA is available. Without FMA pdiv(pset1<Packet>(Scalar(1),a) is
// 30% faster.
template<> EIGEN_STRONG_INLINE Packet4f preciprocal<Packet4f>(const Packet4f& a) {
return generic_reciprocal_newton_step<Packet4f, /*Steps=*/1>::run(a, _mm_rcp_ps(a));
template<> EIGEN_STRONG_INLINE Packet4f preciprocal<Packet4f>(const Packet4f& x) {
return generic_reciprocal_newton_step<Packet4f, /*Steps=*/1>::run(x, _mm_rcp_ps(x));
}
#endif

View File

@ -946,18 +946,35 @@ void packetmath_real() {
VERIFY((numext::isnan)(data2[0]));
VERIFY((numext::isnan)(data2[1]));
}
if (PacketTraits::HasSqrt) {
test::packet_helper<PacketTraits::HasSqrt, Packet> h;
data1[0] = Scalar(-1.0f);
if (std::numeric_limits<Scalar>::has_denorm == std::denorm_present) {
data1[1] = -std::numeric_limits<Scalar>::denorm_min();
} else {
data1[1] = -((std::numeric_limits<Scalar>::min)());
}
h.store(data2, internal::psqrt(h.load(data1)));
VERIFY((numext::isnan)(data2[0]));
VERIFY((numext::isnan)(data2[1]));
CHECK_CWISE1(numext::sqrt, internal::psqrt);
data1[0] = Scalar(0.0f);
data1[1] = NumTraits<Scalar>::infinity();
CHECK_CWISE1(numext::sqrt, internal::psqrt);
}
if (PacketTraits::HasRsqrt) {
data1[0] = Scalar(-1.0f);
if (std::numeric_limits<Scalar>::has_denorm == std::denorm_present) {
data1[1] = -std::numeric_limits<Scalar>::denorm_min();
} else {
data1[1] = -((std::numeric_limits<Scalar>::min)());
}
CHECK_CWISE1(numext::rsqrt, internal::prsqrt);
data1[0] = Scalar(0.0f);
data1[1] = NumTraits<Scalar>::infinity();
CHECK_CWISE1(numext::rsqrt, internal::prsqrt);
}
// TODO(rmlarsen): Re-enable for half and bfloat16.
if (PacketTraits::HasCos
&& !internal::is_same<Scalar, half>::value