/* Patch summary (unified git diff against Eigen/src/Core/arch/AVX512/PacketMath.h, four hunks):
 *   1. adds HasBlend = 0 to one packet_traits specialization;
 *   2. wraps the existing Packet8d ploaddup in an EIGEN_VECTORIZE_AVX512DQ guard;
 *   3. adds an alternative Packet8d ploaddup for builds where that macro is not defined;
 *   4. replaces the asserting Packet8d pblend stub with a mask-based blend and adds
 *      pinsertfirst and pinsertlast for Packet16f and Packet8d.
 * NOTE(review): this copy of the patch is whitespace-collapsed onto two lines; the notes
 * below are placed directly after added (+) content only, so no hunk line count changes.
 */
diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index 12b897572..abece01bc 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -54,6 +54,7 @@ template<> struct packet_traits : default_packet_traits AlignedOnScalar = 1, size = 16, HasHalfPacket = 1, + HasBlend = 0, /* added: turns the HasBlend flag off in this traits struct. NOTE(review): size = 16 and the Packet16f pblend context in the last hunk suggest this is the float packet - confirm */ #if EIGEN_GNUC_AT_LEAST(5, 3) #ifdef EIGEN_VECTORIZE_AVX512DQ HasLog = 1, @@ -470,6 +471,8 @@ EIGEN_STRONG_INLINE Packet16f ploaddup(const float* from) { __m512 pairs = _mm512_permute_ps(even_elements, _MM_SHUFFLE(2, 2, 0, 0)); return pairs; } + +#ifdef EIGEN_VECTORIZE_AVX512DQ /* added guard: the existing Packet8d ploaddup that follows calls _mm512_insertf64x2 and is now compiled only when this macro is defined */ // Loads 4 doubles from memory a returns the packet {a0, a0 a1, a1, a2, a2, a3, // a3} template <> @@ -481,6 +484,17 @@ EIGEN_STRONG_INLINE Packet8d ploaddup(const double* from) { x = _mm512_insertf64x2(x, _mm_loaddup_pd(&from[3]), 3); return x; } +#else /* added fallback, used when EIGEN_VECTORIZE_AVX512DQ is not defined */ +template <> +EIGEN_STRONG_INLINE Packet8d ploaddup(const double* from) { + __m512d x = _mm512_setzero_pd(); /* zero start value; the four masks below together cover all 8 lanes */ + x = _mm512_mask_broadcastsd_pd(x, 0x3<<0, _mm_load_sd(from+0)); /* mask 0x3 shifted by 2k selects lanes 2k and 2k+1, which receive from[k]; unselected lanes keep their value from x */ + x = _mm512_mask_broadcastsd_pd(x, 0x3<<2, _mm_load_sd(from+1)); + x = _mm512_mask_broadcastsd_pd(x, 0x3<<4, _mm_load_sd(from+2)); + x = _mm512_mask_broadcastsd_pd(x, 0x3<<6, _mm_load_sd(from+3)); + return x; /* lanes now hold a0 a0 a1 a1 a2 a2 a3 a3 */ +} +#endif // Loads 4 floats from memory a returns the packet // {a0, a0 a0, a0, a1, a1, a1, a1, a2, a2, a2, a2, a3, a3, a3, a3} @@ -1272,11 +1286,38 @@ EIGEN_STRONG_INLINE Packet16f pblend(const Selector<16>& /*ifPacket*/, return Packet16f(); } template <> -EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& /*ifPacket*/, - const Packet8d& /*thenPacket*/, - const Packet8d& /*elsePacket*/) { - assert(false && "To be implemented"); - return Packet8d(); +EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& ifPacket, + const Packet8d& thenPacket, + const Packet8d& elsePacket) { /* replaces the asserting stub removed above: select[i] becomes bit i of an 8-bit lane mask */ + __mmask8 m = (ifPacket.select[0] ) + | (ifPacket.select[1]<<1) + | (ifPacket.select[2]<<2) + | (ifPacket.select[3]<<3) + | 
(ifPacket.select[4]<<4) + | (ifPacket.select[5]<<5) + | (ifPacket.select[6]<<6) + | (ifPacket.select[7]<<7); + return _mm512_mask_blend_pd(m, elsePacket, thenPacket); /* a set mask bit takes that lane from thenPacket (third operand), a clear bit takes it from elsePacket */ +} + +template<> EIGEN_STRONG_INLINE Packet16f pinsertfirst(const Packet16f& a, float b) +{ + return _mm512_mask_broadcastss_ps(a, (1), _mm_load_ss(&b)); /* mask 1 = lane 0 only: lane 0 becomes b, all other lanes are copied from a */ +} + +template<> EIGEN_STRONG_INLINE Packet8d pinsertfirst(const Packet8d& a, double b) +{ + return _mm512_mask_broadcastsd_pd(a, (1), _mm_load_sd(&b)); /* same lane-0 replacement, for the double packet */ +} + +template<> EIGEN_STRONG_INLINE Packet16f pinsertlast(const Packet16f& a, float b) +{ + return _mm512_mask_broadcastss_ps(a, (1<<15), _mm_load_ss(&b)); /* bit 15 = last of the 16 float lanes; the other lanes are copied from a */ +} + +template<> EIGEN_STRONG_INLINE Packet8d pinsertlast(const Packet8d& a, double b) +{ + return _mm512_mask_broadcastsd_pd(a, (1<<7), _mm_load_sd(&b)); /* bit 7 = last of the 8 double lanes; the other lanes are copied from a */ } } // end namespace internal