From 518fc321cb318c959a51b09030a57c7be0415452 Mon Sep 17 00:00:00 2001 From: b-shi Date: Wed, 16 Mar 2022 18:04:50 +0000 Subject: [PATCH] AVX512 Optimizations for Triangular Solve --- Eigen/Core | 1 + Eigen/src/Core/GenericPacketMath.h | 27 +- Eigen/src/Core/arch/AVX/PacketMath.h | 23 +- Eigen/src/Core/arch/AVX512/PacketMath.h | 201 ++- .../src/Core/arch/AVX512/trsmKernel_impl.hpp | 1106 +++++++++++++++ Eigen/src/Core/arch/AVX512/unrolls_impl.hpp | 1213 +++++++++++++++++ .../Core/products/TriangularSolverMatrix.h | 221 ++- 7 files changed, 2664 insertions(+), 128 deletions(-) create mode 100644 Eigen/src/Core/arch/AVX512/trsmKernel_impl.hpp create mode 100644 Eigen/src/Core/arch/AVX512/unrolls_impl.hpp diff --git a/Eigen/Core b/Eigen/Core index 1074332bb..5f5ccc04f 100644 --- a/Eigen/Core +++ b/Eigen/Core @@ -190,6 +190,7 @@ using std::ptrdiff_t; #include "src/Core/arch/SSE/MathFunctions.h" #include "src/Core/arch/AVX/MathFunctions.h" #include "src/Core/arch/AVX512/MathFunctions.h" + #include "src/Core/arch/AVX512/trsmKernel_impl.hpp" #elif defined EIGEN_VECTORIZE_AVX // Use AVX for floats and doubles, SSE for integers #include "src/Core/arch/SSE/PacketMath.h" diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index 670af354a..3ea6855eb 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -220,6 +220,15 @@ padd(const Packet& a, const Packet& b) { return a+b; } template<> EIGEN_DEVICE_FUNC inline bool padd(const bool& a, const bool& b) { return a || b; } +/** \internal \returns a packet version of \a *from, (un-aligned masked add) + * There is no generic implementation. We only have implementations for specialized + * cases. Generic case should not be called. + */ +template EIGEN_DEVICE_FUNC inline +std::enable_if_t::masked_fpops_available, Packet> +padd(const Packet& a, const Packet& b, typename unpacket_traits::mask_t umask); + + /** \internal \returns a - b (coeff-wise) */ template EIGEN_DEVICE_FUNC inline Packet psub(const Packet& a, const Packet& b) { return a-b; } @@ -359,16 +368,16 @@ struct bytewise_bitwise_helper { EIGEN_DEVICE_FUNC static inline T bitwise_and(const T& a, const T& b) { return binary(a, b, bit_and()); } - EIGEN_DEVICE_FUNC static inline T bitwise_or(const T& a, const T& b) { + EIGEN_DEVICE_FUNC static inline T bitwise_or(const T& a, const T& b) { return binary(a, b, bit_or()); } EIGEN_DEVICE_FUNC static inline T bitwise_xor(const T& a, const T& b) { return binary(a, b, bit_xor()); } - EIGEN_DEVICE_FUNC static inline T bitwise_not(const T& a) { + EIGEN_DEVICE_FUNC static inline T bitwise_not(const T& a) { return unary(a,bit_not()); } - + private: template EIGEN_DEVICE_FUNC static inline T unary(const T& a, Op op) { @@ -810,7 +819,7 @@ Packet plog10(const Packet& a) { EIGEN_USING_STD(log10); return log10(a); } template EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet plog2(const Packet& a) { typedef typename internal::unpacket_traits::type Scalar; - return pmul(pset1(Scalar(EIGEN_LOG2E)), plog(a)); + return pmul(pset1(Scalar(EIGEN_LOG2E)), plog(a)); } /** \internal \returns the square-root of \a a (coeff-wise) */ @@ -877,7 +886,7 @@ predux(const Packet& a) template EIGEN_DEVICE_FUNC inline typename unpacket_traits::type predux_mul( const Packet& a) { - typedef typename unpacket_traits::type Scalar; + typedef typename unpacket_traits::type Scalar; return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmul))); } @@ -885,14 +894,14 @@ EIGEN_DEVICE_FUNC inline typename 
unpacket_traits::type predux_mul( template EIGEN_DEVICE_FUNC inline typename unpacket_traits::type predux_min( const Packet &a) { - typedef typename unpacket_traits::type Scalar; + typedef typename unpacket_traits::type Scalar; return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmin))); } template EIGEN_DEVICE_FUNC inline typename unpacket_traits::type predux_min( const Packet& a) { - typedef typename unpacket_traits::type Scalar; + typedef typename unpacket_traits::type Scalar; return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmin))); } @@ -900,14 +909,14 @@ EIGEN_DEVICE_FUNC inline typename unpacket_traits::type predux_min( template EIGEN_DEVICE_FUNC inline typename unpacket_traits::type predux_max( const Packet &a) { - typedef typename unpacket_traits::type Scalar; + typedef typename unpacket_traits::type Scalar; return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmax))); } template EIGEN_DEVICE_FUNC inline typename unpacket_traits::type predux_max( const Packet& a) { - typedef typename unpacket_traits::type Scalar; + typedef typename unpacket_traits::type Scalar; return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmax))); } diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index 7d7418885..21e26d4e3 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -236,7 +236,11 @@ template<> struct unpacket_traits { typedef Packet4f half; typedef Packet8i integer_packet; typedef uint8_t mask_t; - enum {size=8, alignment=Aligned32, vectorizable=true, masked_load_available=true, masked_store_available=true}; + enum {size=8, alignment=Aligned32, vectorizable=true, masked_load_available=true, masked_store_available=true +#ifdef EIGEN_VECTORIZE_AVX512 + , masked_fpops_available=true +#endif + }; }; template<> struct unpacket_traits { typedef double type; @@ -464,6 +468,13 @@ template<> EIGEN_STRONG_INLINE Packet8f pload1(const float* from) { r template<> EIGEN_STRONG_INLINE Packet4d pload1(const double* from) { return _mm256_broadcast_sd(from); } template<> EIGEN_STRONG_INLINE Packet8f padd(const Packet8f& a, const Packet8f& b) { return _mm256_add_ps(a,b); } +#ifdef EIGEN_VECTORIZE_AVX512 +template <> +EIGEN_STRONG_INLINE Packet8f padd(const Packet8f& a, const Packet8f& b, uint8_t umask) { + __mmask8 mask = static_cast<__mmask8>(umask); + return _mm256_maskz_add_ps(mask, a, b); +} +#endif template<> EIGEN_STRONG_INLINE Packet4d padd(const Packet4d& a, const Packet4d& b) { return _mm256_add_pd(a,b); } template<> EIGEN_STRONG_INLINE Packet8i padd(const Packet8i& a, const Packet8i& b) { #ifdef EIGEN_VECTORIZE_AVX2 @@ -848,11 +859,16 @@ template<> EIGEN_STRONG_INLINE Packet4d ploadu(const double* from) { E template<> EIGEN_STRONG_INLINE Packet8i ploadu(const int* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(reinterpret_cast(from)); } template<> EIGEN_STRONG_INLINE Packet8f ploadu(const float* from, uint8_t umask) { +#ifdef EIGEN_VECTORIZE_AVX512 + __mmask8 mask = static_cast<__mmask8>(umask); + EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_maskz_loadu_ps(mask, from); +#else Packet8i mask = _mm256_set1_epi8(static_cast(umask)); const Packet8i bit_mask = _mm256_set_epi32(0xffffff7f, 0xffffffbf, 0xffffffdf, 0xffffffef, 0xfffffff7, 0xfffffffb, 0xfffffffd, 0xfffffffe); mask = por(mask, bit_mask); mask = pcmp_eq(mask, _mm256_set1_epi32(0xffffffff)); EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_maskload_ps(from, mask); +#endif } // Loads 4 floats from memory a returns 
the packet {a0, a0 a1, a1, a2, a2, a3, a3} @@ -911,11 +927,16 @@ template<> EIGEN_STRONG_INLINE void pstoreu(double* to, const Packet4d& template<> EIGEN_STRONG_INLINE void pstoreu(int* to, const Packet8i& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(reinterpret_cast<__m256i*>(to), from); } template<> EIGEN_STRONG_INLINE void pstoreu(float* to, const Packet8f& from, uint8_t umask) { +#ifdef EIGEN_VECTORIZE_AVX512 + __mmask8 mask = static_cast<__mmask8>(umask); + EIGEN_DEBUG_UNALIGNED_STORE return _mm256_mask_storeu_ps(to, mask, from); +#else Packet8i mask = _mm256_set1_epi8(static_cast(umask)); const Packet8i bit_mask = _mm256_set_epi32(0xffffff7f, 0xffffffbf, 0xffffffdf, 0xffffffef, 0xfffffff7, 0xfffffffb, 0xfffffffd, 0xfffffffe); mask = por(mask, bit_mask); mask = pcmp_eq(mask, _mm256_set1_epi32(0xffffffff)); EIGEN_DEBUG_UNALIGNED_STORE return _mm256_maskstore_ps(to, mask, from); +#endif } // NOTE: leverage _mm256_i32gather_ps and _mm256_i32gather_pd if AVX2 instructions are available diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index 2a98b78a3..f6916c853 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -180,13 +180,14 @@ struct unpacket_traits { typedef Packet8f half; typedef Packet16i integer_packet; typedef uint16_t mask_t; - enum { size = 16, alignment=Aligned64, vectorizable=true, masked_load_available=true, masked_store_available=true }; + enum { size = 16, alignment=Aligned64, vectorizable=true, masked_load_available=true, masked_store_available=true, masked_fpops_available=true }; }; template <> struct unpacket_traits { typedef double type; typedef Packet4d half; - enum { size = 8, alignment=Aligned64, vectorizable=true, masked_load_available=false, masked_store_available=false }; + typedef uint8_t mask_t; + enum { size = 8, alignment=Aligned64, vectorizable=true, masked_load_available=true, masked_store_available=true, masked_fpops_available=true }; }; template <> struct unpacket_traits { @@ -244,11 +245,25 @@ template<> EIGEN_STRONG_INLINE Packet8d peven_mask(const Packet8d& /*a*/) { template <> EIGEN_STRONG_INLINE Packet16f pload1(const float* from) { +#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0) + // Inline asm here helps reduce some register spilling in TRSM kernels. 
+ // See note in unrolls::gemm::microKernel in trsmKernel_impl.hpp + Packet16f ret; + __asm__ ("vbroadcastss %[mem], %[dst]" : [dst] "=v" (ret) : [mem] "m" (*from)); + return ret; +#else return _mm512_broadcastss_ps(_mm_load_ps1(from)); +#endif } template <> EIGEN_STRONG_INLINE Packet8d pload1(const double* from) { +#if (EIGEN_COMP_GNUC != 0) || (EIGEN_COMP_CLANG != 0) + Packet8d ret; + __asm__ ("vbroadcastsd %[mem], %[dst]" : [dst] "=v" (ret) : [mem] "m" (*from)); + return ret; +#else return _mm512_set1_pd(*from); +#endif } template <> @@ -285,6 +300,20 @@ EIGEN_STRONG_INLINE Packet16i padd(const Packet16i& a, const Packet16i& b) { return _mm512_add_epi32(a, b); } +template <> +EIGEN_STRONG_INLINE Packet16f padd(const Packet16f& a, + const Packet16f& b, + uint16_t umask) { + __mmask16 mask = static_cast<__mmask16>(umask); + return _mm512_maskz_add_ps(mask, a, b); +} +template <> +EIGEN_STRONG_INLINE Packet8d padd(const Packet8d& a, + const Packet8d& b, + uint8_t umask) { + __mmask8 mask = static_cast<__mmask8>(umask); + return _mm512_maskz_add_pd(mask, a, b); +} template <> EIGEN_STRONG_INLINE Packet16f psub(const Packet16f& a, @@ -771,12 +800,16 @@ EIGEN_STRONG_INLINE Packet16i ploadu(const int* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_si512( reinterpret_cast(from)); } - template <> EIGEN_STRONG_INLINE Packet16f ploadu(const float* from, uint16_t umask) { __mmask16 mask = static_cast<__mmask16>(umask); EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_maskz_loadu_ps(mask, from); } +template <> +EIGEN_STRONG_INLINE Packet8d ploadu(const double* from, uint8_t umask) { + __mmask8 mask = static_cast<__mmask8>(umask); + EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_maskz_loadu_pd(mask, from); +} // Loads 8 floats from memory a returns the packet // {a0, a0 a1, a1, a2, a2, a3, a3, a4, a4, a5, a5, a6, a6, a7, a7} @@ -886,6 +919,11 @@ EIGEN_STRONG_INLINE void pstoreu(float* to, const Packet16f& from, uint16 __mmask16 mask = static_cast<__mmask16>(umask); EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_ps(to, mask, from); } +template <> +EIGEN_STRONG_INLINE void pstoreu(double* to, const Packet8d& from, uint8_t umask) { + __mmask8 mask = static_cast<__mmask8>(umask); + EIGEN_DEBUG_UNALIGNED_STORE return _mm512_mask_storeu_pd(to, mask, from); +} template <> EIGEN_DEVICE_FUNC inline Packet16f pgather(const float* from, @@ -1017,7 +1055,7 @@ EIGEN_STRONG_INLINE Packet16f pfrexp(const Packet16f& a, Packet16f& e // Extract exponent without existence of Packet8l. template<> -EIGEN_STRONG_INLINE +EIGEN_STRONG_INLINE Packet8d pfrexp_generic_get_biased_exponent(const Packet8d& a) { const Packet8d cst_exp_mask = pset1frombits(static_cast(0x7ff0000000000000ull)); #ifdef EIGEN_VECTORIZE_AVX512DQ @@ -1040,11 +1078,11 @@ template<> EIGEN_STRONG_INLINE Packet8d pldexp(const Packet8d& a, cons // Clamp exponent to [-2099, 2099] const Packet8d max_exponent = pset1(2099.0); const Packet8i e = _mm512_cvtpd_epi32(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent)); - + // Split 2^e into four factors and multiply. 
const Packet8i bias = pset1(1023); Packet8i b = parithmetic_shift_right<2>(e); // floor(e/4) - + // 2^b const Packet8i permute_idx = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7); Packet8i hi = _mm256_permutevar8x32_epi32(padd(b, bias), permute_idx); @@ -1052,7 +1090,7 @@ template<> EIGEN_STRONG_INLINE Packet8d pldexp(const Packet8d& a, cons hi = _mm256_slli_epi64(_mm256_srli_epi64(hi, 32), 52); Packet8d c = _mm512_castsi512_pd(_mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1)); Packet8d out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b) - + // 2^(e - 3b) b = psub(psub(psub(e, b), b), b); // e - 3b hi = _mm256_permutevar8x32_epi32(padd(b, bias), permute_idx); @@ -1072,7 +1110,7 @@ template<> EIGEN_STRONG_INLINE Packet8d pldexp(const Packet8d& a, cons // AVX512F does not define _mm512_extracti32x8_epi32 to extract _m256i from _m512i #define EIGEN_EXTRACT_8i_FROM_16i(INPUT, OUTPUT) \ __m256i OUTPUT##_0 = _mm512_extracti32x8_epi32(INPUT, 0); \ - __m256i OUTPUT##_1 = _mm512_extracti32x8_epi32(INPUT, 1) + __m256i OUTPUT##_1 = _mm512_extracti32x8_epi32(INPUT, 1) #else #define EIGEN_EXTRACT_8f_FROM_16f(INPUT, OUTPUT) \ __m256 OUTPUT##_0 = _mm256_insertf128_ps( \ @@ -1392,6 +1430,56 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { EIGEN_INSERT_8f_INTO_16f(OUTPUT[INDEX], INPUT[2 * INDEX], \ INPUT[2 * INDEX + STRIDE]); +EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { + __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0],kernel.packet[1]); + __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0],kernel.packet[1]); + __m512 T2 = _mm512_unpacklo_ps(kernel.packet[2],kernel.packet[3]); + __m512 T3 = _mm512_unpackhi_ps(kernel.packet[2],kernel.packet[3]); + __m512 T4 = _mm512_unpacklo_ps(kernel.packet[4],kernel.packet[5]); + __m512 T5 = _mm512_unpackhi_ps(kernel.packet[4],kernel.packet[5]); + __m512 T6 = _mm512_unpacklo_ps(kernel.packet[6],kernel.packet[7]); + __m512 T7 = _mm512_unpackhi_ps(kernel.packet[6],kernel.packet[7]); + + kernel.packet[0] = reinterpret_cast<__m512>( + _mm512_unpacklo_pd(reinterpret_cast<__m512d>(T0),reinterpret_cast<__m512d>(T2))); + kernel.packet[1] = reinterpret_cast<__m512>( + _mm512_unpackhi_pd(reinterpret_cast<__m512d>(T0),reinterpret_cast<__m512d>(T2))); + kernel.packet[2] = reinterpret_cast<__m512>( + _mm512_unpacklo_pd(reinterpret_cast<__m512d>(T1),reinterpret_cast<__m512d>(T3))); + kernel.packet[3] = reinterpret_cast<__m512>( + _mm512_unpackhi_pd(reinterpret_cast<__m512d>(T1),reinterpret_cast<__m512d>(T3))); + kernel.packet[4] = reinterpret_cast<__m512>( + _mm512_unpacklo_pd(reinterpret_cast<__m512d>(T4),reinterpret_cast<__m512d>(T6))); + kernel.packet[5] = reinterpret_cast<__m512>( + _mm512_unpackhi_pd(reinterpret_cast<__m512d>(T4),reinterpret_cast<__m512d>(T6))); + kernel.packet[6] = reinterpret_cast<__m512>( + _mm512_unpacklo_pd(reinterpret_cast<__m512d>(T5),reinterpret_cast<__m512d>(T7))); + kernel.packet[7] = reinterpret_cast<__m512>( + _mm512_unpackhi_pd(reinterpret_cast<__m512d>(T5),reinterpret_cast<__m512d>(T7))); + + T0 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[4]), 0x4E)); + T0 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[0], T0); + T4 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[0]), 0x4E)); + T4 = _mm512_mask_blend_ps(0xF0F0, T4, kernel.packet[4]); + T1 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[5]), 0x4E)); + T1 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[1], T1); + T5 = 
reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[1]), 0x4E)); + T5 = _mm512_mask_blend_ps(0xF0F0, T5, kernel.packet[5]); + T2 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[6]), 0x4E)); + T2 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[2], T2); + T6 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[2]), 0x4E)); + T6 = _mm512_mask_blend_ps(0xF0F0, T6, kernel.packet[6]); + T3 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[7]), 0x4E)); + T3 = _mm512_mask_blend_ps(0xF0F0, kernel.packet[3], T3); + T7 = reinterpret_cast<__m512>(_mm512_permutex_pd(reinterpret_cast<__m512d>(kernel.packet[3]), 0x4E)); + T7 = _mm512_mask_blend_ps(0xF0F0, T7, kernel.packet[7]); + + kernel.packet[0] = T0; kernel.packet[1] = T1; + kernel.packet[2] = T2; kernel.packet[3] = T3; + kernel.packet[4] = T4; kernel.packet[5] = T5; + kernel.packet[6] = T6; kernel.packet[7] = T7; +} + EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { __m512 T0 = _mm512_unpacklo_ps(kernel.packet[0], kernel.packet[1]); __m512 T1 = _mm512_unpackhi_ps(kernel.packet[0], kernel.packet[1]); @@ -1468,62 +1556,53 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { } EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& kernel) { - __m512d T0 = _mm512_unpacklo_pd(kernel.packet[0], kernel.packet[1]); - __m512d T1 = _mm512_unpackhi_pd(kernel.packet[0], kernel.packet[1]); - __m512d T2 = _mm512_unpacklo_pd(kernel.packet[2], kernel.packet[3]); - __m512d T3 = _mm512_unpackhi_pd(kernel.packet[2], kernel.packet[3]); - __m512d T4 = _mm512_unpacklo_pd(kernel.packet[4], kernel.packet[5]); - __m512d T5 = _mm512_unpackhi_pd(kernel.packet[4], kernel.packet[5]); - __m512d T6 = _mm512_unpacklo_pd(kernel.packet[6], kernel.packet[7]); - __m512d T7 = _mm512_unpackhi_pd(kernel.packet[6], kernel.packet[7]); + __m512d T0 = _mm512_unpacklo_pd(kernel.packet[0],kernel.packet[1]); + __m512d T1 = _mm512_unpackhi_pd(kernel.packet[0],kernel.packet[1]); + __m512d T2 = _mm512_unpacklo_pd(kernel.packet[2],kernel.packet[3]); + __m512d T3 = _mm512_unpackhi_pd(kernel.packet[2],kernel.packet[3]); + __m512d T4 = _mm512_unpacklo_pd(kernel.packet[4],kernel.packet[5]); + __m512d T5 = _mm512_unpackhi_pd(kernel.packet[4],kernel.packet[5]); + __m512d T6 = _mm512_unpacklo_pd(kernel.packet[6],kernel.packet[7]); + __m512d T7 = _mm512_unpackhi_pd(kernel.packet[6],kernel.packet[7]); - PacketBlock tmp; + kernel.packet[0] = _mm512_permutex_pd(T2, 0x4E); + kernel.packet[0] = _mm512_mask_blend_pd(0xCC, T0, kernel.packet[0]); + kernel.packet[2] = _mm512_permutex_pd(T0, 0x4E); + kernel.packet[2] = _mm512_mask_blend_pd(0xCC, kernel.packet[2], T2); + kernel.packet[1] = _mm512_permutex_pd(T3, 0x4E); + kernel.packet[1] = _mm512_mask_blend_pd(0xCC, T1, kernel.packet[1]); + kernel.packet[3] = _mm512_permutex_pd(T1, 0x4E); + kernel.packet[3] = _mm512_mask_blend_pd(0xCC, kernel.packet[3], T3); + kernel.packet[4] = _mm512_permutex_pd(T6, 0x4E); + kernel.packet[4] = _mm512_mask_blend_pd(0xCC, T4, kernel.packet[4]); + kernel.packet[6] = _mm512_permutex_pd(T4, 0x4E); + kernel.packet[6] = _mm512_mask_blend_pd(0xCC, kernel.packet[6], T6); + kernel.packet[5] = _mm512_permutex_pd(T7, 0x4E); + kernel.packet[5] = _mm512_mask_blend_pd(0xCC, T5, kernel.packet[5]); + kernel.packet[7] = _mm512_permutex_pd(T5, 0x4E); + kernel.packet[7] = _mm512_mask_blend_pd(0xCC, kernel.packet[7], T7); - tmp.packet[0] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 
0), - _mm512_extractf64x4_pd(T2, 0), 0x20); - tmp.packet[1] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0), - _mm512_extractf64x4_pd(T3, 0), 0x20); - tmp.packet[2] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 0), - _mm512_extractf64x4_pd(T2, 0), 0x31); - tmp.packet[3] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 0), - _mm512_extractf64x4_pd(T3, 0), 0x31); + T0 = _mm512_shuffle_f64x2(kernel.packet[4], kernel.packet[4], 0x4E); + T0 = _mm512_mask_blend_pd(0xF0, kernel.packet[0], T0); + T4 = _mm512_shuffle_f64x2(kernel.packet[0], kernel.packet[0], 0x4E); + T4 = _mm512_mask_blend_pd(0xF0, T4, kernel.packet[4]); + T1 = _mm512_shuffle_f64x2(kernel.packet[5], kernel.packet[5], 0x4E); + T1 = _mm512_mask_blend_pd(0xF0, kernel.packet[1], T1); + T5 = _mm512_shuffle_f64x2(kernel.packet[1], kernel.packet[1], 0x4E); + T5 = _mm512_mask_blend_pd(0xF0, T5, kernel.packet[5]); + T2 = _mm512_shuffle_f64x2(kernel.packet[6], kernel.packet[6], 0x4E); + T2 = _mm512_mask_blend_pd(0xF0, kernel.packet[2], T2); + T6 = _mm512_shuffle_f64x2(kernel.packet[2], kernel.packet[2], 0x4E); + T6 = _mm512_mask_blend_pd(0xF0, T6, kernel.packet[6]); + T3 = _mm512_shuffle_f64x2(kernel.packet[7], kernel.packet[7], 0x4E); + T3 = _mm512_mask_blend_pd(0xF0, kernel.packet[3], T3); + T7 = _mm512_shuffle_f64x2(kernel.packet[3], kernel.packet[3], 0x4E); + T7 = _mm512_mask_blend_pd(0xF0, T7, kernel.packet[7]); - tmp.packet[4] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1), - _mm512_extractf64x4_pd(T2, 1), 0x20); - tmp.packet[5] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1), - _mm512_extractf64x4_pd(T3, 1), 0x20); - tmp.packet[6] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T0, 1), - _mm512_extractf64x4_pd(T2, 1), 0x31); - tmp.packet[7] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T1, 1), - _mm512_extractf64x4_pd(T3, 1), 0x31); - - tmp.packet[8] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 0), - _mm512_extractf64x4_pd(T6, 0), 0x20); - tmp.packet[9] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 0), - _mm512_extractf64x4_pd(T7, 0), 0x20); - tmp.packet[10] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 0), - _mm512_extractf64x4_pd(T6, 0), 0x31); - tmp.packet[11] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 0), - _mm512_extractf64x4_pd(T7, 0), 0x31); - - tmp.packet[12] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 1), - _mm512_extractf64x4_pd(T6, 1), 0x20); - tmp.packet[13] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 1), - _mm512_extractf64x4_pd(T7, 1), 0x20); - tmp.packet[14] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T4, 1), - _mm512_extractf64x4_pd(T6, 1), 0x31); - tmp.packet[15] = _mm256_permute2f128_pd(_mm512_extractf64x4_pd(T5, 1), - _mm512_extractf64x4_pd(T7, 1), 0x31); - - PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 0, 8); - PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 1, 8); - PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 2, 8); - PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 3, 8); - - PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 4, 8); - PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 5, 8); - PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 6, 8); - PACK_OUTPUT_SQ_D(kernel.packet, tmp.packet, 7, 8); + kernel.packet[0] = T0; kernel.packet[1] = T1; + kernel.packet[2] = T2; kernel.packet[3] = T3; + kernel.packet[4] = T4; kernel.packet[5] = T5; + kernel.packet[6] = T6; kernel.packet[7] = T7; } #define PACK_OUTPUT_I32(OUTPUT, INPUT, INDEX, STRIDE) \ diff --git a/Eigen/src/Core/arch/AVX512/trsmKernel_impl.hpp b/Eigen/src/Core/arch/AVX512/trsmKernel_impl.hpp new file 
mode 100644 index 000000000..139c69f2b --- /dev/null +++ b/Eigen/src/Core/arch/AVX512/trsmKernel_impl.hpp @@ -0,0 +1,1106 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2022 Intel Corporation +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_TRSM_KERNEL_IMPL_H +#define EIGEN_TRSM_KERNEL_IMPL_H + +#include "../../InternalHeaderCheck.h" + +#define EIGEN_USE_AVX512_TRSM_KERNELS // Comment out to prevent using optimized trsm kernels. + +#if defined(EIGEN_HAS_CXX17_IFCONSTEXPR) +#define EIGEN_IF_CONSTEXPR(X) if constexpr (X) +#else +#define EIGEN_IF_CONSTEXPR(X) if (X) +#endif + +// Need this for some std::min calls. +#ifdef min +#undef min +#endif + +namespace Eigen { +namespace internal { + +#define EIGEN_AVX_MAX_NUM_ACC (24L) +#define EIGEN_AVX_MAX_NUM_ROW (8L) // Denoted L in code. +#define EIGEN_AVX_MAX_K_UNROL (4L) +#define EIGEN_AVX_B_LOAD_SETS (2L) +#define EIGEN_AVX_MAX_A_BCAST (2L) +typedef Packet16f vecFullFloat; +typedef Packet8d vecFullDouble; +typedef Packet8f vecHalfFloat; +typedef Packet4d vecHalfDouble; + +// Compile-time unrolls are implemented here +#include "unrolls_impl.hpp" + + +#if defined(EIGEN_USE_AVX512_TRSM_KERNELS) && (EIGEN_COMP_CLANG != 0) +/** + * For smaller problem sizes, and certain compilers, using the optimized kernels trsmKernelL/R directly + * is faster than the packed versions in TriangularSolverMatrix.h. + * + * The current heuristic is based on having having all arrays used in the largest gemm-update + * in triSolve fit in roughly L2Cap (percentage) of the L2 cache. These cutoffs are a bit conservative and could be + * larger for some trsm cases. + * The formula: + * + * (L*M + M*N + L*N)*sizeof(Scalar) < L2Cache*L2Cap + * + * L = number of rows to solve at a time + * N = number of rhs + * M = Dimension of triangular matrix + * + */ +#define EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS // Comment out to disable no-copy dispatch +template +int64_t avx512_trsm_cutoff(int64_t L2Size, int64_t N, double L2Cap){ + const int64_t U3 = 3*packet_traits::size; + const int64_t MaxNb = 5*U3; + int64_t Nb = std::min(MaxNb, N); + double cutoff_d = (((L2Size*L2Cap)/(sizeof(Scalar)))-(EIGEN_AVX_MAX_NUM_ROW)*Nb)/ + ((EIGEN_AVX_MAX_NUM_ROW)+Nb); + int64_t cutoff_l = static_cast(cutoff_d); + return (cutoff_l/EIGEN_AVX_MAX_NUM_ROW)*EIGEN_AVX_MAX_NUM_ROW; +} +#endif + + +/** + * Used by gemmKernel for the case A/B row-major and C col-major. 
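+ * The accumulator tile zmm (produced by the gemm micro-kernels) is transposed in registers and then
+ * added into the column-major C. remM_/remN_ handle tiles that only partially cover C, using masked
+ * or reduced-width stores.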
+ */ +template +static EIGEN_ALWAYS_INLINE +void transStoreC(PacketBlock &zmm, + Scalar *C_arr, int64_t LDC, int64_t remM_ = 0, int64_t remN_ = 0) { + EIGEN_UNUSED_VARIABLE(remN_); + EIGEN_UNUSED_VARIABLE(remM_); + using urolls = unrolls::trans; + + constexpr int64_t U3 = urolls::PacketSize * 3; + constexpr int64_t U2 = urolls::PacketSize * 2; + constexpr int64_t U1 = urolls::PacketSize * 1; + + static_assert( unrollN == U1 || unrollN == U2 || unrollN == U3, "unrollN should be a multiple of PacketSize"); + static_assert( unrollM == EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW"); + + urolls::template transpose(zmm); + EIGEN_IF_CONSTEXPR(unrollN > U2) urolls::template transpose(zmm); + EIGEN_IF_CONSTEXPR(unrollN > U1) urolls::template transpose(zmm); + + static_assert( (remN && unrollN == U1) || !remN, "When handling N remainder set unrollN=U1"); + EIGEN_IF_CONSTEXPR(!remN) { + urolls::template storeC(C_arr, LDC, zmm, remM_); + EIGEN_IF_CONSTEXPR(unrollN > U1) { + constexpr int64_t unrollN_ = std::min(unrollN-U1, U1); + urolls::template storeC(C_arr + U1*LDC, LDC, zmm, remM_); + } + EIGEN_IF_CONSTEXPR(unrollN > U2) { + constexpr int64_t unrollN_ = std::min(unrollN-U2, U1); + urolls:: template storeC(C_arr + U2*LDC, LDC, zmm, remM_); + } + } + else { + EIGEN_IF_CONSTEXPR( (std::is_same::value) ) { + // Note: without "if constexpr" this section of code will also be + // parsed by the compiler so each of the storeC will still be instantiated. + // We use enable_if in aux_storeC to set it to an empty function for + // these cases. + if(remN_ == 15) + urolls::template storeC<15, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + else if(remN_ == 14) + urolls::template storeC<14, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + else if(remN_ == 13) + urolls::template storeC<13, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + else if(remN_ == 12) + urolls::template storeC<12, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + else if(remN_ == 11) + urolls::template storeC<11, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + else if(remN_ == 10) + urolls::template storeC<10, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + else if(remN_ == 9) + urolls::template storeC<9, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + else if(remN_ == 8) + urolls::template storeC<8, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + else if(remN_ == 7) + urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + else if(remN_ == 6) + urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + else if(remN_ == 5) + urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + else if(remN_ == 4) + urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + else if(remN_ == 3) + urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + else if(remN_ == 2) + urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + else if(remN_ == 1) + urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + } + else { + if(remN_ == 7) + urolls::template storeC<7, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + else if(remN_ == 6) + urolls::template storeC<6, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + else if(remN_ == 5) + urolls::template storeC<5, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + else if(remN_ == 4) + urolls::template storeC<4, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + else if(remN_ == 3) + urolls::template storeC<3, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + else if(remN_ == 2) + urolls::template storeC<2, unrollN, 0, remM>(C_arr, LDC, 
zmm, remM_); + else if(remN_ == 1) + urolls::template storeC<1, unrollN, 0, remM>(C_arr, LDC, zmm, remM_); + } + } +} + +/** + * GEMM like operation for trsm panel updates. + * Computes: C -= A*B + * K must be multipe of 4. + * + * Unrolls used are {1,2,4,8}x{U1,U2,U3}; + * For good performance we want K to be large with M/N relatively small, but also large enough + * to use the {8,U3} unroll block. + * + * isARowMajor: is A_arr row-major? + * isCRowMajor: is C_arr row-major? (B_arr is assumed to be row-major). + * isAdd: C += A*B or C -= A*B (used by trsm) + * handleKRem: Handle arbitrary K? This is not needed for trsm. + */ +template +void gemmKernel(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr, + int64_t M, int64_t N, int64_t K, + int64_t LDA, int64_t LDB, int64_t LDC) { + using urolls = unrolls::gemm; + constexpr int64_t U3 = urolls::PacketSize * 3; + constexpr int64_t U2 = urolls::PacketSize * 2; + constexpr int64_t U1 = urolls::PacketSize * 1; + using vec = typename std::conditional::value, + vecFullFloat, + vecFullDouble>::type; + int64_t N_ = (N/U3)*U3; + int64_t M_ = (M/EIGEN_AVX_MAX_NUM_ROW)*EIGEN_AVX_MAX_NUM_ROW; + int64_t K_ = (K/EIGEN_AVX_MAX_K_UNROL)*EIGEN_AVX_MAX_K_UNROL; + int64_t j = 0; + for(; j < N_; j += U3) { + constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS*3; + int64_t i = 0; + for(; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) { + Scalar *A_t = &A_arr[idA(i,0,LDA)], *B_t = &B_arr[0*LDB + j]; + PacketBlock zmm; + urolls::template setzero<3,EIGEN_AVX_MAX_NUM_ROW>(zmm); + for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL*LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + } + EIGEN_IF_CONSTEXPR(handleKRem) { + for(int64_t k = K_; k < K ; k ++) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + } + } + EIGEN_IF_CONSTEXPR(isCRowMajor) { + urolls::template updateC<3,EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template storeC<3,EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i*LDC+ j], LDC, zmm); + } + else { + transStoreC(zmm, &C_arr[i + j*LDC], LDC); + } + } + if(M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. 
Should be removed otherwise + Scalar *A_t = &A_arr[idA(i,0,LDA)]; + Scalar *B_t = &B_arr[0*LDB + j]; + PacketBlock zmm; + urolls::template setzero<3,4>(zmm); + for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL*LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + } + EIGEN_IF_CONSTEXPR(handleKRem) { + for(int64_t k = K_; k < K ; k ++) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + } + } + EIGEN_IF_CONSTEXPR(isCRowMajor) { + urolls::template updateC<3,4>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template storeC<3,4>(&C_arr[i*LDC + j], LDC, zmm); + } + else { + transStoreC(zmm, &C_arr[i + j*LDC], LDC, 4); + } + i += 4; + } + if(M - i >= 2) { + Scalar *A_t = &A_arr[idA(i,0,LDA)]; + Scalar *B_t = &B_arr[0*LDB + j]; + PacketBlock zmm; + urolls::template setzero<3,2>(zmm); + for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL*LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + } + EIGEN_IF_CONSTEXPR(handleKRem) { + for(int64_t k = K_; k < K ; k ++) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + } + } + EIGEN_IF_CONSTEXPR(isCRowMajor) { + urolls::template updateC<3,2>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template storeC<3,2>(&C_arr[i*LDC + j], LDC, zmm); + } + else { + transStoreC(zmm, &C_arr[i + j*LDC], LDC, 2); + } + i += 2; + } + if(M - i > 0) { + Scalar *A_t = &A_arr[idA(i,0,LDA)]; + Scalar *B_t = &B_arr[0*LDB + j]; + PacketBlock zmm; + urolls::template setzero<3,1>(zmm); + { + for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL*LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + } + EIGEN_IF_CONSTEXPR(handleKRem) { + for(int64_t k = K_; k < K ; k ++) { + urolls:: template microKernel(B_t, A_t, LDB, LDA, zmm); + B_t += LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + } + } + EIGEN_IF_CONSTEXPR(isCRowMajor) { + urolls::template updateC<3,1>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template storeC<3,1>(&C_arr[i*LDC + j], LDC, zmm); + } + else { + transStoreC(zmm, &C_arr[i + j*LDC], LDC, 1); + } + } + } + } + if(N - j >= U2) { + constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS*2; + int64_t i = 0; + for(; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) { + Scalar *A_t = &A_arr[idA(i,0,LDA)], *B_t = &B_arr[0*LDB + j]; + EIGEN_IF_CONSTEXPR(isCRowMajor) B_t = &B_arr[0*LDB + j]; + PacketBlock zmm; + urolls::template setzero<2,EIGEN_AVX_MAX_NUM_ROW>(zmm); + for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL*LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + } + EIGEN_IF_CONSTEXPR(handleKRem) { + for(int64_t k = K_; k < K ; k ++) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + } + } + EIGEN_IF_CONSTEXPR(isCRowMajor) { + urolls::template updateC<2,EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i*LDC + j], LDC, zmm); + 
urolls::template storeC<2,EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i*LDC + j], LDC, zmm); + } + else { + transStoreC(zmm, &C_arr[i + j*LDC], LDC); + } + } + if(M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise + Scalar *A_t = &A_arr[idA(i,0,LDA)]; + Scalar *B_t = &B_arr[0*LDB + j]; + PacketBlock zmm; + urolls::template setzero<2,4>(zmm); + for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL*LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + } + EIGEN_IF_CONSTEXPR(handleKRem) { + for(int64_t k = K_; k < K ; k ++) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + } + } + EIGEN_IF_CONSTEXPR(isCRowMajor) { + urolls::template updateC<2,4>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template storeC<2,4>(&C_arr[i*LDC + j], LDC, zmm); + } + else { + transStoreC(zmm, &C_arr[i + j*LDC], LDC, 4); + } + i += 4; + } + if(M - i >= 2) { + Scalar *A_t = &A_arr[idA(i,0,LDA)]; + Scalar *B_t = &B_arr[0*LDB + j]; + PacketBlock zmm; + urolls::template setzero<2,2>(zmm); + for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL*LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + } + EIGEN_IF_CONSTEXPR(handleKRem) { + for(int64_t k = K_; k < K ; k ++) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + } + } + EIGEN_IF_CONSTEXPR(isCRowMajor) { + urolls::template updateC<2,2>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template storeC<2,2>(&C_arr[i*LDC + j], LDC, zmm); + } + else { + transStoreC(zmm, &C_arr[i + j*LDC], LDC, 2); + } + i += 2; + } + if(M - i > 0) { + Scalar *A_t = &A_arr[idA(i,0,LDA)]; + Scalar *B_t = &B_arr[0*LDB + j]; + PacketBlock zmm; + urolls::template setzero<2,1>(zmm); + for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL*LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + } + EIGEN_IF_CONSTEXPR(handleKRem) { + for(int64_t k = K_; k < K ; k ++) { + urolls:: template microKernel(B_t, A_t, LDB, LDA, zmm); + B_t += LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + } + } + EIGEN_IF_CONSTEXPR(isCRowMajor) { + urolls::template updateC<2,1>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template storeC<2,1>(&C_arr[i*LDC + j], LDC, zmm); + } + else { + transStoreC(zmm, &C_arr[i + j*LDC], LDC, 1); + } + } + j += U2; + } + if(N - j >= U1) { + constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS*1; + int64_t i = 0; + for(; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) { + Scalar *A_t = &A_arr[idA(i,0,LDA)], *B_t = &B_arr[0*LDB + j]; + PacketBlock zmm; + urolls::template setzero<1,EIGEN_AVX_MAX_NUM_ROW>(zmm); + for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL*LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + } + EIGEN_IF_CONSTEXPR(handleKRem) { + for(int64_t k = K_; k < K ; k ++) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += LDB; + 
EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + } + } + EIGEN_IF_CONSTEXPR(isCRowMajor) { + urolls::template updateC<1,EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template storeC<1,EIGEN_AVX_MAX_NUM_ROW>(&C_arr[i*LDC + j], LDC, zmm); + } + else { + transStoreC(zmm, &C_arr[i + j*LDC], LDC); + } + } + if(M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise + Scalar *A_t = &A_arr[idA(i,0,LDA)]; + Scalar *B_t = &B_arr[0*LDB + j]; + PacketBlock zmm; + urolls::template setzero<1,4>(zmm); + for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL*LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + } + EIGEN_IF_CONSTEXPR(handleKRem) { + for(int64_t k = K_; k < K ; k ++) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + } + } + EIGEN_IF_CONSTEXPR(isCRowMajor) { + urolls::template updateC<1,4>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template storeC<1,4>(&C_arr[i*LDC + j], LDC, zmm); + } + else { + transStoreC(zmm, &C_arr[i + j*LDC], LDC, 4); + } + i += 4; + } + if(M - i >= 2) { + Scalar *A_t = &A_arr[idA(i,0,LDA)]; + Scalar *B_t = &B_arr[0*LDB + j]; + PacketBlock zmm; + urolls::template setzero<1,2>(zmm); + for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL*LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + } + EIGEN_IF_CONSTEXPR(handleKRem) { + for(int64_t k = K_; k < K ; k ++) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + } + } + EIGEN_IF_CONSTEXPR(isCRowMajor) { + urolls::template updateC<1,2>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template storeC<1,2>(&C_arr[i*LDC + j], LDC, zmm); + } + else { + transStoreC(zmm, &C_arr[i + j*LDC], LDC, 2); + } + i += 2; + } + if(M - i > 0) { + Scalar *A_t = &A_arr[idA(i,0,LDA)]; + Scalar *B_t = &B_arr[0*LDB + j]; + PacketBlock zmm; + urolls::template setzero<1,1>(zmm); + { + for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm); + B_t += EIGEN_AVX_MAX_K_UNROL*LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + } + EIGEN_IF_CONSTEXPR(handleKRem) { + for(int64_t k = K_; k < K ; k ++) { + urolls:: template microKernel(B_t, A_t, LDB, LDA, zmm); + B_t += LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + } + } + EIGEN_IF_CONSTEXPR(isCRowMajor) { + urolls::template updateC<1,1>(&C_arr[i*LDC + j], LDC, zmm); + urolls::template storeC<1,1>(&C_arr[i*LDC + j], LDC, zmm); + } + else { + transStoreC(zmm, &C_arr[i + j*LDC], LDC, 1); + } + } + } + j += U1; + } + if(N - j > 0) { + constexpr int64_t EIGEN_AVX_MAX_B_LOAD = EIGEN_AVX_B_LOAD_SETS*1; + int64_t i = 0; + for(; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) { + Scalar *A_t = &A_arr[idA(i,0,LDA)]; + Scalar *B_t = &B_arr[0*LDB + j]; + PacketBlock zmm; + urolls::template setzero<1,EIGEN_AVX_MAX_NUM_ROW>(zmm); + for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm, N - j); + B_t += EIGEN_AVX_MAX_K_UNROL*LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += 
EIGEN_AVX_MAX_K_UNROL*LDA; + } + EIGEN_IF_CONSTEXPR(handleKRem) { + for(int64_t k = K_; k < K ; k ++) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm, N - j); + B_t += LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + } + } + EIGEN_IF_CONSTEXPR(isCRowMajor) { + urolls::template updateC<1,EIGEN_AVX_MAX_NUM_ROW,true>(&C_arr[i*LDC + j], LDC, zmm, N - j); + urolls::template storeC<1,EIGEN_AVX_MAX_NUM_ROW,true>(&C_arr[i*LDC + j], LDC, zmm, N - j); + } + else { + transStoreC(zmm, &C_arr[i + j*LDC], LDC, 0, N-j); + } + } + if(M - i >= 4) { // Note: this block assumes EIGEN_AVX_MAX_NUM_ROW = 8. Should be removed otherwise + Scalar *A_t = &A_arr[idA(i,0,LDA)]; + Scalar *B_t = &B_arr[0*LDB + j]; + PacketBlock zmm; + urolls::template setzero<1,4>(zmm); + for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm, N - j); + B_t += EIGEN_AVX_MAX_K_UNROL*LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + } + EIGEN_IF_CONSTEXPR(handleKRem) { + for(int64_t k = K_; k < K ; k ++) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm, N - j); + B_t += LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + } + } + EIGEN_IF_CONSTEXPR(isCRowMajor) { + urolls::template updateC<1,4,true>(&C_arr[i*LDC + j], LDC, zmm, N - j); + urolls::template storeC<1,4,true>(&C_arr[i*LDC + j], LDC, zmm, N - j); + } + else { + transStoreC(zmm, &C_arr[i + j*LDC], LDC, 4, N-j); + } + i += 4; + } + if(M - i >= 2) { + Scalar *A_t = &A_arr[idA(i,0,LDA)]; + Scalar *B_t = &B_arr[0*LDB + j]; + PacketBlock zmm; + urolls::template setzero<1,2>(zmm); + for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm, N - j); + B_t += EIGEN_AVX_MAX_K_UNROL*LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + } + EIGEN_IF_CONSTEXPR(handleKRem) { + for(int64_t k = K_; k < K ; k ++) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm, N - j); + B_t += LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + } + } + EIGEN_IF_CONSTEXPR(isCRowMajor) { + urolls::template updateC<1,2,true>(&C_arr[i*LDC + j], LDC, zmm, N - j); + urolls::template storeC<1,2,true>(&C_arr[i*LDC + j], LDC, zmm, N - j); + } + else { + transStoreC(zmm, &C_arr[i + j*LDC], LDC, 2, N-j); + } + i += 2; + } + if(M - i > 0) { + Scalar *A_t = &A_arr[idA(i,0,LDA)]; + Scalar *B_t = &B_arr[0*LDB + j]; + PacketBlock zmm; + urolls::template setzero<1,1>(zmm); + for(int64_t k = 0; k < K_ ; k += EIGEN_AVX_MAX_K_UNROL) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm, N - j); + B_t += EIGEN_AVX_MAX_K_UNROL*LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t += EIGEN_AVX_MAX_K_UNROL; else A_t += EIGEN_AVX_MAX_K_UNROL*LDA; + } + EIGEN_IF_CONSTEXPR(handleKRem) { + for(int64_t k = K_; k < K ; k ++) { + urolls:: template microKernel( + B_t, A_t, LDB, LDA, zmm, N - j); + B_t += LDB; + EIGEN_IF_CONSTEXPR(isARowMajor) A_t++; else A_t += LDA; + } + } + EIGEN_IF_CONSTEXPR(isCRowMajor) { + urolls::template updateC<1,1,true>(&C_arr[i*LDC + j], LDC, zmm, N - j); + urolls::template storeC<1,1,true>(&C_arr[i*LDC + j], LDC, zmm, N - j); + } + else { + transStoreC(zmm, &C_arr[i + j*LDC], LDC, 1, N-j); + } + } + } +} + +/** + * Triangular solve kernel with A on left with K number of rhs. dim(A) = unrollM + * + * unrollM: dimension of A matrix (triangular matrix). 
unrollM should be <= EIGEN_AVX_MAX_NUM_ROW + * isFWDSolve: is forward solve? + * isUnitDiag: is the diagonal of A all ones? + * The B matrix (RHS) is assumed to be row-major +*/ +template +static EIGEN_ALWAYS_INLINE +void triSolveKernel(Scalar *A_arr, Scalar *B_arr, int64_t K, int64_t LDA, int64_t LDB) { + + static_assert( unrollM <= EIGEN_AVX_MAX_NUM_ROW, "unrollM should be equal to EIGEN_AVX_MAX_NUM_ROW" ); + using urolls = unrolls::trsm; + constexpr int64_t U3 = urolls::PacketSize * 3; + constexpr int64_t U2 = urolls::PacketSize * 2; + constexpr int64_t U1 = urolls::PacketSize * 1; + + PacketBlock RHSInPacket; + PacketBlock AInPacket; + + int64_t k = 0; + while(K - k >= U3) { + urolls:: template loadRHS(B_arr + k, LDB, RHSInPacket); + urolls:: template triSolveMicroKernel( + A_arr, LDA, RHSInPacket, AInPacket); + urolls:: template storeRHS(B_arr + k, LDB, RHSInPacket); + k += U3; + } + if(K - k >= U2) { + urolls:: template loadRHS(B_arr + k, LDB, RHSInPacket); + urolls:: template triSolveMicroKernel( + A_arr, LDA, RHSInPacket, AInPacket); + urolls:: template storeRHS(B_arr + k, LDB, RHSInPacket); + k += U2; + } + if(K - k >= U1) { + urolls:: template loadRHS(B_arr + k, LDB, RHSInPacket); + urolls:: template triSolveMicroKernel( + A_arr, LDA, RHSInPacket, AInPacket); + urolls:: template storeRHS(B_arr + k, LDB, RHSInPacket); + k += U1; + } + if(K - k > 0) { + // Handle remaining number of RHS + urolls::template loadRHS(B_arr + k, LDB, RHSInPacket, K-k); + urolls::template triSolveMicroKernel( + A_arr, LDA, RHSInPacket, AInPacket); + urolls::template storeRHS(B_arr + k, LDB, RHSInPacket, K-k); + } +} + +/** + * Triangular solve routine with A on left and dimension of at most L with K number of rhs. This is essentially + * a wrapper for triSolveMicrokernel for M = {1,2,3,4,5,6,7,8}. + * + * isFWDSolve: is forward solve? + * isUnitDiag: is the diagonal of A all ones? + * The B matrix (RHS) is assumed to be row-major +*/ +template +void triSolveKernelLxK(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t K, int64_t LDA, int64_t LDB) { + // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted + // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller. + using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; + if (M == 8) + triSolveKernel(A_arr, B_arr, K, LDA, LDB); + else if (M == 7) + triSolveKernel(A_arr, B_arr, K, LDA, LDB); + else if (M == 6) + triSolveKernel(A_arr, B_arr, K, LDA, LDB); + else if (M == 5) + triSolveKernel(A_arr, B_arr, K, LDA, LDB); + else if (M == 4) + triSolveKernel(A_arr, B_arr, K, LDA, LDB); + else if (M == 3) + triSolveKernel(A_arr, B_arr, K, LDA, LDB); + else if (M == 2) + triSolveKernel(A_arr, B_arr, K, LDA, LDB); + else if (M == 1) + triSolveKernel(A_arr, B_arr, K, LDA, LDB); + return; +} + +/** + * This routine is used to copy B to/from a temporary array (row-major) for cases where B is column-major. 
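+ * triSolve calls this once with toTemp=true to pack a block of B into B_temp ahead of the small
+ * triangular solve, and once with toTemp=false to copy the solved values back; the row-major copy
+ * kept in B_temp is then reused by the subsequent gemm update.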
+ * + * toTemp: true => copy to temporary array, false => copy from temporary array + * remM: true = need to handle remainder values for M (M < EIGEN_AVX_MAX_NUM_ROW) + * + */ +template +static EIGEN_ALWAYS_INLINE +void copyBToRowMajor(Scalar *B_arr, int64_t LDB, int64_t K, + Scalar *B_temp, int64_t LDB_, int64_t remM_ = 0) { + EIGEN_UNUSED_VARIABLE(remM_); + using urolls = unrolls::transB; + using vecHalf = typename std::conditional::value, vecHalfFloat, vecFullDouble>::type; + PacketBlock ymm; + constexpr int64_t U3 = urolls::PacketSize * 3; + constexpr int64_t U2 = urolls::PacketSize * 2; + constexpr int64_t U1 = urolls::PacketSize * 1; + int64_t K_ = K/U3*U3; + int64_t k = 0; + + for(; k < K_; k += U3) { + urolls::template transB_kernel(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_); + B_temp += U3; + } + if(K - k >= U2) { + urolls::template transB_kernel(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_); + B_temp += U2; k += U2; + } + if(K - k >= U1) { + urolls::template transB_kernel(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_); + B_temp += U1; k += U1; + } + EIGEN_IF_CONSTEXPR( U1 > 8) { + // Note: without "if constexpr" this section of code will also be + // parsed by the compiler so there is an additional check in {load/store}BBlock + // to make sure the counter is not non-negative. + if(K - k >= 8) { + urolls::template transB_kernel<8, toTemp, remM>(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_); + B_temp += 8; k += 8; + } + } + EIGEN_IF_CONSTEXPR( U1 > 4) { + // Note: without "if constexpr" this section of code will also be + // parsed by the compiler so there is an additional check in {load/store}BBlock + // to make sure the counter is not non-negative. + if(K - k >= 4) { + urolls::template transB_kernel<4, toTemp, remM>(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_); + B_temp += 4; k += 4; + } + } + if(K - k >= 2) { + urolls::template transB_kernel<2, toTemp, remM>(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_); + B_temp += 2; k += 2; + } + if(K - k >= 1) { + urolls::template transB_kernel<1, toTemp, remM>(B_arr + k*LDB, LDB, B_temp, LDB_, ymm, remM_); + B_temp += 1; k += 1; + } +} + +/** + * Main triangular solve driver + * + * Triangular solve with A on the left. + * Scalar: Scalar precision, only float/double is supported. + * isARowMajor: is A row-major? + * isBRowMajor: is B row-major? + * isFWDSolve: is this forward solve or backward (true => forward)? + * isUnitDiag: is diagonal of A unit or nonunit (true => A has unit diagonal)? + * + * M: dimension of A + * numRHS: number of right hand sides (coincides with K dimension for gemm updates) + * + * Here are the mapping between the different TRSM cases (col-major) and triSolve: + * + * LLN (left , lower, A non-transposed) :: isARowMajor=false, isBRowMajor=false, isFWDSolve=true + * LUT (left , upper, A transposed) :: isARowMajor=true, isBRowMajor=false, isFWDSolve=true + * LUN (left , upper, A non-transposed) :: isARowMajor=false, isBRowMajor=false, isFWDSolve=false + * LLT (left , lower, A transposed) :: isARowMajor=true, isBRowMajor=false, isFWDSolve=false + * RUN (right, upper, A non-transposed) :: isARowMajor=true, isBRowMajor=true, isFWDSolve=true + * RLT (right, lower, A transposed) :: isARowMajor=false, isBRowMajor=true, isFWDSolve=true + * RUT (right, upper, A transposed) :: isARowMajor=false, isBRowMajor=true, isFWDSolve=false + * RLN (right, lower, A non-transposed) :: isARowMajor=true, isBRowMajor=true, isFWDSolve=false + * + * Note: For RXX cases M,numRHS should be swapped. 
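+ *       (A right-side solve X*A = B is performed as the equivalent left-side solve of the transposed
+ *        system, which exchanges the row/column roles of B; hence M and numRHS trade places.)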
+ * +*/ +template +void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t LDA, int64_t LDB) { + /** + * The values for kB, numM were determined experimentally. + * kB: Number of RHS we process at a time. + * numM: number of rows of B we will store in a temporary array (see below.) This should be a multiple of L. + * + * kB was determined by initially setting kB = numRHS and benchmarking triSolve (TRSM-RUN case) + * performance with M=numRHS. + * It was observed that performance started to drop around M=numRHS=240. This is likely machine dependent. + * + * numM was chosen "arbitrarily". It should be relatively small so B_temp is not too large, but it should be + * large enough to allow GEMM updates to have larger "K"s (see below.) No benchmarking has been done so far to + * determine optimal values for numM. + */ + const int64_t kB = (3*packet_traits::size)*5; // 5*U3 + const int64_t numM = 64; + + int64_t sizeBTemp = 0; + Scalar *B_temp = NULL; + EIGEN_IF_CONSTEXPR(!isBRowMajor) { + /** + * If B is col-major, we copy it to a fixed-size temporary array of size at most ~numM*kB and + * transpose it to row-major. Call the solve routine, and copy+transpose it back to the original array. + * The updated row-major copy of B is reused in the GEMM updates. + */ + sizeBTemp = (((std::min(kB, numRHS) + 15)/16+ 4)*16)*numM; + B_temp = (Scalar*) aligned_alloc(4096,sizeof(Scalar)*sizeBTemp); + } + for(int64_t k = 0; k < numRHS; k += kB) { + int64_t bK = numRHS - k > kB ? kB : numRHS - k; + int64_t M_ = (M/EIGEN_AVX_MAX_NUM_ROW)*EIGEN_AVX_MAX_NUM_ROW, gemmOff = 0; + + // bK rounded up to next multiple of L=EIGEN_AVX_MAX_NUM_ROW. When B_temp is used, we solve for bkL RHS + // instead of bK RHS in triSolveKernelLxK. + int64_t bkL = ((bK + (EIGEN_AVX_MAX_NUM_ROW-1))/EIGEN_AVX_MAX_NUM_ROW)*EIGEN_AVX_MAX_NUM_ROW; + const int64_t numScalarPerCache = 64/sizeof(Scalar); + // Leading dimension of B_temp, will be a multiple of the cache line size. + int64_t LDT = ((bkL+(numScalarPerCache-1))/numScalarPerCache)*numScalarPerCache; + int64_t offsetBTemp = 0; + for(int64_t i = 0; i < M_; i += EIGEN_AVX_MAX_NUM_ROW) { + EIGEN_IF_CONSTEXPR(!isBRowMajor) { + int64_t indA_i = isFWDSolve ? i : M - 1 - i; + int64_t indB_i = isFWDSolve ? i : M - (i + EIGEN_AVX_MAX_NUM_ROW); + int64_t offB_1 = isFWDSolve ? offsetBTemp : sizeBTemp - EIGEN_AVX_MAX_NUM_ROW*LDT - offsetBTemp; + int64_t offB_2 = isFWDSolve ? offsetBTemp : sizeBTemp - LDT - offsetBTemp; + // Copy values from B to B_temp. + copyBToRowMajor(B_arr + indB_i + k*LDB, LDB, bK, B_temp + offB_1, LDT); + // Triangular solve with a small block of A and long horizontal blocks of B (or B_temp if B col-major) + triSolveKernelLxK( + &A_arr[idA(indA_i, indA_i, LDA)], B_temp + offB_2, EIGEN_AVX_MAX_NUM_ROW, bkL, LDA, LDT); + // Copy values from B_temp back to B. B_temp will be reused in gemm call below. + copyBToRowMajor(B_arr + indB_i + k*LDB, LDB, bK, B_temp + offB_1, LDT); + + offsetBTemp += EIGEN_AVX_MAX_NUM_ROW*LDT; + } + else { + int64_t ind = isFWDSolve ? i : M - 1 - i; + triSolveKernelLxK( + &A_arr[idA(ind, ind, LDA)], B_arr + k + ind*LDB, EIGEN_AVX_MAX_NUM_ROW, bK, LDA, LDB); + } + if(i+EIGEN_AVX_MAX_NUM_ROW < M_) { + /** + * For the GEMM updates, we want "K" (K=i+8 in this case) to be large as soon as possible + * to reuse the accumulators in GEMM as much as possible. 
So we only update 8xbK blocks of + * B as follows: + * + * A B + * __ + * |__|__ |__| + * |__|__|__ |__| + * |__|__|__|__ |__| + * |********|__| |**| + */ + EIGEN_IF_CONSTEXPR(isBRowMajor) { + int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2*EIGEN_AVX_MAX_NUM_ROW); + int64_t indA_j = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW); + int64_t indB_i = isFWDSolve ? 0 : M - (i + EIGEN_AVX_MAX_NUM_ROW); + int64_t indB_i2 = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2*EIGEN_AVX_MAX_NUM_ROW); + gemmKernel( + &A_arr[idA(indA_i,indA_j,LDA)], + B_arr + k + indB_i*LDB, + B_arr + k + indB_i2*LDB, + EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW, + LDA, LDB, LDB); + } + else { + if(offsetBTemp + EIGEN_AVX_MAX_NUM_ROW*LDT > sizeBTemp) { + /** + * Similar idea as mentioned above, but here we are limited by the number of updated values of B + * that can be stored (row-major) in B_temp. + * + * If there is not enough space to store the next batch of 8xbK of B in B_temp, we call GEMM + * update and partially update the remaining old values of B which depends on the new values + * of B stored in B_temp. These values are then no longer needed and can be overwritten. + */ + int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0; + int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW); + int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : 0; + int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp; + gemmKernel( + &A_arr[idA(indA_i, indA_j,LDA)], + B_temp + offB_1, + B_arr + indB_i + (k)*LDB, + M - (i + EIGEN_AVX_MAX_NUM_ROW), bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, + LDA, LDT, LDB); + offsetBTemp = 0; gemmOff = i + EIGEN_AVX_MAX_NUM_ROW; + } + else { + /** + * If there is enough space in B_temp, we only update the next 8xbK values of B. + */ + int64_t indA_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2*EIGEN_AVX_MAX_NUM_ROW); + int64_t indA_j = isFWDSolve ? gemmOff : M - (i + EIGEN_AVX_MAX_NUM_ROW); + int64_t indB_i = isFWDSolve ? i + EIGEN_AVX_MAX_NUM_ROW : M - (i + 2*EIGEN_AVX_MAX_NUM_ROW); + int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp; + gemmKernel( + &A_arr[idA(indA_i,indA_j,LDA)], + B_temp + offB_1, + B_arr + indB_i + (k)*LDB, + EIGEN_AVX_MAX_NUM_ROW, bK, i + EIGEN_AVX_MAX_NUM_ROW - gemmOff, + LDA, LDT, LDB); + } + } + } + } + // Handle M remainder.. + int64_t bM = M-M_; + if (bM > 0){ + if( M_ > 0) { + EIGEN_IF_CONSTEXPR(isBRowMajor) { + int64_t indA_i = isFWDSolve ? M_ : 0; + int64_t indA_j = isFWDSolve ? 0 : bM; + int64_t indB_i = isFWDSolve ? 0 : bM; + int64_t indB_i2 = isFWDSolve ? M_ : 0; + gemmKernel( + &A_arr[idA(indA_i,indA_j,LDA)], + B_arr + k +indB_i*LDB, + B_arr + k + indB_i2*LDB, + bM , bK, M_, + LDA, LDB, LDB); + } + else { + int64_t indA_i = isFWDSolve ? M_ : 0; + int64_t indA_j = isFWDSolve ? gemmOff : bM; + int64_t indB_i = isFWDSolve ? M_ : 0; + int64_t offB_1 = isFWDSolve ? 0 : sizeBTemp - offsetBTemp; + gemmKernel( + &A_arr[idA(indA_i,indA_j,LDA)], + B_temp + offB_1, + B_arr + indB_i + (k)*LDB, + bM , bK, M_ - gemmOff, + LDA, LDT, LDB); + } + } + EIGEN_IF_CONSTEXPR(!isBRowMajor) { + int64_t indA_i = isFWDSolve ? M_ : M - 1 - M_; + int64_t indB_i = isFWDSolve ? M_ : 0; + int64_t offB_1 = isFWDSolve ? 0 : (bM-1)*bkL; + copyBToRowMajor(B_arr + indB_i + k*LDB, LDB, bK, B_temp, bkL, bM); + triSolveKernelLxK( + &A_arr[idA(indA_i, indA_i, LDA)], B_temp + offB_1, bM, bkL, LDA, bkL); + copyBToRowMajor(B_arr + indB_i + k*LDB, LDB, bK, B_temp, bkL, bM); + } + else { + int64_t ind = isFWDSolve ? 
M_ : M - 1 - M_; + triSolveKernelLxK( + &A_arr[idA(ind, ind, LDA)], B_arr + k + ind*LDB, bM, bK, LDA, LDB); + } + } + } + EIGEN_IF_CONSTEXPR(!isBRowMajor) free(B_temp); +} + +template +void gemmKer(Scalar *A_arr, Scalar *B_arr, Scalar *C_arr, + int64_t M, int64_t N, int64_t K, + int64_t LDA, int64_t LDB, int64_t LDC) { + gemmKernel(B_arr, A_arr, C_arr, N, M, K, LDB, LDA, LDC); +} + + +// Template specializations of trsmKernelL/R for float/double and inner strides of 1. +#if defined(EIGEN_USE_AVX512_TRSM_KERNELS) +template +struct trsm_kernels; + +template +struct trsm_kernels{ + static void trsmKernelL(Index size, Index otherSize, const float* _tri, Index triStride, + float* _other, Index otherIncr, Index otherStride); + static void trsmKernelR(Index size, Index otherSize, const float* _tri, Index triStride, + float* _other, Index otherIncr, Index otherStride); +}; + +template +struct trsm_kernels{ + static void trsmKernelL(Index size, Index otherSize, const double* _tri, Index triStride, + double* _other, Index otherIncr, Index otherStride); + static void trsmKernelR(Index size, Index otherSize, const double* _tri, Index triStride, + double* _other, Index otherIncr, Index otherStride); +}; + +template +EIGEN_DONT_INLINE void trsm_kernels::trsmKernelL( + Index size, Index otherSize, + const float* _tri, Index triStride, + float* _other, Index otherIncr, Index otherStride) +{ + EIGEN_UNUSED_VARIABLE(otherIncr); + triSolve( + const_cast(_tri), _other, size, otherSize, triStride, otherStride); +} + +template +EIGEN_DONT_INLINE void trsm_kernels::trsmKernelR( + Index size, Index otherSize, + const float* _tri, Index triStride, + float* _other, Index otherIncr, Index otherStride) +{ + EIGEN_UNUSED_VARIABLE(otherIncr); + triSolve( + const_cast(_tri), _other, size, otherSize, triStride, otherStride); +} + +template +EIGEN_DONT_INLINE void trsm_kernels::trsmKernelL( + Index size, Index otherSize, + const double* _tri, Index triStride, + double* _other, Index otherIncr, Index otherStride) +{ + EIGEN_UNUSED_VARIABLE(otherIncr); + triSolve( + const_cast(_tri), _other, size, otherSize, triStride, otherStride); +} + +template +EIGEN_DONT_INLINE void trsm_kernels::trsmKernelR( + Index size, Index otherSize, + const double* _tri, Index triStride, + double* _other, Index otherIncr, Index otherStride) +{ + EIGEN_UNUSED_VARIABLE(otherIncr); + triSolve( + const_cast(_tri), _other, size, otherSize, triStride, otherStride); +} +#endif //EIGEN_USE_AVX512_TRSM_KERNELS +} +} +#endif //EIGEN_TRSM_KERNEL_IMPL_H diff --git a/Eigen/src/Core/arch/AVX512/unrolls_impl.hpp b/Eigen/src/Core/arch/AVX512/unrolls_impl.hpp new file mode 100644 index 000000000..db1308a81 --- /dev/null +++ b/Eigen/src/Core/arch/AVX512/unrolls_impl.hpp @@ -0,0 +1,1213 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2022 Intel Corporation +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_UNROLLS_IMPL_H +#define EIGEN_UNROLLS_IMPL_H + +template +static EIGEN_ALWAYS_INLINE +int64_t idA(int64_t i, int64_t j, int64_t LDA) { + EIGEN_IF_CONSTEXPR(isARowMajor) return i * LDA + j; + else return i + j * LDA; +} + +/** + * This namespace contains various classes used to generate compile-time unrolls which are + * used throughout the trsm/gemm kernels. 
The unrolls are characterized as for-loops (1-D), nested + * for-loops (2-D), or triple nested for-loops (3-D). Unrolls are generated using template recursion + * + * Example, the 2-D for-loop is unrolled recursively by first flattening to a 1-D loop. + * + * for(startI = 0; startI < endI; startI++) for(startC = 0; startC < endI*endJ; startC++) + * for(startJ = 0; startJ < endJ; startJ++) ----> startI = (startC)/(endJ) + * func(startI,startJ) startJ = (startC)%(endJ) + * func(...) + * + * The 1-D loop can be unrolled recursively by using enable_if and defining an auxillary function + * with a template parameter used as a counter. + * + * template + * std::enable_if_t<(counter <= 0)> <---- tail case. + * aux_func {} + * + * template + * std::enable_if_t<(counter > 0)> <---- actual for-loop + * aux_func { + * startC = endI*endJ - counter + * startI = (startC)/(endJ) + * startJ = (startC)%(endJ) + * func(startI, startJ) + * aux_func() + * } + * + * Note: Additional wrapper functions are provided for aux_func which hides the counter template + * parameter since counter usually depends on endI, endJ, etc... + * + * Conventions: + * 1) endX: specifies the terminal value for the for-loop, (ex: for(startX = 0; startX < endX; startX++)) + * + * 2) rem, remM, remK template parameters are used for deciding whether to use masked operations for + * handling remaining tails (when sizes are not multiples of PacketSize or EIGEN_AVX_MAX_NUM_ROW) + */ +namespace unrolls { + +template +EIGEN_ALWAYS_INLINE auto remMask(int64_t m) { + EIGEN_IF_CONSTEXPR( N == 16) { return 0xFFFF >> (16 - m); } + else EIGEN_IF_CONSTEXPR( N == 8) { return 0xFF >> (8 - m); } + else EIGEN_IF_CONSTEXPR( N == 4) { return 0x0F >> (4 - m); } + return 0; +} + +template +T2 castPacket(T1 &a) { + return reinterpret_cast(a); +} + +template<> +vecHalfFloat castPacket(vecFullFloat &a) { + return _mm512_castps512_ps256(a); +} + +template<> +vecFullDouble castPacket(vecFullDouble &a) { + return a; +} + +/*** + * Unrolls for tranposed C stores + */ +template +class trans { +public: + using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; + using vecHalf = typename std::conditional::value, vecHalfFloat, vecFullDouble>::type; + static constexpr int64_t PacketSize = packet_traits::size; + + + /*********************************** + * Auxillary Functions for: + * - storeC + *********************************** + */ + + /** + * aux_storeC + * + * 1-D unroll + * for(startN = 0; startN < endN; startN++) + * + * (endN <= PacketSize) is required to handle the fp32 case, see comments in transStoreC + * + **/ + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && endN <= PacketSize)> + aux_storeC(Scalar *C_arr, int64_t LDC, + PacketBlock &zmm, int64_t remM_ = 0) { + constexpr int64_t counterReverse = endN-counter; + constexpr int64_t startN = counterReverse; + + EIGEN_IF_CONSTEXPR(startN < EIGEN_AVX_MAX_NUM_ROW) { + EIGEN_IF_CONSTEXPR(remM) { + pstoreu( + C_arr + LDC*startN, + padd(ploadu((const Scalar*)C_arr + LDC*startN, remMask(remM_)), + castPacket(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN]), + remMask(remM_)), + remMask(remM_)); + } + else { + pstoreu( + C_arr + LDC*startN, + padd(ploadu((const Scalar*)C_arr + LDC*startN), + castPacket(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*startN]))); + } + } + else { + zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)] = + _mm512_shuffle_f32x4( + zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - 
EIGEN_AVX_MAX_NUM_ROW)], + zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN - EIGEN_AVX_MAX_NUM_ROW)], 0b01001110); + EIGEN_IF_CONSTEXPR(remM) { + pstoreu( + C_arr + LDC*startN, + padd(ploadu((const Scalar*)C_arr + LDC*startN, + remMask(remM_)), + castPacket(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)])), + remMask(remM_)); + } + else { + pstoreu( + C_arr + LDC*startN, + padd(ploadu((const Scalar*)C_arr + LDC*startN), + castPacket(zmm.packet[packetIndexOffset + (unrollN/PacketSize)*(startN-EIGEN_AVX_MAX_NUM_ROW)]))); + } + } + aux_storeC(C_arr, LDC, zmm, remM_); + } + + template + static EIGEN_ALWAYS_INLINE std::enable_if_t 0 && endN <= PacketSize)> + aux_storeC(Scalar *C_arr, int64_t LDC, + PacketBlock &zmm, int64_t remM_ = 0) + { + EIGEN_UNUSED_VARIABLE(C_arr); + EIGEN_UNUSED_VARIABLE(LDC); + EIGEN_UNUSED_VARIABLE(zmm); + EIGEN_UNUSED_VARIABLE(remM_); + } + + template + static EIGEN_ALWAYS_INLINE + void storeC(Scalar *C_arr, int64_t LDC, + PacketBlock &zmm, int64_t remM_ = 0){ + aux_storeC(C_arr, LDC, zmm, remM_); + } + + /** + * Transposes LxunrollN row major block of matrices stored EIGEN_AVX_MAX_NUM_ACC zmm registers to + * "unrollN"xL ymm registers to be stored col-major into C. + * + * For 8x48, the 8x48 block (row-major) is stored in zmm as follows: + * + * row0: zmm0 zmm1 zmm2 + * row1: zmm3 zmm4 zmm5 + * . + * . + * row7: zmm21 zmm22 zmm23 + * + * For 8x32, the 8x32 block (row-major) is stored in zmm as follows: + * + * row0: zmm0 zmm1 + * row1: zmm2 zmm3 + * . + * . + * row7: zmm14 zmm15 + * + * + * In general we will have {1,2,3} groups of avx registers each of size + * EIGEN_AVX_MAX_NUM_ROW. packetIndexOffset is used to select which "block" of + * avx registers are being transposed. + */ + template + static EIGEN_ALWAYS_INLINE + void transpose(PacketBlock &zmm) { + // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted + // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller. + constexpr int64_t zmmStride = unrollN/PacketSize; + PacketBlock r; + r.packet[0] = zmm.packet[packetIndexOffset + zmmStride*0]; + r.packet[1] = zmm.packet[packetIndexOffset + zmmStride*1]; + r.packet[2] = zmm.packet[packetIndexOffset + zmmStride*2]; + r.packet[3] = zmm.packet[packetIndexOffset + zmmStride*3]; + r.packet[4] = zmm.packet[packetIndexOffset + zmmStride*4]; + r.packet[5] = zmm.packet[packetIndexOffset + zmmStride*5]; + r.packet[6] = zmm.packet[packetIndexOffset + zmmStride*6]; + r.packet[7] = zmm.packet[packetIndexOffset + zmmStride*7]; + ptranspose(r); + zmm.packet[packetIndexOffset + zmmStride*0] = r.packet[0]; + zmm.packet[packetIndexOffset + zmmStride*1] = r.packet[1]; + zmm.packet[packetIndexOffset + zmmStride*2] = r.packet[2]; + zmm.packet[packetIndexOffset + zmmStride*3] = r.packet[3]; + zmm.packet[packetIndexOffset + zmmStride*4] = r.packet[4]; + zmm.packet[packetIndexOffset + zmmStride*5] = r.packet[5]; + zmm.packet[packetIndexOffset + zmmStride*6] = r.packet[6]; + zmm.packet[packetIndexOffset + zmmStride*7] = r.packet[7]; + } +}; + +/** + * Unrolls for copyBToRowMajor + * + * Idea: + * 1) Load a block of right-hand sides to registers (using loadB). + * 2) Convert the block from column-major to row-major (transposeLxL) + * 3) Store the blocks from register either to a temp array (toTemp == true), or back to B (toTemp == false). + * + * We use at most EIGEN_AVX_MAX_NUM_ACC avx registers to store the blocks of B. The remaining registers are + * used as temps for transposing. + * + * Blocks will be of size Lx{U1,U2,U3}. 
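+ * (Concretely, with AVX512 and L = EIGEN_AVX_MAX_NUM_ROW = 8: fp32 has PacketSize = 16, so
+ * {U1,U2,U3} = {16,32,48}; fp64 has PacketSize = 8, so {U1,U2,U3} = {8,16,24}.)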
packetIndexOffset is used to index between these subblocks + * For fp32, PacketSize = 2*EIGEN_AVX_MAX_NUM_ROW, so we cast packets to packets half the size (zmm -> ymm). + */ +template +class transB { +public: + using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; + using vecHalf = typename std::conditional::value, vecHalfFloat, vecFullDouble>::type; + static constexpr int64_t PacketSize = packet_traits::size; + + /*********************************** + * Auxillary Functions for: + * - loadB + * - storeB + * - loadBBlock + * - storeBBlock + *********************************** + */ + + /** + * aux_loadB + * + * 1-D unroll + * for(startN = 0; startN < endN; startN++) + **/ + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> + aux_loadB(Scalar *B_arr, int64_t LDB, + PacketBlock &ymm, int64_t remM_ = 0) { + constexpr int64_t counterReverse = endN-counter; + constexpr int64_t startN = counterReverse; + + EIGEN_IF_CONSTEXPR(remM) { + ymm.packet[packetIndexOffset + startN] = ploadu( + (const Scalar*)&B_arr[startN*LDB], remMask(remM_)); + } + else + ymm.packet[packetIndexOffset + startN] = ploadu((const Scalar*)&B_arr[startN*LDB]); + + aux_loadB(B_arr, LDB, ymm, remM_); + } + + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> + aux_loadB(Scalar *B_arr, int64_t LDB, + PacketBlock &ymm, int64_t remM_ = 0) + { + EIGEN_UNUSED_VARIABLE(B_arr); + EIGEN_UNUSED_VARIABLE(LDB); + EIGEN_UNUSED_VARIABLE(ymm); + EIGEN_UNUSED_VARIABLE(remM_); + } + + /** + * aux_storeB + * + * 1-D unroll + * for(startN = 0; startN < endN; startN++) + **/ + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> + aux_storeB(Scalar *B_arr, int64_t LDB, + PacketBlock &ymm, int64_t rem_ = 0) { + constexpr int64_t counterReverse = endN-counter; + constexpr int64_t startN = counterReverse; + + EIGEN_IF_CONSTEXPR( remK || remM) { + pstoreu( + &B_arr[startN*LDB], + ymm.packet[packetIndexOffset + startN], + remMask(rem_)); + } + else { + pstoreu(&B_arr[startN*LDB], ymm.packet[packetIndexOffset + startN]); + } + + aux_storeB(B_arr, LDB, ymm, rem_); + } + + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> + aux_storeB(Scalar *B_arr, int64_t LDB, + PacketBlock &ymm, int64_t rem_ = 0) + { + EIGEN_UNUSED_VARIABLE(B_arr); + EIGEN_UNUSED_VARIABLE(LDB); + EIGEN_UNUSED_VARIABLE(ymm); + EIGEN_UNUSED_VARIABLE(rem_); + } + + /** + * aux_loadBBlock + * + * 1-D unroll + * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW) + **/ + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> + aux_loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, + PacketBlock &ymm, + int64_t remM_ = 0) { + constexpr int64_t counterReverse = endN-counter; + constexpr int64_t startN = counterReverse; + transB::template loadB(&B_temp[startN], LDB_, ymm); + aux_loadBBlock( + B_arr, LDB, B_temp, LDB_, ymm, remM_); + } + + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> + aux_loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, + PacketBlock &ymm, + int64_t remM_ = 0) + { + EIGEN_UNUSED_VARIABLE(B_arr); + EIGEN_UNUSED_VARIABLE(LDB); + EIGEN_UNUSED_VARIABLE(B_temp); + EIGEN_UNUSED_VARIABLE(LDB_); + EIGEN_UNUSED_VARIABLE(ymm); + EIGEN_UNUSED_VARIABLE(remM_); + } + + + /** + * aux_storeBBlock + * + * 1-D unroll + * for(startN = 0; startN < endN; startN += EIGEN_AVX_MAX_NUM_ROW) + **/ + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> + aux_storeBBlock(Scalar *B_arr, int64_t 
LDB, Scalar *B_temp, int64_t LDB_, + PacketBlock &ymm, + int64_t remM_ = 0) { + constexpr int64_t counterReverse = endN-counter; + constexpr int64_t startN = counterReverse; + + EIGEN_IF_CONSTEXPR(toTemp) { + transB::template storeB( + &B_temp[startN], LDB_, ymm, remK_); + } + else { + transB::template storeB( + &B_arr[0 + startN*LDB], LDB, ymm, remM_); + } + aux_storeBBlock( + B_arr, LDB, B_temp, LDB_, ymm, remM_); + } + + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> + aux_storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, + PacketBlock &ymm, + int64_t remM_ = 0) + { + EIGEN_UNUSED_VARIABLE(B_arr); + EIGEN_UNUSED_VARIABLE(LDB); + EIGEN_UNUSED_VARIABLE(B_temp); + EIGEN_UNUSED_VARIABLE(LDB_); + EIGEN_UNUSED_VARIABLE(ymm); + EIGEN_UNUSED_VARIABLE(remM_); + } + + + /******************************************************** + * Wrappers for aux_XXXX to hide counter parameter + ********************************************************/ + + template + static EIGEN_ALWAYS_INLINE + void loadB(Scalar *B_arr, int64_t LDB, + PacketBlock &ymm, int64_t remM_ = 0) { + aux_loadB(B_arr, LDB, ymm, remM_); + } + + template + static EIGEN_ALWAYS_INLINE + void storeB(Scalar *B_arr, int64_t LDB, + PacketBlock &ymm, int64_t rem_ = 0) { + aux_storeB(B_arr, LDB, ymm, rem_); + } + + template + static EIGEN_ALWAYS_INLINE + void loadBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, + PacketBlock &ymm, + int64_t remM_ = 0) { + EIGEN_IF_CONSTEXPR(toTemp) { + transB::template loadB(&B_arr[0],LDB, ymm, remM_); + } + else { + aux_loadBBlock( + B_arr, LDB, B_temp, LDB_, ymm, remM_); + } + } + + template + static EIGEN_ALWAYS_INLINE + void storeBBlock(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, + PacketBlock &ymm, + int64_t remM_ = 0) { + aux_storeBBlock( + B_arr, LDB, B_temp, LDB_, ymm, remM_); + } + + template + static EIGEN_ALWAYS_INLINE + void transposeLxL(PacketBlock &ymm){ + // Note: this assumes EIGEN_AVX_MAX_NUM_ROW = 8. Unrolls should be adjusted + // accordingly if EIGEN_AVX_MAX_NUM_ROW is smaller. + PacketBlock r; + r.packet[0] = ymm.packet[packetIndexOffset + 0]; + r.packet[1] = ymm.packet[packetIndexOffset + 1]; + r.packet[2] = ymm.packet[packetIndexOffset + 2]; + r.packet[3] = ymm.packet[packetIndexOffset + 3]; + r.packet[4] = ymm.packet[packetIndexOffset + 4]; + r.packet[5] = ymm.packet[packetIndexOffset + 5]; + r.packet[6] = ymm.packet[packetIndexOffset + 6]; + r.packet[7] = ymm.packet[packetIndexOffset + 7]; + ptranspose(r); + ymm.packet[packetIndexOffset + 0] = r.packet[0]; + ymm.packet[packetIndexOffset + 1] = r.packet[1]; + ymm.packet[packetIndexOffset + 2] = r.packet[2]; + ymm.packet[packetIndexOffset + 3] = r.packet[3]; + ymm.packet[packetIndexOffset + 4] = r.packet[4]; + ymm.packet[packetIndexOffset + 5] = r.packet[5]; + ymm.packet[packetIndexOffset + 6] = r.packet[6]; + ymm.packet[packetIndexOffset + 7] = r.packet[7]; + } + + template + static EIGEN_ALWAYS_INLINE + void transB_kernel(Scalar *B_arr, int64_t LDB, Scalar *B_temp, int64_t LDB_, + PacketBlock &ymm, int64_t remM_ = 0) { + constexpr int64_t U3 = PacketSize * 3; + constexpr int64_t U2 = PacketSize * 2; + constexpr int64_t U1 = PacketSize * 1; + /** + * Unrolls needed for each case: + * - AVX512 fp32 48 32 16 8 4 2 1 + * - AVX512 fp64 24 16 8 4 2 1 + * + * For fp32 L and U1 are 1:2 so for U3/U2 cases the loads/stores need to be split up. 
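+ * For example, for fp32 maxUBlock = 3*L = 24 while U2 = 32 and U3 = 48, so the LxU2 and LxU3 cases
+ * below handle columns [0, maxUBlock) first and the remaining columns in a second pass.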
+ */ + EIGEN_IF_CONSTEXPR(unrollN == U3) { + // load LxU3 B col major, transpose LxU3 row major + constexpr int64_t maxUBlock = std::min(3*EIGEN_AVX_MAX_NUM_ROW, U3); + transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); + transB::template transposeLxL<0*EIGEN_AVX_MAX_NUM_ROW>(ymm); + transB::template transposeLxL<1*EIGEN_AVX_MAX_NUM_ROW>(ymm); + transB::template transposeLxL<2*EIGEN_AVX_MAX_NUM_ROW>(ymm); + transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); + + EIGEN_IF_CONSTEXPR( maxUBlock < U3) { + transB::template loadBBlock(&B_arr[maxUBlock*LDB], LDB, &B_temp[maxUBlock], LDB_, ymm, remM_); + transB::template transposeLxL<0*EIGEN_AVX_MAX_NUM_ROW>(ymm); + transB::template transposeLxL<1*EIGEN_AVX_MAX_NUM_ROW>(ymm); + transB::template transposeLxL<2*EIGEN_AVX_MAX_NUM_ROW>(ymm); + transB::template storeBBlock(&B_arr[maxUBlock*LDB], LDB, &B_temp[maxUBlock], LDB_, ymm, remM_); + } + } + else EIGEN_IF_CONSTEXPR(unrollN == U2) { + // load LxU2 B col major, transpose LxU2 row major + constexpr int64_t maxUBlock = std::min(3*EIGEN_AVX_MAX_NUM_ROW, U2); + transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); + transB::template transposeLxL<0*EIGEN_AVX_MAX_NUM_ROW>(ymm); + transB::template transposeLxL<1*EIGEN_AVX_MAX_NUM_ROW>(ymm); + EIGEN_IF_CONSTEXPR(maxUBlock < U2) transB::template transposeLxL<2*EIGEN_AVX_MAX_NUM_ROW>(ymm); + transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); + + EIGEN_IF_CONSTEXPR( maxUBlock < U2) { + transB::template loadBBlock( + &B_arr[maxUBlock*LDB], LDB, &B_temp[maxUBlock], LDB_, ymm, remM_); + transB::template transposeLxL<0>(ymm); + transB::template storeBBlock( + &B_arr[maxUBlock*LDB], LDB, &B_temp[maxUBlock], LDB_, ymm, remM_); + } + } + else EIGEN_IF_CONSTEXPR(unrollN == U1) { + // load LxU1 B col major, transpose LxU1 row major + transB::template loadBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); + transB::template transposeLxL<0>(ymm); + EIGEN_IF_CONSTEXPR(EIGEN_AVX_MAX_NUM_ROW < U1) { + transB::template transposeLxL<1*EIGEN_AVX_MAX_NUM_ROW>(ymm); + } + transB::template storeBBlock(B_arr, LDB, B_temp, LDB_, ymm, remM_); + } + else EIGEN_IF_CONSTEXPR(unrollN == 8 && U1 > 8) { + // load Lx4 B col major, transpose Lx4 row major + transB::template loadBBlock<8,toTemp,remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); + transB::template transposeLxL<0>(ymm); + transB::template storeBBlock<8,toTemp,remM,8>(B_arr, LDB, B_temp, LDB_, ymm, remM_); + } + else EIGEN_IF_CONSTEXPR(unrollN == 4 && U1 > 4) { + // load Lx4 B col major, transpose Lx4 row major + transB::template loadBBlock<4,toTemp,remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); + transB::template transposeLxL<0>(ymm); + transB::template storeBBlock<4,toTemp,remM,4>(B_arr, LDB, B_temp, LDB_, ymm, remM_); + } + else EIGEN_IF_CONSTEXPR(unrollN == 2) { + // load Lx2 B col major, transpose Lx2 row major + transB::template loadBBlock<2,toTemp,remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); + transB::template transposeLxL<0>(ymm); + transB::template storeBBlock<2,toTemp,remM,2>(B_arr, LDB, B_temp, LDB_, ymm, remM_); + } + else EIGEN_IF_CONSTEXPR(unrollN == 1) { + // load Lx1 B col major, transpose Lx1 row major + transB::template loadBBlock<1,toTemp,remM>(B_arr, LDB, B_temp, LDB_, ymm, remM_); + transB::template transposeLxL<0>(ymm); + transB::template storeBBlock<1,toTemp,remM,1>(B_arr, LDB, B_temp, LDB_, ymm, remM_); + } + } +}; + +/** + * Unrolls for triSolveKernel + * + * Idea: + * 1) Load a block of right-hand sides to registers in RHSInPacket (using loadRHS). 
+ * 2) Do triangular solve with RHSInPacket and a small block of A (triangular matrix) + * stored in AInPacket (using triSolveMicroKernel). + * 3) Store final results (in avx registers) back into memory (using storeRHS). + * + * RHSInPacket uses at most EIGEN_AVX_MAX_NUM_ACC avx registers and AInPacket uses at most + * EIGEN_AVX_MAX_NUM_ROW registers. + */ +template +class trsm { +public: + using vec = typename std::conditional::value, + vecFullFloat, + vecFullDouble>::type; + static constexpr int64_t PacketSize = packet_traits::size; + + /*********************************** + * Auxillary Functions for: + * - loadRHS + * - storeRHS + * - divRHSByDiag + * - updateRHS + * - triSolveMicroKernel + ************************************/ + /** + * aux_loadRHS + * + * 2-D unroll + * for(startM = 0; startM < endM; startM++) + * for(startK = 0; startK < endK; startK++) + **/ + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> + aux_loadRHS(Scalar* B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { + + constexpr int64_t counterReverse = endM*endK-counter; + constexpr int64_t startM = counterReverse/(endK); + constexpr int64_t startK = counterReverse%endK; + + constexpr int64_t packetIndex = startM*endK + startK; + constexpr int64_t startM_ = isFWDSolve ? startM : -startM; + const int64_t rhsIndex = (startK*PacketSize) + startM_*LDB; + EIGEN_IF_CONSTEXPR(krem) { + RHSInPacket.packet[packetIndex] = ploadu(&B_arr[rhsIndex], remMask(rem)); + } + else { + RHSInPacket.packet[packetIndex] = ploadu(&B_arr[rhsIndex]); + } + aux_loadRHS(B_arr, LDB, RHSInPacket, rem); + } + + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> + aux_loadRHS(Scalar* B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) + { + EIGEN_UNUSED_VARIABLE(B_arr); + EIGEN_UNUSED_VARIABLE(LDB); + EIGEN_UNUSED_VARIABLE(RHSInPacket); + EIGEN_UNUSED_VARIABLE(rem); + } + + /** + * aux_storeRHS + * + * 2-D unroll + * for(startM = 0; startM < endM; startM++) + * for(startK = 0; startK < endK; startK++) + **/ + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> + aux_storeRHS(Scalar* B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { + constexpr int64_t counterReverse = endM*endK-counter; + constexpr int64_t startM = counterReverse/(endK); + constexpr int64_t startK = counterReverse%endK; + + constexpr int64_t packetIndex = startM*endK + startK; + constexpr int64_t startM_ = isFWDSolve ? 
startM : -startM; + const int64_t rhsIndex = (startK*PacketSize) + startM_*LDB; + EIGEN_IF_CONSTEXPR(krem) { + pstoreu(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex], remMask(rem)); + } + else { + pstoreu(&B_arr[rhsIndex], RHSInPacket.packet[packetIndex]); + } + aux_storeRHS(B_arr, LDB, RHSInPacket, rem); + } + + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> + aux_storeRHS(Scalar* B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) + { + EIGEN_UNUSED_VARIABLE(B_arr); + EIGEN_UNUSED_VARIABLE(LDB); + EIGEN_UNUSED_VARIABLE(RHSInPacket); + EIGEN_UNUSED_VARIABLE(rem); + } + + /** + * aux_divRHSByDiag + * + * currM may be -1, (currM >=0) in enable_if checks for this + * + * 1-D unroll + * for(startK = 0; startK < endK; startK++) + **/ + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0 && currM >= 0)> + aux_divRHSByDiag(PacketBlock &RHSInPacket, PacketBlock &AInPacket) { + constexpr int64_t counterReverse = endK-counter; + constexpr int64_t startK = counterReverse; + + constexpr int64_t packetIndex = currM*endK + startK; + RHSInPacket.packet[packetIndex] = pmul(AInPacket.packet[currM], RHSInPacket.packet[packetIndex]); + aux_divRHSByDiag(RHSInPacket, AInPacket); + } + + template + static EIGEN_ALWAYS_INLINE std::enable_if_t 0 && currM >= 0)> + aux_divRHSByDiag(PacketBlock &RHSInPacket, PacketBlock &AInPacket) { + EIGEN_UNUSED_VARIABLE(RHSInPacket); + EIGEN_UNUSED_VARIABLE(AInPacket); + } + + /** + * aux_updateRHS + * + * 2-D unroll + * for(startM = initM; startM < endM; startM++) + * for(startK = 0; startK < endK; startK++) + **/ + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> + aux_updateRHS(Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, PacketBlock &AInPacket) { + + constexpr int64_t counterReverse = (endM-initM)*endK-counter; + constexpr int64_t startM = initM + counterReverse/(endK); + constexpr int64_t startK = counterReverse%endK; + + // For each row of A, first update all corresponding RHS + constexpr int64_t packetIndex = startM*endK + startK; + EIGEN_IF_CONSTEXPR(currentM > 0) { + RHSInPacket.packet[packetIndex] = + pnmadd(AInPacket.packet[startM], + RHSInPacket.packet[(currentM-1)*endK+startK], + RHSInPacket.packet[packetIndex]); + } + + EIGEN_IF_CONSTEXPR(startK == endK - 1) { + // Once all RHS for previous row of A is updated, we broadcast the next element in the column A_{i, currentM}. + EIGEN_IF_CONSTEXPR(startM == currentM && !isUnitDiag) { + // If diagonal is not unit, we broadcast reciprocals of diagonals AinPacket.packet[currentM]. 
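+ // (Multiplying by the stored reciprocal later replaces a per-packet division with a cheaper pmul.)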
+ // This will be used in divRHSByDiag + EIGEN_IF_CONSTEXPR(isFWDSolve) + AInPacket.packet[currentM] = pset1(Scalar(1)/A_arr[idA(currentM,currentM,LDA)]); + else + AInPacket.packet[currentM] = pset1(Scalar(1)/A_arr[idA(-currentM,-currentM,LDA)]); + } + else { + // Broadcast next off diagonal element of A + EIGEN_IF_CONSTEXPR(isFWDSolve) + AInPacket.packet[startM] = pset1(A_arr[idA(startM,currentM,LDA)]); + else + AInPacket.packet[startM] = pset1(A_arr[idA(-startM,-currentM,LDA)]); + } + } + + aux_updateRHS(A_arr, LDA, RHSInPacket, AInPacket); + } + + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> + aux_updateRHS(Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, PacketBlock &AInPacket) { + EIGEN_UNUSED_VARIABLE(A_arr); + EIGEN_UNUSED_VARIABLE(LDA); + EIGEN_UNUSED_VARIABLE(RHSInPacket); + EIGEN_UNUSED_VARIABLE(AInPacket); + } + + /** + * aux_triSolverMicroKernel + * + * 1-D unroll + * for(startM = 0; startM < endM; startM++) + **/ + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> + aux_triSolveMicroKernel(Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, PacketBlock &AInPacket) { + + constexpr int64_t counterReverse = endM-counter; + constexpr int64_t startM = counterReverse; + + constexpr int64_t currentM = startM; + // Divides the right-hand side in row startM, by digonal value of A + // broadcasted to AInPacket.packet[startM-1] in the previous iteration. + // + // Without "if constexpr" the compiler instantiates the case <-1, numK> + // this is handled with enable_if to prevent out-of-bound warnings + // from the compiler + EIGEN_IF_CONSTEXPR(!isUnitDiag && startM > 0) + trsm::template divRHSByDiag(RHSInPacket, AInPacket); + + // After division, the rhs corresponding to subsequent rows of A can be partially updated + // We also broadcast the reciprocal of the next diagonal to AInPacket.packet[currentM] (if needed) + // to be used in the next iteration. + trsm::template + updateRHS( + A_arr, LDA, RHSInPacket, AInPacket); + + // Handle division for the RHS corresponding to the final row of A. 
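+ // (Its reciprocal was already broadcast by the updateRHS call above, so only the multiply remains.)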
+ EIGEN_IF_CONSTEXPR(!isUnitDiag && startM == endM-1) + trsm::template divRHSByDiag(RHSInPacket, AInPacket); + + aux_triSolveMicroKernel(A_arr, LDA, RHSInPacket, AInPacket); + } + + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> + aux_triSolveMicroKernel(Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, PacketBlock &AInPacket) + { + EIGEN_UNUSED_VARIABLE(A_arr); + EIGEN_UNUSED_VARIABLE(LDA); + EIGEN_UNUSED_VARIABLE(RHSInPacket); + EIGEN_UNUSED_VARIABLE(AInPacket); + } + + /******************************************************** + * Wrappers for aux_XXXX to hide counter parameter + ********************************************************/ + + /** + * Load endMxendK block of B to RHSInPacket + * Masked loads are used for cases where endK is not a multiple of PacketSize + */ + template + static EIGEN_ALWAYS_INLINE + void loadRHS(Scalar* B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { + aux_loadRHS(B_arr, LDB, RHSInPacket, rem); + } + + /** + * Load endMxendK block of B to RHSInPacket + * Masked loads are used for cases where endK is not a multiple of PacketSize + */ + template + static EIGEN_ALWAYS_INLINE + void storeRHS(Scalar* B_arr, int64_t LDB, PacketBlock &RHSInPacket, int64_t rem = 0) { + aux_storeRHS(B_arr, LDB, RHSInPacket, rem); + } + + /** + * Only used if Triangular matrix has non-unit diagonal values + */ + template + static EIGEN_ALWAYS_INLINE + void divRHSByDiag(PacketBlock &RHSInPacket, PacketBlock &AInPacket) { + aux_divRHSByDiag(RHSInPacket, AInPacket); + } + + /** + * Update right-hand sides (stored in avx registers) + * Traversing along the column A_{i,currentM}, where currentM <= i <= endM, and broadcasting each value to AInPacket. + **/ + template + static EIGEN_ALWAYS_INLINE + void updateRHS(Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, PacketBlock &AInPacket) { + aux_updateRHS( + A_arr, LDA, RHSInPacket, AInPacket); + } + + /** + * endM: dimension of A. 1 <= endM <= EIGEN_AVX_MAX_NUM_ROW + * numK: number of avx registers to use for each row of B (ex fp32: 48 rhs => 3 avx reg used). 1 <= endK <= 3. + * isFWDSolve: true => forward substitution, false => backwards substitution + * isUnitDiag: true => triangular matrix has unit diagonal. 
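+ * Example (fp32, full block): endM = 8 and numK = 3 cover an 8x48 block of right-hand sides held in
+ * 8*3 = 24 packets of RHSInPacket; each of the 8 substitution steps divides one row by its diagonal
+ * (skipped for unit diagonals) and uses that row to update the remaining rows.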
+ */ + template + static EIGEN_ALWAYS_INLINE + void triSolveMicroKernel(Scalar *A_arr, int64_t LDA, PacketBlock &RHSInPacket, PacketBlock &AInPacket) { + static_assert( numK >= 1 && numK <= 3, "numK out of range" ); + aux_triSolveMicroKernel( + A_arr, LDA, RHSInPacket, AInPacket); + } +}; + +/** + * Unrolls for gemm kernel + * + * isAdd: true => C += A*B, false => C -= A*B + */ +template +class gemm { +public: + using vec = typename std::conditional::value, vecFullFloat, vecFullDouble>::type; + static constexpr int64_t PacketSize = packet_traits::size; + + /*********************************** + * Auxillary Functions for: + * - setzero + * - updateC + * - storeC + * - startLoadB + * - triSolveMicroKernel + ************************************/ + + /** + * aux_setzero + * + * 2-D unroll + * for(startM = 0; startM < endM; startM++) + * for(startN = 0; startN < endN; startN++) + **/ + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> + aux_setzero(PacketBlock &zmm) { + constexpr int64_t counterReverse = endM*endN-counter; + constexpr int64_t startM = counterReverse/(endN); + constexpr int64_t startN = counterReverse%endN; + + zmm.packet[startN*endM + startM] = pzero(zmm.packet[startN*endM + startM]); + aux_setzero(zmm); + } + + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> + aux_setzero(PacketBlock &zmm) + { + EIGEN_UNUSED_VARIABLE(zmm); + } + + /** + * aux_updateC + * + * 2-D unroll + * for(startM = 0; startM < endM; startM++) + * for(startN = 0; startN < endN; startN++) + **/ + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> + aux_updateC(Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { + EIGEN_UNUSED_VARIABLE(rem_); + constexpr int64_t counterReverse = endM*endN-counter; + constexpr int64_t startM = counterReverse/(endN); + constexpr int64_t startN = counterReverse%endN; + + EIGEN_IF_CONSTEXPR(rem) + zmm.packet[startN*endM + startM] = + padd(ploadu(&C_arr[(startN) * LDC + startM*PacketSize], remMask(rem_)), + zmm.packet[startN*endM + startM], + remMask(rem_)); + else + zmm.packet[startN*endM + startM] = + padd(ploadu(&C_arr[(startN) * LDC + startM*PacketSize]), zmm.packet[startN*endM + startM]); + aux_updateC(C_arr, LDC, zmm, rem_); + } + + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> + aux_updateC(Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) + { + EIGEN_UNUSED_VARIABLE(C_arr); + EIGEN_UNUSED_VARIABLE(LDC); + EIGEN_UNUSED_VARIABLE(zmm); + EIGEN_UNUSED_VARIABLE(rem_); + } + + /** + * aux_storeC + * + * 2-D unroll + * for(startM = 0; startM < endM; startM++) + * for(startN = 0; startN < endN; startN++) + **/ + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> + aux_storeC(Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) { + EIGEN_UNUSED_VARIABLE(rem_); + constexpr int64_t counterReverse = endM*endN-counter; + constexpr int64_t startM = counterReverse/(endN); + constexpr int64_t startN = counterReverse%endN; + + EIGEN_IF_CONSTEXPR(rem) + pstoreu(&C_arr[(startN) * LDC + startM*PacketSize], zmm.packet[startN*endM + startM], remMask(rem_)); + else + pstoreu(&C_arr[(startN) * LDC + startM*PacketSize], zmm.packet[startN*endM + startM]); + aux_storeC(C_arr, LDC, zmm, rem_); + } + + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> + aux_storeC(Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0) + { + EIGEN_UNUSED_VARIABLE(C_arr); + EIGEN_UNUSED_VARIABLE(LDC); + EIGEN_UNUSED_VARIABLE(zmm); + 
EIGEN_UNUSED_VARIABLE(rem_); + } + + /** + * aux_startLoadB + * + * 1-D unroll + * for(startL = 0; startL < endL; startL++) + **/ + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> + aux_startLoadB(Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { + EIGEN_UNUSED_VARIABLE(rem_); + constexpr int64_t counterReverse = endL-counter; + constexpr int64_t startL = counterReverse; + + EIGEN_IF_CONSTEXPR(rem) + zmm.packet[unrollM*unrollN+startL] = + ploadu(&B_t[(startL/unrollM)*LDB + (startL%unrollM)*PacketSize], remMask(rem_)); + else + zmm.packet[unrollM*unrollN+startL] = ploadu(&B_t[(startL/unrollM)*LDB + (startL%unrollM)*PacketSize]); + + aux_startLoadB(B_t, LDB, zmm, rem_); + } + + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> + aux_startLoadB( + Scalar *B_t, int64_t LDB, + PacketBlock &zmm, int64_t rem_ = 0) + { + EIGEN_UNUSED_VARIABLE(B_t); + EIGEN_UNUSED_VARIABLE(LDB); + EIGEN_UNUSED_VARIABLE(zmm); + EIGEN_UNUSED_VARIABLE(rem_); + } + + /** + * aux_startBCastA + * + * 1-D unroll + * for(startB = 0; startB < endB; startB++) + **/ + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> + aux_startBCastA(Scalar *A_t, int64_t LDA, PacketBlock &zmm) { + constexpr int64_t counterReverse = endB-counter; + constexpr int64_t startB = counterReverse; + + zmm.packet[unrollM*unrollN+numLoad+startB] = pload1(&A_t[idA(startB, 0,LDA)]); + + aux_startBCastA(A_t, LDA, zmm); + } + + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> + aux_startBCastA(Scalar *A_t, int64_t LDA, PacketBlock &zmm) + { + EIGEN_UNUSED_VARIABLE(A_t); + EIGEN_UNUSED_VARIABLE(LDA); + EIGEN_UNUSED_VARIABLE(zmm); + } + + /** + * aux_loadB + * currK: current K + * + * 1-D unroll + * for(startM = 0; startM < endM; startM++) + **/ + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> + aux_loadB(Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0) { + EIGEN_UNUSED_VARIABLE(rem_); + if ((numLoad/endM + currK < unrollK)) { + constexpr int64_t counterReverse = endM-counter; + constexpr int64_t startM = counterReverse; + + EIGEN_IF_CONSTEXPR(rem) { + zmm.packet[endM*unrollN+(startM+currK*endM)%numLoad] = + ploadu(&B_t[(numLoad/endM + currK)*LDB + startM*PacketSize], remMask(rem_)); + } + else { + zmm.packet[endM*unrollN+(startM+currK*endM)%numLoad] = + ploadu(&B_t[(numLoad/endM + currK)*LDB + startM*PacketSize]); + } + + aux_loadB(B_t, LDB, zmm, rem_); + } + } + + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> + aux_loadB( + Scalar *B_t, int64_t LDB, + PacketBlock &zmm, int64_t rem_ = 0) + { + EIGEN_UNUSED_VARIABLE(B_t); + EIGEN_UNUSED_VARIABLE(LDB); + EIGEN_UNUSED_VARIABLE(zmm); + EIGEN_UNUSED_VARIABLE(rem_); + } + + /** + * aux_microKernel + * + * 3-D unroll + * for(startM = 0; startM < endM; startM++) + * for(startN = 0; startN < endN; startN++) + * for(startK = 0; startK < endK; startK++) + **/ + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter > 0)> + aux_microKernel( + Scalar *B_t, Scalar* A_t, int64_t LDB, int64_t LDA, + PacketBlock &zmm, int64_t rem_ = 0) { + EIGEN_UNUSED_VARIABLE(rem_); + constexpr int64_t counterReverse = endM*endN*endK-counter; + constexpr int startK = counterReverse/(endM*endN); + constexpr int startN = (counterReverse/(endM))%endN; + constexpr int startM = counterReverse%endM; + + EIGEN_IF_CONSTEXPR(startK == 0 && startM == 0 && startN == 0) { + gemm:: template + startLoadB(B_t, LDB, zmm, rem_); + gemm:: template + startBCastA(A_t, LDA, zmm); + } + + { + 
// Interleave FMA and Bcast + EIGEN_IF_CONSTEXPR(isAdd) { + zmm.packet[startN*endM + startM] = + pmadd(zmm.packet[endM*endN+numLoad+(startN+startK*endN)%numBCast], + zmm.packet[endM*endN+(startM+startK*endM)%numLoad], zmm.packet[startN*endM + startM]); + } + else { + zmm.packet[startN*endM + startM] = + pnmadd(zmm.packet[endM*endN+numLoad+(startN+startK*endN)%numBCast], + zmm.packet[endM*endN+(startM+startK*endM)%numLoad], zmm.packet[startN*endM + startM]); + } + // Bcast + EIGEN_IF_CONSTEXPR(startM == endM - 1 && (numBCast + startN + startK*endN < endK*endN)) { + zmm.packet[endM*endN+numLoad+(startN+startK*endN)%numBCast] = + pload1(&A_t[idA((numBCast + startN + startK*endN)%endN, + (numBCast + startN + startK*endN)/endN, LDA)]); + } + } + + // We have updated all accumlators, time to load next set of B's + EIGEN_IF_CONSTEXPR( (startN == endN - 1) && (startM == endM - 1) ) { + gemm::template loadB(B_t, LDB, zmm, rem_); + } + aux_microKernel(B_t, A_t, LDB, LDA, zmm, rem_); + + } + + template + static EIGEN_ALWAYS_INLINE std::enable_if_t<(counter <= 0)> + aux_microKernel( + Scalar *B_t, Scalar* A_t, int64_t LDB, int64_t LDA, + PacketBlock &zmm, int64_t rem_ = 0) + { + EIGEN_UNUSED_VARIABLE(B_t); + EIGEN_UNUSED_VARIABLE(A_t); + EIGEN_UNUSED_VARIABLE(LDB); + EIGEN_UNUSED_VARIABLE(LDA); + EIGEN_UNUSED_VARIABLE(zmm); + EIGEN_UNUSED_VARIABLE(rem_); + } + + /******************************************************** + * Wrappers for aux_XXXX to hide counter parameter + ********************************************************/ + + template + static EIGEN_ALWAYS_INLINE + void setzero(PacketBlock &zmm){ + aux_setzero(zmm); + } + + /** + * Ideally the compiler folds these into vaddp{s,d} with an embedded memory load. + */ + template + static EIGEN_ALWAYS_INLINE + void updateC(Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0){ + EIGEN_UNUSED_VARIABLE(rem_); + aux_updateC(C_arr, LDC, zmm, rem_); + } + + template + static EIGEN_ALWAYS_INLINE + void storeC(Scalar *C_arr, int64_t LDC, PacketBlock &zmm, int64_t rem_ = 0){ + EIGEN_UNUSED_VARIABLE(rem_); + aux_storeC(C_arr, LDC, zmm, rem_); + } + + /** + * Use numLoad registers for loading B at start of microKernel + */ + template + static EIGEN_ALWAYS_INLINE + void startLoadB(Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0){ + EIGEN_UNUSED_VARIABLE(rem_); + aux_startLoadB(B_t, LDB, zmm, rem_); + } + + /** + * Use numBCast registers for broadcasting A at start of microKernel + */ + template + static EIGEN_ALWAYS_INLINE + void startBCastA(Scalar *A_t, int64_t LDA, PacketBlock &zmm){ + aux_startBCastA(A_t, LDA, zmm); + } + + /** + * Loads next set of B into vector registers between each K unroll. + */ + template + static EIGEN_ALWAYS_INLINE + void loadB( + Scalar *B_t, int64_t LDB, PacketBlock &zmm, int64_t rem_ = 0){ + EIGEN_UNUSED_VARIABLE(rem_); + aux_loadB(B_t, LDB, zmm, rem_); + } + + /** + * Generates a microkernel for gemm (row-major) with unrolls {1,2,4,8}x{U1,U2,U3} to compute C -= A*B. + * A matrix can be row/col-major. B matrix is assumed row-major. + * + * isARowMajor: is A row major + * endM: Number registers per row + * endN: Number of rows + * endK: Loop unroll for K. + * numLoad: Number of registers for loading B. + * numBCast: Number of registers for broadcasting A. + * + * Ex: microkernel: 8x48 unroll (24 accumulators), k unrolled 4 times, + * 6 register for loading B, 2 for broadcasting A. + * + * Note: Ideally the microkernel should not have any register spilling. 
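+ * (For the 8x48 example above this is exactly the register budget: 24 accumulators + 6 B loads +
+ * 2 A broadcasts = 32, the full set of zmm registers available with AVX512.)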
+ * The avx instruction counts should be: + * - endK*endN vbroadcasts{s,d} + * - endK*endM vmovup{s,d} + * - endK*endN*endM FMAs + * + * From testing, there are no register spills with clang. There are register spills with GNU, which + * causes a performance hit. + */ + template + static EIGEN_ALWAYS_INLINE + void microKernel( + Scalar *B_t, Scalar* A_t, int64_t LDB, int64_t LDA, + PacketBlock &zmm, int64_t rem_ = 0){ + EIGEN_UNUSED_VARIABLE(rem_); + aux_microKernel( + B_t, A_t, LDB, LDA, zmm, rem_); + } + +}; +} // namespace unrolls + + +#endif //EIGEN_UNROLLS_IMPL_H diff --git a/Eigen/src/Core/products/TriangularSolverMatrix.h b/Eigen/src/Core/products/TriangularSolverMatrix.h index 520cfc98a..def6a28f2 100644 --- a/Eigen/src/Core/products/TriangularSolverMatrix.h +++ b/Eigen/src/Core/products/TriangularSolverMatrix.h @@ -2,6 +2,7 @@ // for linear algebra. // // Copyright (C) 2009 Gael Guennebaud +// Modifications Copyright (C) 2022 Intel Corporation // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed @@ -12,10 +13,118 @@ #include "../InternalHeaderCheck.h" -namespace Eigen { +namespace Eigen { namespace internal { +template +struct trsm_kernels { + // Generic Implementation of triangular solve for triangular matrix on left and multiple rhs. + // Handles non-packed matrices. + static void trsmKernelL( + Index size, Index otherSize, + const Scalar* _tri, Index triStride, + Scalar* _other, Index otherIncr, Index otherStride); + + // Generic Implementation of triangular solve for triangular matrix on right and multiple lhs. + // Handles non-packed matrices. + static void trsmKernelR( + Index size, Index otherSize, + const Scalar* _tri, Index triStride, + Scalar* _other, Index otherIncr, Index otherStride); +}; + +template +EIGEN_STRONG_INLINE void trsm_kernels::trsmKernelL( + Index size, Index otherSize, + const Scalar* _tri, Index triStride, + Scalar* _other, Index otherIncr, Index otherStride) + { + typedef const_blas_data_mapper TriMapper; + typedef blas_data_mapper OtherMapper; + TriMapper tri(_tri, triStride); + OtherMapper other(_other, otherStride, otherIncr); + + enum { IsLower = (Mode&Lower) == Lower }; + conj_if conj; + + // tr solve + for (Index k=0; k +EIGEN_STRONG_INLINE void trsm_kernels::trsmKernelR( + Index size, Index otherSize, + const Scalar* _tri, Index triStride, + Scalar* _other, Index otherIncr, Index otherStride) +{ + typedef typename NumTraits::Real RealScalar; + typedef blas_data_mapper LhsMapper; + typedef const_blas_data_mapper RhsMapper; + LhsMapper lhs(_other, otherStride, otherIncr); + RhsMapper rhs(_tri, triStride); + + enum { + RhsStorageOrder = TriStorageOrder, + IsLower = (Mode&Lower) == Lower + }; + conj_if conj; + + for (Index k=0; k struct triangular_solve_matrix @@ -46,6 +155,7 @@ struct triangular_solve_matrix& blocking); }; + template EIGEN_DONT_INLINE void triangular_solve_matrix::run( Index size, Index otherSize, @@ -55,6 +165,25 @@ EIGEN_DONT_INLINE void triangular_solve_matrix::value || + std::is_same::value)) ) { + // Very rough cutoffs to determine when to call trsm w/o packing + // For small problem sizes trsmKernel compiled with clang is generally faster. + // TODO: Investigate better heuristics for cutoffs. 
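+ // When the problem is small enough (as judged by avx512_trsm_cutoff against 50% of the L2 size),
+ // solve directly with the AVX512 kernel below and skip the packing path entirely.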
+ double L2Cap = 0.5; // 50% of L2 size + if (size < avx512_trsm_cutoff(l2, cols, L2Cap)) { + trsm_kernels::trsmKernelL( + size, cols, _tri, triStride, _other, 1, otherStride); + return; + } + } +#endif + typedef const_blas_data_mapper TriMapper; typedef blas_data_mapper OtherMapper; TriMapper tri(_tri, triStride); @@ -76,15 +205,12 @@ EIGEN_DONT_INLINE void triangular_solve_matrix conj; gebp_kernel gebp_kernel; gemm_pack_lhs pack_lhs; gemm_pack_rhs pack_rhs; // the goal here is to subdivise the Rhs panels such that we keep some cache // coherence when accessing the rhs elements - std::ptrdiff_t l1, l2, l3; - manage_caching_sizes(GetAction, &l1, &l2, &l3); Index subcols = cols>0 ? l2/(4 * sizeof(Scalar) * std::max(otherStride,size)) : 0; subcols = std::max((subcols/Traits::nr)*Traits::nr, Traits::nr); @@ -115,38 +241,19 @@ EIGEN_DONT_INLINE void triangular_solve_matrix(actual_kc-k1, SmallPanelWidth); // tr solve - for (Index k=0; k::value || + std::is_same::value)) ) { + i = IsLower ? k2 + k1: k2 - k1 - actualPanelWidth; } +#endif + trsm_kernels::trsmKernelL( + actualPanelWidth, actual_cols, + _tri + i + (i)*triStride, triStride, + _other + i*OtherInnerStride + j2*otherStride, otherIncr, otherStride); } Index lengthTarget = actual_kc-k1-actualPanelWidth; @@ -168,7 +275,7 @@ EIGEN_DONT_INLINE void triangular_solve_matrix GEPP { Index start = IsLower ? k2+kc : 0; @@ -198,6 +305,7 @@ struct triangular_solve_matrix& blocking); }; + template EIGEN_DONT_INLINE void triangular_solve_matrix::run( Index size, Index otherSize, @@ -206,7 +314,22 @@ EIGEN_DONT_INLINE void triangular_solve_matrix& blocking) { Index rows = otherSize; - typedef typename NumTraits::Real RealScalar; + +#if defined(EIGEN_ENABLE_AVX512_NOCOPY_TRSM_CUTOFFS) + EIGEN_IF_CONSTEXPR( (OtherInnerStride == 1 && + (std::is_same::value || + std::is_same::value)) ) { + // TODO: Investigate better heuristics for cutoffs. + std::ptrdiff_t l1, l2, l3; + manage_caching_sizes(GetAction, &l1, &l2, &l3); + double L2Cap = 0.5; // 50% of L2 size + if (size < avx512_trsm_cutoff(l2, rows, L2Cap)) { + trsm_kernels:: + trsmKernelR(size, rows, _tri, triStride, _other, 1, otherStride); + return; + } + } +#endif typedef blas_data_mapper LhsMapper; typedef const_blas_data_mapper RhsMapper; @@ -229,7 +352,6 @@ EIGEN_DONT_INLINE void triangular_solve_matrix conj; gebp_kernel gebp_kernel; gemm_pack_rhs pack_rhs; gemm_pack_rhs pack_rhs_panel; @@ -296,27 +418,13 @@ EIGEN_DONT_INLINE void triangular_solve_matrix:: + trsmKernelR(actualPanelWidth, actual_mc, + _tri + absolute_j2 + absolute_j2*triStride, triStride, + _other + i2*OtherInnerStride + absolute_j2*otherStride, otherIncr, otherStride); } - // pack the just computed part of lhs to A pack_lhs_panel(blockA, lhs.getSubMapper(i2,absolute_j2), actualPanelWidth, actual_mc, @@ -331,7 +439,6 @@ EIGEN_DONT_INLINE void triangular_solve_matrix