From 0b9e3dcd06585d28ac4b59dfd518b0a49af3a359 Mon Sep 17 00:00:00 2001
From: Benoit Steiner
Date: Tue, 10 May 2016 11:05:33 -0700
Subject: [PATCH] Added packet primitives to compute exp, log, sqrt and rsqrt on fp16. This improves the performance by 10 to 30%.

---
Review notes (commentary only; not part of the change itself):
  * Hunk 1 turns on HasSqrt, HasRsqrt, HasExp and HasLog in the
    packet_traits enum, next to the existing HasDiv flag.
  * Hunk 2 adds plog, pexp, psqrt and prsqrt specializations taking a
    const half2&.  All four have the same shape: split the packet into
    two floats with __low2float / __high2float, apply logf / expf /
    sqrtf / rsqrtf to each float, and repack the two results with
    __floats2half2_rn.

 Eigen/src/Core/arch/CUDA/PacketMathHalf.h | 38 ++++++++++++++++++++++-
 1 file changed, 37 insertions(+), 1 deletion(-)

diff --git a/Eigen/src/Core/arch/CUDA/PacketMathHalf.h b/Eigen/src/Core/arch/CUDA/PacketMathHalf.h
index 0cebc1017..8873d5357 100644
--- a/Eigen/src/Core/arch/CUDA/PacketMathHalf.h
+++ b/Eigen/src/Core/arch/CUDA/PacketMathHalf.h
@@ -34,7 +34,11 @@ template<> struct packet_traits<half> : default_packet_traits
     AlignedOnScalar = 1,
     size=2,
     HasHalfPacket = 0,
-    HasDiv = 1
+    HasDiv = 1,
+    HasSqrt = 1,
+    HasRsqrt = 1,
+    HasExp = 1,
+    HasLog = 1
   };
 };
 
@@ -267,6 +271,38 @@ template<> EIGEN_DEVICE_FUNC inline half predux_mul<half2>(const half2& a) {
 #endif
 }
 
+template<> EIGEN_DEVICE_FUNC inline half2 plog<half2>(const half2& a) {
+  float a1 = __low2float(a);
+  float a2 = __high2float(a);
+  float r1 = logf(a1);
+  float r2 = logf(a2);
+  return __floats2half2_rn(r1, r2);
+}
+
+template<> EIGEN_DEVICE_FUNC inline half2 pexp<half2>(const half2& a) {
+  float a1 = __low2float(a);
+  float a2 = __high2float(a);
+  float r1 = expf(a1);
+  float r2 = expf(a2);
+  return __floats2half2_rn(r1, r2);
+}
+
+template<> EIGEN_DEVICE_FUNC inline half2 psqrt<half2>(const half2& a) {
+  float a1 = __low2float(a);
+  float a2 = __high2float(a);
+  float r1 = sqrtf(a1);
+  float r2 = sqrtf(a2);
+  return __floats2half2_rn(r1, r2);
+}
+
+template<> EIGEN_DEVICE_FUNC inline half2 prsqrt<half2>(const half2& a) {
+  float a1 = __low2float(a);
+  float a2 = __high2float(a);
+  float r1 = rsqrtf(a1);
+  float r2 = rsqrtf(a2);
+  return __floats2half2_rn(r1, r2);
+}
+
 } // end namespace internal
 
 } // end namespace Eigen