From 1f54164eca5a8960205b94dd3f3cec87089e4ac6 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Sat, 7 Jul 2018 00:15:07 +0200 Subject: [PATCH] Fix a few issues with Packet16h --- Eigen/src/Core/arch/CUDA/PacketMathHalf.h | 33 ++++++++++++++++++----- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/Eigen/src/Core/arch/CUDA/PacketMathHalf.h b/Eigen/src/Core/arch/CUDA/PacketMathHalf.h index c068351ce..ae9193b0d 100644 --- a/Eigen/src/Core/arch/CUDA/PacketMathHalf.h +++ b/Eigen/src/Core/arch/CUDA/PacketMathHalf.h @@ -362,10 +362,10 @@ struct packet_traits : default_packet_traits { AlignedOnScalar = 1, size = 16, HasHalfPacket = 0, - HasAdd = 0, - HasSub = 0, - HasMul = 0, - HasNegate = 0, + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasNegate = 1, HasAbs = 0, HasAbs2 = 0, HasMin = 0, @@ -414,6 +414,21 @@ template<> EIGEN_STRONG_INLINE void pstoreu(Eigen::half* to, const Packet1 _mm256_storeu_si256((__m256i*)to, from.x); } +template<> EIGEN_STRONG_INLINE Packet16h +ploaddup(const Eigen::half* from) { + Packet16h result; + unsigned short a = from[0].x; + unsigned short b = from[1].x; + unsigned short c = from[2].x; + unsigned short d = from[3].x; + unsigned short e = from[4].x; + unsigned short f = from[5].x; + unsigned short g = from[6].x; + unsigned short h = from[7].x; + result.x = _mm256_set_epi16(h, h, g, g, f, f, e, e, d, d, c, c, b, b, a, a); + return result; +} + template<> EIGEN_STRONG_INLINE Packet16h ploadquad(const Eigen::half* from) { Packet16h result; @@ -519,6 +534,11 @@ template<> EIGEN_STRONG_INLINE half predux(const Packet16h& from) { return half(predux(from_float)); } +template<> EIGEN_STRONG_INLINE half predux_mul(const Packet16h& from) { + Packet16f from_float = half2float(from); + return half(predux_mul(from_float)); +} + template<> EIGEN_STRONG_INLINE Packet16h preduxp(const Packet16h* p) { Packet16f pf[16]; pf[0] = half2float(p[0]); @@ -545,8 +565,9 @@ template<> EIGEN_STRONG_INLINE Packet16h preverse(const Packet16h& a) { __m128i m = _mm_setr_epi8(14,15,12,13,10,11,8,9,6,7,4,5,2,3,0,1); Packet16h res; - res.x = _mm256_set_m128i(_mm_shuffle_epi8(_mm256_extractf128_si256(a.x,0),m), - _mm_shuffle_epi8(_mm256_extractf128_si256(a.x,1),m)); + res.x = _mm256_insertf128_si256( + _mm256_castsi128_si256(_mm_shuffle_epi8(_mm256_extractf128_si256(a.x,1),m)), + _mm_shuffle_epi8(_mm256_extractf128_si256(a.x,0),m), 1); return res; }