From f41b1f1666e91dc674a42fed9c444c91f483133f Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 26 Feb 2015 09:42:41 -0800 Subject: [PATCH] Added support for fast reciprocal square root computation. --- Eigen/src/Core/GenericPacketMath.h | 9 ++++++++ Eigen/src/Core/arch/AVX/MathFunctions.h | 21 +++++++++++++++---- Eigen/src/Core/arch/AVX/PacketMath.h | 2 ++ Eigen/src/Core/arch/SSE/MathFunctions.h | 18 ++++++++++++++-- Eigen/src/Core/arch/SSE/PacketMath.h | 2 ++ Eigen/src/Core/functors/UnaryFunctors.h | 19 +++++++++++++++++ Eigen/src/Core/util/ForwardDeclarations.h | 1 + .../Eigen/CXX11/src/Tensor/TensorBase.h | 6 ++++++ 8 files changed, 72 insertions(+), 6 deletions(-) diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index 967a07df5..721280b2c 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -58,6 +58,7 @@ struct default_packet_traits HasDiv = 0, HasSqrt = 0, + HasRsqrt = 0, HasExp = 0, HasLog = 0, HasPow = 0, @@ -352,6 +353,14 @@ Packet plog(const Packet& a) { using std::log; return log(a); } template EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psqrt(const Packet& a) { using std::sqrt; return sqrt(a); } +/** \internal \returns the reciprocal square-root of \a a (coeff-wise) */ +template EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS +Packet prsqrt(const Packet& a) { + using std::sqrt; + const Packet one(1); + return one/sqrt(a); +} + /*************************************************************************** * The following functions might not have to be overwritten for vectorized types ***************************************************************************/ diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h index 2810a7a0b..aecfdd6ad 100644 --- a/Eigen/src/Core/arch/AVX/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX/MathFunctions.h @@ -300,16 +300,29 @@ psqrt(const Packet8f& _x) { return pmul(_x, x); } #else -template <> -EIGEN_STRONG_INLINE Packet8f psqrt(const Packet8f& x) { +template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet8f psqrt(const Packet8f& x) { return _mm256_sqrt_ps(x); } #endif -template <> -EIGEN_STRONG_INLINE Packet4d psqrt(const Packet4d& x) { +template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4d psqrt(const Packet4d& x) { return _mm256_sqrt_pd(x); } +template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet8f prsqrt(const Packet8f& x) { + _EIGEN_DECLARE_CONST_Packet8f(one, 1.0f); + return _mm256_div_ps(p8f_one, _mm256_sqrt_ps(x)); +} +#endif +template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4d prsqrt(const Packet4d& x) { + _EIGEN_DECLARE_CONST_Packet4d(one, 1.0); + return _mm256_div_pd(p4d_one, _mm256_sqrt_pd(x)); +} + + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index 695185a49..fb20a45cc 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -65,6 +65,7 @@ template<> struct packet_traits : default_packet_traits HasLog = 1, HasExp = 1, HasSqrt = 1, + HasRsqrt = 1, HasBlend = 1 }; }; @@ -81,6 +82,7 @@ template<> struct packet_traits : default_packet_traits HasDiv = 1, HasExp = 0, HasSqrt = 1, + HasRsqrt = 1, HasBlend = 1 }; }; diff --git a/Eigen/src/Core/arch/SSE/MathFunctions.h b/Eigen/src/Core/arch/SSE/MathFunctions.h index f86c0a39a..f9cb93bfc 100644 --- a/Eigen/src/Core/arch/SSE/MathFunctions.h +++ b/Eigen/src/Core/arch/SSE/MathFunctions.h @@ -462,11 +462,25 @@ Packet4f psqrt(const Packet4f& _x) #else -template<> EIGEN_STRONG_INLINE Packet4f psqrt(const Packet4f& x) { return _mm_sqrt_ps(x); } +template<>EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f psqrt(const Packet4f& x) { return _mm_sqrt_ps(x); } #endif -template<> EIGEN_STRONG_INLINE Packet2d psqrt(const Packet2d& x) { return _mm_sqrt_pd(x); } +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet2d psqrt(const Packet2d& x) { return _mm_sqrt_pd(x); } + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet4f prsqrt(const Packet4f& x) { + // Unfortunately we can't use the much faster mm_rqsrt_ps since it only provides an approximation. + return _mm_div_ps(pset1(1.0f), _mm_sqrt_ps(x)); +} + +template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED +Packet2d prsqrt(const Packet2d& x) { + // Unfortunately we can't use the much faster mm_rqsrt_pd since it only provides an approximation. + return _mm_div_pd(pset1(1.0), _mm_sqrt_pd(x)); +} } // end namespace internal diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h index 3653783fd..38a84273d 100755 --- a/Eigen/src/Core/arch/SSE/PacketMath.h +++ b/Eigen/src/Core/arch/SSE/PacketMath.h @@ -108,6 +108,7 @@ template<> struct packet_traits : default_packet_traits HasLog = 1, HasExp = 1, HasSqrt = 1, + HasRsqrt = 1, HasBlend = 1 }; }; @@ -124,6 +125,7 @@ template<> struct packet_traits : default_packet_traits HasDiv = 1, HasExp = 1, HasSqrt = 1, + HasRsqrt = 1, HasBlend = 1 }; }; diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index ec42e6850..a3b5f50b1 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -224,6 +224,25 @@ struct functor_traits > }; }; +/** \internal + * \brief Template functor to compute the reciprocal square root of a scalar + * \sa class CwiseUnaryOp, Cwise::rsqrt() + */ +template struct scalar_rsqrt_op { + EIGEN_EMPTY_STRUCT_CTOR(scalar_rsqrt_op) + EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const { using std::sqrt; return Scalar(1)/sqrt(a); } + typedef typename packet_traits::type Packet; + inline Packet packetOp(const Packet& a) const { return internal::prsqrt(a); } +}; + +template +struct functor_traits > +{ enum { + Cost = 5 * NumTraits::MulCost, + PacketAccess = packet_traits::HasRsqrt + }; +}; + /** \internal * \brief Template functor to compute the cosine of a scalar * \sa class CwiseUnaryOp, ArrayBase::cos() diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index c23892c50..0d24beb5a 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -189,6 +189,7 @@ template struct scalar_imag_op; template struct scalar_abs_op; template struct scalar_abs2_op; template struct scalar_sqrt_op; +template struct scalar_rsqrt_op; template struct scalar_exp_op; template struct scalar_log_op; template struct scalar_cos_op; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index cfcf18e8e..13709b504 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -77,6 +77,12 @@ class TensorBase return unaryExpr(internal::scalar_sqrt_op()); } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> + rsqrt() const { + return unaryExpr(internal::scalar_rsqrt_op()); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseUnaryOp, const Derived> square() const {