Fix a bug in the implementation of Carmack's fast sqrt algorithm in Eigen (enabled by EIGEN_FAST_MATH), which causes the vectorized parts of the computation to return -0.0 instead of NaN for negative arguments.

Benchmark speed in Giga-sqrts/s
Intel(R) Xeon(R) CPU E5-1650 v3 @ 3.50GHz
-----------------------------------------
                    SSE        AVX
Fast=1              2.529G     4.380G
Fast=0              1.944G     1.898G
Fast=1 fixed        2.214G     3.739G

This table illustrates the worst case in terms speed impact: It was measured by repeatedly computing the sqrt of an n=4096 float vector that fits in L1 cache. For large vectors the operation becomes memory bound and the differences between the different versions almost negligible.
This commit is contained in:
Rasmus Munk Larsen 2016-10-04 14:22:56 -07:00
parent 6af5ac7e27
commit 3ed67cb0bb
3 changed files with 34 additions and 40 deletions

View File

@ -362,23 +362,17 @@ pexp<Packet4d>(const Packet4d& _x) {
template <> template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
psqrt<Packet8f>(const Packet8f& _x) { psqrt<Packet8f>(const Packet8f& _x) {
_EIGEN_DECLARE_CONST_Packet8f(one_point_five, 1.5f); Packet8f half = pmul(_x, pset1<Packet8f>(.5f));
_EIGEN_DECLARE_CONST_Packet8f(minus_half, -0.5f); Packet8f denormal_mask = _mm256_and_ps(
_EIGEN_DECLARE_CONST_Packet8f_FROM_INT(flt_min, 0x00800000); _mm256_cmpge_ps(_x, _mm256_setzero_ps()),
_mm256_cmplt_ps(_x, pset1<Packet8f>((std::numeric_limits<float>::min)())));
Packet8f neg_half = pmul(_x, p8f_minus_half);
// select only the inverse sqrt of positive normal inputs (denormals are
// flushed to zero and cause infs as well).
Packet8f non_zero_mask = _mm256_cmp_ps(_x, p8f_flt_min, _CMP_GE_OQ);
Packet8f x = _mm256_and_ps(non_zero_mask, _mm256_rsqrt_ps(_x));
// Compute approximate reciprocal sqrt.
Packet8f x = _mm256_rsqrt_ps(_x);
// Do a single step of Newton's iteration. // Do a single step of Newton's iteration.
x = pmul(x, pmadd(neg_half, pmul(x, x), p8f_one_point_five)); x = pmul(x, psub(pset1<Packet8f>(1.5f), pmul(half, pmul(x,x))));
// Flush results for denormals to zero.
// Multiply the original _x by it's reciprocal square root to extract the return _mm256_andnot_ps(denormal_mask, pmul(_x,x));
// square root.
return pmul(_x, x);
} }
#else #else
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED

View File

@ -451,13 +451,16 @@ template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f psqrt<Packet4f>(const Packet4f& _x) Packet4f psqrt<Packet4f>(const Packet4f& _x)
{ {
Packet4f half = pmul(_x, pset1<Packet4f>(.5f)); Packet4f half = pmul(_x, pset1<Packet4f>(.5f));
Packet4f denormal_mask = _mm_and_ps(
_mm_cmpge_ps(_x, _mm_setzero_ps()),
_mm_cmplt_ps(_x, pset1<Packet4f>((std::numeric_limits<float>::min)())));
/* select only the inverse sqrt of non-zero inputs */ // Compute approximate reciprocal sqrt.
Packet4f non_zero_mask = _mm_cmpge_ps(_x, pset1<Packet4f>((std::numeric_limits<float>::min)())); Packet4f x = _mm_rsqrt_ps(_x);
Packet4f x = _mm_and_ps(non_zero_mask, _mm_rsqrt_ps(_x)); // Do a single step of Newton's iteration.
x = pmul(x, psub(pset1<Packet4f>(1.5f), pmul(half, pmul(x,x)))); x = pmul(x, psub(pset1<Packet4f>(1.5f), pmul(half, pmul(x,x))));
return pmul(_x,x); // Flush results for denormals to zero.
return _mm_andnot_ps(denormal_mask, pmul(_x,x));
} }
#else #else

View File

@ -440,12 +440,9 @@ template<typename Scalar> void packetmath_real()
data1[0] = Scalar(-1.0f); data1[0] = Scalar(-1.0f);
h.store(data2, internal::plog(h.load(data1))); h.store(data2, internal::plog(h.load(data1)));
VERIFY((numext::isnan)(data2[0])); VERIFY((numext::isnan)(data2[0]));
#if !EIGEN_FAST_MATH
h.store(data2, internal::psqrt(h.load(data1))); h.store(data2, internal::psqrt(h.load(data1)));
VERIFY((numext::isnan)(data2[0])); VERIFY((numext::isnan)(data2[0]));
VERIFY((numext::isnan)(data2[1])); VERIFY((numext::isnan)(data2[1]));
#endif
} }
} }