From 28e008b99aab0c80b2c9eb0e4199a5d27149de3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20S=C3=A1nchez?= Date: Tue, 15 Feb 2022 21:31:51 +0000 Subject: [PATCH] Fix sqrt/rsqrt for NEON. --- Eigen/src/Core/arch/NEON/PacketMath.h | 30 +++++---------------------- test/packetmath.cpp | 8 +++---- 2 files changed, 9 insertions(+), 29 deletions(-) diff --git a/Eigen/src/Core/arch/NEON/PacketMath.h b/Eigen/src/Core/arch/NEON/PacketMath.h index e073535bb..33af653a4 100644 --- a/Eigen/src/Core/arch/NEON/PacketMath.h +++ b/Eigen/src/Core/arch/NEON/PacketMath.h @@ -3340,23 +3340,13 @@ template<> EIGEN_STRONG_INLINE Packet4ui psqrt(const Packet4ui& a) { } template<> EIGEN_STRONG_INLINE Packet4f prsqrt(const Packet4f& a) { - // Compute approximate reciprocal sqrt. - Packet4f x = vrsqrteq_f32(a); // Do Newton iterations for 1/sqrt(x). - x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, x), x), x); - x = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, x), x), x); - const Packet4f infinity = pset1(NumTraits::infinity()); - return pselect(pcmp_eq(a, pzero(a)), infinity, x); + return generic_rsqrt_newton_step::run(a, vrsqrteq_f32(a)); } template<> EIGEN_STRONG_INLINE Packet2f prsqrt(const Packet2f& a) { // Compute approximate reciprocal sqrt. - Packet2f x = vrsqrte_f32(a); - // Do Newton iterations for 1/sqrt(x). - x = vmul_f32(vrsqrts_f32(vmul_f32(a, x), x), x); - x = vmul_f32(vrsqrts_f32(vmul_f32(a, x), x), x); - const Packet2f infinity = pset1(NumTraits::infinity()); - return pselect(pcmp_eq(a, pzero(a)), infinity, x); + return generic_rsqrt_newton_step::run(a, vrsqrte_f32(a)); } // Unfortunately vsqrt_f32 is only available for A64. @@ -3365,14 +3355,10 @@ template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& _x){return vsqrtq_ template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& _x){return vsqrt_f32(_x); } #else template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& a) { - const Packet4f infinity = pset1(NumTraits::infinity()); - const Packet4f is_zero_or_inf = por(pcmp_eq(a, pzero(a)), pcmp_eq(a, infinity)); - return pselect(is_zero_or_inf, a, pmul(a, prsqrt(a))); + return generic_sqrt_newton_step::run(a, prsqrt(a)); } template<> EIGEN_STRONG_INLINE Packet2f psqrt(const Packet2f& a) { - const Packet2f infinity = pset1(NumTraits::infinity()); - const Packet2f is_zero_or_inf = por(pcmp_eq(a, pzero(a)), pcmp_eq(a, infinity)); - return pselect(is_zero_or_inf, a, pmul(a, prsqrt(a))); + return generic_sqrt_newton_step::run(a, prsqrt(a)); } #endif @@ -3966,14 +3952,8 @@ template<> EIGEN_STRONG_INLINE Packet2d pset1frombits(uint64_t from) { return vreinterpretq_f64_u64(vdupq_n_u64(from)); } template<> EIGEN_STRONG_INLINE Packet2d prsqrt(const Packet2d& a) { - // Compute approximate reciprocal sqrt. - Packet2d x = vrsqrteq_f64(a); // Do Newton iterations for 1/sqrt(x). - x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x); - x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x); - x = vmulq_f64(vrsqrtsq_f64(vmulq_f64(a, x), x), x); - const Packet2d infinity = pset1(NumTraits::infinity()); - return pselect(pcmp_eq(a, pzero(a)), infinity, x); + return generic_rsqrt_newton_step::run(a, vrsqrteq_f64(a)); } template<> EIGEN_STRONG_INLINE Packet2d psqrt(const Packet2d& _x){ return vsqrtq_f64(_x); } diff --git a/test/packetmath.cpp b/test/packetmath.cpp index 8150438ef..71ccb86f7 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -954,11 +954,11 @@ void packetmath_real() { } else { data1[1] = -((std::numeric_limits::min)()); } - CHECK_CWISE1(numext::sqrt, internal::psqrt); + CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt); data1[0] = Scalar(0.0f); data1[1] = NumTraits::infinity(); - CHECK_CWISE1(numext::sqrt, internal::psqrt); + CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt); } if (PacketTraits::HasRsqrt) { @@ -968,11 +968,11 @@ void packetmath_real() { } else { data1[1] = -((std::numeric_limits::min)()); } - CHECK_CWISE1(numext::rsqrt, internal::prsqrt); + CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt); data1[0] = Scalar(0.0f); data1[1] = NumTraits::infinity(); - CHECK_CWISE1(numext::rsqrt, internal::prsqrt); + CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt); } // TODO(rmlarsen): Re-enable for half and bfloat16.