From 72f77ccb3ee2aeb3f0d7122dd1ab90d215206320 Mon Sep 17 00:00:00 2001 From: Charles Schlosser Date: Tue, 29 Aug 2023 00:36:07 +0000 Subject: [PATCH] Fix arm32 float division and related bugs (cherry picked from commit 81b48065ea673cd352d11ef9b6a3d86778ac962d) --- Eigen/src/Core/arch/NEON/PacketMath.h | 172 +++++++++++++++----------- test/array_cwise.cpp | 10 +- test/packetmath.cpp | 83 ++++++++++++- 3 files changed, 188 insertions(+), 77 deletions(-) diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h index da4c49d73..f0fa74946 100644 --- a/Eigen/src/Core/arch/NEON/PacketMath.h +++ b/Eigen/src/Core/arch/NEON/PacketMath.h @@ -960,57 +960,6 @@ template<> EIGEN_STRONG_INLINE Packet2ul pmul(const Packet2ul& a, con vdup_n_u64(vgetq_lane_u64(a, 1)*vgetq_lane_u64(b, 1))); } -template<> EIGEN_STRONG_INLINE Packet2f pdiv(const Packet2f& a, const Packet2f& b) -{ -#if EIGEN_ARCH_ARM64 - return vdiv_f32(a,b); -#else - Packet2f inv, restep, div; - - // NEON does not offer a divide instruction, we have to do a reciprocal approximation - // However NEON in contrast to other SIMD engines (AltiVec/SSE), offers - // a reciprocal estimate AND a reciprocal step -which saves a few instructions - // vrecpeq_f32() returns an estimate to 1/b, which we will finetune with - // Newton-Raphson and vrecpsq_f32() - inv = vrecpe_f32(b); - - // This returns a differential, by which we will have to multiply inv to get a better - // approximation of 1/b. - restep = vrecps_f32(b, inv); - inv = vmul_f32(restep, inv); - - // Finally, multiply a by 1/b and get the wanted result of the division. - div = vmul_f32(a, inv); - - return div; -#endif -} -template<> EIGEN_STRONG_INLINE Packet4f pdiv(const Packet4f& a, const Packet4f& b) -{ -#if EIGEN_ARCH_ARM64 - return vdivq_f32(a,b); -#else - Packet4f inv, restep, div; - - // NEON does not offer a divide instruction, we have to do a reciprocal approximation - // However NEON in contrast to other SIMD engines (AltiVec/SSE), offers - // a reciprocal estimate AND a reciprocal step -which saves a few instructions - // vrecpeq_f32() returns an estimate to 1/b, which we will finetune with - // Newton-Raphson and vrecpsq_f32() - inv = vrecpeq_f32(b); - - // This returns a differential, by which we will have to multiply inv to get a better - // approximation of 1/b. - restep = vrecpsq_f32(b, inv); - inv = vmulq_f32(restep, inv); - - // Finally, multiply a by 1/b and get the wanted result of the division. - div = vmulq_f32(a, inv); - - return div; -#endif -} - template<> EIGEN_STRONG_INLINE Packet4c pdiv(const Packet4c& /*a*/, const Packet4c& /*b*/) { eigen_assert(false && "packet integer division are not supported by NEON"); @@ -3289,40 +3238,115 @@ template<> EIGEN_STRONG_INLINE Packet4ui psqrt(const Packet4ui& a) { return res; } -template<> EIGEN_STRONG_INLINE Packet4f prsqrt(const Packet4f& a) { +EIGEN_STRONG_INLINE Packet4f prsqrt_float_unsafe(const Packet4f& a) { // Compute approximate reciprocal sqrt. - Packet4f x = vrsqrteq_f32(a); - // Do Newton iterations for 1/sqrt(x). - x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, x), x), x); - x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, x), x), x); - const Packet4f infinity = pset1(NumTraits::infinity()); - return pselect(pcmp_eq(a, pzero(a)), infinity, x); + // Does not correctly handle +/- 0 or +inf + float32x4_t result = vrsqrteq_f32(a); + result = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, result), result), result); + result = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, result), result), result); + return result; +} + +EIGEN_STRONG_INLINE Packet2f prsqrt_float_unsafe(const Packet2f& a) { + // Compute approximate reciprocal sqrt. + // Does not correctly handle +/- 0 or +inf + float32x2_t result = vrsqrte_f32(a); + result = vmul_f32(vrsqrts_f32(vmul_f32(a, result), result), result); + result = vmul_f32(vrsqrts_f32(vmul_f32(a, result), result), result); + return result; +} + +template Packet prsqrt_float_common(const Packet& a) { + const Packet cst_zero = pzero(a); + const Packet cst_inf = pset1(NumTraits::infinity()); + Packet return_zero = pcmp_eq(a, cst_inf); + Packet return_inf = pcmp_eq(a, cst_zero); + Packet result = prsqrt_float_unsafe(a); + result = pselect(return_inf, por(cst_inf, a), result); + result = pandnot(result, return_zero); + return result; +} + +template<> EIGEN_STRONG_INLINE Packet4f prsqrt(const Packet4f& a) { + return prsqrt_float_common(a); } template<> EIGEN_STRONG_INLINE Packet2f prsqrt(const Packet2f& a) { - // Compute approximate reciprocal sqrt. - Packet2f x = vrsqrte_f32(a); - // Do Newton iterations for 1/sqrt(x). - x = vmul_f32(vrsqrts_f32(vmul_f32(a, x), x), x); - x = vmul_f32(vrsqrts_f32(vmul_f32(a, x), x), x); - const Packet2f infinity = pset1(NumTraits::infinity()); - return pselect(pcmp_eq(a, pzero(a)), infinity, x); + return prsqrt_float_common(a); +} + +EIGEN_STRONG_INLINE Packet4f preciprocal(const Packet4f& a) +{ + // Compute approximate reciprocal. + float32x4_t result = vrecpeq_f32(a); + result = vmulq_f32(vrecpsq_f32(a, result), result); + result = vmulq_f32(vrecpsq_f32(a, result), result); + return result; +} + +EIGEN_STRONG_INLINE Packet2f preciprocal(const Packet2f& a) +{ + // Compute approximate reciprocal. + float32x2_t result = vrecpe_f32(a); + result = vmul_f32(vrecps_f32(a, result), result); + result = vmul_f32(vrecps_f32(a, result), result); + return result; } // Unfortunately vsqrt_f32 is only available for A64. #if EIGEN_ARCH_ARM64 -template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& _x){return vsqrtq_f32(_x);} -template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& _x){return vsqrt_f32(_x); } +template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& a) { return vsqrtq_f32(a); } + +template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& a) { return vsqrt_f32(a); } + +template<> EIGEN_STRONG_INLINE Packet4f pdiv(const Packet4f& a, const Packet4f& b) { return vdivq_f32(a, b); } + +template<> EIGEN_STRONG_INLINE Packet2f pdiv(const Packet2f& a, const Packet2f& b) { return vdiv_f32(a, b); } #else -template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& a) { - const Packet4f infinity = pset1(NumTraits::infinity()); - const Packet4f is_zero_or_inf = por(pcmp_eq(a, pzero(a)), pcmp_eq(a, infinity)); - return pselect(is_zero_or_inf, a, pmul(a, prsqrt(a))); +template +EIGEN_STRONG_INLINE Packet psqrt_float_common(const Packet& a) { + const Packet cst_zero = pzero(a); + const Packet cst_inf = pset1(NumTraits::infinity()); + + Packet result = pmul(a, prsqrt_float_unsafe(a)); + Packet a_is_zero = pcmp_eq(a, cst_zero); + Packet a_is_inf = pcmp_eq(a, cst_inf); + Packet return_a = por(a_is_zero, a_is_inf); + + result = pselect(return_a, a, result); + return result; } + +template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& a) { + return psqrt_float_common(a); +} + template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& a) { - const Packet2f infinity = pset1(NumTraits::infinity()); - const Packet2f is_zero_or_inf = por(pcmp_eq(a, pzero(a)), pcmp_eq(a, infinity)); - return pselect(is_zero_or_inf, a, pmul(a, prsqrt(a))); + return psqrt_float_common(a); +} + +template +EIGEN_STRONG_INLINE Packet pdiv_float_common(const Packet& a, const Packet& b) { + // if b is large, NEON intrinsics will flush preciprocal(b) to zero + // avoid underflow with the following manipulation: + // a / b = f * (a * reciprocal(f * b)) + + const Packet cst_one = pset1(1.0f); + const Packet cst_quarter = pset1(0.25f); + const Packet cst_thresh = pset1(NumTraits::highest() / 4.0f); + + Packet b_will_underflow = pcmp_le(cst_thresh, pabs(b)); + Packet f = pselect(b_will_underflow, cst_quarter, cst_one); + Packet result = pmul(f, pmul(a, preciprocal(pmul(b, f)))); + return result; +} + +template<> EIGEN_STRONG_INLINE Packet4f pdiv(const Packet4f& a, const Packet4f& b) { + return pdiv_float_common(a, b); +} + +template<> EIGEN_STRONG_INLINE Packet2f pdiv(const Packet2f& a, const Packet2f& b) { + return pdiv_float_common(a, b); } #endif diff --git a/test/array_cwise.cpp b/test/array_cwise.cpp index f57e04273..bbb74b1a6 100644 --- a/test/array_cwise.cpp +++ b/test/array_cwise.cpp @@ -22,7 +22,7 @@ void pow_test() { const Scalar sqrt2 = Scalar(std::sqrt(2)); const Scalar inf = Eigen::NumTraits::infinity(); const Scalar nan = Eigen::NumTraits::quiet_NaN(); - const Scalar denorm_min = std::numeric_limits::denorm_min(); + const Scalar denorm_min = EIGEN_ARCH_ARM ? zero : std::numeric_limits::denorm_min(); const Scalar min = (std::numeric_limits::min)(); const Scalar max = (std::numeric_limits::max)(); const Scalar max_exp = (static_cast(int(Eigen::NumTraits::max_exponent())) * Scalar(EIGEN_LN2)) / eps; @@ -356,7 +356,12 @@ template void array_real(const ArrayType& m) m3(rows, cols), m4 = m1; - m4 = (m4.abs()==Scalar(0)).select(Scalar(1),m4); + // avoid denormalized values so verification doesn't fail on platforms that don't support them + // denormalized behavior is tested elsewhere (unary_op_test, binary_ops_test) + const Scalar min = (std::numeric_limits::min)(); + m1 = (m1.abs()(); @@ -396,6 +401,7 @@ template void array_real(const ArrayType& m) // avoid inf and NaNs so verification doesn't fail m3 = m4.abs(); + VERIFY_IS_APPROX(m3.sqrt(), sqrt(abs(m3))); VERIFY_IS_APPROX(m3.rsqrt(), Scalar(1)/sqrt(abs(m3))); VERIFY_IS_APPROX(rsqrt(m3), Scalar(1)/sqrt(abs(m3))); diff --git a/test/packetmath.cpp b/test/packetmath.cpp index 518b801b9..711a69474 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -631,6 +631,85 @@ Scalar log2(Scalar x) { return Scalar(EIGEN_LOG2E) * std::log(x); } +// Create a functor out of a function so it can be passed (with overloads) +// to another function as an input argument. +#define CREATE_FUNCTOR(Name, Func) \ +struct Name { \ + template \ + T operator()(const T& val) const { \ + return Func(val); \ + } \ + } + +CREATE_FUNCTOR(psqrt_functor, internal::psqrt); +CREATE_FUNCTOR(prsqrt_functor, internal::prsqrt); + +// TODO(rmlarsen): Run this test for more functions. +template +void packetmath_test_IEEE_corner_cases(const RefFunctorT& ref_fun, + const FunctorT& fun) { + const int PacketSize = internal::unpacket_traits::size; + const Scalar norm_min = (std::numeric_limits::min)(); + const Scalar norm_max = (std::numeric_limits::max)(); + + constexpr int size = PacketSize * 2; + EIGEN_ALIGN_MAX Scalar data1[size]; + EIGEN_ALIGN_MAX Scalar data2[size]; + EIGEN_ALIGN_MAX Scalar ref[size]; + for (int i = 0; i < size; ++i) { + data1[i] = data2[i] = ref[i] = Scalar(0); + } + + // Test for subnormals. + if (Cond && std::numeric_limits::has_denorm == std::denorm_present && !EIGEN_ARCH_ARM) { + + for (int scale = 1; scale < 5; ++scale) { + // When EIGEN_FAST_MATH is 1 we relax the conditions slightly, and allow the function + // to return the same value for subnormals as the reference would return for zero with + // the same sign as the input. +#if EIGEN_FAST_MATH + data1[0] = Scalar(scale) * std::numeric_limits::denorm_min(); + data1[1] = -data1[0]; + test::packet_helper h; + h.store(data2, fun(h.load(data1))); + for (int i=0; i < PacketSize; ++i) { + const Scalar ref_zero = ref_fun(data1[i] < 0 ? -Scalar(0) : Scalar(0)); + const Scalar ref_val = ref_fun(data1[i]); + VERIFY(((std::isnan)(data2[i]) && (std::isnan)(ref_val)) || data2[i] == ref_zero || + verifyIsApprox(data2[i], ref_val)); + } +#else + CHECK_CWISE1_IF(Cond, ref_fun, fun); +#endif + } + } + + // Test for smallest normalized floats. + data1[0] = norm_min; + data1[1] = -data1[0]; + CHECK_CWISE1_IF(Cond, ref_fun, fun); + + // Test for largest floats. + data1[0] = norm_max; + data1[1] = -data1[0]; + CHECK_CWISE1_IF(Cond, ref_fun, fun); + + // Test for zeros. + data1[0] = Scalar(0.0); + data1[1] = -data1[0]; + CHECK_CWISE1_IF(Cond, ref_fun, fun); + + // Test for infinities. + data1[0] = NumTraits::infinity(); + data1[1] = -data1[0]; + CHECK_CWISE1_IF(Cond, ref_fun, fun); + + // Test for quiet NaNs. + data1[0] = std::numeric_limits::quiet_NaN(); + data1[1] = -std::numeric_limits::quiet_NaN(); + CHECK_CWISE1_IF(Cond, ref_fun, fun); +} + template void packetmath_real() { typedef internal::packet_traits PacketTraits; @@ -735,13 +814,15 @@ void packetmath_real() { CHECK_CWISE1_BYREF1_IF(PacketTraits::HasExp, REF_FREXP, internal::pfrexp); if (PacketTraits::HasExp) { // Check denormals: + #if !EIGEN_ARCH_ARM for (int j=0; j<3; ++j) { data1[0] = Scalar(std::ldexp(1, NumTraits::min_exponent()-j)); CHECK_CWISE1_BYREF1_IF(PacketTraits::HasExp, REF_FREXP, internal::pfrexp); data1[0] = -data1[0]; CHECK_CWISE1_BYREF1_IF(PacketTraits::HasExp, REF_FREXP, internal::pfrexp); } - + #endif + // zero data1[0] = Scalar(0); CHECK_CWISE1_BYREF1_IF(PacketTraits::HasExp, REF_FREXP, internal::pfrexp);