From 3580a382980b27cf4d032b9ec5634d3df66ae24a Mon Sep 17 00:00:00 2001 From: Antonio Sanchez Date: Sun, 16 Mar 2025 20:58:59 -0700 Subject: [PATCH] Use native _Float16 for AVX512FP16 and update vectorization. This allows us to do faster native scalar operations. Also updated half/quarter packets to use the native type if available. Benchmark improvement: ``` Comparing ./2910_without_float16 to ./2910_with_float16 Benchmark Time CPU Time Old Time New CPU Old CPU New ------------------------------------------------------------------------------------------------------------------------------------ BM_CalcMat<float>/10000/768/500 -0.0041 -0.0040 58276392 58039442 58273420 58039582 BM_CalcMat<_Float16>/10000/768/500 +0.0073 +0.0073 642506339 647214446 642481384 647188303 BM_CalcMat<Eigen::half>/10000/768/500 -0.3170 -0.3170 92511115 63182101 92506771 63179258 BM_CalcVec<float>/10000/768/500 +0.0022 +0.0022 5198157 5209469 5197913 5209334 BM_CalcVec<_Float16>/10000/768/500 +0.0025 +0.0026 10133324 10159111 10132641 10158507 BM_CalcVec<Eigen::half>/10000/768/500 -0.7760 -0.7760 45337937 10156952 45336532 10156389 OVERALL_GEOMEAN -0.2677 -0.2677 0 0 0 0 ``` Fixes #2910. 
--- Eigen/Core | 16 +- Eigen/src/Core/MathFunctionsImpl.h | 2 +- Eigen/src/Core/arch/AVX/MathFunctions.h | 3 + Eigen/src/Core/arch/AVX/PacketMath.h | 11 +- Eigen/src/Core/arch/AVX/TypeCasting.h | 12 +- Eigen/src/Core/arch/AVX512/MathFunctions.h | 23 +- .../src/Core/arch/AVX512/MathFunctionsFP16.h | 75 ++ Eigen/src/Core/arch/AVX512/PacketMath.h | 198 +++- Eigen/src/Core/arch/AVX512/PacketMathFP16.h | 996 ++++++++++++++---- Eigen/src/Core/arch/AVX512/TypeCasting.h | 75 +- Eigen/src/Core/arch/AVX512/TypeCastingFP16.h | 130 +++ Eigen/src/Core/arch/Default/Half.h | 317 +++--- Eigen/src/Core/util/ConfigureVectorization.h | 2 + test/packet_ostream.h | 3 +- test/packetmath.cpp | 8 +- 15 files changed, 1422 insertions(+), 449 deletions(-) create mode 100644 Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h create mode 100644 Eigen/src/Core/arch/AVX512/TypeCastingFP16.h diff --git a/Eigen/Core b/Eigen/Core index 99cd47390..6ae069a92 100644 --- a/Eigen/Core +++ b/Eigen/Core @@ -193,21 +193,27 @@ using std::ptrdiff_t; #include "src/Core/arch/Default/GenericPacketMathFunctionsFwd.h" #if defined EIGEN_VECTORIZE_AVX512 +#include "src/Core/arch/SSE/PacketMath.h" +#include "src/Core/arch/AVX/PacketMath.h" +#include "src/Core/arch/AVX512/PacketMath.h" #if defined EIGEN_VECTORIZE_AVX512FP16 #include "src/Core/arch/AVX512/PacketMathFP16.h" #endif -#include "src/Core/arch/SSE/PacketMath.h" #include "src/Core/arch/SSE/TypeCasting.h" -#include "src/Core/arch/SSE/Complex.h" -#include "src/Core/arch/AVX/PacketMath.h" #include "src/Core/arch/AVX/TypeCasting.h" -#include "src/Core/arch/AVX/Complex.h" -#include "src/Core/arch/AVX512/PacketMath.h" #include "src/Core/arch/AVX512/TypeCasting.h" +#if defined EIGEN_VECTORIZE_AVX512FP16 +#include "src/Core/arch/AVX512/TypeCastingFP16.h" +#endif +#include "src/Core/arch/SSE/Complex.h" +#include "src/Core/arch/AVX/Complex.h" #include "src/Core/arch/AVX512/Complex.h" #include "src/Core/arch/SSE/MathFunctions.h" #include 
"src/Core/arch/AVX/MathFunctions.h" #include "src/Core/arch/AVX512/MathFunctions.h" +#if defined EIGEN_VECTORIZE_AVX512FP16 +#include "src/Core/arch/AVX512/MathFunctionsFP16.h" +#endif #include "src/Core/arch/AVX512/TrsmKernel.h" #elif defined EIGEN_VECTORIZE_AVX // Use AVX for floats and doubles, SSE for integers diff --git a/Eigen/src/Core/MathFunctionsImpl.h b/Eigen/src/Core/MathFunctionsImpl.h index 8e2705ba9..cf8dcc3b8 100644 --- a/Eigen/src/Core/MathFunctionsImpl.h +++ b/Eigen/src/Core/MathFunctionsImpl.h @@ -76,7 +76,7 @@ struct generic_rsqrt_newton_step { static_assert(Steps > 0, "Steps must be at least 1."); using Scalar = typename unpacket_traits::type; EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet run(const Packet& a, const Packet& approx_rsqrt) { - constexpr Scalar kMinusHalf = Scalar(-1) / Scalar(2); + const Scalar kMinusHalf = Scalar(-1) / Scalar(2); const Packet cst_minus_half = pset1(kMinusHalf); const Packet cst_minus_one = pset1(Scalar(-1)); diff --git a/Eigen/src/Core/arch/AVX/MathFunctions.h b/Eigen/src/Core/arch/AVX/MathFunctions.h index a5c38e787..eb0011c2b 100644 --- a/Eigen/src/Core/arch/AVX/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX/MathFunctions.h @@ -106,6 +106,8 @@ BF16_PACKET_FUNCTION(Packet8f, Packet8bf, prsqrt) BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin) BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psqrt) BF16_PACKET_FUNCTION(Packet8f, Packet8bf, ptanh) + +#ifndef EIGEN_VECTORIZE_AVX512FP16 F16_PACKET_FUNCTION(Packet8f, Packet8h, pcos) F16_PACKET_FUNCTION(Packet8f, Packet8h, pexp) F16_PACKET_FUNCTION(Packet8f, Packet8h, pexp2) @@ -118,6 +120,7 @@ F16_PACKET_FUNCTION(Packet8f, Packet8h, prsqrt) F16_PACKET_FUNCTION(Packet8f, Packet8h, psin) F16_PACKET_FUNCTION(Packet8f, Packet8h, psqrt) F16_PACKET_FUNCTION(Packet8f, Packet8h, ptanh) +#endif } // end namespace internal diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index c29523a7c..aa93e4516 100644 --- 
a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -1839,10 +1839,13 @@ EIGEN_STRONG_INLINE Packet8ui pabs(const Packet8ui& a) { return a; } +#ifndef EIGEN_VECTORIZE_AVX512FP16 template <> EIGEN_STRONG_INLINE Packet8h psignbit(const Packet8h& a) { return _mm_cmpgt_epi16(_mm_setzero_si128(), a); } +#endif // EIGEN_VECTORIZE_AVX512FP16 + template <> EIGEN_STRONG_INLINE Packet8bf psignbit(const Packet8bf& a) { return _mm_cmpgt_epi16(_mm_setzero_si128(), a); @@ -2044,10 +2047,13 @@ EIGEN_STRONG_INLINE bool predux_any(const Packet8ui& x) { return _mm256_movemask_ps(_mm256_castsi256_ps(x)) != 0; } +#ifndef EIGEN_VECTORIZE_AVX512FP16 template <> EIGEN_STRONG_INLINE bool predux_any(const Packet8h& x) { return _mm_movemask_epi8(x) != 0; } +#endif // EIGEN_VECTORIZE_AVX512FP16 + template <> EIGEN_STRONG_INLINE bool predux_any(const Packet8bf& x) { return _mm_movemask_epi8(x) != 0; @@ -2211,7 +2217,6 @@ struct unpacket_traits { }; typedef Packet8h half; }; -#endif template <> EIGEN_STRONG_INLINE Packet8h pset1(const Eigen::half& from) { @@ -2446,14 +2451,12 @@ EIGEN_STRONG_INLINE void pscatter(Eigen::half* to, const to[stride * 7] = aux[7]; } -#ifndef EIGEN_VECTORIZE_AVX512FP16 template <> EIGEN_STRONG_INLINE Eigen::half predux(const Packet8h& a) { Packet8f af = half2float(a); float reduced = predux(af); return Eigen::half(reduced); } -#endif template <> EIGEN_STRONG_INLINE Eigen::half predux_max(const Packet8h& a) { @@ -2553,6 +2556,8 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { kernel.packet[3] = pload(out[3]); } +#endif + // BFloat16 implementation. 
EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) { diff --git a/Eigen/src/Core/arch/AVX/TypeCasting.h b/Eigen/src/Core/arch/AVX/TypeCasting.h index 9dcd6ef84..5b73ffe86 100644 --- a/Eigen/src/Core/arch/AVX/TypeCasting.h +++ b/Eigen/src/Core/arch/AVX/TypeCasting.h @@ -279,20 +279,22 @@ EIGEN_STRONG_INLINE Packet2l preinterpret(const Packet4l& a) } #endif +#ifndef EIGEN_VECTORIZE_AVX512FP16 template <> EIGEN_STRONG_INLINE Packet8f pcast(const Packet8h& a) { return half2float(a); } -template <> -EIGEN_STRONG_INLINE Packet8f pcast(const Packet8bf& a) { - return Bf16ToF32(a); -} - template <> EIGEN_STRONG_INLINE Packet8h pcast(const Packet8f& a) { return float2half(a); } +#endif + +template <> +EIGEN_STRONG_INLINE Packet8f pcast(const Packet8bf& a) { + return Bf16ToF32(a); +} template <> EIGEN_STRONG_INLINE Packet8bf pcast(const Packet8f& a) { diff --git a/Eigen/src/Core/arch/AVX512/MathFunctions.h b/Eigen/src/Core/arch/AVX512/MathFunctions.h index 603925429..04499a0c2 100644 --- a/Eigen/src/Core/arch/AVX512/MathFunctions.h +++ b/Eigen/src/Core/arch/AVX512/MathFunctions.h @@ -47,16 +47,16 @@ EIGEN_STRONG_INLINE Packet16bf pldexp(const Packet16bf& a, const Packet16bf& exp #if EIGEN_FAST_MATH template <> -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f psqrt(const Packet16f& _x) { - return generic_sqrt_newton_step::run(_x, _mm512_rsqrt14_ps(_x)); +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f psqrt(const Packet16f& x) { + return generic_sqrt_newton_step::run(x, _mm512_rsqrt14_ps(x)); } template <> -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d psqrt(const Packet8d& _x) { +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d psqrt(const Packet8d& x) { #ifdef EIGEN_VECTORIZE_AVX512ER - return generic_sqrt_newton_step::run(_x, _mm512_rsqrt28_pd(_x)); + return generic_sqrt_newton_step::run(x, _mm512_rsqrt28_pd(x)); #else - return generic_sqrt_newton_step::run(_x, _mm512_rsqrt14_pd(_x)); + return 
generic_sqrt_newton_step::run(x, _mm512_rsqrt14_pd(x)); #endif } #else @@ -80,19 +80,19 @@ EIGEN_STRONG_INLINE Packet16f prsqrt(const Packet16f& x) { #elif EIGEN_FAST_MATH template <> -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f prsqrt(const Packet16f& _x) { - return generic_rsqrt_newton_step::run(_x, _mm512_rsqrt14_ps(_x)); +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f prsqrt(const Packet16f& x) { + return generic_rsqrt_newton_step::run(x, _mm512_rsqrt14_ps(x)); } #endif // prsqrt for double. #if EIGEN_FAST_MATH template <> -EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d prsqrt(const Packet8d& _x) { +EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d prsqrt(const Packet8d& x) { #ifdef EIGEN_VECTORIZE_AVX512ER - return generic_rsqrt_newton_step::run(_x, _mm512_rsqrt28_pd(_x)); + return generic_rsqrt_newton_step::run(x, _mm512_rsqrt28_pd(x)); #else - return generic_rsqrt_newton_step::run(_x, _mm512_rsqrt14_pd(_x)); + return generic_rsqrt_newton_step::run(x, _mm512_rsqrt14_pd(x)); #endif } @@ -118,6 +118,8 @@ BF16_PACKET_FUNCTION(Packet16f, Packet16bf, prsqrt) BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin) BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt) BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh) + +#ifndef EIGEN_VECTORIZE_AVX512FP16 F16_PACKET_FUNCTION(Packet16f, Packet16h, pcos) F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp) F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp2) @@ -130,6 +132,7 @@ F16_PACKET_FUNCTION(Packet16f, Packet16h, prsqrt) F16_PACKET_FUNCTION(Packet16f, Packet16h, psin) F16_PACKET_FUNCTION(Packet16f, Packet16h, psqrt) F16_PACKET_FUNCTION(Packet16f, Packet16h, ptanh) +#endif // EIGEN_VECTORIZE_AVX512FP16 } // end namespace internal diff --git a/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h b/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h new file mode 100644 index 000000000..240ade43e --- /dev/null +++ b/Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h @@ -0,0 +1,75 @@ +// 
This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2025 The Eigen Authors. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_MATH_FUNCTIONS_FP16_AVX512_H +#define EIGEN_MATH_FUNCTIONS_FP16_AVX512_H + +// IWYU pragma: private +#include "../../InternalHeaderCheck.h" + +namespace Eigen { +namespace internal { + +EIGEN_STRONG_INLINE Packet32h combine2Packet16h(const Packet16h& a, const Packet16h& b) { + __m512i result = _mm512_castsi256_si512(_mm256_castph_si256(a)); + result = _mm512_inserti64x4(result, _mm256_castph_si256(b), 1); + return _mm512_castsi512_ph(result); +} + +EIGEN_STRONG_INLINE void extract2Packet16h(const Packet32h& x, Packet16h& a, Packet16h& b) { + a = _mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_castph_si512(x))); + b = _mm256_castsi256_ph(_mm512_extracti64x4_epi64(_mm512_castph_si512(x), 1)); +} + +#define _EIGEN_GENERATE_FP16_MATH_FUNCTION(func) \ + template <> \ + EIGEN_STRONG_INLINE Packet8h func(const Packet8h& a) { \ + return float2half(func(half2float(a))); \ + } \ + \ + template <> \ + EIGEN_STRONG_INLINE Packet16h func(const Packet16h& a) { \ + return float2half(func(half2float(a))); \ + } \ + \ + template <> \ + EIGEN_STRONG_INLINE Packet32h func(const Packet32h& a) { \ + Packet16h low; \ + Packet16h high; \ + extract2Packet16h(a, low, high); \ + return combine2Packet16h(func(low), func(high)); \ + } + +_EIGEN_GENERATE_FP16_MATH_FUNCTION(psin) +_EIGEN_GENERATE_FP16_MATH_FUNCTION(pcos) +_EIGEN_GENERATE_FP16_MATH_FUNCTION(plog) +_EIGEN_GENERATE_FP16_MATH_FUNCTION(plog2) +_EIGEN_GENERATE_FP16_MATH_FUNCTION(plog1p) +_EIGEN_GENERATE_FP16_MATH_FUNCTION(pexp) +_EIGEN_GENERATE_FP16_MATH_FUNCTION(pexpm1) +_EIGEN_GENERATE_FP16_MATH_FUNCTION(pexp2) +_EIGEN_GENERATE_FP16_MATH_FUNCTION(ptanh) +#undef 
_EIGEN_GENERATE_FP16_MATH_FUNCTION + +// pfrexp +template <> +EIGEN_STRONG_INLINE Packet32h pfrexp(const Packet32h& a, Packet32h& exponent) { + return pfrexp_generic(a, exponent); +} + +// pldexp +template <> +EIGEN_STRONG_INLINE Packet32h pldexp(const Packet32h& a, const Packet32h& exponent) { + return pldexp_generic(a, exponent); +} + +} // end namespace internal +} // end namespace Eigen + +#endif // EIGEN_MATH_FUNCTIONS_FP16_AVX512_H \ No newline at end of file diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index 5d869e42b..c07774964 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -40,6 +40,10 @@ typedef eigen_packet_wrapper<__m256i, 1> Packet16h; #endif typedef eigen_packet_wrapper<__m256i, 2> Packet16bf; +typedef eigen_packet_wrapper<__m512i, 6> Packet32s; +typedef eigen_packet_wrapper<__m256i, 6> Packet16s; +typedef eigen_packet_wrapper<__m128i, 6> Packet8s; + template <> struct is_arithmetic<__m512> { enum { value = true }; @@ -248,6 +252,39 @@ struct unpacket_traits { }; #endif +template <> +struct unpacket_traits { + typedef numext::int16_t type; + typedef Packet16s half; + enum { + size = 32, + alignment = Aligned64, + vectorizable = false, + }; +}; + +template <> +struct unpacket_traits { + typedef numext::int16_t type; + typedef Packet8s half; + enum { + size = 16, + alignment = Aligned32, + vectorizable = false, + }; +}; + +template <> +struct unpacket_traits { + typedef numext::int16_t type; + typedef Packet8s half; + enum { + size = 8, + alignment = Aligned16, + vectorizable = false, + }; +}; + template <> EIGEN_STRONG_INLINE Packet16f pset1(const float& from) { return _mm512_set1_ps(from); @@ -1335,10 +1372,13 @@ EIGEN_STRONG_INLINE Packet8l pabs(const Packet8l& a) { return _mm512_abs_epi64(a); } +#ifndef EIGEN_VECTORIZE_AVX512FP16 template <> EIGEN_STRONG_INLINE Packet16h psignbit(const Packet16h& a) { return _mm256_srai_epi16(a, 15); } +#endif // 
EIGEN_VECTORIZE_AVX512FP16 + template <> EIGEN_STRONG_INLINE Packet16bf psignbit(const Packet16bf& a) { return _mm256_srai_epi16(a, 15); @@ -2199,6 +2239,7 @@ EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& ifPacket, const Packet8d& } // Packet math for Eigen::half +#ifndef EIGEN_VECTORIZE_AVX512FP16 template <> EIGEN_STRONG_INLINE Packet16h pset1(const Eigen::half& from) { return _mm256_set1_epi16(from.x); @@ -2369,7 +2410,6 @@ EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) { return _mm256_xor_si256(a, sign_mask); } -#ifndef EIGEN_VECTORIZE_AVX512FP16 template <> EIGEN_STRONG_INLINE Packet16h padd(const Packet16h& a, const Packet16h& b) { Packet16f af = half2float(a); @@ -2408,8 +2448,6 @@ EIGEN_STRONG_INLINE half predux(const Packet16h& from) { return half(predux(from_float)); } -#endif - template <> EIGEN_STRONG_INLINE Packet8h predux_half_dowto4(const Packet16h& a) { Packet8h lane0 = _mm256_extractf128_si256(a, 0); @@ -2643,6 +2681,8 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { kernel.packet[3] = pload(out[3]); } +#endif // EIGEN_VECTORIZE_AVX512FP16 + template <> struct is_arithmetic { enum { value = true }; @@ -3095,6 +3135,158 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { kernel.packet[3] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31); } +// Minimal implementation of 16-bit int packets for use in pfrexp, pldexp. 
+ +template <> +EIGEN_STRONG_INLINE Packet32s pset1(const numext::int16_t& x) { + return _mm512_set1_epi16(x); +} + +template <> +EIGEN_STRONG_INLINE Packet16s pset1(const numext::int16_t& x) { + return _mm256_set1_epi16(x); +} + +template <> +EIGEN_STRONG_INLINE Packet8s pset1(const numext::int16_t& x) { + return _mm_set1_epi16(x); +} + +template <> +EIGEN_STRONG_INLINE void pstore(numext::int16_t* out, const Packet32s& x) { + _mm512_storeu_epi16(out, x); +} + +template <> +EIGEN_STRONG_INLINE void pstore(numext::int16_t* out, const Packet16s& x) { + _mm256_storeu_epi16(out, x); +} + +template <> +EIGEN_STRONG_INLINE void pstore(numext::int16_t* out, const Packet8s& x) { + _mm_storeu_epi16(out, x); +} + +template <> +EIGEN_STRONG_INLINE void pstoreu(numext::int16_t* out, const Packet32s& x) { + _mm512_storeu_epi16(out, x); +} + +template <> +EIGEN_STRONG_INLINE void pstoreu(numext::int16_t* out, const Packet16s& x) { + _mm256_storeu_epi16(out, x); +} + +template <> +EIGEN_STRONG_INLINE void pstoreu(numext::int16_t* out, const Packet8s& x) { + _mm_storeu_epi16(out, x); +} + +template <> +EIGEN_STRONG_INLINE Packet32s padd(const Packet32s& a, const Packet32s& b) { + return _mm512_add_epi16(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet16s padd(const Packet16s& a, const Packet16s& b) { + return _mm256_add_epi16(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet8s padd(const Packet8s& a, const Packet8s& b) { + return _mm_add_epi16(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet32s psub(const Packet32s& a, const Packet32s& b) { + return _mm512_sub_epi16(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet16s psub(const Packet16s& a, const Packet16s& b) { + return _mm256_sub_epi16(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet8s psub(const Packet8s& a, const Packet8s& b) { + return _mm_sub_epi16(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet32s pmul(const Packet32s& a, const Packet32s& b) { + return _mm512_mullo_epi16(a, b); +} + +template <> 
+EIGEN_STRONG_INLINE Packet16s pmul(const Packet16s& a, const Packet16s& b) { + return _mm256_mullo_epi16(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet8s pmul(const Packet8s& a, const Packet8s& b) { + return _mm_mullo_epi16(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet32s pnegate(const Packet32s& a) { + return _mm512_sub_epi16(_mm512_setzero_si512(), a); +} + +template <> +EIGEN_STRONG_INLINE Packet16s pnegate(const Packet16s& a) { + return _mm256_sub_epi16(_mm256_setzero_si256(), a); +} + +template <> +EIGEN_STRONG_INLINE Packet8s pnegate(const Packet8s& a) { + return _mm_sub_epi16(_mm_setzero_si128(), a); +} + +template +EIGEN_STRONG_INLINE Packet32s parithmetic_shift_right(Packet32s a) { + return _mm512_srai_epi16(a, N); +} + +template +EIGEN_STRONG_INLINE Packet16s parithmetic_shift_right(Packet16s a) { + return _mm256_srai_epi16(a, N); +} + +template +EIGEN_STRONG_INLINE Packet8s parithmetic_shift_right(Packet8s a) { + return _mm_srai_epi16(a, N); +} + +template +EIGEN_STRONG_INLINE Packet32s plogical_shift_left(Packet32s a) { + return _mm512_slli_epi16(a, N); +} + +template +EIGEN_STRONG_INLINE Packet16s plogical_shift_left(Packet16s a) { + return _mm256_slli_epi16(a, N); +} + +template +EIGEN_STRONG_INLINE Packet8s plogical_shift_left(Packet8s a) { + return _mm_slli_epi16(a, N); +} + +template +EIGEN_STRONG_INLINE Packet32s plogical_shift_right(Packet32s a) { + return _mm512_srli_epi16(a, N); +} + +template +EIGEN_STRONG_INLINE Packet16s plogical_shift_right(Packet16s a) { + return _mm256_srli_epi16(a, N); +} + +template +EIGEN_STRONG_INLINE Packet8s plogical_shift_right(Packet8s a) { + return _mm_srli_epi16(a, N); +} + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/AVX512/PacketMathFP16.h b/Eigen/src/Core/arch/AVX512/PacketMathFP16.h index df5a0ef7a..a040bbead 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMathFP16.h +++ b/Eigen/src/Core/arch/AVX512/PacketMathFP16.h @@ -1,7 +1,7 @@ // This file is part of 
Eigen, a lightweight C++ template library // for linear algebra. // -// +// Copyright (C) 2025 The Eigen Authors. // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed @@ -18,8 +18,8 @@ namespace Eigen { namespace internal { typedef __m512h Packet32h; -typedef eigen_packet_wrapper<__m256i, 1> Packet16h; -typedef eigen_packet_wrapper<__m128i, 2> Packet8h; +typedef __m256h Packet16h; +typedef __m128h Packet8h; template <> struct is_arithmetic { @@ -68,6 +68,7 @@ template <> struct unpacket_traits { typedef Eigen::half type; typedef Packet16h half; + typedef Packet32s integer_packet; enum { size = 32, alignment = Aligned64, @@ -81,6 +82,7 @@ template <> struct unpacket_traits { typedef Eigen::half type; typedef Packet8h half; + typedef Packet16s integer_packet; enum { size = 16, alignment = Aligned32, @@ -94,6 +96,7 @@ template <> struct unpacket_traits { typedef Eigen::half type; typedef Packet8h half; + typedef Packet8s integer_packet; enum { size = 8, alignment = Aligned16, @@ -103,14 +106,33 @@ struct unpacket_traits { }; }; +// Conversions + +EIGEN_STRONG_INLINE Packet16f half2float(const Packet16h& a) { return _mm512_cvtxph_ps(a); } + +EIGEN_STRONG_INLINE Packet8f half2float(const Packet8h& a) { return _mm256_cvtxph_ps(a); } + +EIGEN_STRONG_INLINE Packet16h float2half(const Packet16f& a) { return _mm512_cvtxps_ph(a); } + +EIGEN_STRONG_INLINE Packet8h float2half(const Packet8f& a) { return _mm256_cvtxps_ph(a); } + // Memory functions // pset1 template <> EIGEN_STRONG_INLINE Packet32h pset1(const Eigen::half& from) { - // half/half_raw is bit compatible - return _mm512_set1_ph(numext::bit_cast<_Float16>(from)); + return _mm512_set1_ph(from.x); +} + +template <> +EIGEN_STRONG_INLINE Packet16h pset1(const Eigen::half& from) { + return _mm256_set1_ph(from.x); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pset1(const Eigen::half& from) { + return _mm_set1_ph(from.x); } template <> @@ 
-118,24 +140,47 @@ EIGEN_STRONG_INLINE Packet32h pzero(const Packet32h& /*a*/) { return _mm512_setzero_ph(); } +template <> +EIGEN_STRONG_INLINE Packet16h pzero(const Packet16h& /*a*/) { + return _mm256_setzero_ph(); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pzero(const Packet8h& /*a*/) { + return _mm_setzero_ph(); +} + // pset1frombits template <> EIGEN_STRONG_INLINE Packet32h pset1frombits(unsigned short from) { return _mm512_castsi512_ph(_mm512_set1_epi16(from)); } +template <> +EIGEN_STRONG_INLINE Packet16h pset1frombits(unsigned short from) { + return _mm256_castsi256_ph(_mm256_set1_epi16(from)); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pset1frombits(unsigned short from) { + return _mm_castsi128_ph(_mm_set1_epi16(from)); +} + // pfirst template <> EIGEN_STRONG_INLINE Eigen::half pfirst(const Packet32h& from) { -#ifdef EIGEN_VECTORIZE_AVX512DQ - return half_impl::raw_uint16_to_half( - static_cast(_mm256_extract_epi16(_mm512_extracti32x8_epi32(_mm512_castph_si512(from), 0), 0))); -#else - Eigen::half dest[32]; - _mm512_storeu_ph(dest, from); - return dest[0]; -#endif + return Eigen::half(_mm512_cvtsh_h(from)); +} + +template <> +EIGEN_STRONG_INLINE Eigen::half pfirst(const Packet16h& from) { + return Eigen::half(_mm256_cvtsh_h(from)); +} + +template <> +EIGEN_STRONG_INLINE Eigen::half pfirst(const Packet8h& from) { + return Eigen::half(_mm_cvtsh_h(from)); } // pload @@ -145,6 +190,16 @@ EIGEN_STRONG_INLINE Packet32h pload(const Eigen::half* from) { EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_ph(from); } +template <> +EIGEN_STRONG_INLINE Packet16h pload(const Eigen::half* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_ph(from); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pload(const Eigen::half* from) { + EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_ph(from); +} + // ploadu template <> @@ -152,6 +207,16 @@ EIGEN_STRONG_INLINE Packet32h ploadu(const Eigen::half* from) { EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_ph(from); } +template <> 
+EIGEN_STRONG_INLINE Packet16h ploadu(const Eigen::half* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_ph(from); +} + +template <> +EIGEN_STRONG_INLINE Packet8h ploadu(const Eigen::half* from) { + EIGEN_DEBUG_UNALIGNED_LOAD return _mm_loadu_ph(from); +} + // pstore template <> @@ -159,6 +224,16 @@ EIGEN_STRONG_INLINE void pstore(Eigen::half* to, const Packet32h& from) { EIGEN_DEBUG_ALIGNED_STORE _mm512_store_ph(to, from); } +template <> +EIGEN_STRONG_INLINE void pstore(Eigen::half* to, const Packet16h& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm256_store_ph(to, from); +} + +template <> +EIGEN_STRONG_INLINE void pstore(Eigen::half* to, const Packet8h& from) { + EIGEN_DEBUG_ALIGNED_STORE _mm_store_ph(to, from); +} + // pstoreu template <> @@ -166,6 +241,16 @@ EIGEN_STRONG_INLINE void pstoreu(Eigen::half* to, const Packet32h& from) { EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_ph(to, from); } +template <> +EIGEN_STRONG_INLINE void pstoreu(Eigen::half* to, const Packet16h& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_ph(to, from); +} + +template <> +EIGEN_STRONG_INLINE void pstoreu(Eigen::half* to, const Packet8h& from) { + EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_ph(to, from); +} + // ploaddup template <> EIGEN_STRONG_INLINE Packet32h ploaddup(const Eigen::half* from) { @@ -175,6 +260,17 @@ EIGEN_STRONG_INLINE Packet32h ploaddup(const Eigen::half* from) { a); } +template <> +EIGEN_STRONG_INLINE Packet16h ploaddup(const Eigen::half* from) { + __m256h a = _mm256_castph128_ph256(_mm_loadu_ph(from)); + return _mm256_permutexvar_ph(_mm256_set_epi16(7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0), a); +} + +template <> +EIGEN_STRONG_INLINE Packet8h ploaddup(const Eigen::half* from) { + return _mm_set_ph(from[3].x, from[3].x, from[2].x, from[2].x, from[1].x, from[1].x, from[0].x, from[0].x); +} + // ploadquad template <> EIGEN_STRONG_INLINE Packet32h ploadquad(const Eigen::half* from) { @@ -184,6 +280,17 @@ EIGEN_STRONG_INLINE Packet32h ploadquad(const Eigen::half* 
from) { a); } +template <> +EIGEN_STRONG_INLINE Packet16h ploadquad(const Eigen::half* from) { + return _mm256_set_ph(from[3].x, from[3].x, from[3].x, from[3].x, from[2].x, from[2].x, from[2].x, from[2].x, + from[1].x, from[1].x, from[1].x, from[1].x, from[0].x, from[0].x, from[0].x, from[0].x); +} + +template <> +EIGEN_STRONG_INLINE Packet8h ploadquad(const Eigen::half* from) { + return _mm_set_ph(from[1].x, from[1].x, from[1].x, from[1].x, from[0].x, from[0].x, from[0].x, from[0].x); +} + // pabs template <> @@ -191,6 +298,16 @@ EIGEN_STRONG_INLINE Packet32h pabs(const Packet32h& a) { return _mm512_abs_ph(a); } +template <> +EIGEN_STRONG_INLINE Packet16h pabs(const Packet16h& a) { + return _mm256_abs_ph(a); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pabs(const Packet8h& a) { + return _mm_abs_ph(a); +} + // psignbit template <> @@ -198,6 +315,16 @@ EIGEN_STRONG_INLINE Packet32h psignbit(const Packet32h& a) { return _mm512_castsi512_ph(_mm512_srai_epi16(_mm512_castph_si512(a), 15)); } +template <> +EIGEN_STRONG_INLINE Packet16h psignbit(const Packet16h& a) { + return _mm256_castsi256_ph(_mm256_srai_epi16(_mm256_castph_si256(a), 15)); +} + +template <> +EIGEN_STRONG_INLINE Packet8h psignbit(const Packet8h& a) { + return _mm_castsi128_ph(_mm_srai_epi16(_mm_castph_si128(a), 15)); +} + // pmin template <> @@ -205,6 +332,16 @@ EIGEN_STRONG_INLINE Packet32h pmin(const Packet32h& a, const Packet32 return _mm512_min_ph(a, b); } +template <> +EIGEN_STRONG_INLINE Packet16h pmin(const Packet16h& a, const Packet16h& b) { + return _mm256_min_ph(a, b); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pmin(const Packet8h& a, const Packet8h& b) { + return _mm_min_ph(a, b); +} + // pmax template <> @@ -212,6 +349,16 @@ EIGEN_STRONG_INLINE Packet32h pmax(const Packet32h& a, const Packet32 return _mm512_max_ph(a, b); } +template <> +EIGEN_STRONG_INLINE Packet16h pmax(const Packet16h& a, const Packet16h& b) { + return _mm256_max_ph(a, b); +} + +template <> +EIGEN_STRONG_INLINE 
Packet8h pmax(const Packet8h& a, const Packet8h& b) { + return _mm_max_ph(a, b); +} + // plset template <> EIGEN_STRONG_INLINE Packet32h plset(const half& a) { @@ -219,6 +366,16 @@ EIGEN_STRONG_INLINE Packet32h plset(const half& a) { 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)); } +template <> +EIGEN_STRONG_INLINE Packet16h plset(const half& a) { + return _mm256_add_ph(pset1(a), _mm256_set_ph(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)); +} + +template <> +EIGEN_STRONG_INLINE Packet8h plset(const half& a) { + return _mm_add_ph(pset1(a), _mm_set_ph(7, 6, 5, 4, 3, 2, 1, 0)); +} + // por template <> @@ -226,6 +383,16 @@ EIGEN_STRONG_INLINE Packet32h por(const Packet32h& a, const Packet32h& b) { return _mm512_castsi512_ph(_mm512_or_si512(_mm512_castph_si512(a), _mm512_castph_si512(b))); } +template <> +EIGEN_STRONG_INLINE Packet16h por(const Packet16h& a, const Packet16h& b) { + return _mm256_castsi256_ph(_mm256_or_si256(_mm256_castph_si256(a), _mm256_castph_si256(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet8h por(const Packet8h& a, const Packet8h& b) { + return _mm_castsi128_ph(_mm_or_si128(_mm_castph_si128(a), _mm_castph_si128(b))); +} + // pxor template <> @@ -233,6 +400,16 @@ EIGEN_STRONG_INLINE Packet32h pxor(const Packet32h& a, const Packet32h& b) { return _mm512_castsi512_ph(_mm512_xor_si512(_mm512_castph_si512(a), _mm512_castph_si512(b))); } +template <> +EIGEN_STRONG_INLINE Packet16h pxor(const Packet16h& a, const Packet16h& b) { + return _mm256_castsi256_ph(_mm256_xor_si256(_mm256_castph_si256(a), _mm256_castph_si256(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pxor(const Packet8h& a, const Packet8h& b) { + return _mm_castsi128_ph(_mm_xor_si128(_mm_castph_si128(a), _mm_castph_si128(b))); +} + // pand template <> @@ -240,6 +417,16 @@ EIGEN_STRONG_INLINE Packet32h pand(const Packet32h& a, const Packet32h& b) { return _mm512_castsi512_ph(_mm512_and_si512(_mm512_castph_si512(a), _mm512_castph_si512(b))); } +template <> 
+EIGEN_STRONG_INLINE Packet16h pand(const Packet16h& a, const Packet16h& b) { + return _mm256_castsi256_ph(_mm256_and_si256(_mm256_castph_si256(a), _mm256_castph_si256(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pand(const Packet8h& a, const Packet8h& b) { + return _mm_castsi128_ph(_mm_and_si128(_mm_castph_si128(a), _mm_castph_si128(b))); +} + // pandnot template <> @@ -247,6 +434,16 @@ EIGEN_STRONG_INLINE Packet32h pandnot(const Packet32h& a, const Packet32h& b) { return _mm512_castsi512_ph(_mm512_andnot_si512(_mm512_castph_si512(b), _mm512_castph_si512(a))); } +template <> +EIGEN_STRONG_INLINE Packet16h pandnot(const Packet16h& a, const Packet16h& b) { + return _mm256_castsi256_ph(_mm256_andnot_si256(_mm256_castph_si256(b), _mm256_castph_si256(a))); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pandnot(const Packet8h& a, const Packet8h& b) { + return _mm_castsi128_ph(_mm_andnot_si128(_mm_castph_si128(b), _mm_castph_si128(a))); +} + // pselect template <> @@ -255,6 +452,18 @@ EIGEN_DEVICE_FUNC inline Packet32h pselect(const Packet32h& mask, const Packet32 return _mm512_mask_blend_ph(mask32, a, b); } +template <> +EIGEN_DEVICE_FUNC inline Packet16h pselect(const Packet16h& mask, const Packet16h& a, const Packet16h& b) { + __mmask16 mask16 = _mm256_cmp_epi16_mask(_mm256_castph_si256(mask), _mm256_setzero_si256(), _MM_CMPINT_EQ); + return _mm256_mask_blend_ph(mask16, a, b); +} + +template <> +EIGEN_DEVICE_FUNC inline Packet8h pselect(const Packet8h& mask, const Packet8h& a, const Packet8h& b) { + __mmask8 mask8 = _mm_cmp_epi16_mask(_mm_castph_si128(mask), _mm_setzero_si128(), _MM_CMPINT_EQ); + return _mm_mask_blend_ph(mask8, a, b); +} + // pcmp_eq template <> @@ -263,6 +472,18 @@ EIGEN_STRONG_INLINE Packet32h pcmp_eq(const Packet32h& a, const Packet32h& b) { return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, static_cast(0xffffu))); } +template <> +EIGEN_STRONG_INLINE Packet16h pcmp_eq(const Packet16h& a, const Packet16h& b) { + 
__mmask16 mask = _mm256_cmp_ph_mask(a, b, _CMP_EQ_OQ); + return _mm256_castsi256_ph(_mm256_mask_set1_epi16(_mm256_set1_epi32(0), mask, static_cast(0xffffu))); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a, const Packet8h& b) { + __mmask8 mask = _mm_cmp_ph_mask(a, b, _CMP_EQ_OQ); + return _mm_castsi128_ph(_mm_mask_set1_epi16(_mm_set1_epi32(0), mask, static_cast(0xffffu))); +} + // pcmp_le template <> @@ -271,6 +492,18 @@ EIGEN_STRONG_INLINE Packet32h pcmp_le(const Packet32h& a, const Packet32h& b) { return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, static_cast(0xffffu))); } +template <> +EIGEN_STRONG_INLINE Packet16h pcmp_le(const Packet16h& a, const Packet16h& b) { + __mmask16 mask = _mm256_cmp_ph_mask(a, b, _CMP_LE_OQ); + return _mm256_castsi256_ph(_mm256_mask_set1_epi16(_mm256_set1_epi32(0), mask, static_cast(0xffffu))); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pcmp_le(const Packet8h& a, const Packet8h& b) { + __mmask8 mask = _mm_cmp_ph_mask(a, b, _CMP_LE_OQ); + return _mm_castsi128_ph(_mm_mask_set1_epi16(_mm_set1_epi32(0), mask, static_cast(0xffffu))); +} + // pcmp_lt template <> @@ -279,6 +512,18 @@ EIGEN_STRONG_INLINE Packet32h pcmp_lt(const Packet32h& a, const Packet32h& b) { return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, static_cast(0xffffu))); } +template <> +EIGEN_STRONG_INLINE Packet16h pcmp_lt(const Packet16h& a, const Packet16h& b) { + __mmask16 mask = _mm256_cmp_ph_mask(a, b, _CMP_LT_OQ); + return _mm256_castsi256_ph(_mm256_mask_set1_epi16(_mm256_set1_epi32(0), mask, static_cast(0xffffu))); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pcmp_lt(const Packet8h& a, const Packet8h& b) { + __mmask8 mask = _mm_cmp_ph_mask(a, b, _CMP_LT_OQ); + return _mm_castsi128_ph(_mm_mask_set1_epi16(_mm_set1_epi32(0), mask, static_cast(0xffffu))); +} + // pcmp_lt_or_nan template <> @@ -287,6 +532,18 @@ EIGEN_STRONG_INLINE Packet32h pcmp_lt_or_nan(const Packet32h& a, const 
Packet32h return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi16(0), mask, static_cast(0xffffu))); } +template <> +EIGEN_STRONG_INLINE Packet16h pcmp_lt_or_nan(const Packet16h& a, const Packet16h& b) { + __mmask16 mask = _mm256_cmp_ph_mask(a, b, _CMP_NGE_UQ); + return _mm256_castsi256_ph(_mm256_mask_set1_epi16(_mm256_set1_epi32(0), mask, static_cast(0xffffu))); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pcmp_lt_or_nan(const Packet8h& a, const Packet8h& b) { + __mmask8 mask = _mm_cmp_ph_mask(a, b, _CMP_NGE_UQ); + return _mm_castsi128_ph(_mm_mask_set1_epi16(_mm_set1_epi32(0), mask, static_cast(0xffffu))); +} + // padd template <> @@ -296,12 +553,12 @@ EIGEN_STRONG_INLINE Packet32h padd(const Packet32h& a, const Packet32 template <> EIGEN_STRONG_INLINE Packet16h padd(const Packet16h& a, const Packet16h& b) { - return _mm256_castph_si256(_mm256_add_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b))); + return _mm256_add_ph(a, b); } template <> EIGEN_STRONG_INLINE Packet8h padd(const Packet8h& a, const Packet8h& b) { - return _mm_castph_si128(_mm_add_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b))); + return _mm_add_ph(a, b); } // psub @@ -313,12 +570,12 @@ EIGEN_STRONG_INLINE Packet32h psub(const Packet32h& a, const Packet32 template <> EIGEN_STRONG_INLINE Packet16h psub(const Packet16h& a, const Packet16h& b) { - return _mm256_castph_si256(_mm256_sub_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b))); + return _mm256_sub_ph(a, b); } template <> EIGEN_STRONG_INLINE Packet8h psub(const Packet8h& a, const Packet8h& b) { - return _mm_castph_si128(_mm_sub_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b))); + return _mm_sub_ph(a, b); } // pmul @@ -330,12 +587,12 @@ EIGEN_STRONG_INLINE Packet32h pmul(const Packet32h& a, const Packet32 template <> EIGEN_STRONG_INLINE Packet16h pmul(const Packet16h& a, const Packet16h& b) { - return _mm256_castph_si256(_mm256_mul_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b))); + return _mm256_mul_ph(a, b); } template <> 
EIGEN_STRONG_INLINE Packet8h pmul(const Packet8h& a, const Packet8h& b) { - return _mm_castph_si128(_mm_mul_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b))); + return _mm_mul_ph(a, b); } // pdiv @@ -347,12 +604,13 @@ EIGEN_STRONG_INLINE Packet32h pdiv(const Packet32h& a, const Packet32 template <> EIGEN_STRONG_INLINE Packet16h pdiv(const Packet16h& a, const Packet16h& b) { - return _mm256_castph_si256(_mm256_div_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b))); + return _mm256_div_ph(a, b); } template <> EIGEN_STRONG_INLINE Packet8h pdiv(const Packet8h& a, const Packet8h& b) { - return _mm_castph_si128(_mm_div_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b))); + return _mm_div_ph(a, b); + ; } // pround @@ -361,14 +619,40 @@ template <> EIGEN_STRONG_INLINE Packet32h pround(const Packet32h& a) { // Work-around for default std::round rounding mode. - // Mask for the sign bit - const Packet32h signMask = pset1frombits(static_cast(0x8000u)); - // The largest half-preicision float less than 0.5 + // Mask for the sign bit. + const Packet32h signMask = + pset1frombits(static_cast(static_cast(0x8000u))); + // The largest half-precision float less than 0.5. const Packet32h prev0dot5 = pset1frombits(static_cast(0x37FFu)); return _mm512_roundscale_ph(padd(por(pand(a, signMask), prev0dot5), a), _MM_FROUND_TO_ZERO); } +template <> +EIGEN_STRONG_INLINE Packet16h pround(const Packet16h& a) { + // Work-around for default std::round rounding mode. + + // Mask for the sign bit. + const Packet16h signMask = + pset1frombits(static_cast(static_cast(0x8000u))); + // The largest half-precision float less than 0.5. + const Packet16h prev0dot5 = pset1frombits(static_cast(0x37FFu)); + + return _mm256_roundscale_ph(padd(por(pand(a, signMask), prev0dot5), a), _MM_FROUND_TO_ZERO); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pround(const Packet8h& a) { + // Work-around for default std::round rounding mode. + + // Mask for the sign bit. 
+ const Packet8h signMask = pset1frombits(static_cast(static_cast(0x8000u))); + // The largest half-precision float less than 0.5. + const Packet8h prev0dot5 = pset1frombits(static_cast(0x37FFu)); + + return _mm_roundscale_ph(padd(por(pand(a, signMask), prev0dot5), a), _MM_FROUND_TO_ZERO); +} + // print template <> @@ -376,6 +660,16 @@ EIGEN_STRONG_INLINE Packet32h print(const Packet32h& a) { return _mm512_roundscale_ph(a, _MM_FROUND_CUR_DIRECTION); } +template <> +EIGEN_STRONG_INLINE Packet16h print(const Packet16h& a) { + return _mm256_roundscale_ph(a, _MM_FROUND_CUR_DIRECTION); +} + +template <> +EIGEN_STRONG_INLINE Packet8h print(const Packet8h& a) { + return _mm_roundscale_ph(a, _MM_FROUND_CUR_DIRECTION); +} + // pceil template <> @@ -383,6 +677,16 @@ EIGEN_STRONG_INLINE Packet32h pceil(const Packet32h& a) { return _mm512_roundscale_ph(a, _MM_FROUND_TO_POS_INF); } +template <> +EIGEN_STRONG_INLINE Packet16h pceil(const Packet16h& a) { + return _mm256_roundscale_ph(a, _MM_FROUND_TO_POS_INF); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pceil(const Packet8h& a) { + return _mm_roundscale_ph(a, _MM_FROUND_TO_POS_INF); +} + // pfloor template <> @@ -390,6 +694,16 @@ EIGEN_STRONG_INLINE Packet32h pfloor(const Packet32h& a) { return _mm512_roundscale_ph(a, _MM_FROUND_TO_NEG_INF); } +template <> +EIGEN_STRONG_INLINE Packet16h pfloor(const Packet16h& a) { + return _mm256_roundscale_ph(a, _MM_FROUND_TO_NEG_INF); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pfloor(const Packet8h& a) { + return _mm_roundscale_ph(a, _MM_FROUND_TO_NEG_INF); +} + // ptrunc template <> @@ -397,47 +711,99 @@ EIGEN_STRONG_INLINE Packet32h ptrunc(const Packet32h& a) { return _mm512_roundscale_ph(a, _MM_FROUND_TO_ZERO); } +template <> +EIGEN_STRONG_INLINE Packet16h ptrunc(const Packet16h& a) { + return _mm256_roundscale_ph(a, _MM_FROUND_TO_ZERO); +} + +template <> +EIGEN_STRONG_INLINE Packet8h ptrunc(const Packet8h& a) { + return _mm_roundscale_ph(a, _MM_FROUND_TO_ZERO); +} + // predux 
template <> EIGEN_STRONG_INLINE half predux(const Packet32h& a) { - return (half)_mm512_reduce_add_ph(a); + return half(_mm512_reduce_add_ph(a)); } template <> EIGEN_STRONG_INLINE half predux(const Packet16h& a) { - return (half)_mm256_reduce_add_ph(_mm256_castsi256_ph(a)); + return half(_mm256_reduce_add_ph(a)); } template <> EIGEN_STRONG_INLINE half predux(const Packet8h& a) { - return (half)_mm_reduce_add_ph(_mm_castsi128_ph(a)); + return half(_mm_reduce_add_ph(a)); } // predux_half_dowto4 template <> EIGEN_STRONG_INLINE Packet16h predux_half_dowto4(const Packet32h& a) { -#ifdef EIGEN_VECTORIZE_AVX512DQ - __m256i lowHalf = _mm256_castps_si256(_mm512_extractf32x8_ps(_mm512_castph_ps(a), 0)); - __m256i highHalf = _mm256_castps_si256(_mm512_extractf32x8_ps(_mm512_castph_ps(a), 1)); + const __m512i bits = _mm512_castph_si512(a); + Packet16h lo = _mm256_castsi256_ph(_mm512_castsi512_si256(bits)); + Packet16h hi = _mm256_castsi256_ph(_mm512_extracti64x4_epi64(bits, 1)); + return padd(lo, hi); +} - return Packet16h(padd(lowHalf, highHalf)); -#else - Eigen::half data[32]; - _mm512_storeu_ph(data, a); - - __m256i lowHalf = _mm256_castph_si256(_mm256_loadu_ph(data)); - __m256i highHalf = _mm256_castph_si256(_mm256_loadu_ph(data + 16)); - - return Packet16h(padd(lowHalf, highHalf)); -#endif +template <> +EIGEN_STRONG_INLINE Packet8h predux_half_dowto4(const Packet16h& a) { + Packet8h lo = _mm_castsi128_ph(_mm256_castsi256_si128(_mm256_castph_si256(a))); + Packet8h hi = _mm_castps_ph(_mm256_extractf128_ps(_mm256_castph_ps(a), 1)); + return padd(lo, hi); } // predux_max +template <> +EIGEN_STRONG_INLINE half predux_max(const Packet32h& a) { + return half(_mm512_reduce_max_ph(a)); +} + +template <> +EIGEN_STRONG_INLINE half predux_max(const Packet16h& a) { + return half(_mm256_reduce_max_ph(a)); +} + +template <> +EIGEN_STRONG_INLINE half predux_max(const Packet8h& a) { + return half(_mm_reduce_max_ph(a)); +} + // predux_min +template <> +EIGEN_STRONG_INLINE half 
predux_min(const Packet32h& a) { + return half(_mm512_reduce_min_ph(a)); +} + +template <> +EIGEN_STRONG_INLINE half predux_min(const Packet16h& a) { + return half(_mm256_reduce_min_ph(a)); +} + +template <> +EIGEN_STRONG_INLINE half predux_min(const Packet8h& a) { + return half(_mm_reduce_min_ph(a)); +} + // predux_mul +template <> +EIGEN_STRONG_INLINE half predux_mul(const Packet32h& a) { + return half(_mm512_reduce_mul_ph(a)); +} + +template <> +EIGEN_STRONG_INLINE half predux_mul(const Packet16h& a) { + return half(_mm256_reduce_mul_ph(a)); +} + +template <> +EIGEN_STRONG_INLINE half predux_mul(const Packet8h& a) { + return half(_mm_reduce_mul_ph(a)); +} + #ifdef EIGEN_VECTORIZE_FMA // pmadd @@ -449,12 +815,12 @@ EIGEN_STRONG_INLINE Packet32h pmadd(const Packet32h& a, const Packet32h& b, cons template <> EIGEN_STRONG_INLINE Packet16h pmadd(const Packet16h& a, const Packet16h& b, const Packet16h& c) { - return _mm256_castph_si256(_mm256_fmadd_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b), _mm256_castsi256_ph(c))); + return _mm256_fmadd_ph(a, b, c); } template <> EIGEN_STRONG_INLINE Packet8h pmadd(const Packet8h& a, const Packet8h& b, const Packet8h& c) { - return _mm_castph_si128(_mm_fmadd_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b), _mm_castsi128_ph(c))); + return _mm_fmadd_ph(a, b, c); } // pmsub @@ -466,12 +832,12 @@ EIGEN_STRONG_INLINE Packet32h pmsub(const Packet32h& a, const Packet32h& b, cons template <> EIGEN_STRONG_INLINE Packet16h pmsub(const Packet16h& a, const Packet16h& b, const Packet16h& c) { - return _mm256_castph_si256(_mm256_fmsub_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b), _mm256_castsi256_ph(c))); + return _mm256_fmsub_ph(a, b, c); } template <> EIGEN_STRONG_INLINE Packet8h pmsub(const Packet8h& a, const Packet8h& b, const Packet8h& c) { - return _mm_castph_si128(_mm_fmsub_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b), _mm_castsi128_ph(c))); + return _mm_fmsub_ph(a, b, c); } // pnmadd @@ -483,12 +849,12 @@ EIGEN_STRONG_INLINE 
Packet32h pnmadd(const Packet32h& a, const Packet32h& b, con template <> EIGEN_STRONG_INLINE Packet16h pnmadd(const Packet16h& a, const Packet16h& b, const Packet16h& c) { - return _mm256_castph_si256(_mm256_fnmadd_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b), _mm256_castsi256_ph(c))); + return _mm256_fnmadd_ph(a, b, c); } template <> EIGEN_STRONG_INLINE Packet8h pnmadd(const Packet8h& a, const Packet8h& b, const Packet8h& c) { - return _mm_castph_si128(_mm_fnmadd_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b), _mm_castsi128_ph(c))); + return _mm_fnmadd_ph(a, b, c); } // pnmsub @@ -500,12 +866,12 @@ EIGEN_STRONG_INLINE Packet32h pnmsub(const Packet32h& a, const Packet32h& b, con template <> EIGEN_STRONG_INLINE Packet16h pnmsub(const Packet16h& a, const Packet16h& b, const Packet16h& c) { - return _mm256_castph_si256(_mm256_fnmsub_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b), _mm256_castsi256_ph(c))); + return _mm256_fnmsub_ph(a, b, c); } template <> EIGEN_STRONG_INLINE Packet8h pnmsub(const Packet8h& a, const Packet8h& b, const Packet8h& c) { - return _mm_castph_si128(_mm_fnmsub_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b), _mm_castsi128_ph(c))); + return _mm_fnmsub_ph(a, b, c); } #endif @@ -514,35 +880,74 @@ EIGEN_STRONG_INLINE Packet8h pnmsub(const Packet8h& a, const Packet8h& b, const template <> EIGEN_STRONG_INLINE Packet32h pnegate(const Packet32h& a) { - return psub(pzero(a), a); + return _mm512_castsi512_ph( + _mm512_xor_si512(_mm512_castph_si512(a), _mm512_set1_epi16(static_cast(0x8000u)))); +} + +template <> +EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) { + return _mm256_castsi256_ph( + _mm256_xor_si256(_mm256_castph_si256(a), _mm256_set1_epi16(static_cast(0x8000u)))); +} + +template <> +EIGEN_STRONG_INLINE Packet8h pnegate(const Packet8h& a) { + return _mm_castsi128_ph(_mm_xor_si128(_mm_castph_si128(a), _mm_set1_epi16(static_cast(0x8000u)))); } // pconj -template <> -EIGEN_STRONG_INLINE Packet32h pconj(const Packet32h& a) { - return a; 
-} +// Nothing, packets are real. // psqrt template <> EIGEN_STRONG_INLINE Packet32h psqrt(const Packet32h& a) { - return _mm512_sqrt_ph(a); + return generic_sqrt_newton_step::run(a, _mm512_rsqrt_ph(a)); +} + +template <> +EIGEN_STRONG_INLINE Packet16h psqrt(const Packet16h& a) { + return generic_sqrt_newton_step::run(a, _mm256_rsqrt_ph(a)); +} + +template <> +EIGEN_STRONG_INLINE Packet8h psqrt(const Packet8h& a) { + return generic_sqrt_newton_step::run(a, _mm_rsqrt_ph(a)); } // prsqrt template <> EIGEN_STRONG_INLINE Packet32h prsqrt(const Packet32h& a) { - return _mm512_rsqrt_ph(a); + return generic_rsqrt_newton_step::run(a, _mm512_rsqrt_ph(a)); +} + +template <> +EIGEN_STRONG_INLINE Packet16h prsqrt(const Packet16h& a) { + return generic_rsqrt_newton_step::run(a, _mm256_rsqrt_ph(a)); +} + +template <> +EIGEN_STRONG_INLINE Packet8h prsqrt(const Packet8h& a) { + return generic_rsqrt_newton_step::run(a, _mm_rsqrt_ph(a)); } // preciprocal template <> EIGEN_STRONG_INLINE Packet32h preciprocal(const Packet32h& a) { - return _mm512_rcp_ph(a); + return generic_reciprocal_newton_step::run(a, _mm512_rcp_ph(a)); +} + +template <> +EIGEN_STRONG_INLINE Packet16h preciprocal(const Packet16h& a) { + return generic_reciprocal_newton_step::run(a, _mm256_rcp_ph(a)); +} + +template <> +EIGEN_STRONG_INLINE Packet8h preciprocal(const Packet8h& a) { + return generic_reciprocal_newton_step::run(a, _mm_rcp_ph(a)); } // ptranspose @@ -663,6 +1068,246 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock& a) { a.packet[3] = _mm512_castsi512_ph(a3); } +EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + __m256i a = _mm256_castph_si256(kernel.packet[0]); + __m256i b = _mm256_castph_si256(kernel.packet[1]); + __m256i c = _mm256_castph_si256(kernel.packet[2]); + __m256i d = _mm256_castph_si256(kernel.packet[3]); + __m256i e = _mm256_castph_si256(kernel.packet[4]); + __m256i f = _mm256_castph_si256(kernel.packet[5]); + __m256i g = _mm256_castph_si256(kernel.packet[6]); + __m256i h = 
_mm256_castph_si256(kernel.packet[7]); + __m256i i = _mm256_castph_si256(kernel.packet[8]); + __m256i j = _mm256_castph_si256(kernel.packet[9]); + __m256i k = _mm256_castph_si256(kernel.packet[10]); + __m256i l = _mm256_castph_si256(kernel.packet[11]); + __m256i m = _mm256_castph_si256(kernel.packet[12]); + __m256i n = _mm256_castph_si256(kernel.packet[13]); + __m256i o = _mm256_castph_si256(kernel.packet[14]); + __m256i p = _mm256_castph_si256(kernel.packet[15]); + + __m256i ab_07 = _mm256_unpacklo_epi16(a, b); + __m256i cd_07 = _mm256_unpacklo_epi16(c, d); + __m256i ef_07 = _mm256_unpacklo_epi16(e, f); + __m256i gh_07 = _mm256_unpacklo_epi16(g, h); + __m256i ij_07 = _mm256_unpacklo_epi16(i, j); + __m256i kl_07 = _mm256_unpacklo_epi16(k, l); + __m256i mn_07 = _mm256_unpacklo_epi16(m, n); + __m256i op_07 = _mm256_unpacklo_epi16(o, p); + + __m256i ab_8f = _mm256_unpackhi_epi16(a, b); + __m256i cd_8f = _mm256_unpackhi_epi16(c, d); + __m256i ef_8f = _mm256_unpackhi_epi16(e, f); + __m256i gh_8f = _mm256_unpackhi_epi16(g, h); + __m256i ij_8f = _mm256_unpackhi_epi16(i, j); + __m256i kl_8f = _mm256_unpackhi_epi16(k, l); + __m256i mn_8f = _mm256_unpackhi_epi16(m, n); + __m256i op_8f = _mm256_unpackhi_epi16(o, p); + + __m256i abcd_03 = _mm256_unpacklo_epi32(ab_07, cd_07); + __m256i abcd_47 = _mm256_unpackhi_epi32(ab_07, cd_07); + __m256i efgh_03 = _mm256_unpacklo_epi32(ef_07, gh_07); + __m256i efgh_47 = _mm256_unpackhi_epi32(ef_07, gh_07); + __m256i ijkl_03 = _mm256_unpacklo_epi32(ij_07, kl_07); + __m256i ijkl_47 = _mm256_unpackhi_epi32(ij_07, kl_07); + __m256i mnop_03 = _mm256_unpacklo_epi32(mn_07, op_07); + __m256i mnop_47 = _mm256_unpackhi_epi32(mn_07, op_07); + + __m256i abcd_8b = _mm256_unpacklo_epi32(ab_8f, cd_8f); + __m256i abcd_cf = _mm256_unpackhi_epi32(ab_8f, cd_8f); + __m256i efgh_8b = _mm256_unpacklo_epi32(ef_8f, gh_8f); + __m256i efgh_cf = _mm256_unpackhi_epi32(ef_8f, gh_8f); + __m256i ijkl_8b = _mm256_unpacklo_epi32(ij_8f, kl_8f); + __m256i ijkl_cf = 
_mm256_unpackhi_epi32(ij_8f, kl_8f); + __m256i mnop_8b = _mm256_unpacklo_epi32(mn_8f, op_8f); + __m256i mnop_cf = _mm256_unpackhi_epi32(mn_8f, op_8f); + + __m256i abcdefgh_01 = _mm256_unpacklo_epi64(abcd_03, efgh_03); + __m256i abcdefgh_23 = _mm256_unpackhi_epi64(abcd_03, efgh_03); + __m256i ijklmnop_01 = _mm256_unpacklo_epi64(ijkl_03, mnop_03); + __m256i ijklmnop_23 = _mm256_unpackhi_epi64(ijkl_03, mnop_03); + __m256i abcdefgh_45 = _mm256_unpacklo_epi64(abcd_47, efgh_47); + __m256i abcdefgh_67 = _mm256_unpackhi_epi64(abcd_47, efgh_47); + __m256i ijklmnop_45 = _mm256_unpacklo_epi64(ijkl_47, mnop_47); + __m256i ijklmnop_67 = _mm256_unpackhi_epi64(ijkl_47, mnop_47); + __m256i abcdefgh_89 = _mm256_unpacklo_epi64(abcd_8b, efgh_8b); + __m256i abcdefgh_ab = _mm256_unpackhi_epi64(abcd_8b, efgh_8b); + __m256i ijklmnop_89 = _mm256_unpacklo_epi64(ijkl_8b, mnop_8b); + __m256i ijklmnop_ab = _mm256_unpackhi_epi64(ijkl_8b, mnop_8b); + __m256i abcdefgh_cd = _mm256_unpacklo_epi64(abcd_cf, efgh_cf); + __m256i abcdefgh_ef = _mm256_unpackhi_epi64(abcd_cf, efgh_cf); + __m256i ijklmnop_cd = _mm256_unpacklo_epi64(ijkl_cf, mnop_cf); + __m256i ijklmnop_ef = _mm256_unpackhi_epi64(ijkl_cf, mnop_cf); + + // NOTE: no unpacklo/hi instr in this case, so using permute instr. 
+ __m256i a_p_0 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x20); + __m256i a_p_1 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x20); + __m256i a_p_2 = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x20); + __m256i a_p_3 = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x20); + __m256i a_p_4 = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x20); + __m256i a_p_5 = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x20); + __m256i a_p_6 = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x20); + __m256i a_p_7 = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x20); + __m256i a_p_8 = _mm256_permute2x128_si256(abcdefgh_01, ijklmnop_01, 0x31); + __m256i a_p_9 = _mm256_permute2x128_si256(abcdefgh_23, ijklmnop_23, 0x31); + __m256i a_p_a = _mm256_permute2x128_si256(abcdefgh_45, ijklmnop_45, 0x31); + __m256i a_p_b = _mm256_permute2x128_si256(abcdefgh_67, ijklmnop_67, 0x31); + __m256i a_p_c = _mm256_permute2x128_si256(abcdefgh_89, ijklmnop_89, 0x31); + __m256i a_p_d = _mm256_permute2x128_si256(abcdefgh_ab, ijklmnop_ab, 0x31); + __m256i a_p_e = _mm256_permute2x128_si256(abcdefgh_cd, ijklmnop_cd, 0x31); + __m256i a_p_f = _mm256_permute2x128_si256(abcdefgh_ef, ijklmnop_ef, 0x31); + + kernel.packet[0] = _mm256_castsi256_ph(a_p_0); + kernel.packet[1] = _mm256_castsi256_ph(a_p_1); + kernel.packet[2] = _mm256_castsi256_ph(a_p_2); + kernel.packet[3] = _mm256_castsi256_ph(a_p_3); + kernel.packet[4] = _mm256_castsi256_ph(a_p_4); + kernel.packet[5] = _mm256_castsi256_ph(a_p_5); + kernel.packet[6] = _mm256_castsi256_ph(a_p_6); + kernel.packet[7] = _mm256_castsi256_ph(a_p_7); + kernel.packet[8] = _mm256_castsi256_ph(a_p_8); + kernel.packet[9] = _mm256_castsi256_ph(a_p_9); + kernel.packet[10] = _mm256_castsi256_ph(a_p_a); + kernel.packet[11] = _mm256_castsi256_ph(a_p_b); + kernel.packet[12] = _mm256_castsi256_ph(a_p_c); + kernel.packet[13] = _mm256_castsi256_ph(a_p_d); + kernel.packet[14] = _mm256_castsi256_ph(a_p_e); + kernel.packet[15] = 
_mm256_castsi256_ph(a_p_f); +} + +EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + EIGEN_ALIGN64 half in[8][16]; + pstore(in[0], kernel.packet[0]); + pstore(in[1], kernel.packet[1]); + pstore(in[2], kernel.packet[2]); + pstore(in[3], kernel.packet[3]); + pstore(in[4], kernel.packet[4]); + pstore(in[5], kernel.packet[5]); + pstore(in[6], kernel.packet[6]); + pstore(in[7], kernel.packet[7]); + + EIGEN_ALIGN64 half out[8][16]; + + for (int i = 0; i < 8; ++i) { + for (int j = 0; j < 8; ++j) { + out[i][j] = in[j][2 * i]; + } + for (int j = 0; j < 8; ++j) { + out[i][j + 8] = in[j][2 * i + 1]; + } + } + + kernel.packet[0] = pload(out[0]); + kernel.packet[1] = pload(out[1]); + kernel.packet[2] = pload(out[2]); + kernel.packet[3] = pload(out[3]); + kernel.packet[4] = pload(out[4]); + kernel.packet[5] = pload(out[5]); + kernel.packet[6] = pload(out[6]); + kernel.packet[7] = pload(out[7]); +} + +EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + EIGEN_ALIGN64 half in[4][16]; + pstore(in[0], kernel.packet[0]); + pstore(in[1], kernel.packet[1]); + pstore(in[2], kernel.packet[2]); + pstore(in[3], kernel.packet[3]); + + EIGEN_ALIGN64 half out[4][16]; + + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + out[i][j] = in[j][4 * i]; + } + for (int j = 0; j < 4; ++j) { + out[i][j + 4] = in[j][4 * i + 1]; + } + for (int j = 0; j < 4; ++j) { + out[i][j + 8] = in[j][4 * i + 2]; + } + for (int j = 0; j < 4; ++j) { + out[i][j + 12] = in[j][4 * i + 3]; + } + } + + kernel.packet[0] = pload(out[0]); + kernel.packet[1] = pload(out[1]); + kernel.packet[2] = pload(out[2]); + kernel.packet[3] = pload(out[3]); +} + +EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + __m128i a = _mm_castph_si128(kernel.packet[0]); + __m128i b = _mm_castph_si128(kernel.packet[1]); + __m128i c = _mm_castph_si128(kernel.packet[2]); + __m128i d = _mm_castph_si128(kernel.packet[3]); + __m128i e = _mm_castph_si128(kernel.packet[4]); + __m128i f = 
_mm_castph_si128(kernel.packet[5]); + __m128i g = _mm_castph_si128(kernel.packet[6]); + __m128i h = _mm_castph_si128(kernel.packet[7]); + + __m128i a03b03 = _mm_unpacklo_epi16(a, b); + __m128i c03d03 = _mm_unpacklo_epi16(c, d); + __m128i e03f03 = _mm_unpacklo_epi16(e, f); + __m128i g03h03 = _mm_unpacklo_epi16(g, h); + __m128i a47b47 = _mm_unpackhi_epi16(a, b); + __m128i c47d47 = _mm_unpackhi_epi16(c, d); + __m128i e47f47 = _mm_unpackhi_epi16(e, f); + __m128i g47h47 = _mm_unpackhi_epi16(g, h); + + __m128i a01b01c01d01 = _mm_unpacklo_epi32(a03b03, c03d03); + __m128i a23b23c23d23 = _mm_unpackhi_epi32(a03b03, c03d03); + __m128i e01f01g01h01 = _mm_unpacklo_epi32(e03f03, g03h03); + __m128i e23f23g23h23 = _mm_unpackhi_epi32(e03f03, g03h03); + __m128i a45b45c45d45 = _mm_unpacklo_epi32(a47b47, c47d47); + __m128i a67b67c67d67 = _mm_unpackhi_epi32(a47b47, c47d47); + __m128i e45f45g45h45 = _mm_unpacklo_epi32(e47f47, g47h47); + __m128i e67f67g67h67 = _mm_unpackhi_epi32(e47f47, g47h47); + + __m128i a0b0c0d0e0f0g0h0 = _mm_unpacklo_epi64(a01b01c01d01, e01f01g01h01); + __m128i a1b1c1d1e1f1g1h1 = _mm_unpackhi_epi64(a01b01c01d01, e01f01g01h01); + __m128i a2b2c2d2e2f2g2h2 = _mm_unpacklo_epi64(a23b23c23d23, e23f23g23h23); + __m128i a3b3c3d3e3f3g3h3 = _mm_unpackhi_epi64(a23b23c23d23, e23f23g23h23); + __m128i a4b4c4d4e4f4g4h4 = _mm_unpacklo_epi64(a45b45c45d45, e45f45g45h45); + __m128i a5b5c5d5e5f5g5h5 = _mm_unpackhi_epi64(a45b45c45d45, e45f45g45h45); + __m128i a6b6c6d6e6f6g6h6 = _mm_unpacklo_epi64(a67b67c67d67, e67f67g67h67); + __m128i a7b7c7d7e7f7g7h7 = _mm_unpackhi_epi64(a67b67c67d67, e67f67g67h67); + + kernel.packet[0] = _mm_castsi128_ph(a0b0c0d0e0f0g0h0); + kernel.packet[1] = _mm_castsi128_ph(a1b1c1d1e1f1g1h1); + kernel.packet[2] = _mm_castsi128_ph(a2b2c2d2e2f2g2h2); + kernel.packet[3] = _mm_castsi128_ph(a3b3c3d3e3f3g3h3); + kernel.packet[4] = _mm_castsi128_ph(a4b4c4d4e4f4g4h4); + kernel.packet[5] = _mm_castsi128_ph(a5b5c5d5e5f5g5h5); + kernel.packet[6] = 
_mm_castsi128_ph(a6b6c6d6e6f6g6h6); + kernel.packet[7] = _mm_castsi128_ph(a7b7c7d7e7f7g7h7); +} + +EIGEN_STRONG_INLINE void ptranspose(PacketBlock& kernel) { + EIGEN_ALIGN32 Eigen::half in[4][8]; + pstore(in[0], kernel.packet[0]); + pstore(in[1], kernel.packet[1]); + pstore(in[2], kernel.packet[2]); + pstore(in[3], kernel.packet[3]); + + EIGEN_ALIGN32 Eigen::half out[4][8]; + + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + out[i][j] = in[j][2 * i]; + } + for (int j = 0; j < 4; ++j) { + out[i][j + 4] = in[j][2 * i + 1]; + } + } + + kernel.packet[0] = pload(out[0]); + kernel.packet[1] = pload(out[1]); + kernel.packet[2] = pload(out[2]); + kernel.packet[3] = pload(out[3]); +} + // preverse template <> @@ -672,6 +1317,20 @@ EIGEN_STRONG_INLINE Packet32h preverse(const Packet32h& a) { a); } +template <> +EIGEN_STRONG_INLINE Packet16h preverse(const Packet16h& a) { + __m128i m = _mm_setr_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1); + return _mm256_castsi256_ph(_mm256_insertf128_si256( + _mm256_castsi128_si256(_mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castph_si256(a), 1), m)), + _mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castph_si256(a), 0), m), 1)); +} + +template <> +EIGEN_STRONG_INLINE Packet8h preverse(const Packet8h& a) { + __m128i m = _mm_setr_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1); + return _mm_castsi128_ph(_mm_shuffle_epi8(_mm_castph_si128(a), m)); +} + // pscatter template <> @@ -684,191 +1343,68 @@ EIGEN_STRONG_INLINE void pscatter(half* to, const Packet32h& fr to[stride * i] = aux[i]; } } +template <> +EIGEN_STRONG_INLINE void pscatter(half* to, const Packet16h& from, Index stride) { + EIGEN_ALIGN64 half aux[16]; + pstore(aux, from); + to[stride * 0] = aux[0]; + to[stride * 1] = aux[1]; + to[stride * 2] = aux[2]; + to[stride * 3] = aux[3]; + to[stride * 4] = aux[4]; + to[stride * 5] = aux[5]; + to[stride * 6] = aux[6]; + to[stride * 7] = aux[7]; + to[stride * 8] = aux[8]; + to[stride * 9] = 
aux[9]; + to[stride * 10] = aux[10]; + to[stride * 11] = aux[11]; + to[stride * 12] = aux[12]; + to[stride * 13] = aux[13]; + to[stride * 14] = aux[14]; + to[stride * 15] = aux[15]; +} + +template <> +EIGEN_STRONG_INLINE void pscatter(Eigen::half* to, const Packet8h& from, Index stride) { + EIGEN_ALIGN32 Eigen::half aux[8]; + pstore(aux, from); + to[stride * 0] = aux[0]; + to[stride * 1] = aux[1]; + to[stride * 2] = aux[2]; + to[stride * 3] = aux[3]; + to[stride * 4] = aux[4]; + to[stride * 5] = aux[5]; + to[stride * 6] = aux[6]; + to[stride * 7] = aux[7]; +} // pgather template <> EIGEN_STRONG_INLINE Packet32h pgather(const Eigen::half* from, Index stride) { - return _mm512_castsi512_ph(_mm512_set_epi16( - from[31 * stride].x, from[30 * stride].x, from[29 * stride].x, from[28 * stride].x, from[27 * stride].x, - from[26 * stride].x, from[25 * stride].x, from[24 * stride].x, from[23 * stride].x, from[22 * stride].x, - from[21 * stride].x, from[20 * stride].x, from[19 * stride].x, from[18 * stride].x, from[17 * stride].x, - from[16 * stride].x, from[15 * stride].x, from[14 * stride].x, from[13 * stride].x, from[12 * stride].x, - from[11 * stride].x, from[10 * stride].x, from[9 * stride].x, from[8 * stride].x, from[7 * stride].x, - from[6 * stride].x, from[5 * stride].x, from[4 * stride].x, from[3 * stride].x, from[2 * stride].x, - from[1 * stride].x, from[0 * stride].x)); + return _mm512_set_ph(from[31 * stride].x, from[30 * stride].x, from[29 * stride].x, from[28 * stride].x, + from[27 * stride].x, from[26 * stride].x, from[25 * stride].x, from[24 * stride].x, + from[23 * stride].x, from[22 * stride].x, from[21 * stride].x, from[20 * stride].x, + from[19 * stride].x, from[18 * stride].x, from[17 * stride].x, from[16 * stride].x, + from[15 * stride].x, from[14 * stride].x, from[13 * stride].x, from[12 * stride].x, + from[11 * stride].x, from[10 * stride].x, from[9 * stride].x, from[8 * stride].x, + from[7 * stride].x, from[6 * stride].x, from[5 * stride].x, from[4 * 
stride].x, + from[3 * stride].x, from[2 * stride].x, from[1 * stride].x, from[0 * stride].x); } template <> -EIGEN_STRONG_INLINE Packet16h pcos(const Packet16h&); -template <> -EIGEN_STRONG_INLINE Packet16h psin(const Packet16h&); -template <> -EIGEN_STRONG_INLINE Packet16h plog(const Packet16h&); -template <> -EIGEN_STRONG_INLINE Packet16h plog2(const Packet16h&); -template <> -EIGEN_STRONG_INLINE Packet16h plog1p(const Packet16h&); -template <> -EIGEN_STRONG_INLINE Packet16h pexp(const Packet16h&); -template <> -EIGEN_STRONG_INLINE Packet16h pexpm1(const Packet16h&); -template <> -EIGEN_STRONG_INLINE Packet16h ptanh(const Packet16h&); -template <> -EIGEN_STRONG_INLINE Packet16h pfrexp(const Packet16h&, Packet16h&); -template <> -EIGEN_STRONG_INLINE Packet16h pldexp(const Packet16h&, const Packet16h&); - -EIGEN_STRONG_INLINE Packet32h combine2Packet16h(const Packet16h& a, const Packet16h& b) { - __m512d result = _mm512_undefined_pd(); - result = _mm512_insertf64x4(result, _mm256_castsi256_pd(a), 0); - result = _mm512_insertf64x4(result, _mm256_castsi256_pd(b), 1); - return _mm512_castpd_ph(result); +EIGEN_STRONG_INLINE Packet16h pgather(const Eigen::half* from, Index stride) { + return _mm256_set_ph(from[15 * stride].x, from[14 * stride].x, from[13 * stride].x, from[12 * stride].x, + from[11 * stride].x, from[10 * stride].x, from[9 * stride].x, from[8 * stride].x, + from[7 * stride].x, from[6 * stride].x, from[5 * stride].x, from[4 * stride].x, + from[3 * stride].x, from[2 * stride].x, from[1 * stride].x, from[0 * stride].x); } -EIGEN_STRONG_INLINE void extract2Packet16h(const Packet32h& x, Packet16h& a, Packet16h& b) { - a = _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(x), 0)); - b = _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(x), 1)); -} - -// psin template <> -EIGEN_STRONG_INLINE Packet32h psin(const Packet32h& a) { - Packet16h low; - Packet16h high; - extract2Packet16h(a, low, high); - - Packet16h lowOut = psin(low); - Packet16h 
highOut = psin(high); - - return combine2Packet16h(lowOut, highOut); -} - -// pcos -template <> -EIGEN_STRONG_INLINE Packet32h pcos(const Packet32h& a) { - Packet16h low; - Packet16h high; - extract2Packet16h(a, low, high); - - Packet16h lowOut = pcos(low); - Packet16h highOut = pcos(high); - - return combine2Packet16h(lowOut, highOut); -} - -// plog -template <> -EIGEN_STRONG_INLINE Packet32h plog(const Packet32h& a) { - Packet16h low; - Packet16h high; - extract2Packet16h(a, low, high); - - Packet16h lowOut = plog(low); - Packet16h highOut = plog(high); - - return combine2Packet16h(lowOut, highOut); -} - -// plog2 -template <> -EIGEN_STRONG_INLINE Packet32h plog2(const Packet32h& a) { - Packet16h low; - Packet16h high; - extract2Packet16h(a, low, high); - - Packet16h lowOut = plog2(low); - Packet16h highOut = plog2(high); - - return combine2Packet16h(lowOut, highOut); -} - -// plog1p -template <> -EIGEN_STRONG_INLINE Packet32h plog1p(const Packet32h& a) { - Packet16h low; - Packet16h high; - extract2Packet16h(a, low, high); - - Packet16h lowOut = plog1p(low); - Packet16h highOut = plog1p(high); - - return combine2Packet16h(lowOut, highOut); -} - -// pexp -template <> -EIGEN_STRONG_INLINE Packet32h pexp(const Packet32h& a) { - Packet16h low; - Packet16h high; - extract2Packet16h(a, low, high); - - Packet16h lowOut = pexp(low); - Packet16h highOut = pexp(high); - - return combine2Packet16h(lowOut, highOut); -} - -// pexpm1 -template <> -EIGEN_STRONG_INLINE Packet32h pexpm1(const Packet32h& a) { - Packet16h low; - Packet16h high; - extract2Packet16h(a, low, high); - - Packet16h lowOut = pexpm1(low); - Packet16h highOut = pexpm1(high); - - return combine2Packet16h(lowOut, highOut); -} - -// ptanh -template <> -EIGEN_STRONG_INLINE Packet32h ptanh(const Packet32h& a) { - Packet16h low; - Packet16h high; - extract2Packet16h(a, low, high); - - Packet16h lowOut = ptanh(low); - Packet16h highOut = ptanh(high); - - return combine2Packet16h(lowOut, highOut); -} - -// pfrexp 
-template <> -EIGEN_STRONG_INLINE Packet32h pfrexp(const Packet32h& a, Packet32h& exponent) { - Packet16h low; - Packet16h high; - extract2Packet16h(a, low, high); - - Packet16h exp1 = _mm256_undefined_si256(); - Packet16h exp2 = _mm256_undefined_si256(); - - Packet16h lowOut = pfrexp(low, exp1); - Packet16h highOut = pfrexp(high, exp2); - - exponent = combine2Packet16h(exp1, exp2); - - return combine2Packet16h(lowOut, highOut); -} - -// pldexp -template <> -EIGEN_STRONG_INLINE Packet32h pldexp(const Packet32h& a, const Packet32h& exponent) { - Packet16h low; - Packet16h high; - extract2Packet16h(a, low, high); - - Packet16h exp1; - Packet16h exp2; - extract2Packet16h(exponent, exp1, exp2); - - Packet16h lowOut = pldexp(low, exp1); - Packet16h highOut = pldexp(high, exp2); - - return combine2Packet16h(lowOut, highOut); +EIGEN_STRONG_INLINE Packet8h pgather(const Eigen::half* from, Index stride) { + return _mm_set_ph(from[7 * stride].x, from[6 * stride].x, from[5 * stride].x, from[4 * stride].x, from[3 * stride].x, + from[2 * stride].x, from[1 * stride].x, from[0 * stride].x); } } // end namespace internal diff --git a/Eigen/src/Core/arch/AVX512/TypeCasting.h b/Eigen/src/Core/arch/AVX512/TypeCasting.h index 9508ac66b..fc55fd861 100644 --- a/Eigen/src/Core/arch/AVX512/TypeCasting.h +++ b/Eigen/src/Core/arch/AVX512/TypeCasting.h @@ -237,16 +237,12 @@ EIGEN_STRONG_INLINE Packet4i preinterpret(const Packet16i& return _mm512_castsi512_si128(a); } +#ifndef EIGEN_VECTORIZE_AVX512FP16 template <> EIGEN_STRONG_INLINE Packet8h preinterpret(const Packet16h& a) { return _mm256_castsi256_si128(a); } -template <> -EIGEN_STRONG_INLINE Packet8bf preinterpret(const Packet16bf& a) { - return _mm256_castsi256_si128(a); -} - template <> EIGEN_STRONG_INLINE Packet16f pcast(const Packet16h& a) { return half2float(a); @@ -257,6 +253,13 @@ EIGEN_STRONG_INLINE Packet16h pcast(const Packet16f& a) { return float2half(a); } +#endif + +template <> +EIGEN_STRONG_INLINE Packet8bf 
preinterpret(const Packet16bf& a) { + return _mm256_castsi256_si128(a); +} + template <> EIGEN_STRONG_INLINE Packet16f pcast(const Packet16bf& a) { return Bf16ToF32(a); @@ -267,68 +270,6 @@ EIGEN_STRONG_INLINE Packet16bf pcast(const Packet16f& a) return F32ToBf16(a); } -#ifdef EIGEN_VECTORIZE_AVX512FP16 - -template <> -EIGEN_STRONG_INLINE Packet16h preinterpret(const Packet32h& a) { - return _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0)); -} -template <> -EIGEN_STRONG_INLINE Packet8h preinterpret(const Packet32h& a) { - return _mm256_castsi256_si128(preinterpret(a)); -} - -template <> -EIGEN_STRONG_INLINE Packet16f pcast(const Packet32h& a) { - // Discard second-half of input. - Packet16h low = _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0)); - return _mm512_cvtxph_ps(_mm256_castsi256_ph(low)); -} - -template <> -EIGEN_STRONG_INLINE Packet32h pcast(const Packet16f& a, const Packet16f& b) { - __m512d result = _mm512_undefined_pd(); - result = _mm512_insertf64x4( - result, _mm256_castsi256_pd(_mm512_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 0); - result = _mm512_insertf64x4( - result, _mm256_castsi256_pd(_mm512_cvtps_ph(b, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 1); - return _mm512_castpd_ph(result); -} - -template <> -EIGEN_STRONG_INLINE Packet8f pcast(const Packet16h& a) { - // Discard second-half of input. 
- Packet8h low = _mm_castps_si128(_mm256_extractf32x4_ps(_mm256_castsi256_ps(a), 0)); - return _mm256_cvtxph_ps(_mm_castsi128_ph(low)); -} - -template <> -EIGEN_STRONG_INLINE Packet16h pcast(const Packet8f& a, const Packet8f& b) { - __m256d result = _mm256_undefined_pd(); - result = _mm256_insertf64x2(result, - _mm_castsi128_pd(_mm256_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 0); - result = _mm256_insertf64x2(result, - _mm_castsi128_pd(_mm256_cvtps_ph(b, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 1); - return _mm256_castpd_si256(result); -} - -template <> -EIGEN_STRONG_INLINE Packet4f pcast(const Packet8h& a) { - Packet8f full = _mm256_cvtxph_ps(_mm_castsi128_ph(a)); - // Discard second-half of input. - return _mm256_extractf32x4_ps(full, 0); -} - -template <> -EIGEN_STRONG_INLINE Packet8h pcast(const Packet4f& a, const Packet4f& b) { - __m256 result = _mm256_undefined_ps(); - result = _mm256_insertf128_ps(result, a, 0); - result = _mm256_insertf128_ps(result, b, 1); - return _mm256_cvtps_ph(result, _MM_FROUND_TO_NEAREST_INT); -} - -#endif - } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h b/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h new file mode 100644 index 000000000..f06f13df9 --- /dev/null +++ b/Eigen/src/Core/arch/AVX512/TypeCastingFP16.h @@ -0,0 +1,130 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2025 The Eigen Authors. +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +#ifndef EIGEN_TYPE_CASTING_FP16_AVX512_H +#define EIGEN_TYPE_CASTING_FP16_AVX512_H + +// IWYU pragma: private +#include "../../InternalHeaderCheck.h" + +namespace Eigen { +namespace internal { + +template <> +EIGEN_STRONG_INLINE Packet32s preinterpret(const Packet32h& a) { + return _mm512_castph_si512(a); +} +template <> +EIGEN_STRONG_INLINE Packet16s preinterpret(const Packet16h& a) { + return _mm256_castph_si256(a); +} +template <> +EIGEN_STRONG_INLINE Packet8s preinterpret(const Packet8h& a) { + return _mm_castph_si128(a); +} + +template <> +EIGEN_STRONG_INLINE Packet32h preinterpret(const Packet32s& a) { + return _mm512_castsi512_ph(a); +} +template <> +EIGEN_STRONG_INLINE Packet16h preinterpret(const Packet16s& a) { + return _mm256_castsi256_ph(a); +} +template <> +EIGEN_STRONG_INLINE Packet8h preinterpret(const Packet8s& a) { + return _mm_castsi128_ph(a); +} + +template <> +EIGEN_STRONG_INLINE Packet16f pcast(const Packet16h& a) { + return half2float(a); +} +template <> +EIGEN_STRONG_INLINE Packet8f pcast(const Packet8h& a) { + return half2float(a); +} + +template <> +EIGEN_STRONG_INLINE Packet16h pcast(const Packet16f& a) { + return float2half(a); +} +template <> +EIGEN_STRONG_INLINE Packet8h pcast(const Packet8f& a) { + return float2half(a); +} + +template <> +EIGEN_STRONG_INLINE Packet16f pcast(const Packet32h& a) { + // Discard second-half of input. + Packet16h low = _mm256_castpd_ph(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0)); + return _mm512_cvtxph_ps(low); +} +template <> +EIGEN_STRONG_INLINE Packet8f pcast(const Packet16h& a) { + // Discard second-half of input. + Packet8h low = _mm_castps_ph(_mm256_extractf32x4_ps(_mm256_castph_ps(a), 0)); + return _mm256_cvtxph_ps(low); +} +template <> +EIGEN_STRONG_INLINE Packet4f pcast(const Packet8h& a) { + Packet8f full = _mm256_cvtxph_ps(a); + // Discard second-half of input. 
+ return _mm256_extractf32x4_ps(full, 0); +} + +template <> +EIGEN_STRONG_INLINE Packet32h pcast(const Packet16f& a, const Packet16f& b) { + __m512 result = _mm512_castsi512_ps(_mm512_castsi256_si512(_mm256_castph_si256(_mm512_cvtxps_ph(a)))); + result = _mm512_insertf32x8(result, _mm256_castph_ps(_mm512_cvtxps_ph(b)), 1); + return _mm512_castps_ph(result); +} +template <> +EIGEN_STRONG_INLINE Packet16h pcast(const Packet8f& a, const Packet8f& b) { + __m256 result = _mm256_castsi256_ps(_mm256_castsi128_si256(_mm_castph_si128(_mm256_cvtxps_ph(a)))); + result = _mm256_insertf32x4(result, _mm_castph_ps(_mm256_cvtxps_ph(b)), 1); + return _mm256_castps_ph(result); +} +template <> +EIGEN_STRONG_INLINE Packet8h pcast(const Packet4f& a, const Packet4f& b) { + __m256 result = _mm256_castsi256_ps(_mm256_castsi128_si256(_mm_castps_si128(a))); + result = _mm256_insertf128_ps(result, b, 1); + return _mm256_cvtxps_ph(result); +} + +template <> +EIGEN_STRONG_INLINE Packet32s pcast(const Packet32h& a) { + return _mm512_cvtph_epi16(a); +} +template <> +EIGEN_STRONG_INLINE Packet16s pcast(const Packet16h& a) { + return _mm256_cvtph_epi16(a); +} +template <> +EIGEN_STRONG_INLINE Packet8s pcast(const Packet8h& a) { + return _mm_cvtph_epi16(a); +} + +template <> +EIGEN_STRONG_INLINE Packet32h pcast(const Packet32s& a) { + return _mm512_cvtepi16_ph(a); +} +template <> +EIGEN_STRONG_INLINE Packet16h pcast(const Packet16s& a) { + return _mm256_cvtepi16_ph(a); +} +template <> +EIGEN_STRONG_INLINE Packet8h pcast(const Packet8s& a) { + return _mm_cvtepi16_ph(a); +} + +} // namespace internal +} // namespace Eigen + +#endif // EIGEN_TYPE_CASTING_FP16_AVX512_H diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h index 95697f3cf..d8c9d5a77 100644 --- a/Eigen/src/Core/arch/Default/Half.h +++ b/Eigen/src/Core/arch/Default/Half.h @@ -37,21 +37,23 @@ // IWYU pragma: private #include "../../InternalHeaderCheck.h" -#if defined(EIGEN_HAS_GPU_FP16) || 
defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) // When compiling with GPU support, the "__half_raw" base class as well as // some other routines are defined in the GPU compiler header files // (cuda_fp16.h, hip_fp16.h), and they are not tagged constexpr // As a consequence, we get compile failures when compiling Eigen with // GPU support. Hence the need to disable EIGEN_CONSTEXPR when building -// Eigen with GPU support -#pragma push_macro("EIGEN_CONSTEXPR") -#undef EIGEN_CONSTEXPR -#define EIGEN_CONSTEXPR +// Eigen with GPU support. +// Any functions that require `numext::bit_cast` may also not be constexpr, +// including any native types when setting via raw bit values. +#if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16) +#define _EIGEN_MAYBE_CONSTEXPR +#else +#define _EIGEN_MAYBE_CONSTEXPR constexpr #endif #define F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, METHOD) \ template <> \ - EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_UNUSED PACKET_F16 METHOD(const PACKET_F16& _x) { \ + EIGEN_UNUSED EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC PACKET_F16 METHOD(const PACKET_F16& _x) { \ return float2half(METHOD(half2float(_x))); \ } @@ -81,8 +83,10 @@ namespace half_impl { // Making the host side compile phase of hipcc use the same Eigen::half impl, as the gcc compile, resolves // this error, and hence the following convoluted #if condition #if !defined(EIGEN_HAS_GPU_FP16) || !defined(EIGEN_GPU_COMPILE_PHASE) + // Make our own __half_raw definition that is similar to CUDA's. 
struct __half_raw { + struct construct_from_rep_tag {}; #if (defined(EIGEN_HAS_GPU_FP16) && !defined(EIGEN_GPU_COMPILE_PHASE)) // Eigen::half can be used as the datatype for shared memory declarations (in Eigen and TF) // The element type for shared memory cannot have non-trivial constructors @@ -91,43 +95,53 @@ struct __half_raw { // hence the need for this EIGEN_DEVICE_FUNC __half_raw() {} #else - EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw() : x(0) {} + EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR __half_raw() : x(0) {} #endif + #if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) - explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(numext::uint16_t raw) : x(numext::bit_cast<__fp16>(raw)) {} + explicit EIGEN_DEVICE_FUNC __half_raw(numext::uint16_t raw) : x(numext::bit_cast<__fp16>(raw)) {} + EIGEN_DEVICE_FUNC constexpr __half_raw(construct_from_rep_tag, __fp16 rep) : x{rep} {} __fp16 x; +#elif defined(EIGEN_HAS_BUILTIN_FLOAT16) + explicit EIGEN_DEVICE_FUNC __half_raw(numext::uint16_t raw) : x(numext::bit_cast<_Float16>(raw)) {} + EIGEN_DEVICE_FUNC constexpr __half_raw(construct_from_rep_tag, _Float16 rep) : x{rep} {} + _Float16 x; #else - explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(numext::uint16_t raw) : x(raw) {} + explicit EIGEN_DEVICE_FUNC constexpr __half_raw(numext::uint16_t raw) : x(raw) {} + EIGEN_DEVICE_FUNC constexpr __half_raw(construct_from_rep_tag, numext::uint16_t rep) : x{rep} {} numext::uint16_t x; #endif }; #elif defined(EIGEN_HAS_HIP_FP16) -// Nothing to do here +// HIP GPU compile phase: nothing to do here. // HIP fp16 header file has a definition for __half_raw #elif defined(EIGEN_HAS_CUDA_FP16) + +// CUDA GPU compile phase. 
#if EIGEN_CUDA_SDK_VER < 90000 // In CUDA < 9.0, __half is the equivalent of CUDA 9's __half_raw typedef __half __half_raw; #endif // defined(EIGEN_HAS_CUDA_FP16) + #elif defined(SYCL_DEVICE_ONLY) typedef cl::sycl::half __half_raw; #endif -EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x); +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x); EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff); EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h); struct half_base : public __half_raw { - EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base() {} - EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half_raw& h) : __half_raw(h) {} + EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half_base() {} + EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half_base(const __half_raw& h) : __half_raw(h) {} #if defined(EIGEN_HAS_GPU_FP16) #if defined(EIGEN_HAS_HIP_FP16) - EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half& h) { x = __half_as_ushort(h); } + EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half_base(const __half& h) { x = __half_as_ushort(h); } #elif defined(EIGEN_HAS_CUDA_FP16) #if EIGEN_CUDA_SDK_VER >= 90000 - EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half& h) : __half_raw(*(__half_raw*)&h) {} + EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half_base(const __half& h) : __half_raw(*(__half_raw*)&h) {} #endif #endif #endif @@ -156,21 +170,29 @@ struct half : public half_impl::half_base { #endif #endif - EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half() {} + EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half() {} - EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half_raw& h) : half_impl::half_base(h) {} + EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(const __half_raw& h) : half_impl::half_base(h) {} #if defined(EIGEN_HAS_GPU_FP16) #if defined(EIGEN_HAS_HIP_FP16) - EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {} 
+ EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {} #elif defined(EIGEN_HAS_CUDA_FP16) #if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000 - EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {} + EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {} #endif #endif #endif - explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(bool b) +#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) + explicit EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(__fp16 b) + : half(__half_raw(__half_raw::construct_from_rep_tag(), b)) {} +#elif defined(EIGEN_HAS_BUILTIN_FLOAT16) + explicit EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(_Float16 b) + : half(__half_raw(__half_raw::construct_from_rep_tag(), b)) {} +#endif + + explicit EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(bool b) : half_impl::half_base(half_impl::raw_uint16_to_half(b ? 0x3c00 : 0)) {} template explicit EIGEN_DEVICE_FUNC half(T val) @@ -201,99 +223,99 @@ struct half : public half_impl::half_base { namespace half_impl { template struct numeric_limits_half_impl { - static EIGEN_CONSTEXPR const bool is_specialized = true; - static EIGEN_CONSTEXPR const bool is_signed = true; - static EIGEN_CONSTEXPR const bool is_integer = false; - static EIGEN_CONSTEXPR const bool is_exact = false; - static EIGEN_CONSTEXPR const bool has_infinity = true; - static EIGEN_CONSTEXPR const bool has_quiet_NaN = true; - static EIGEN_CONSTEXPR const bool has_signaling_NaN = true; + static constexpr const bool is_specialized = true; + static constexpr const bool is_signed = true; + static constexpr const bool is_integer = false; + static constexpr const bool is_exact = false; + static constexpr const bool has_infinity = true; + static constexpr const bool has_quiet_NaN = true; + static constexpr const bool has_signaling_NaN = true; EIGEN_DIAGNOSTICS(push) EIGEN_DISABLE_DEPRECATED_WARNING - static EIGEN_CONSTEXPR const 
std::float_denorm_style has_denorm = std::denorm_present; - static EIGEN_CONSTEXPR const bool has_denorm_loss = false; + static constexpr const std::float_denorm_style has_denorm = std::denorm_present; + static constexpr const bool has_denorm_loss = false; EIGEN_DIAGNOSTICS(pop) - static EIGEN_CONSTEXPR const std::float_round_style round_style = std::round_to_nearest; - static EIGEN_CONSTEXPR const bool is_iec559 = true; + static constexpr const std::float_round_style round_style = std::round_to_nearest; + static constexpr const bool is_iec559 = true; // The C++ standard defines this as "true if the set of values representable // by the type is finite." Half has finite precision. - static EIGEN_CONSTEXPR const bool is_bounded = true; - static EIGEN_CONSTEXPR const bool is_modulo = false; - static EIGEN_CONSTEXPR const int digits = 11; - static EIGEN_CONSTEXPR const int digits10 = + static constexpr const bool is_bounded = true; + static constexpr const bool is_modulo = false; + static constexpr const int digits = 11; + static constexpr const int digits10 = 3; // according to http://half.sourceforge.net/structstd_1_1numeric__limits_3_01half__float_1_1half_01_4.html - static EIGEN_CONSTEXPR const int max_digits10 = + static constexpr const int max_digits10 = 5; // according to http://half.sourceforge.net/structstd_1_1numeric__limits_3_01half__float_1_1half_01_4.html - static EIGEN_CONSTEXPR const int radix = std::numeric_limits::radix; - static EIGEN_CONSTEXPR const int min_exponent = -13; - static EIGEN_CONSTEXPR const int min_exponent10 = -4; - static EIGEN_CONSTEXPR const int max_exponent = 16; - static EIGEN_CONSTEXPR const int max_exponent10 = 4; - static EIGEN_CONSTEXPR const bool traps = std::numeric_limits::traps; + static constexpr const int radix = std::numeric_limits::radix; + static constexpr const int min_exponent = -13; + static constexpr const int min_exponent10 = -4; + static constexpr const int max_exponent = 16; + static constexpr const int 
max_exponent10 = 4; + static constexpr const bool traps = std::numeric_limits::traps; // IEEE754: "The implementer shall choose how tininess is detected, but shall // detect tininess in the same way for all operations in radix two" - static EIGEN_CONSTEXPR const bool tinyness_before = std::numeric_limits::tinyness_before; + static constexpr const bool tinyness_before = std::numeric_limits::tinyness_before; - static EIGEN_CONSTEXPR Eigen::half(min)() { return Eigen::half_impl::raw_uint16_to_half(0x0400); } - static EIGEN_CONSTEXPR Eigen::half lowest() { return Eigen::half_impl::raw_uint16_to_half(0xfbff); } - static EIGEN_CONSTEXPR Eigen::half(max)() { return Eigen::half_impl::raw_uint16_to_half(0x7bff); } - static EIGEN_CONSTEXPR Eigen::half epsilon() { return Eigen::half_impl::raw_uint16_to_half(0x1400); } - static EIGEN_CONSTEXPR Eigen::half round_error() { return Eigen::half_impl::raw_uint16_to_half(0x3800); } - static EIGEN_CONSTEXPR Eigen::half infinity() { return Eigen::half_impl::raw_uint16_to_half(0x7c00); } - static EIGEN_CONSTEXPR Eigen::half quiet_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7e00); } - static EIGEN_CONSTEXPR Eigen::half signaling_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7d00); } - static EIGEN_CONSTEXPR Eigen::half denorm_min() { return Eigen::half_impl::raw_uint16_to_half(0x0001); } + static _EIGEN_MAYBE_CONSTEXPR Eigen::half(min)() { return Eigen::half_impl::raw_uint16_to_half(0x0400); } + static _EIGEN_MAYBE_CONSTEXPR Eigen::half lowest() { return Eigen::half_impl::raw_uint16_to_half(0xfbff); } + static _EIGEN_MAYBE_CONSTEXPR Eigen::half(max)() { return Eigen::half_impl::raw_uint16_to_half(0x7bff); } + static _EIGEN_MAYBE_CONSTEXPR Eigen::half epsilon() { return Eigen::half_impl::raw_uint16_to_half(0x1400); } + static _EIGEN_MAYBE_CONSTEXPR Eigen::half round_error() { return Eigen::half_impl::raw_uint16_to_half(0x3800); } + static _EIGEN_MAYBE_CONSTEXPR Eigen::half infinity() { return 
Eigen::half_impl::raw_uint16_to_half(0x7c00); } + static _EIGEN_MAYBE_CONSTEXPR Eigen::half quiet_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7e00); } + static _EIGEN_MAYBE_CONSTEXPR Eigen::half signaling_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7d00); } + static _EIGEN_MAYBE_CONSTEXPR Eigen::half denorm_min() { return Eigen::half_impl::raw_uint16_to_half(0x0001); } }; template -EIGEN_CONSTEXPR const bool numeric_limits_half_impl::is_specialized; +constexpr const bool numeric_limits_half_impl::is_specialized; template -EIGEN_CONSTEXPR const bool numeric_limits_half_impl::is_signed; +constexpr const bool numeric_limits_half_impl::is_signed; template -EIGEN_CONSTEXPR const bool numeric_limits_half_impl::is_integer; +constexpr const bool numeric_limits_half_impl::is_integer; template -EIGEN_CONSTEXPR const bool numeric_limits_half_impl::is_exact; +constexpr const bool numeric_limits_half_impl::is_exact; template -EIGEN_CONSTEXPR const bool numeric_limits_half_impl::has_infinity; +constexpr const bool numeric_limits_half_impl::has_infinity; template -EIGEN_CONSTEXPR const bool numeric_limits_half_impl::has_quiet_NaN; +constexpr const bool numeric_limits_half_impl::has_quiet_NaN; template -EIGEN_CONSTEXPR const bool numeric_limits_half_impl::has_signaling_NaN; +constexpr const bool numeric_limits_half_impl::has_signaling_NaN; EIGEN_DIAGNOSTICS(push) EIGEN_DISABLE_DEPRECATED_WARNING template -EIGEN_CONSTEXPR const std::float_denorm_style numeric_limits_half_impl::has_denorm; +constexpr const std::float_denorm_style numeric_limits_half_impl::has_denorm; template -EIGEN_CONSTEXPR const bool numeric_limits_half_impl::has_denorm_loss; +constexpr const bool numeric_limits_half_impl::has_denorm_loss; EIGEN_DIAGNOSTICS(pop) template -EIGEN_CONSTEXPR const std::float_round_style numeric_limits_half_impl::round_style; +constexpr const std::float_round_style numeric_limits_half_impl::round_style; template -EIGEN_CONSTEXPR const bool 
numeric_limits_half_impl::is_iec559; +constexpr const bool numeric_limits_half_impl::is_iec559; template -EIGEN_CONSTEXPR const bool numeric_limits_half_impl::is_bounded; +constexpr const bool numeric_limits_half_impl::is_bounded; template -EIGEN_CONSTEXPR const bool numeric_limits_half_impl::is_modulo; +constexpr const bool numeric_limits_half_impl::is_modulo; template -EIGEN_CONSTEXPR const int numeric_limits_half_impl::digits; +constexpr const int numeric_limits_half_impl::digits; template -EIGEN_CONSTEXPR const int numeric_limits_half_impl::digits10; +constexpr const int numeric_limits_half_impl::digits10; template -EIGEN_CONSTEXPR const int numeric_limits_half_impl::max_digits10; +constexpr const int numeric_limits_half_impl::max_digits10; template -EIGEN_CONSTEXPR const int numeric_limits_half_impl::radix; +constexpr const int numeric_limits_half_impl::radix; template -EIGEN_CONSTEXPR const int numeric_limits_half_impl::min_exponent; +constexpr const int numeric_limits_half_impl::min_exponent; template -EIGEN_CONSTEXPR const int numeric_limits_half_impl::min_exponent10; +constexpr const int numeric_limits_half_impl::min_exponent10; template -EIGEN_CONSTEXPR const int numeric_limits_half_impl::max_exponent; +constexpr const int numeric_limits_half_impl::max_exponent; template -EIGEN_CONSTEXPR const int numeric_limits_half_impl::max_exponent10; +constexpr const int numeric_limits_half_impl::max_exponent10; template -EIGEN_CONSTEXPR const bool numeric_limits_half_impl::traps; +constexpr const bool numeric_limits_half_impl::traps; template -EIGEN_CONSTEXPR const bool numeric_limits_half_impl::tinyness_before; +constexpr const bool numeric_limits_half_impl::tinyness_before; } // end namespace half_impl } // end namespace Eigen @@ -320,8 +342,7 @@ namespace half_impl { (defined(EIGEN_HAS_HIP_FP16) && defined(HIP_DEVICE_COMPILE)) // Note: We deliberately do *not* define this to 1 even if we have Arm's native // fp16 type since GPU half types are rather different 
from native CPU half types. -// TODO: Rename to something like EIGEN_HAS_NATIVE_GPU_FP16 -#define EIGEN_HAS_NATIVE_FP16 +#define EIGEN_HAS_NATIVE_GPU_FP16 #endif // Intrinsics for native fp16 support. Note that on current hardware, @@ -329,7 +350,7 @@ namespace half_impl { // versions to get the ALU speed increased), but you do save the // conversion steps back and forth. -#if defined(EIGEN_HAS_NATIVE_FP16) +#if defined(EIGEN_HAS_NATIVE_GPU_FP16) EIGEN_STRONG_INLINE __device__ half operator+(const half& a, const half& b) { #if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000 return __hadd(::__half(a), ::__half(b)); @@ -371,7 +392,8 @@ EIGEN_STRONG_INLINE __device__ bool operator<(const half& a, const half& b) { re EIGEN_STRONG_INLINE __device__ bool operator<=(const half& a, const half& b) { return __hle(a, b); } EIGEN_STRONG_INLINE __device__ bool operator>(const half& a, const half& b) { return __hgt(a, b); } EIGEN_STRONG_INLINE __device__ bool operator>=(const half& a, const half& b) { return __hge(a, b); } -#endif + +#endif // EIGEN_HAS_NATIVE_GPU_FP16 #if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) && !defined(EIGEN_GPU_COMPILE_PHASE) EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator+(const half& a, const half& b) { return half(vaddh_f16(a.x, b.x)); } @@ -401,16 +423,47 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const half& a, const half& EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const half& a, const half& b) { return vcleh_f16(a.x, b.x); } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const half& a, const half& b) { return vcgth_f16(a.x, b.x); } EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half& b) { return vcgeh_f16(a.x, b.x); } + +#elif defined(EIGEN_HAS_BUILTIN_FLOAT16) && !defined(EIGEN_GPU_COMPILE_PHASE) + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator+(const half& a, const half& b) { return half(a.x + b.x); } +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator*(const 
half& a, const half& b) { return half(a.x * b.x); } +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator-(const half& a, const half& b) { return half(a.x - b.x); } +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator/(const half& a, const half& b) { return half(a.x / b.x); } +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator-(const half& a) { return half(-a.x); } +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator+=(half& a, const half& b) { + a = a + b; + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator*=(half& a, const half& b) { + a = a * b; + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator-=(half& a, const half& b) { + a = a - b; + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator/=(half& a, const half& b) { + a = a / b; + return a; +} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator==(const half& a, const half& b) { return a.x == b.x; } +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator!=(const half& a, const half& b) { return a.x != b.x; } +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const half& a, const half& b) { return a.x < b.x; } +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const half& a, const half& b) { return a.x <= b.x; } +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const half& a, const half& b) { return a.x > b.x; } +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half& b) { return a.x >= b.x; } + // We need to distinguish ‘clang as the CUDA compiler’ from ‘clang as the host compiler, // invoked by NVCC’ (e.g. on MacOS). The former needs to see both host and device implementation // of the functions, while the latter can only deal with one of them. 
-#elif !defined(EIGEN_HAS_NATIVE_FP16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for half floats +#elif !defined(EIGEN_HAS_NATIVE_GPU_FP16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for half floats #if EIGEN_COMP_CLANG && defined(EIGEN_GPUCC) // We need to provide emulated *host-side* FP16 operators for clang. #pragma push_macro("EIGEN_DEVICE_FUNC") #undef EIGEN_DEVICE_FUNC -#if defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_HAS_NATIVE_FP16) +#if defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_HAS_NATIVE_GPU_FP16) #define EIGEN_DEVICE_FUNC __host__ #else // both host and device need emulated ops. #define EIGEN_DEVICE_FUNC __host__ __device__ @@ -458,6 +511,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half& #if EIGEN_COMP_CLANG && defined(EIGEN_GPUCC) #pragma pop_macro("EIGEN_DEVICE_FUNC") #endif + #endif // Emulate support for half floats // Division by an index. Do it in full float precision to avoid accuracy @@ -493,7 +547,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator--(half& a, int) { // these in hardware. If we need more performance on older/other CPUs, they are // also possible to vectorize directly. -EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x) { +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x) { // We cannot simply do a "return __half_raw(x)" here, because __half_raw is union type // in the hip_fp16 header file, and that will trigger a compile error // On the other hand, having anything but a return statement also triggers a compile error @@ -515,6 +569,8 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC numext::uint16_t raw_half_as_uint16(const // For SYCL, cl::sycl::half is _Float16, so cast directly. 
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) return numext::bit_cast(h.x); +#elif defined(EIGEN_HAS_BUILTIN_FLOAT16) + return numext::bit_cast(h.x); #elif defined(SYCL_DEVICE_ONLY) return numext::bit_cast(h); #else @@ -528,6 +584,16 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) { __half tmp_ff = __float2half(ff); return *(__half_raw*)&tmp_ff; +#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) + __half_raw h; + h.x = static_cast<__fp16>(ff); + return h; + +#elif defined(EIGEN_HAS_BUILTIN_FLOAT16) + __half_raw h; + h.x = static_cast<_Float16>(ff); + return h; + #elif defined(EIGEN_HAS_FP16_C) __half_raw h; #if EIGEN_COMP_MSVC @@ -538,11 +604,6 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) { #endif return h; -#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) - __half_raw h; - h.x = static_cast<__fp16>(ff); - return h; - #else uint32_t f_bits = Eigen::numext::bit_cast(ff); const uint32_t f32infty_bits = {255 << 23}; @@ -595,6 +656,8 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h) { #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) return __half2float(h); +#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16) + return static_cast(h.x); #elif defined(EIGEN_HAS_FP16_C) #if EIGEN_COMP_MSVC // MSVC does not have scalar instructions. 
@@ -602,8 +665,6 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h) { #else return _cvtsh_ss(h.x); #endif -#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) - return static_cast<float>(h.x); #else const float magic = Eigen::numext::bit_cast<float>(static_cast<uint32_t>(113 << 23)); const uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift @@ -628,7 +689,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h) { // --- standard functions --- EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isinf)(const half& a) { -#ifdef EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC +#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16) return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) == 0x7c00; #else return (a.x & 0x7fff) == 0x7c00; @@ -638,7 +699,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isnan)(const half& a) { #if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \ (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) return __hisnan(a); -#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) +#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16) return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) > 0x7c00; #else return (a.x & 0x7fff) > 0x7c00; @@ -651,6 +712,11 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isfinite)(const half& a) { EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half abs(const half& a) { #if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) return half(vabsh_f16(a.x)); +#elif defined(EIGEN_HAS_BUILTIN_FLOAT16) + half result; + result.x = + numext::bit_cast<_Float16>(static_cast<numext::uint16_t>(numext::bit_cast<numext::uint16_t>(a.x) & 0x7FFF)); + return result; #else half result; result.x = a.x & 0x7FFF; @@ -734,26 +800,9 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half fmod(const half& a, const half& b) { return half(::fmodf(float(a), float(b))); } -EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(min)(const half& a, const half& b) { -#if (defined(EIGEN_HAS_CUDA_FP16) &&
defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \ - (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) - return __hlt(b, a) ? b : a; -#else - const float f1 = static_cast<float>(a); - const float f2 = static_cast<float>(b); - return f2 < f1 ? b : a; -#endif -} -EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(max)(const half& a, const half& b) { -#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \ - (defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE)) - return __hlt(a, b) ? b : a; -#else - const float f1 = static_cast<float>(a); - const float f2 = static_cast<float>(b); - return f1 < f2 ? b : a; -#endif -} +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(min)(const half& a, const half& b) { return b < a ? b : a; } + +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(max)(const half& a, const half& b) { return a < b ? b : a; } #ifndef EIGEN_NO_IO EIGEN_ALWAYS_INLINE std::ostream& operator<<(std::ostream& os, const half& v) { @@ -794,31 +843,29 @@ template <> struct NumTraits<Eigen::half> : GenericNumTraits<Eigen::half> { enum { IsSigned = true, IsInteger = false, IsComplex = false, RequireInitialization = false }; - EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half epsilon() { + EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half epsilon() { return half_impl::raw_uint16_to_half(0x0800); } - EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half dummy_precision() { + EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half dummy_precision() { return half_impl::raw_uint16_to_half(0x211f); // Eigen::half(1e-2f); } - EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half highest() { + EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half highest() { return half_impl::raw_uint16_to_half(0x7bff); } - EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half lowest() { + EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static
EIGEN_STRONG_INLINE Eigen::half lowest() { return half_impl::raw_uint16_to_half(0xfbff); } - EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half infinity() { + EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half infinity() { return half_impl::raw_uint16_to_half(0x7c00); } - EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half quiet_NaN() { + EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half quiet_NaN() { return half_impl::raw_uint16_to_half(0x7e00); } }; } // end namespace Eigen -#if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) -#pragma pop_macro("EIGEN_CONSTEXPR") -#endif +#undef _EIGEN_MAYBE_CONSTEXPR namespace Eigen { namespace numext { @@ -976,6 +1023,36 @@ struct cast_impl { } }; +#ifdef EIGEN_VECTORIZE_FMA + +template <> +EIGEN_DEVICE_FUNC inline half pmadd(const half& a, const half& b, const half& c) { +#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) + return half(vfmah_f16(a.x, b.x, c.x)); +#elif defined(EIGEN_VECTORIZE_AVX512FP16) + // Reduces to vfmadd213sh. + return half(_mm_cvtsh_h(_mm_fmadd_ph(_mm_set_sh(a.x), _mm_set_sh(b.x), _mm_set_sh(c.x)))); +#else + // Emulate FMA via float. + return half(static_cast<float>(a) * static_cast<float>(b) + static_cast<float>(c)); +#endif +} + +template <> +EIGEN_DEVICE_FUNC inline half pmsub(const half& a, const half& b, const half& c) { +#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) + return half(vfmah_f16(a.x, b.x, -c.x)); +#elif defined(EIGEN_VECTORIZE_AVX512FP16) + // Reduces to vfmadd213sh. + return half(_mm_cvtsh_h(_mm_fmadd_ph(_mm_set_sh(a.x), _mm_set_sh(b.x), -_mm_set_sh(c.x)))); +#else + // Emulate FMA via float.
+ return half(static_cast<float>(a) * static_cast<float>(b) - static_cast<float>(c)); +#endif +} + +#endif + } // namespace internal } // namespace Eigen diff --git a/Eigen/src/Core/util/ConfigureVectorization.h b/Eigen/src/Core/util/ConfigureVectorization.h index 5d3f1cf1a..49f307c73 100644 --- a/Eigen/src/Core/util/ConfigureVectorization.h +++ b/Eigen/src/Core/util/ConfigureVectorization.h @@ -285,6 +285,8 @@ #ifdef __AVX512FP16__ #ifdef __AVX512VL__ #define EIGEN_VECTORIZE_AVX512FP16 +// Built-in _Float16. +#define EIGEN_HAS_BUILTIN_FLOAT16 1 #else #if EIGEN_COMP_GNUC #error Please add -mavx512vl to your compiler flags: compiling with -mavx512fp16 alone without AVX512-VL is not supported. diff --git a/test/packet_ostream.h b/test/packet_ostream.h index 49e1bb076..4a3ee9caf 100644 --- a/test/packet_ostream.h +++ b/test/packet_ostream.h @@ -7,7 +7,8 @@ // Include this header to be able to print Packets while debugging. template <typename Packet, - typename EnableIf = std::enable_if_t<Eigen::internal::unpacket_traits<Packet>::vectorizable> > + typename EnableIf = std::enable_if_t<(Eigen::internal::unpacket_traits<Packet>::vectorizable || + Eigen::internal::unpacket_traits<Packet>::size > 1)> > std::ostream& operator<<(std::ostream& os, const Packet& packet) { using Scalar = typename Eigen::internal::unpacket_traits<Packet>::type; Scalar v[Eigen::internal::unpacket_traits<Packet>::size]; diff --git a/test/packetmath.cpp b/test/packetmath.cpp index 102817f02..64c55fbfd 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -26,19 +26,19 @@ inline T REF_MUL(const T& a, const T& b) { } template <typename T> inline T REF_MADD(const T& a, const T& b, const T& c) { - return a * b + c; + return internal::pmadd(a, b, c); } template <typename T> inline T REF_MSUB(const T& a, const T& b, const T& c) { - return a * b - c; + return internal::pmsub(a, b, c); } template <typename T> inline T REF_NMADD(const T& a, const T& b, const T& c) { - return c - a * b; + return internal::pnmadd(a, b, c); } template <typename T> inline T REF_NMSUB(const T& a, const T& b, const T& c) { - return test::negate(a * b + c); + return internal::pnmsub(a, b, c); } template <typename T> inline T
REF_DIV(const T& a, const T& b) {