diff --git a/Eigen/src/Core/arch/AVX512/Complex.h b/Eigen/src/Core/arch/AVX512/Complex.h index 247f89860..42cdfcd25 100644 --- a/Eigen/src/Core/arch/AVX512/Complex.h +++ b/Eigen/src/Core/arch/AVX512/Complex.h @@ -66,7 +66,7 @@ template<> EIGEN_STRONG_INLINE Packet8cf pconj(const Packet8cf& a) const __m512 mask = _mm512_castsi512_ps(_mm512_setr_epi32( 0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000, 0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000)); - return Packet8cf(_mm512_xor_ps(a.v,mask)); + return Packet8cf(pxor(a.v,mask)); } template<> EIGEN_STRONG_INLINE Packet8cf pmul(const Packet8cf& a, const Packet8cf& b) @@ -75,10 +75,10 @@ template<> EIGEN_STRONG_INLINE Packet8cf pmul(const Packet8cf& a, con return Packet8cf(_mm512_fmaddsub_ps(_mm512_moveldup_ps(a.v), b.v, tmp2)); } -template<> EIGEN_STRONG_INLINE Packet8cf pand (const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_and_ps(a.v,b.v)); } -template<> EIGEN_STRONG_INLINE Packet8cf por (const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_or_ps(a.v,b.v)); } -template<> EIGEN_STRONG_INLINE Packet8cf pxor (const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_xor_ps(a.v,b.v)); } -template<> EIGEN_STRONG_INLINE Packet8cf pandnot(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_andnot_ps(b.v,a.v)); } +template<> EIGEN_STRONG_INLINE Packet8cf pand (const Packet8cf& a, const Packet8cf& b) { return Packet8cf(pand(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet8cf por (const Packet8cf& a, const Packet8cf& b) { return Packet8cf(por(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet8cf pxor (const Packet8cf& a, const Packet8cf& b) { return Packet8cf(pxor(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet8cf pandnot(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(pandnot(a.v,b.v)); } template<> EIGEN_STRONG_INLINE Packet8cf pload (const std::complex* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet8cf(pload(&numext::real_ref(*from))); } template<> EIGEN_STRONG_INLINE Packet8cf ploadu(const std::complex* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet8cf(ploadu(&numext::real_ref(*from))); } @@ -124,20 +124,20 @@ template<> EIGEN_STRONG_INLINE Packet8cf preverse(const Packet8cf& a) { template<> EIGEN_STRONG_INLINE std::complex predux(const Packet8cf& a) { - return predux(padd(Packet4cf(_mm512_extractf32x8_ps(a.v,0)), - Packet4cf(_mm512_extractf32x8_ps(a.v,1)))); + return predux(padd(Packet4cf(extract256<0>(a.v)), + Packet4cf(extract256<1>(a.v)))); } template<> EIGEN_STRONG_INLINE std::complex predux_mul(const Packet8cf& a) { - return predux_mul(pmul(Packet4cf(_mm512_extractf32x8_ps(a.v, 0)), - Packet4cf(_mm512_extractf32x8_ps(a.v, 1)))); + return predux_mul(pmul(Packet4cf(extract256<0>(a.v)), + Packet4cf(extract256<1>(a.v)))); } template <> EIGEN_STRONG_INLINE Packet4cf predux_half_dowto4(const Packet8cf& a) { - __m256 lane0 = _mm512_extractf32x8_ps(a.v, 0); - __m256 lane1 = _mm512_extractf32x8_ps(a.v, 1); + __m256 lane0 = extract256<0>(a.v); + __m256 lane1 = extract256<1>(a.v); __m256 res = _mm256_add_ps(lane0, lane1); return Packet4cf(res); } @@ -262,10 +262,10 @@ template<> EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, con return Packet4cd(_mm512_fmaddsub_pd(tmp1, b.v, odd)); } -template<> EIGEN_STRONG_INLINE Packet4cd pand (const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_and_pd(a.v,b.v)); } -template<> EIGEN_STRONG_INLINE Packet4cd por (const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_or_pd(a.v,b.v)); } -template<> EIGEN_STRONG_INLINE Packet4cd pxor (const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_xor_pd(a.v,b.v)); } -template<> EIGEN_STRONG_INLINE Packet4cd pandnot(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_andnot_pd(b.v,a.v)); } +template<> EIGEN_STRONG_INLINE Packet4cd pand (const Packet4cd& a, const Packet4cd& b) { return Packet4cd(pand(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cd por (const Packet4cd& a, const Packet4cd& b) { return Packet4cd(por(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cd pxor (const Packet4cd& a, const Packet4cd& b) { return Packet4cd(pxor(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cd pandnot(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(pandnot(a.v,b.v)); } template<> EIGEN_STRONG_INLINE Packet4cd pload (const std::complex* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet4cd(pload((const double*)from)); } @@ -308,7 +308,7 @@ template<> EIGEN_DEVICE_FUNC inline void pscatter, Packet4c template<> EIGEN_STRONG_INLINE std::complex pfirst(const Packet4cd& a) { - __m128d low = _mm512_extractf64x2_pd(a.v, 0); + __m128d low = extract128<0>(a.v); EIGEN_ALIGN16 double res[2]; _mm_store_pd(res, low); return std::complex(res[0],res[1]); diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index b1cbef9f1..72b09d998 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -264,12 +264,19 @@ EIGEN_STRONG_INLINE Packet8d pmax(const Packet8d& a, #ifdef EIGEN_VECTORIZE_AVX512DQ template EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) { return _mm512_extractf32x8_ps(x,I); } +template EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) { return _mm512_extractf64x2_pd(x,I); } EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) { return _mm512_insertf32x8(_mm512_castps256_ps512(a),b,1); } #else // AVX512F does not define _mm512_extractf32x8_ps to extract _m256 from _m512 template EIGEN_STRONG_INLINE Packet8f extract256(Packet16f x) { return _mm256_castsi256_ps(_mm512_extracti64x4_epi64( _mm512_castps_si512(x),I)); } + +// AVX512F does not define _mm512_extractf64x2_pd to extract _m128 from _m512 +template EIGEN_STRONG_INLINE Packet2d extract128(Packet8d x) { + return _mm_castsi128_pd(_mm512_extracti32x4_epi32( _mm512_castpd_si512(x),I)); +} + EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) { return _mm512_castsi512_ps(_mm512_inserti64x4(_mm512_castsi256_si512(_mm256_castps_si256(a)), _mm256_castps_si256(b),1));