diff --git a/Eigen/Core b/Eigen/Core index 0e23247d3..759b1bb80 100644 --- a/Eigen/Core +++ b/Eigen/Core @@ -154,14 +154,12 @@ using std::ptrdiff_t; #if defined EIGEN_VECTORIZE_AVX512 #include "src/Core/arch/SSE/PacketMath.h" #include "src/Core/arch/SSE/TypeCasting.h" - // FIXME: this needs to be fixed (compilation issue if included) - // there is no reason to disable SSE/AVX vectorization - // of complex<> while enabling AVX512. - // #include "src/Core/arch/SSE/Complex.h" + #include "src/Core/arch/SSE/Complex.h" #include "src/Core/arch/AVX/PacketMath.h" #include "src/Core/arch/AVX/TypeCasting.h" - // #include "src/Core/arch/AVX/Complex.h" + #include "src/Core/arch/AVX/Complex.h" #include "src/Core/arch/AVX512/PacketMath.h" + #include "src/Core/arch/AVX512/Complex.h" #include "src/Core/arch/SSE/MathFunctions.h" #include "src/Core/arch/AVX/MathFunctions.h" #include "src/Core/arch/AVX512/MathFunctions.h" diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index 9c2a437bf..2b2ee9e2c 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -56,6 +56,7 @@ struct default_packet_traits HasConj = 1, HasSetLinear = 1, HasBlend = 0, + HasReduxp = 1, HasDiv = 0, HasSqrt = 0, diff --git a/Eigen/src/Core/arch/AVX/Complex.h b/Eigen/src/Core/arch/AVX/Complex.h index 7fa61969d..2bb40fc79 100644 --- a/Eigen/src/Core/arch/AVX/Complex.h +++ b/Eigen/src/Core/arch/AVX/Complex.h @@ -22,6 +22,7 @@ struct Packet4cf __m256 v; }; +#ifndef EIGEN_VECTORIZE_AVX512 template<> struct packet_traits > : default_packet_traits { typedef Packet4cf type; @@ -44,6 +45,7 @@ template<> struct packet_traits > : default_packet_traits HasSetLinear = 0 }; }; +#endif template<> struct unpacket_traits { typedef std::complex type; enum {size=4, alignment=Aligned32}; typedef Packet2cf half; }; @@ -228,6 +230,7 @@ struct Packet2cd __m256d v; }; +#ifndef EIGEN_VECTORIZE_AVX512 template<> struct packet_traits > : default_packet_traits { typedef Packet2cd type; @@ -250,6 +253,7 @@ template<> struct packet_traits > : default_packet_traits HasSetLinear = 0 }; }; +#endif template<> struct unpacket_traits { typedef std::complex type; enum {size=2, alignment=Aligned32}; typedef Packet1cd half; }; diff --git a/Eigen/src/Core/arch/AVX512/Complex.h b/Eigen/src/Core/arch/AVX512/Complex.h new file mode 100644 index 000000000..2820abb49 --- /dev/null +++ b/Eigen/src/Core/arch/AVX512/Complex.h @@ -0,0 +1,442 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2018 Gael Guennebaud +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_COMPLEX_AVX512_H +#define EIGEN_COMPLEX_AVX512_H + +namespace Eigen { + +namespace internal { + +//---------- float ---------- +struct Packet8cf +{ + EIGEN_STRONG_INLINE Packet8cf() {} + EIGEN_STRONG_INLINE explicit Packet8cf(const __m512& a) : v(a) {} + __m512 v; +}; + +template<> struct packet_traits > : default_packet_traits +{ + typedef Packet8cf type; + typedef Packet4cf half; + enum { + Vectorizable = 1, + AlignedOnScalar = 1, + size = 8, + HasHalfPacket = 1, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasNegate = 1, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 0, + HasMax = 0, + HasSetLinear = 0, + HasReduxp = 0 + }; +}; + +template<> struct unpacket_traits { + typedef std::complex type; + enum { + size = 8, + alignment=unpacket_traits::alignment + }; + typedef Packet4cf half; +}; + +template<> EIGEN_STRONG_INLINE Packet8cf padd(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_add_ps(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet8cf psub(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_sub_ps(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet8cf pnegate(const Packet8cf& a) +{ + return Packet8cf(pnegate(a.v)); +} +template<> EIGEN_STRONG_INLINE Packet8cf pconj(const Packet8cf& a) +{ + const __m512 mask = _mm512_castsi512_ps(_mm512_setr_epi32( + 0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000, + 0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000,0x00000000,0x80000000)); + return Packet8cf(_mm512_xor_ps(a.v,mask)); +} + +template<> EIGEN_STRONG_INLINE Packet8cf pmul(const Packet8cf& a, const Packet8cf& b) +{ + __m512 tmp2 = _mm512_mul_ps(_mm512_movehdup_ps(a.v), _mm512_permute_ps(b.v, _MM_SHUFFLE(2,3,0,1))); + return Packet8cf(_mm512_fmaddsub_ps(_mm512_moveldup_ps(a.v), b.v, tmp2)); +} + +template<> EIGEN_STRONG_INLINE Packet8cf pand (const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_and_ps(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet8cf por (const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_or_ps(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet8cf pxor (const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_xor_ps(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet8cf pandnot(const Packet8cf& a, const Packet8cf& b) { return Packet8cf(_mm512_andnot_ps(a.v,b.v)); } + +template<> EIGEN_STRONG_INLINE Packet8cf pload (const std::complex* from) { EIGEN_DEBUG_ALIGNED_LOAD return Packet8cf(pload(&numext::real_ref(*from))); } +template<> EIGEN_STRONG_INLINE Packet8cf ploadu(const std::complex* from) { EIGEN_DEBUG_UNALIGNED_LOAD return Packet8cf(ploadu(&numext::real_ref(*from))); } + + +template<> EIGEN_STRONG_INLINE Packet8cf pset1(const std::complex& from) +{ + return Packet8cf(_mm512_castpd_ps(pload1((const double*)(const void*)&from))); +} + +template<> EIGEN_STRONG_INLINE Packet8cf ploaddup(const std::complex* from) +{ + return Packet8cf( _mm512_castpd_ps( ploaddup((const double*)(const void*)from )) ); +} +template<> EIGEN_STRONG_INLINE Packet8cf ploadquad(const std::complex* from) +{ + return Packet8cf( _mm512_castpd_ps( ploadquad((const double*)(const void*)from )) ); +} + +template<> EIGEN_STRONG_INLINE void pstore >(std::complex* to, const Packet8cf& from) { EIGEN_DEBUG_ALIGNED_STORE pstore(&numext::real_ref(*to), from.v); } +template<> EIGEN_STRONG_INLINE void pstoreu >(std::complex* to, const Packet8cf& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu(&numext::real_ref(*to), from.v); } + +template<> EIGEN_DEVICE_FUNC inline Packet8cf pgather, Packet8cf>(const std::complex* from, Index stride) +{ + return Packet8cf(_mm512_castpd_ps(pgather((const double*)(const void*)from, stride))); +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter, Packet8cf>(std::complex* to, const Packet8cf& from, Index stride) +{ + pscatter((double*)(void*)to, _mm512_castps_pd(from.v), stride); +} + +template<> EIGEN_STRONG_INLINE std::complex pfirst(const Packet8cf& a) +{ + return pfirst(Packet2cf(_mm512_castps512_ps128(a.v))); +} + +template<> EIGEN_STRONG_INLINE Packet8cf preverse(const Packet8cf& a) { + return Packet8cf(_mm512_castsi512_ps( + _mm512_permutexvar_epi64( _mm512_set_epi32(0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7), + _mm512_castps_si512(a.v)))); +} + +template<> EIGEN_STRONG_INLINE std::complex predux(const Packet8cf& a) +{ + return predux(padd(Packet4cf(_mm512_extractf32x8_ps(a.v,0)), + Packet4cf(_mm512_extractf32x8_ps(a.v,1)))); +} + +template<> EIGEN_STRONG_INLINE std::complex predux_mul(const Packet8cf& a) +{ + return predux_mul(pmul(Packet4cf(_mm512_extractf32x8_ps(a.v, 0)), + Packet4cf(_mm512_extractf32x8_ps(a.v, 1)))); +} + +template <> +EIGEN_STRONG_INLINE Packet4cf predux_half_dowto4(const Packet8cf& a) { + __m256 lane0 = _mm512_extractf32x8_ps(a.v, 0); + __m256 lane1 = _mm512_extractf32x8_ps(a.v, 1); + __m256 res = _mm256_add_ps(lane0, lane1); + return Packet4cf(res); +} + +template +struct palign_impl +{ + static EIGEN_STRONG_INLINE void run(Packet8cf& first, const Packet8cf& second) + { + if (Offset==0) return; + palign_impl::run(first.v, second.v); + } +}; + +template<> struct conj_helper +{ + EIGEN_STRONG_INLINE Packet8cf pmadd(const Packet8cf& x, const Packet8cf& y, const Packet8cf& c) const + { return padd(pmul(x,y),c); } + + EIGEN_STRONG_INLINE Packet8cf pmul(const Packet8cf& a, const Packet8cf& b) const + { + return internal::pmul(a, pconj(b)); + } +}; + +template<> struct conj_helper +{ + EIGEN_STRONG_INLINE Packet8cf pmadd(const Packet8cf& x, const Packet8cf& y, const Packet8cf& c) const + { return padd(pmul(x,y),c); } + + EIGEN_STRONG_INLINE Packet8cf pmul(const Packet8cf& a, const Packet8cf& b) const + { + return internal::pmul(pconj(a), b); + } +}; + +template<> struct conj_helper +{ + EIGEN_STRONG_INLINE Packet8cf pmadd(const Packet8cf& x, const Packet8cf& y, const Packet8cf& c) const + { return padd(pmul(x,y),c); } + + EIGEN_STRONG_INLINE Packet8cf pmul(const Packet8cf& a, const Packet8cf& b) const + { + return pconj(internal::pmul(a, b)); + } +}; + +EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet8cf,Packet16f) + +template<> EIGEN_STRONG_INLINE Packet8cf pdiv(const Packet8cf& a, const Packet8cf& b) +{ + Packet8cf num = pmul(a, pconj(b)); + __m512 tmp = _mm512_mul_ps(b.v, b.v); + __m512 tmp2 = _mm512_shuffle_ps(tmp,tmp,0xB1); + __m512 denom = _mm512_add_ps(tmp, tmp2); + return Packet8cf(_mm512_div_ps(num.v, denom)); +} + +template<> EIGEN_STRONG_INLINE Packet8cf pcplxflip(const Packet8cf& x) +{ + return Packet8cf(_mm512_shuffle_ps(x.v, x.v, _MM_SHUFFLE(2, 3, 0 ,1))); +} + +//---------- double ---------- +struct Packet4cd +{ + EIGEN_STRONG_INLINE Packet4cd() {} + EIGEN_STRONG_INLINE explicit Packet4cd(const __m512d& a) : v(a) {} + __m512d v; +}; + +template<> struct packet_traits > : default_packet_traits +{ + typedef Packet4cd type; + typedef Packet2cd half; + enum { + Vectorizable = 1, + AlignedOnScalar = 0, + size = 4, + HasHalfPacket = 1, + + HasAdd = 1, + HasSub = 1, + HasMul = 1, + HasDiv = 1, + HasNegate = 1, + HasAbs = 0, + HasAbs2 = 0, + HasMin = 0, + HasMax = 0, + HasSetLinear = 0, + HasReduxp = 0 + }; +}; + +template<> struct unpacket_traits { + typedef std::complex type; + enum { + size = 4, + alignment = unpacket_traits::alignment + }; + typedef Packet2cd half; +}; + +template<> EIGEN_STRONG_INLINE Packet4cd padd(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_add_pd(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cd psub(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_sub_pd(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cd pnegate(const Packet4cd& a) { return Packet4cd(pnegate(a.v)); } +template<> EIGEN_STRONG_INLINE Packet4cd pconj(const Packet4cd& a) +{ + const __m512d mask = _mm512_castsi512_pd( + _mm512_set_epi32(0x80000000,0x0,0x0,0x0,0x80000000,0x0,0x0,0x0, + 0x80000000,0x0,0x0,0x0,0x80000000,0x0,0x0,0x0)); + return Packet4cd(_mm512_xor_pd(a.v,mask)); +} + +template<> EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, const Packet4cd& b) +{ + __m512d tmp1 = _mm512_shuffle_pd(a.v,a.v,0x0); + __m512d tmp2 = _mm512_shuffle_pd(a.v,a.v,0xFF); + __m512d tmp3 = _mm512_shuffle_pd(b.v,b.v,0x55); + __m512d odd = _mm512_mul_pd(tmp2, tmp3); + return Packet4cd(_mm512_fmaddsub_pd(tmp1, b.v, odd)); +} + +template<> EIGEN_STRONG_INLINE Packet4cd pand (const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_and_pd(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cd por (const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_or_pd(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cd pxor (const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_xor_pd(a.v,b.v)); } +template<> EIGEN_STRONG_INLINE Packet4cd pandnot(const Packet4cd& a, const Packet4cd& b) { return Packet4cd(_mm512_andnot_pd(a.v,b.v)); } + +template<> EIGEN_STRONG_INLINE Packet4cd pload (const std::complex* from) +{ EIGEN_DEBUG_ALIGNED_LOAD return Packet4cd(pload((const double*)from)); } +template<> EIGEN_STRONG_INLINE Packet4cd ploadu(const std::complex* from) +{ EIGEN_DEBUG_UNALIGNED_LOAD return Packet4cd(ploadu((const double*)from)); } + +template<> EIGEN_STRONG_INLINE Packet4cd pset1(const std::complex& from) +{ + #ifdef EIGEN_VECTORIZE_AVX512DQ + return Packet4cd(_mm512_broadcast_f64x2(pset1(from).v)); + #else + return Packet4cd(_mm512_castps_pd(_mm512_broadcast_f32x4( _mm_castpd_ps(pset1(from).v)))); + #endif +} + +template<> EIGEN_STRONG_INLINE Packet4cd ploaddup(const std::complex* from) { + return Packet4cd(_mm512_insertf64x4( + _mm512_castpd256_pd512(ploaddup(from).v), ploaddup(from+1).v, 1)); +} + +template<> EIGEN_STRONG_INLINE void pstore >(std::complex * to, const Packet4cd& from) { EIGEN_DEBUG_ALIGNED_STORE pstore((double*)to, from.v); } +template<> EIGEN_STRONG_INLINE void pstoreu >(std::complex * to, const Packet4cd& from) { EIGEN_DEBUG_UNALIGNED_STORE pstoreu((double*)to, from.v); } + +template<> EIGEN_DEVICE_FUNC inline Packet4cd pgather, Packet4cd>(const std::complex* from, Index stride) +{ + return Packet4cd(_mm512_insertf64x4(_mm512_castpd256_pd512( + _mm256_insertf128_pd(_mm256_castpd128_pd256(pload(from+0*stride).v), pload(from+1*stride).v,1)), + _mm256_insertf128_pd(_mm256_castpd128_pd256(pload(from+2*stride).v), pload(from+3*stride).v,1), 1)); +} + +template<> EIGEN_DEVICE_FUNC inline void pscatter, Packet4cd>(std::complex* to, const Packet4cd& from, Index stride) +{ + __m512i fromi = _mm512_castpd_si512(from.v); + double* tod = (double*)(void*)to; + _mm_store_pd(tod+0*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,0)) ); + _mm_store_pd(tod+2*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,1)) ); + _mm_store_pd(tod+4*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,2)) ); + _mm_store_pd(tod+6*stride, _mm_castsi128_pd(_mm512_extracti32x4_epi32(fromi,3)) ); +} + +template<> EIGEN_STRONG_INLINE std::complex pfirst(const Packet4cd& a) +{ + __m128d low = _mm512_extractf64x2_pd(a.v, 0); + EIGEN_ALIGN16 double res[2]; + _mm_store_pd(res, low); + return std::complex(res[0],res[1]); +} + +template<> EIGEN_STRONG_INLINE Packet4cd preverse(const Packet4cd& a) { + return Packet4cd(_mm512_shuffle_f64x2(a.v, a.v, EIGEN_SSE_SHUFFLE_MASK(3,2,1,0))); +} + +template<> EIGEN_STRONG_INLINE std::complex predux(const Packet4cd& a) +{ + return predux(padd(Packet2cd(_mm512_extractf64x4_pd(a.v,0)), + Packet2cd(_mm512_extractf64x4_pd(a.v,1)))); +} + +template<> EIGEN_STRONG_INLINE std::complex predux_mul(const Packet4cd& a) +{ + return predux_mul(pmul(Packet2cd(_mm512_extractf64x4_pd(a.v,0)), + Packet2cd(_mm512_extractf64x4_pd(a.v,1)))); +} + +template +struct palign_impl +{ + static EIGEN_STRONG_INLINE void run(Packet4cd& first, const Packet4cd& second) + { + if (Offset==0) return; + palign_impl::run(first.v, second.v); + } +}; + +template<> struct conj_helper +{ + EIGEN_STRONG_INLINE Packet4cd pmadd(const Packet4cd& x, const Packet4cd& y, const Packet4cd& c) const + { return padd(pmul(x,y),c); } + + EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, const Packet4cd& b) const + { + return internal::pmul(a, pconj(b)); + } +}; + +template<> struct conj_helper +{ + EIGEN_STRONG_INLINE Packet4cd pmadd(const Packet4cd& x, const Packet4cd& y, const Packet4cd& c) const + { return padd(pmul(x,y),c); } + + EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, const Packet4cd& b) const + { + return internal::pmul(pconj(a), b); + } +}; + +template<> struct conj_helper +{ + EIGEN_STRONG_INLINE Packet4cd pmadd(const Packet4cd& x, const Packet4cd& y, const Packet4cd& c) const + { return padd(pmul(x,y),c); } + + EIGEN_STRONG_INLINE Packet4cd pmul(const Packet4cd& a, const Packet4cd& b) const + { + return pconj(internal::pmul(a, b)); + } +}; + +EIGEN_MAKE_CONJ_HELPER_CPLX_REAL(Packet4cd,Packet8d) + +template<> EIGEN_STRONG_INLINE Packet4cd pdiv(const Packet4cd& a, const Packet4cd& b) +{ + Packet4cd num = pmul(a, pconj(b)); + __m512d tmp = _mm512_mul_pd(b.v, b.v); + __m512d denom = padd(_mm512_permute_pd(tmp,0x55), tmp); + return Packet4cd(_mm512_div_pd(num.v, denom)); +} + +template<> EIGEN_STRONG_INLINE Packet4cd pcplxflip(const Packet4cd& x) +{ + return Packet4cd(_mm512_permute_pd(x.v,0x55)); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + ptranspose(reinterpret_cast&>(kernel)); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + ptranspose(reinterpret_cast&>(kernel)); +} + +EIGEN_DEVICE_FUNC inline void +ptranspose(PacketBlock& kernel) { + __m512d T0 = _mm512_shuffle_f64x2(kernel.packet[0].v, kernel.packet[1].v, EIGEN_SSE_SHUFFLE_MASK(0,1,0,1)); // [a0 a1 b0 b1] + __m512d T1 = _mm512_shuffle_f64x2(kernel.packet[0].v, kernel.packet[1].v, EIGEN_SSE_SHUFFLE_MASK(2,3,2,3)); // [a2 a3 b2 b3] + __m512d T2 = _mm512_shuffle_f64x2(kernel.packet[2].v, kernel.packet[3].v, EIGEN_SSE_SHUFFLE_MASK(0,1,0,1)); // [c0 c1 d0 d1] + __m512d T3 = _mm512_shuffle_f64x2(kernel.packet[2].v, kernel.packet[3].v, EIGEN_SSE_SHUFFLE_MASK(2,3,2,3)); // [c2 c3 d2 d3] + + kernel.packet[3] = Packet4cd(_mm512_shuffle_f64x2(T1, T3, EIGEN_SSE_SHUFFLE_MASK(1,3,1,3))); // [a3 b3 c3 d3] + kernel.packet[2] = Packet4cd(_mm512_shuffle_f64x2(T1, T3, EIGEN_SSE_SHUFFLE_MASK(0,2,0,2))); // [a2 b2 c2 d2] + kernel.packet[1] = Packet4cd(_mm512_shuffle_f64x2(T0, T2, EIGEN_SSE_SHUFFLE_MASK(1,3,1,3))); // [a1 b1 c1 d1] + kernel.packet[0] = Packet4cd(_mm512_shuffle_f64x2(T0, T2, EIGEN_SSE_SHUFFLE_MASK(0,2,0,2))); // [a0 b0 c0 d0] +} + +template<> EIGEN_STRONG_INLINE Packet8cf pinsertfirst(const Packet8cf& a, std::complex b) +{ + Packet2cf tmp = Packet2cf(_mm512_extractf32x4_ps(a.v,0)); + tmp = pinsertfirst(tmp, b); + return Packet8cf( _mm512_insertf32x4(a.v, tmp.v, 0) ); +} + +template<> EIGEN_STRONG_INLINE Packet4cd pinsertfirst(const Packet4cd& a, std::complex b) +{ + return Packet4cd(_mm512_castsi512_pd( _mm512_inserti32x4(_mm512_castpd_si512(a.v), _mm_castpd_si128(pset1(b).v), 0) )); +} + +template<> EIGEN_STRONG_INLINE Packet8cf pinsertlast(const Packet8cf& a, std::complex b) +{ + Packet2cf tmp = Packet2cf(_mm512_extractf32x4_ps(a.v,3) ); + tmp = pinsertlast(tmp, b); + return Packet8cf( _mm512_insertf32x4(a.v, tmp.v, 3) ); +} + +template<> EIGEN_STRONG_INLINE Packet4cd pinsertlast(const Packet4cd& a, std::complex b) +{ + return Packet4cd(_mm512_castsi512_pd( _mm512_inserti32x4(_mm512_castpd_si512(a.v), _mm_castpd_si128(pset1(b).v), 3) )); +} + +} // end namespace internal + +} // end namespace Eigen + +#endif // EIGEN_COMPLEX_AVX512_H diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index 1d38fb758..ce453f019 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -417,6 +417,7 @@ EIGEN_STRONG_INLINE Packet16f ploaddup(const float* from) { } #ifdef EIGEN_VECTORIZE_AVX512DQ +// FIXME: this does not look optimal, better load a Packet4d and shuffle... // Loads 4 doubles from memory a returns the packet {a0, a0 a1, a1, a2, a2, a3, // a3} template <> diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h index 0e8e0d2b3..b073fc8d4 100755 --- a/Eigen/src/Core/arch/SSE/PacketMath.h +++ b/Eigen/src/Core/arch/SSE/PacketMath.h @@ -61,20 +61,22 @@ template<> struct is_arithmetic<__m128> { enum { value = true }; }; template<> struct is_arithmetic<__m128i> { enum { value = true }; }; template<> struct is_arithmetic<__m128d> { enum { value = true }; }; +#define EIGEN_SSE_SHUFFLE_MASK(p,q,r,s) ((s)<<6|(r)<<4|(q)<<2|(p)) + #define vec4f_swizzle1(v,p,q,r,s) \ - (_mm_castsi128_ps(_mm_shuffle_epi32( _mm_castps_si128(v), ((s)<<6|(r)<<4|(q)<<2|(p))))) + (_mm_castsi128_ps(_mm_shuffle_epi32( _mm_castps_si128(v), EIGEN_SSE_SHUFFLE_MASK(p,q,r,s)))) #define vec4i_swizzle1(v,p,q,r,s) \ - (_mm_shuffle_epi32( v, ((s)<<6|(r)<<4|(q)<<2|(p)))) + (_mm_shuffle_epi32( v, EIGEN_SSE_SHUFFLE_MASK(p,q,r,s))) #define vec2d_swizzle1(v,p,q) \ - (_mm_castsi128_pd(_mm_shuffle_epi32( _mm_castpd_si128(v), ((q*2+1)<<6|(q*2)<<4|(p*2+1)<<2|(p*2))))) + (_mm_castsi128_pd(_mm_shuffle_epi32( _mm_castpd_si128(v), EIGEN_SSE_SHUFFLE_MASK(2*p,2*p+1,2*q,2*q+1)))) #define vec4f_swizzle2(a,b,p,q,r,s) \ - (_mm_shuffle_ps( (a), (b), ((s)<<6|(r)<<4|(q)<<2|(p)))) + (_mm_shuffle_ps( (a), (b), EIGEN_SSE_SHUFFLE_MASK(p,q,r,s))) #define vec4i_swizzle2(a,b,p,q,r,s) \ - (_mm_castps_si128( (_mm_shuffle_ps( _mm_castsi128_ps(a), _mm_castsi128_ps(b), ((s)<<6|(r)<<4|(q)<<2|(p)))))) + (_mm_castps_si128( (_mm_shuffle_ps( _mm_castsi128_ps(a), _mm_castsi128_ps(b), EIGEN_SSE_SHUFFLE_MASK(p,q,r,s))))) #define _EIGEN_DECLARE_CONST_Packet4f(NAME,X) \ const Packet4f p4f_##NAME = pset1(X) diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h index 9ca865bd1..9475a6ecc 100644 --- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h +++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h @@ -506,13 +506,28 @@ public: p = pset1(ResScalar(0)); } - EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const + template + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketType& dest) const { - dest = pset1(*b); + dest = pset1(*b); } EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const { + loadRhsQuad_impl(b,dest, typename conditional::type()); + } + + EIGEN_STRONG_INLINE void loadRhsQuad_impl(const RhsScalar* b, RhsPacket& dest, const true_type&) const + { + // FIXME we can do better! + // what we want here is a ploadheight + RhsScalar tmp[4] = {b[0],b[0],b[1],b[1]}; + dest = ploadquad(tmp); + } + + EIGEN_STRONG_INLINE void loadRhsQuad_impl(const RhsScalar* b, RhsPacket& dest, const false_type&) const + { + eigen_internal_assert(RhsPacketSize<=8); dest = pset1(*b); } @@ -521,9 +536,10 @@ public: dest = pload(a); } - EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const + template + EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacketType& dest) const { - dest = ploadu(a); + dest = ploadu(a); } EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3) @@ -536,12 +552,14 @@ public: // pbroadcast2(b, b0, b1); // } - EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const + template + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp) const { madd_impl(a, b, c, tmp, typename conditional::type()); } - EIGEN_STRONG_INLINE void madd_impl(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp, const true_type&) const + template + EIGEN_STRONG_INLINE void madd_impl(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp, const true_type&) const { #ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD EIGEN_UNUSED_VARIABLE(tmp); @@ -556,13 +574,14 @@ public: c += a * b; } - EIGEN_STRONG_INLINE void acc(const AccPacket& c, const ResPacket& alpha, ResPacket& r) const + template + EIGEN_STRONG_INLINE void acc(const AccPacketType& c, const ResPacketType& alpha, ResPacketType& r) const { + const conj_helper cj; r = cj.pmadd(c,alpha,r); } protected: - conj_helper cj; }; template @@ -581,13 +600,57 @@ DoublePacket padd(const DoublePacket &a, const DoublePacket the "4" in "downto4" +// corresponds to the number of complexes, so it means "8" +// it terms of real coefficients. + template -const DoublePacket& predux_half_dowto4(const DoublePacket &a) +const DoublePacket& +predux_half_dowto4(const DoublePacket &a, + typename enable_if::size<=8>::type* = 0) { return a; } -template struct unpacket_traits > { typedef DoublePacket half; }; +template +DoublePacket::half> +predux_half_dowto4(const DoublePacket &a, + typename enable_if::size==16>::type* = 0) +{ + // yes, that's pretty hackish :( + DoublePacket::half> res; + typedef std::complex::type> Cplx; + typedef typename packet_traits::type CplxPacket; + res.first = predux_half_dowto4(CplxPacket(a.first)).v; + res.second = predux_half_dowto4(CplxPacket(a.second)).v; + return res; +} + +// same here, "quad" actually means "8" in terms of real coefficients +template +void loadQuadToDoublePacket(const Scalar* b, DoublePacket& dest, + typename enable_if::size<=8>::type* = 0) +{ + dest.first = pset1(real(*b)); + dest.second = pset1(imag(*b)); +} + +template +void loadQuadToDoublePacket(const Scalar* b, DoublePacket& dest, + typename enable_if::size==16>::type* = 0) +{ + // yes, that's pretty hackish too :( + typedef typename NumTraits::Real RealScalar; + RealScalar r[4] = {real(b[0]), real(b[0]), real(b[1]), real(b[1])}; + RealScalar i[4] = {imag(b[0]), imag(b[0]), imag(b[1]), imag(b[1])}; + dest.first = ploadquad(r); + dest.second = ploadquad(i); +} + + +template struct unpacket_traits > { + typedef DoublePacket::half> half; +}; // template // DoublePacket pmadd(const DoublePacket &a, const DoublePacket &b) // { @@ -611,10 +674,10 @@ public: ConjRhs = _ConjRhs, Vectorizable = packet_traits::Vectorizable && packet_traits::Vectorizable, - RealPacketSize = Vectorizable ? packet_traits::size : 1, - ResPacketSize = Vectorizable ? packet_traits::size : 1, - LhsPacketSize = Vectorizable ? packet_traits::size : 1, - RhsPacketSize = Vectorizable ? packet_traits::size : 1, + ResPacketSize = Vectorizable ? packet_traits::size : 1, + LhsPacketSize = Vectorizable ? packet_traits::size : 1, + RhsPacketSize = Vectorizable ? packet_traits::size : 1, + RealPacketSize = Vectorizable ? packet_traits::size : 1, // FIXME: should depend on NumberOfRegisters nr = 4, @@ -626,7 +689,7 @@ public: typedef typename packet_traits::type RealPacket; typedef typename packet_traits::type ScalarPacket; - typedef DoublePacket DoublePacketType; + typedef DoublePacket DoublePacketType; typedef typename conditional::type LhsPacket4Packing; typedef typename conditional::type LhsPacket; @@ -643,16 +706,17 @@ public: } // Scalar path - EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, ResPacket& dest) const + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, ScalarPacket& dest) const { - dest = pset1(*b); + dest = pset1(*b); } // Vectorized path - EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, DoublePacketType& dest) const + template + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, DoublePacket& dest) const { - dest.first = pset1(real(*b)); - dest.second = pset1(imag(*b)); + dest.first = pset1(real(*b)); + dest.second = pset1(imag(*b)); } EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, ResPacket& dest) const @@ -661,8 +725,7 @@ public: } EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, DoublePacketType& dest) const { - eigen_internal_assert(unpacket_traits::size<=4); - loadRhs(b,dest); + loadQuadToDoublePacket(b,dest); } EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3) @@ -696,12 +759,14 @@ public: dest = pload((const typename unpacket_traits::type*)(a)); } - EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const + template + EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacketType& dest) const { - dest = ploadu((const typename unpacket_traits::type*)(a)); + dest = ploadu((const typename unpacket_traits::type*)(a)); } - EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, DoublePacketType& c, RhsPacket& /*tmp*/) const + template + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, DoublePacket& c, TmpType& /*tmp*/) const { c.first = padd(pmul(a,b.first), c.first); c.second = padd(pmul(a,b.second),c.second); @@ -714,29 +779,30 @@ public: EIGEN_STRONG_INLINE void acc(const Scalar& c, const Scalar& alpha, Scalar& r) const { r += alpha * c; } - EIGEN_STRONG_INLINE void acc(const DoublePacketType& c, const ResPacket& alpha, ResPacket& r) const + template + EIGEN_STRONG_INLINE void acc(const DoublePacket& c, const ResPacketType& alpha, ResPacketType& r) const { // assemble c - ResPacket tmp; + ResPacketType tmp; if((!ConjLhs)&&(!ConjRhs)) { - tmp = pcplxflip(pconj(ResPacket(c.second))); - tmp = padd(ResPacket(c.first),tmp); + tmp = pcplxflip(pconj(ResPacketType(c.second))); + tmp = padd(ResPacketType(c.first),tmp); } else if((!ConjLhs)&&(ConjRhs)) { - tmp = pconj(pcplxflip(ResPacket(c.second))); - tmp = padd(ResPacket(c.first),tmp); + tmp = pconj(pcplxflip(ResPacketType(c.second))); + tmp = padd(ResPacketType(c.first),tmp); } else if((ConjLhs)&&(!ConjRhs)) { - tmp = pcplxflip(ResPacket(c.second)); - tmp = padd(pconj(ResPacket(c.first)),tmp); + tmp = pcplxflip(ResPacketType(c.second)); + tmp = padd(pconj(ResPacketType(c.first)),tmp); } else if((ConjLhs)&&(ConjRhs)) { - tmp = pcplxflip(ResPacket(c.second)); - tmp = psub(pconj(ResPacket(c.first)),tmp); + tmp = pcplxflip(ResPacketType(c.second)); + tmp = psub(pconj(ResPacketType(c.first)),tmp); } r = pmadd(tmp,alpha,r); @@ -789,9 +855,10 @@ public: p = pset1(ResScalar(0)); } - EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const + template + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketType& dest) const { - dest = pset1(*b); + dest = pset1(*b); } void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3) @@ -813,21 +880,23 @@ public: EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const { - eigen_internal_assert(unpacket_traits::size<=4); - loadRhs(b,dest); + dest = ploadquad(b); } - EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacket& dest) const + template + EIGEN_STRONG_INLINE void loadLhsUnaligned(const LhsScalar* a, LhsPacketType& dest) const { - dest = ploaddup(a); + dest = ploaddup(a); } - EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp) const + template + EIGEN_STRONG_INLINE void madd(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp) const { madd_impl(a, b, c, tmp, typename conditional::type()); } - EIGEN_STRONG_INLINE void madd_impl(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& tmp, const true_type&) const + template + EIGEN_STRONG_INLINE void madd_impl(const LhsPacketType& a, const RhsPacketType& b, AccPacketType& c, RhsPacketType& tmp, const true_type&) const { #ifdef EIGEN_HAS_SINGLE_INSTRUCTION_MADD EIGEN_UNUSED_VARIABLE(tmp); @@ -843,13 +912,15 @@ public: c += a * b; } - EIGEN_STRONG_INLINE void acc(const AccPacket& c, const ResPacket& alpha, ResPacket& r) const + template + EIGEN_STRONG_INLINE void acc(const AccPacketType& c, const ResPacketType& alpha, ResPacketType& r) const { + const conj_helper cj; r = cj.pmadd(alpha,c,r); } protected: - conj_helper cj; + }; @@ -1649,7 +1720,7 @@ void gebp_kernel::half>::half>::size; if ((SwappedTraits::LhsProgress % 4) == 0 && (SwappedTraits::LhsProgress<=16) && - (SwappedTraits::LhsProgress!=8 || SResPacketHalfSize==nr) && + (SwappedTraits::LhsProgress!=8 || SResPacketHalfSize==nr) && (SwappedTraits::LhsProgress!=16 || SResPacketQuarterSize==nr)) { SAccPacket C0, C1, C2, C3; @@ -1704,7 +1775,7 @@ void gebp_kernel=8,typename unpacket_traits::half,SResPacket>::type SResPacketHalf; typedef typename conditional=8,typename unpacket_traits::half,SLhsPacket>::type SLhsPacketHalf; - typedef typename conditional=8,typename unpacket_traits::half,SRhsPacket>::type SRhsPacketHalf; + typedef typename conditional=8,typename unpacket_traits::half,SRhsPacket>::type SRhsPacketHalf; typedef typename conditional=8,typename unpacket_traits::half,SAccPacket>::type SAccPacketHalf; SResPacketHalf R = res.template gatherPacket(i, j2); @@ -1734,7 +1805,7 @@ void gebp_kernel p; - p(res, straits, blA, blB, depth, endk, i, j2,alpha, C0); + p(res, straits, blA, blB, depth, endk, i, j2,alpha, C0); } else { diff --git a/test/packetmath.cpp b/test/packetmath.cpp index 144083f1b..60c9dbc36 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -314,15 +314,18 @@ template void packetmath() ref[0] *= data1[i]; VERIFY(internal::isApprox(ref[0], internal::predux_mul(internal::pload(data1))) && "internal::predux_mul"); - for (int j=0; j(data1+j*PacketSize); + for (int j=0; j(data1+j*PacketSize); + } + internal::pstore(data2, internal::preduxp(packets)); + VERIFY(areApproxAbs(ref, data2, PacketSize, refvalue) && "internal::preduxp"); } - internal::pstore(data2, internal::preduxp(packets)); - VERIFY(areApproxAbs(ref, data2, PacketSize, refvalue) && "internal::preduxp"); for (int i=0; i