Use native _Float16 for AVX512FP16 and update vectorization.

This allows us to do faster native scalar operations.  Also
updated half/quarter packets to use the native type if available.

Benchmark improvement:
```
Comparing ./2910_without_float16 to ./2910_with_float16
Benchmark                                               Time             CPU      Time Old      Time New       CPU Old       CPU New
------------------------------------------------------------------------------------------------------------------------------------
BM_CalcMat<float>/10000/768/500                      -0.0041         -0.0040      58276392      58039442      58273420      58039582
BM_CalcMat<_Float16>/10000/768/500                   +0.0073         +0.0073     642506339     647214446     642481384     647188303
BM_CalcMat<Eigen::half>/10000/768/500                -0.3170         -0.3170      92511115      63182101      92506771      63179258
BM_CalcVec<float>/10000/768/500                      +0.0022         +0.0022       5198157       5209469       5197913       5209334
BM_CalcVec<_Float16>/10000/768/500                   +0.0025         +0.0026      10133324      10159111      10132641      10158507
BM_CalcVec<Eigen::half>/10000/768/500                -0.7760         -0.7760      45337937      10156952      45336532      10156389
OVERALL_GEOMEAN                                      -0.2677         -0.2677             0             0             0             0
```

Fixes #2910.
This commit is contained in:
Antonio Sanchez 2025-03-16 20:58:59 -07:00
parent 0259a52b0e
commit 3580a38298
15 changed files with 1422 additions and 449 deletions

View File

@ -193,21 +193,27 @@ using std::ptrdiff_t;
#include "src/Core/arch/Default/GenericPacketMathFunctionsFwd.h"
#if defined EIGEN_VECTORIZE_AVX512
#include "src/Core/arch/SSE/PacketMath.h"
#include "src/Core/arch/AVX/PacketMath.h"
#include "src/Core/arch/AVX512/PacketMath.h"
#if defined EIGEN_VECTORIZE_AVX512FP16
#include "src/Core/arch/AVX512/PacketMathFP16.h"
#endif
#include "src/Core/arch/SSE/PacketMath.h"
#include "src/Core/arch/SSE/TypeCasting.h"
#include "src/Core/arch/SSE/Complex.h"
#include "src/Core/arch/AVX/PacketMath.h"
#include "src/Core/arch/AVX/TypeCasting.h"
#include "src/Core/arch/AVX/Complex.h"
#include "src/Core/arch/AVX512/PacketMath.h"
#include "src/Core/arch/AVX512/TypeCasting.h"
#if defined EIGEN_VECTORIZE_AVX512FP16
#include "src/Core/arch/AVX512/TypeCastingFP16.h"
#endif
#include "src/Core/arch/SSE/Complex.h"
#include "src/Core/arch/AVX/Complex.h"
#include "src/Core/arch/AVX512/Complex.h"
#include "src/Core/arch/SSE/MathFunctions.h"
#include "src/Core/arch/AVX/MathFunctions.h"
#include "src/Core/arch/AVX512/MathFunctions.h"
#if defined EIGEN_VECTORIZE_AVX512FP16
#include "src/Core/arch/AVX512/MathFunctionsFP16.h"
#endif
#include "src/Core/arch/AVX512/TrsmKernel.h"
#elif defined EIGEN_VECTORIZE_AVX
// Use AVX for floats and doubles, SSE for integers

View File

@ -76,7 +76,7 @@ struct generic_rsqrt_newton_step {
static_assert(Steps > 0, "Steps must be at least 1.");
using Scalar = typename unpacket_traits<Packet>::type;
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet run(const Packet& a, const Packet& approx_rsqrt) {
constexpr Scalar kMinusHalf = Scalar(-1) / Scalar(2);
const Scalar kMinusHalf = Scalar(-1) / Scalar(2);
const Packet cst_minus_half = pset1<Packet>(kMinusHalf);
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));

View File

@ -106,6 +106,8 @@ BF16_PACKET_FUNCTION(Packet8f, Packet8bf, prsqrt)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psqrt)
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, ptanh)
#ifndef EIGEN_VECTORIZE_AVX512FP16
F16_PACKET_FUNCTION(Packet8f, Packet8h, pcos)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pexp)
F16_PACKET_FUNCTION(Packet8f, Packet8h, pexp2)
@ -118,6 +120,7 @@ F16_PACKET_FUNCTION(Packet8f, Packet8h, prsqrt)
F16_PACKET_FUNCTION(Packet8f, Packet8h, psin)
F16_PACKET_FUNCTION(Packet8f, Packet8h, psqrt)
F16_PACKET_FUNCTION(Packet8f, Packet8h, ptanh)
#endif
} // end namespace internal

View File

@ -1839,10 +1839,13 @@ EIGEN_STRONG_INLINE Packet8ui pabs(const Packet8ui& a) {
return a;
}
#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet8h psignbit(const Packet8h& a) {
return _mm_cmpgt_epi16(_mm_setzero_si128(), a);
}
#endif // EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet8bf psignbit(const Packet8bf& a) {
return _mm_cmpgt_epi16(_mm_setzero_si128(), a);
@ -2044,10 +2047,13 @@ EIGEN_STRONG_INLINE bool predux_any(const Packet8ui& x) {
return _mm256_movemask_ps(_mm256_castsi256_ps(x)) != 0;
}
#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet8h& x) {
return _mm_movemask_epi8(x) != 0;
}
#endif // EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet8bf& x) {
return _mm_movemask_epi8(x) != 0;
@ -2211,7 +2217,6 @@ struct unpacket_traits<Packet8h> {
};
typedef Packet8h half;
};
#endif
template <>
EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) {
@ -2446,14 +2451,12 @@ EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet8h>(Eigen::half* to, const
to[stride * 7] = aux[7];
}
#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Eigen::half predux<Packet8h>(const Packet8h& a) {
Packet8f af = half2float(a);
float reduced = predux<Packet8f>(af);
return Eigen::half(reduced);
}
#endif
template <>
EIGEN_STRONG_INLINE Eigen::half predux_max<Packet8h>(const Packet8h& a) {
@ -2553,6 +2556,8 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8h, 4>& kernel) {
kernel.packet[3] = pload<Packet8h>(out[3]);
}
#endif
// BFloat16 implementation.
EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) {

View File

@ -279,20 +279,22 @@ EIGEN_STRONG_INLINE Packet2l preinterpret<Packet2l, Packet4l>(const Packet4l& a)
}
#endif
#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet8f pcast<Packet8h, Packet8f>(const Packet8h& a) {
return half2float(a);
}
template <>
EIGEN_STRONG_INLINE Packet8f pcast<Packet8bf, Packet8f>(const Packet8bf& a) {
return Bf16ToF32(a);
}
template <>
EIGEN_STRONG_INLINE Packet8h pcast<Packet8f, Packet8h>(const Packet8f& a) {
return float2half(a);
}
#endif
template <>
EIGEN_STRONG_INLINE Packet8f pcast<Packet8bf, Packet8f>(const Packet8bf& a) {
return Bf16ToF32(a);
}
template <>
EIGEN_STRONG_INLINE Packet8bf pcast<Packet8f, Packet8bf>(const Packet8f& a) {

View File

@ -47,16 +47,16 @@ EIGEN_STRONG_INLINE Packet16bf pldexp(const Packet16bf& a, const Packet16bf& exp
#if EIGEN_FAST_MATH
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f psqrt<Packet16f>(const Packet16f& _x) {
return generic_sqrt_newton_step<Packet16f>::run(_x, _mm512_rsqrt14_ps(_x));
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f psqrt<Packet16f>(const Packet16f& x) {
return generic_sqrt_newton_step<Packet16f>::run(x, _mm512_rsqrt14_ps(x));
}
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d psqrt<Packet8d>(const Packet8d& _x) {
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d psqrt<Packet8d>(const Packet8d& x) {
#ifdef EIGEN_VECTORIZE_AVX512ER
return generic_sqrt_newton_step<Packet8d, /*Steps=*/1>::run(_x, _mm512_rsqrt28_pd(_x));
return generic_sqrt_newton_step<Packet8d, /*Steps=*/1>::run(x, _mm512_rsqrt28_pd(x));
#else
return generic_sqrt_newton_step<Packet8d, /*Steps=*/2>::run(_x, _mm512_rsqrt14_pd(_x));
return generic_sqrt_newton_step<Packet8d, /*Steps=*/2>::run(x, _mm512_rsqrt14_pd(x));
#endif
}
#else
@ -80,19 +80,19 @@ EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
#elif EIGEN_FAST_MATH
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f prsqrt<Packet16f>(const Packet16f& _x) {
return generic_rsqrt_newton_step<Packet16f, /*Steps=*/1>::run(_x, _mm512_rsqrt14_ps(_x));
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f prsqrt<Packet16f>(const Packet16f& x) {
return generic_rsqrt_newton_step<Packet16f, /*Steps=*/1>::run(x, _mm512_rsqrt14_ps(x));
}
#endif
// prsqrt for double.
#if EIGEN_FAST_MATH
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d prsqrt<Packet8d>(const Packet8d& _x) {
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d prsqrt<Packet8d>(const Packet8d& x) {
#ifdef EIGEN_VECTORIZE_AVX512ER
return generic_rsqrt_newton_step<Packet8d, /*Steps=*/1>::run(_x, _mm512_rsqrt28_pd(_x));
return generic_rsqrt_newton_step<Packet8d, /*Steps=*/1>::run(x, _mm512_rsqrt28_pd(x));
#else
return generic_rsqrt_newton_step<Packet8d, /*Steps=*/2>::run(_x, _mm512_rsqrt14_pd(_x));
return generic_rsqrt_newton_step<Packet8d, /*Steps=*/2>::run(x, _mm512_rsqrt14_pd(x));
#endif
}
@ -118,6 +118,8 @@ BF16_PACKET_FUNCTION(Packet16f, Packet16bf, prsqrt)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt)
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh)
#ifndef EIGEN_VECTORIZE_AVX512FP16
F16_PACKET_FUNCTION(Packet16f, Packet16h, pcos)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp)
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp2)
@ -130,6 +132,7 @@ F16_PACKET_FUNCTION(Packet16f, Packet16h, prsqrt)
F16_PACKET_FUNCTION(Packet16f, Packet16h, psin)
F16_PACKET_FUNCTION(Packet16f, Packet16h, psqrt)
F16_PACKET_FUNCTION(Packet16f, Packet16h, ptanh)
#endif // EIGEN_VECTORIZE_AVX512FP16
} // end namespace internal

View File

@ -0,0 +1,75 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2025 The Eigen Authors.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_MATH_FUNCTIONS_FP16_AVX512_H
#define EIGEN_MATH_FUNCTIONS_FP16_AVX512_H
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
EIGEN_STRONG_INLINE Packet32h combine2Packet16h(const Packet16h& a, const Packet16h& b) {
__m512i result = _mm512_castsi256_si512(_mm256_castph_si256(a));
result = _mm512_inserti64x4(result, _mm256_castph_si256(b), 1);
return _mm512_castsi512_ph(result);
}
EIGEN_STRONG_INLINE void extract2Packet16h(const Packet32h& x, Packet16h& a, Packet16h& b) {
a = _mm256_castsi256_ph(_mm512_castsi512_si256(_mm512_castph_si512(x)));
b = _mm256_castsi256_ph(_mm512_extracti64x4_epi64(_mm512_castph_si512(x), 1));
}
#define _EIGEN_GENERATE_FP16_MATH_FUNCTION(func) \
template <> \
EIGEN_STRONG_INLINE Packet8h func<Packet8h>(const Packet8h& a) { \
return float2half(func(half2float(a))); \
} \
\
template <> \
EIGEN_STRONG_INLINE Packet16h func<Packet16h>(const Packet16h& a) { \
return float2half(func(half2float(a))); \
} \
\
template <> \
EIGEN_STRONG_INLINE Packet32h func<Packet32h>(const Packet32h& a) { \
Packet16h low; \
Packet16h high; \
extract2Packet16h(a, low, high); \
return combine2Packet16h(func(low), func(high)); \
}
_EIGEN_GENERATE_FP16_MATH_FUNCTION(psin)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(pcos)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(plog)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(plog2)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(plog1p)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(pexp)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(pexpm1)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(pexp2)
_EIGEN_GENERATE_FP16_MATH_FUNCTION(ptanh)
#undef _EIGEN_GENERATE_FP16_MATH_FUNCTION
// pfrexp
template <>
EIGEN_STRONG_INLINE Packet32h pfrexp<Packet32h>(const Packet32h& a, Packet32h& exponent) {
return pfrexp_generic(a, exponent);
}
// pldexp
template <>
EIGEN_STRONG_INLINE Packet32h pldexp<Packet32h>(const Packet32h& a, const Packet32h& exponent) {
return pldexp_generic(a, exponent);
}
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_MATH_FUNCTIONS_FP16_AVX512_H

View File

@ -40,6 +40,10 @@ typedef eigen_packet_wrapper<__m256i, 1> Packet16h;
#endif
typedef eigen_packet_wrapper<__m256i, 2> Packet16bf;
typedef eigen_packet_wrapper<__m512i, 6> Packet32s;
typedef eigen_packet_wrapper<__m256i, 6> Packet16s;
typedef eigen_packet_wrapper<__m128i, 6> Packet8s;
template <>
struct is_arithmetic<__m512> {
enum { value = true };
@ -248,6 +252,39 @@ struct unpacket_traits<Packet16h> {
};
#endif
template <>
struct unpacket_traits<Packet32s> {
typedef numext::int16_t type;
typedef Packet16s half;
enum {
size = 32,
alignment = Aligned64,
vectorizable = false,
};
};
template <>
struct unpacket_traits<Packet16s> {
typedef numext::int16_t type;
typedef Packet8s half;
enum {
size = 16,
alignment = Aligned32,
vectorizable = false,
};
};
template <>
struct unpacket_traits<Packet8s> {
typedef numext::int16_t type;
typedef Packet8s half;
enum {
size = 8,
alignment = Aligned16,
vectorizable = false,
};
};
template <>
EIGEN_STRONG_INLINE Packet16f pset1<Packet16f>(const float& from) {
return _mm512_set1_ps(from);
@ -1335,10 +1372,13 @@ EIGEN_STRONG_INLINE Packet8l pabs(const Packet8l& a) {
return _mm512_abs_epi64(a);
}
#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet16h psignbit(const Packet16h& a) {
return _mm256_srai_epi16(a, 15);
}
#endif // EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet16bf psignbit(const Packet16bf& a) {
return _mm256_srai_epi16(a, 15);
@ -2199,6 +2239,7 @@ EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& ifPacket, const Packet8d&
}
// Packet math for Eigen::half
#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) {
return _mm256_set1_epi16(from.x);
@ -2369,7 +2410,6 @@ EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) {
return _mm256_xor_si256(a, sign_mask);
}
#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet16h padd<Packet16h>(const Packet16h& a, const Packet16h& b) {
Packet16f af = half2float(a);
@ -2408,8 +2448,6 @@ EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& from) {
return half(predux(from_float));
}
#endif
template <>
EIGEN_STRONG_INLINE Packet8h predux_half_dowto4<Packet16h>(const Packet16h& a) {
Packet8h lane0 = _mm256_extractf128_si256(a, 0);
@ -2643,6 +2681,8 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 4>& kernel) {
kernel.packet[3] = pload<Packet16h>(out[3]);
}
#endif // EIGEN_VECTORIZE_AVX512FP16
template <>
struct is_arithmetic<Packet16bf> {
enum { value = true };
@ -3095,6 +3135,158 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf, 4>& kernel) {
kernel.packet[3] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31);
}
// Minimal implementation of 16-bit int packets for use in pfrexp, pldexp.
template <>
EIGEN_STRONG_INLINE Packet32s pset1<Packet32s>(const numext::int16_t& x) {
return _mm512_set1_epi16(x);
}
template <>
EIGEN_STRONG_INLINE Packet16s pset1<Packet16s>(const numext::int16_t& x) {
return _mm256_set1_epi16(x);
}
template <>
EIGEN_STRONG_INLINE Packet8s pset1<Packet8s>(const numext::int16_t& x) {
return _mm_set1_epi16(x);
}
template <>
EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet32s>(numext::int16_t* out, const Packet32s& x) {
_mm512_storeu_epi16(out, x);
}
template <>
EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet16s>(numext::int16_t* out, const Packet16s& x) {
_mm256_storeu_epi16(out, x);
}
template <>
EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet8s>(numext::int16_t* out, const Packet8s& x) {
_mm_storeu_epi16(out, x);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet32s>(numext::int16_t* out, const Packet32s& x) {
_mm512_storeu_epi16(out, x);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet16s>(numext::int16_t* out, const Packet16s& x) {
_mm256_storeu_epi16(out, x);
}
template <>
EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet8s>(numext::int16_t* out, const Packet8s& x) {
_mm_storeu_epi16(out, x);
}
template <>
EIGEN_STRONG_INLINE Packet32s padd(const Packet32s& a, const Packet32s& b) {
return _mm512_add_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16s padd(const Packet16s& a, const Packet16s& b) {
return _mm256_add_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8s padd(const Packet8s& a, const Packet8s& b) {
return _mm_add_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet32s psub(const Packet32s& a, const Packet32s& b) {
return _mm512_sub_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16s psub(const Packet16s& a, const Packet16s& b) {
return _mm256_sub_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8s psub(const Packet8s& a, const Packet8s& b) {
return _mm_sub_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet32s pmul(const Packet32s& a, const Packet32s& b) {
return _mm512_mullo_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet16s pmul(const Packet16s& a, const Packet16s& b) {
return _mm256_mullo_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet8s pmul(const Packet8s& a, const Packet8s& b) {
return _mm_mullo_epi16(a, b);
}
template <>
EIGEN_STRONG_INLINE Packet32s pnegate(const Packet32s& a) {
return _mm512_sub_epi16(_mm512_setzero_si512(), a);
}
template <>
EIGEN_STRONG_INLINE Packet16s pnegate(const Packet16s& a) {
return _mm256_sub_epi16(_mm256_setzero_si256(), a);
}
template <>
EIGEN_STRONG_INLINE Packet8s pnegate(const Packet8s& a) {
return _mm_sub_epi16(_mm_setzero_si128(), a);
}
template <int N>
EIGEN_STRONG_INLINE Packet32s parithmetic_shift_right(Packet32s a) {
return _mm512_srai_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet16s parithmetic_shift_right(Packet16s a) {
return _mm256_srai_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet8s parithmetic_shift_right(Packet8s a) {
return _mm_srai_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet32s plogical_shift_left(Packet32s a) {
return _mm512_slli_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet16s plogical_shift_left(Packet16s a) {
return _mm256_slli_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet8s plogical_shift_left(Packet8s a) {
return _mm_slli_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet32s plogical_shift_right(Packet32s a) {
return _mm512_srli_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet16s plogical_shift_right(Packet16s a) {
return _mm256_srli_epi16(a, N);
}
template <int N>
EIGEN_STRONG_INLINE Packet8s plogical_shift_right(Packet8s a) {
return _mm_srli_epi16(a, N);
}
} // end namespace internal
} // end namespace Eigen

File diff suppressed because it is too large Load Diff

View File

@ -237,16 +237,12 @@ EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i, Packet16i>(const Packet16i&
return _mm512_castsi512_si128(a);
}
#ifndef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet8h preinterpret<Packet8h, Packet16h>(const Packet16h& a) {
return _mm256_castsi256_si128(a);
}
template <>
EIGEN_STRONG_INLINE Packet8bf preinterpret<Packet8bf, Packet16bf>(const Packet16bf& a) {
return _mm256_castsi256_si128(a);
}
template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet16h, Packet16f>(const Packet16h& a) {
return half2float(a);
@ -257,6 +253,13 @@ EIGEN_STRONG_INLINE Packet16h pcast<Packet16f, Packet16h>(const Packet16f& a) {
return float2half(a);
}
#endif
template <>
EIGEN_STRONG_INLINE Packet8bf preinterpret<Packet8bf, Packet16bf>(const Packet16bf& a) {
return _mm256_castsi256_si128(a);
}
template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet16bf, Packet16f>(const Packet16bf& a) {
return Bf16ToF32(a);
@ -267,68 +270,6 @@ EIGEN_STRONG_INLINE Packet16bf pcast<Packet16f, Packet16bf>(const Packet16f& a)
return F32ToBf16(a);
}
#ifdef EIGEN_VECTORIZE_AVX512FP16
template <>
EIGEN_STRONG_INLINE Packet16h preinterpret<Packet16h, Packet32h>(const Packet32h& a) {
return _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0));
}
template <>
EIGEN_STRONG_INLINE Packet8h preinterpret<Packet8h, Packet32h>(const Packet32h& a) {
return _mm256_castsi256_si128(preinterpret<Packet16h>(a));
}
template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet32h, Packet16f>(const Packet32h& a) {
// Discard second-half of input.
Packet16h low = _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0));
return _mm512_cvtxph_ps(_mm256_castsi256_ph(low));
}
template <>
EIGEN_STRONG_INLINE Packet32h pcast<Packet16f, Packet32h>(const Packet16f& a, const Packet16f& b) {
__m512d result = _mm512_undefined_pd();
result = _mm512_insertf64x4(
result, _mm256_castsi256_pd(_mm512_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 0);
result = _mm512_insertf64x4(
result, _mm256_castsi256_pd(_mm512_cvtps_ph(b, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 1);
return _mm512_castpd_ph(result);
}
template <>
EIGEN_STRONG_INLINE Packet8f pcast<Packet16h, Packet8f>(const Packet16h& a) {
// Discard second-half of input.
Packet8h low = _mm_castps_si128(_mm256_extractf32x4_ps(_mm256_castsi256_ps(a), 0));
return _mm256_cvtxph_ps(_mm_castsi128_ph(low));
}
template <>
EIGEN_STRONG_INLINE Packet16h pcast<Packet8f, Packet16h>(const Packet8f& a, const Packet8f& b) {
__m256d result = _mm256_undefined_pd();
result = _mm256_insertf64x2(result,
_mm_castsi128_pd(_mm256_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 0);
result = _mm256_insertf64x2(result,
_mm_castsi128_pd(_mm256_cvtps_ph(b, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 1);
return _mm256_castpd_si256(result);
}
template <>
EIGEN_STRONG_INLINE Packet4f pcast<Packet8h, Packet4f>(const Packet8h& a) {
Packet8f full = _mm256_cvtxph_ps(_mm_castsi128_ph(a));
// Discard second-half of input.
return _mm256_extractf32x4_ps(full, 0);
}
template <>
EIGEN_STRONG_INLINE Packet8h pcast<Packet4f, Packet8h>(const Packet4f& a, const Packet4f& b) {
__m256 result = _mm256_undefined_ps();
result = _mm256_insertf128_ps(result, a, 0);
result = _mm256_insertf128_ps(result, b, 1);
return _mm256_cvtps_ph(result, _MM_FROUND_TO_NEAREST_INT);
}
#endif
} // end namespace internal
} // end namespace Eigen

View File

@ -0,0 +1,130 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2025 The Eigen Authors.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_TYPE_CASTING_FP16_AVX512_H
#define EIGEN_TYPE_CASTING_FP16_AVX512_H
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
template <>
EIGEN_STRONG_INLINE Packet32s preinterpret<Packet32s, Packet32h>(const Packet32h& a) {
return _mm512_castph_si512(a);
}
template <>
EIGEN_STRONG_INLINE Packet16s preinterpret<Packet16s, Packet16h>(const Packet16h& a) {
return _mm256_castph_si256(a);
}
template <>
EIGEN_STRONG_INLINE Packet8s preinterpret<Packet8s, Packet8h>(const Packet8h& a) {
return _mm_castph_si128(a);
}
template <>
EIGEN_STRONG_INLINE Packet32h preinterpret<Packet32h, Packet32s>(const Packet32s& a) {
return _mm512_castsi512_ph(a);
}
template <>
EIGEN_STRONG_INLINE Packet16h preinterpret<Packet16h, Packet16s>(const Packet16s& a) {
return _mm256_castsi256_ph(a);
}
template <>
EIGEN_STRONG_INLINE Packet8h preinterpret<Packet8h, Packet8s>(const Packet8s& a) {
return _mm_castsi128_ph(a);
}
template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet16h, Packet16f>(const Packet16h& a) {
return half2float(a);
}
template <>
EIGEN_STRONG_INLINE Packet8f pcast<Packet8h, Packet8f>(const Packet8h& a) {
return half2float(a);
}
template <>
EIGEN_STRONG_INLINE Packet16h pcast<Packet16f, Packet16h>(const Packet16f& a) {
return float2half(a);
}
template <>
EIGEN_STRONG_INLINE Packet8h pcast<Packet8f, Packet8h>(const Packet8f& a) {
return float2half(a);
}
template <>
EIGEN_STRONG_INLINE Packet16f pcast<Packet32h, Packet16f>(const Packet32h& a) {
// Discard second-half of input.
Packet16h low = _mm256_castpd_ph(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0));
return _mm512_cvtxph_ps(low);
}
template <>
EIGEN_STRONG_INLINE Packet8f pcast<Packet16h, Packet8f>(const Packet16h& a) {
// Discard second-half of input.
Packet8h low = _mm_castps_ph(_mm256_extractf32x4_ps(_mm256_castph_ps(a), 0));
return _mm256_cvtxph_ps(low);
}
template <>
EIGEN_STRONG_INLINE Packet4f pcast<Packet8h, Packet4f>(const Packet8h& a) {
Packet8f full = _mm256_cvtxph_ps(a);
// Discard second-half of input.
return _mm256_extractf32x4_ps(full, 0);
}
template <>
EIGEN_STRONG_INLINE Packet32h pcast<Packet16f, Packet32h>(const Packet16f& a, const Packet16f& b) {
__m512 result = _mm512_castsi512_ps(_mm512_castsi256_si512(_mm256_castph_si256(_mm512_cvtxps_ph(a))));
result = _mm512_insertf32x8(result, _mm256_castph_ps(_mm512_cvtxps_ph(b)), 1);
return _mm512_castps_ph(result);
}
template <>
EIGEN_STRONG_INLINE Packet16h pcast<Packet8f, Packet16h>(const Packet8f& a, const Packet8f& b) {
__m256 result = _mm256_castsi256_ps(_mm256_castsi128_si256(_mm_castph_si128(_mm256_cvtxps_ph(a))));
result = _mm256_insertf32x4(result, _mm_castph_ps(_mm256_cvtxps_ph(b)), 1);
return _mm256_castps_ph(result);
}
template <>
EIGEN_STRONG_INLINE Packet8h pcast<Packet4f, Packet8h>(const Packet4f& a, const Packet4f& b) {
__m256 result = _mm256_castsi256_ps(_mm256_castsi128_si256(_mm_castps_si128(a)));
result = _mm256_insertf128_ps(result, b, 1);
return _mm256_cvtxps_ph(result);
}
template <>
EIGEN_STRONG_INLINE Packet32s pcast<Packet32h, Packet32s>(const Packet32h& a) {
return _mm512_cvtph_epi16(a);
}
template <>
EIGEN_STRONG_INLINE Packet16s pcast<Packet16h, Packet16s>(const Packet16h& a) {
return _mm256_cvtph_epi16(a);
}
template <>
EIGEN_STRONG_INLINE Packet8s pcast<Packet8h, Packet8s>(const Packet8h& a) {
return _mm_cvtph_epi16(a);
}
template <>
EIGEN_STRONG_INLINE Packet32h pcast<Packet32s, Packet32h>(const Packet32s& a) {
return _mm512_cvtepi16_ph(a);
}
template <>
EIGEN_STRONG_INLINE Packet16h pcast<Packet16s, Packet16h>(const Packet16s& a) {
return _mm256_cvtepi16_ph(a);
}
template <>
EIGEN_STRONG_INLINE Packet8h pcast<Packet8s, Packet8h>(const Packet8s& a) {
return _mm_cvtepi16_ph(a);
}
} // namespace internal
} // namespace Eigen
#endif // EIGEN_TYPE_CASTING_FP16_AVX512_H

View File

@ -37,21 +37,23 @@
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
#if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
// When compiling with GPU support, the "__half_raw" base class as well as
// some other routines are defined in the GPU compiler header files
// (cuda_fp16.h, hip_fp16.h), and they are not tagged constexpr
// As a consequence, we get compile failures when compiling Eigen with
// GPU support. Hence the need to disable EIGEN_CONSTEXPR when building
// Eigen with GPU support
#pragma push_macro("EIGEN_CONSTEXPR")
#undef EIGEN_CONSTEXPR
#define EIGEN_CONSTEXPR
// Eigen with GPU support.
// Any functions that require `numext::bit_cast` may also not be constexpr,
// including any native types when setting via raw bit values.
#if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16)
#define _EIGEN_MAYBE_CONSTEXPR
#else
#define _EIGEN_MAYBE_CONSTEXPR constexpr
#endif
#define F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, METHOD) \
template <> \
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_UNUSED PACKET_F16 METHOD<PACKET_F16>(const PACKET_F16& _x) { \
EIGEN_UNUSED EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC PACKET_F16 METHOD<PACKET_F16>(const PACKET_F16& _x) { \
return float2half(METHOD<PACKET_F>(half2float(_x))); \
}
@ -81,8 +83,10 @@ namespace half_impl {
// Making the host side compile phase of hipcc use the same Eigen::half impl, as the gcc compile, resolves
// this error, and hence the following convoluted #if condition
#if !defined(EIGEN_HAS_GPU_FP16) || !defined(EIGEN_GPU_COMPILE_PHASE)
// Make our own __half_raw definition that is similar to CUDA's.
struct __half_raw {
struct construct_from_rep_tag {};
#if (defined(EIGEN_HAS_GPU_FP16) && !defined(EIGEN_GPU_COMPILE_PHASE))
// Eigen::half can be used as the datatype for shared memory declarations (in Eigen and TF)
// The element type for shared memory cannot have non-trivial constructors
@ -91,43 +95,53 @@ struct __half_raw {
// hence the need for this
EIGEN_DEVICE_FUNC __half_raw() {}
#else
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw() : x(0) {}
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR __half_raw() : x(0) {}
#endif
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(numext::uint16_t raw) : x(numext::bit_cast<__fp16>(raw)) {}
explicit EIGEN_DEVICE_FUNC __half_raw(numext::uint16_t raw) : x(numext::bit_cast<__fp16>(raw)) {}
EIGEN_DEVICE_FUNC constexpr __half_raw(construct_from_rep_tag, __fp16 rep) : x{rep} {}
__fp16 x;
#elif defined(EIGEN_HAS_BUILTIN_FLOAT16)
explicit EIGEN_DEVICE_FUNC __half_raw(numext::uint16_t raw) : x(numext::bit_cast<_Float16>(raw)) {}
EIGEN_DEVICE_FUNC constexpr __half_raw(construct_from_rep_tag, _Float16 rep) : x{rep} {}
_Float16 x;
#else
explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(numext::uint16_t raw) : x(raw) {}
explicit EIGEN_DEVICE_FUNC constexpr __half_raw(numext::uint16_t raw) : x(raw) {}
EIGEN_DEVICE_FUNC constexpr __half_raw(construct_from_rep_tag, numext::uint16_t rep) : x{rep} {}
numext::uint16_t x;
#endif
};
#elif defined(EIGEN_HAS_HIP_FP16)
// Nothing to do here
// HIP GPU compile phase: nothing to do here.
// HIP fp16 header file has a definition for __half_raw
#elif defined(EIGEN_HAS_CUDA_FP16)
// CUDA GPU compile phase.
#if EIGEN_CUDA_SDK_VER < 90000
// In CUDA < 9.0, __half is the equivalent of CUDA 9's __half_raw
typedef __half __half_raw;
#endif // defined(EIGEN_HAS_CUDA_FP16)
#elif defined(SYCL_DEVICE_ONLY)
typedef cl::sycl::half __half_raw;
#endif
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x);
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x);
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff);
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h);
struct half_base : public __half_raw {
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base() {}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half_raw& h) : __half_raw(h) {}
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half_base() {}
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half_base(const __half_raw& h) : __half_raw(h) {}
#if defined(EIGEN_HAS_GPU_FP16)
#if defined(EIGEN_HAS_HIP_FP16)
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half& h) { x = __half_as_ushort(h); }
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half_base(const __half& h) { x = __half_as_ushort(h); }
#elif defined(EIGEN_HAS_CUDA_FP16)
#if EIGEN_CUDA_SDK_VER >= 90000
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half& h) : __half_raw(*(__half_raw*)&h) {}
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half_base(const __half& h) : __half_raw(*(__half_raw*)&h) {}
#endif
#endif
#endif
@ -156,21 +170,29 @@ struct half : public half_impl::half_base {
#endif
#endif
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half() {}
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half() {}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half_raw& h) : half_impl::half_base(h) {}
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(const __half_raw& h) : half_impl::half_base(h) {}
#if defined(EIGEN_HAS_GPU_FP16)
#if defined(EIGEN_HAS_HIP_FP16)
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
#elif defined(EIGEN_HAS_CUDA_FP16)
#if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
#endif
#endif
#endif
explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(bool b)
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
explicit EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(__fp16 b)
: half(__half_raw(__half_raw::construct_from_rep_tag(), b)) {}
#elif defined(EIGEN_HAS_BUILTIN_FLOAT16)
explicit EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(_Float16 b)
: half(__half_raw(__half_raw::construct_from_rep_tag(), b)) {}
#endif
explicit EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(bool b)
: half_impl::half_base(half_impl::raw_uint16_to_half(b ? 0x3c00 : 0)) {}
template <class T>
explicit EIGEN_DEVICE_FUNC half(T val)
@ -201,99 +223,99 @@ struct half : public half_impl::half_base {
namespace half_impl {
template <typename = void>
struct numeric_limits_half_impl {
static EIGEN_CONSTEXPR const bool is_specialized = true;
static EIGEN_CONSTEXPR const bool is_signed = true;
static EIGEN_CONSTEXPR const bool is_integer = false;
static EIGEN_CONSTEXPR const bool is_exact = false;
static EIGEN_CONSTEXPR const bool has_infinity = true;
static EIGEN_CONSTEXPR const bool has_quiet_NaN = true;
static EIGEN_CONSTEXPR const bool has_signaling_NaN = true;
static constexpr const bool is_specialized = true;
static constexpr const bool is_signed = true;
static constexpr const bool is_integer = false;
static constexpr const bool is_exact = false;
static constexpr const bool has_infinity = true;
static constexpr const bool has_quiet_NaN = true;
static constexpr const bool has_signaling_NaN = true;
EIGEN_DIAGNOSTICS(push)
EIGEN_DISABLE_DEPRECATED_WARNING
static EIGEN_CONSTEXPR const std::float_denorm_style has_denorm = std::denorm_present;
static EIGEN_CONSTEXPR const bool has_denorm_loss = false;
static constexpr const std::float_denorm_style has_denorm = std::denorm_present;
static constexpr const bool has_denorm_loss = false;
EIGEN_DIAGNOSTICS(pop)
static EIGEN_CONSTEXPR const std::float_round_style round_style = std::round_to_nearest;
static EIGEN_CONSTEXPR const bool is_iec559 = true;
static constexpr const std::float_round_style round_style = std::round_to_nearest;
static constexpr const bool is_iec559 = true;
// The C++ standard defines this as "true if the set of values representable
// by the type is finite." Half has finite precision.
static EIGEN_CONSTEXPR const bool is_bounded = true;
static EIGEN_CONSTEXPR const bool is_modulo = false;
static EIGEN_CONSTEXPR const int digits = 11;
static EIGEN_CONSTEXPR const int digits10 =
static constexpr const bool is_bounded = true;
static constexpr const bool is_modulo = false;
static constexpr const int digits = 11;
static constexpr const int digits10 =
3; // according to http://half.sourceforge.net/structstd_1_1numeric__limits_3_01half__float_1_1half_01_4.html
static EIGEN_CONSTEXPR const int max_digits10 =
static constexpr const int max_digits10 =
5; // according to http://half.sourceforge.net/structstd_1_1numeric__limits_3_01half__float_1_1half_01_4.html
static EIGEN_CONSTEXPR const int radix = std::numeric_limits<float>::radix;
static EIGEN_CONSTEXPR const int min_exponent = -13;
static EIGEN_CONSTEXPR const int min_exponent10 = -4;
static EIGEN_CONSTEXPR const int max_exponent = 16;
static EIGEN_CONSTEXPR const int max_exponent10 = 4;
static EIGEN_CONSTEXPR const bool traps = std::numeric_limits<float>::traps;
static constexpr const int radix = std::numeric_limits<float>::radix;
static constexpr const int min_exponent = -13;
static constexpr const int min_exponent10 = -4;
static constexpr const int max_exponent = 16;
static constexpr const int max_exponent10 = 4;
static constexpr const bool traps = std::numeric_limits<float>::traps;
// IEEE754: "The implementer shall choose how tininess is detected, but shall
// detect tininess in the same way for all operations in radix two"
static EIGEN_CONSTEXPR const bool tinyness_before = std::numeric_limits<float>::tinyness_before;
static constexpr const bool tinyness_before = std::numeric_limits<float>::tinyness_before;
static EIGEN_CONSTEXPR Eigen::half(min)() { return Eigen::half_impl::raw_uint16_to_half(0x0400); }
static EIGEN_CONSTEXPR Eigen::half lowest() { return Eigen::half_impl::raw_uint16_to_half(0xfbff); }
static EIGEN_CONSTEXPR Eigen::half(max)() { return Eigen::half_impl::raw_uint16_to_half(0x7bff); }
static EIGEN_CONSTEXPR Eigen::half epsilon() { return Eigen::half_impl::raw_uint16_to_half(0x1400); }
static EIGEN_CONSTEXPR Eigen::half round_error() { return Eigen::half_impl::raw_uint16_to_half(0x3800); }
static EIGEN_CONSTEXPR Eigen::half infinity() { return Eigen::half_impl::raw_uint16_to_half(0x7c00); }
static EIGEN_CONSTEXPR Eigen::half quiet_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7e00); }
static EIGEN_CONSTEXPR Eigen::half signaling_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7d00); }
static EIGEN_CONSTEXPR Eigen::half denorm_min() { return Eigen::half_impl::raw_uint16_to_half(0x0001); }
static _EIGEN_MAYBE_CONSTEXPR Eigen::half(min)() { return Eigen::half_impl::raw_uint16_to_half(0x0400); }
static _EIGEN_MAYBE_CONSTEXPR Eigen::half lowest() { return Eigen::half_impl::raw_uint16_to_half(0xfbff); }
static _EIGEN_MAYBE_CONSTEXPR Eigen::half(max)() { return Eigen::half_impl::raw_uint16_to_half(0x7bff); }
static _EIGEN_MAYBE_CONSTEXPR Eigen::half epsilon() { return Eigen::half_impl::raw_uint16_to_half(0x1400); }
static _EIGEN_MAYBE_CONSTEXPR Eigen::half round_error() { return Eigen::half_impl::raw_uint16_to_half(0x3800); }
static _EIGEN_MAYBE_CONSTEXPR Eigen::half infinity() { return Eigen::half_impl::raw_uint16_to_half(0x7c00); }
static _EIGEN_MAYBE_CONSTEXPR Eigen::half quiet_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7e00); }
static _EIGEN_MAYBE_CONSTEXPR Eigen::half signaling_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7d00); }
static _EIGEN_MAYBE_CONSTEXPR Eigen::half denorm_min() { return Eigen::half_impl::raw_uint16_to_half(0x0001); }
};
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_specialized;
constexpr const bool numeric_limits_half_impl<T>::is_specialized;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_signed;
constexpr const bool numeric_limits_half_impl<T>::is_signed;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_integer;
constexpr const bool numeric_limits_half_impl<T>::is_integer;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_exact;
constexpr const bool numeric_limits_half_impl<T>::is_exact;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::has_infinity;
constexpr const bool numeric_limits_half_impl<T>::has_infinity;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::has_quiet_NaN;
constexpr const bool numeric_limits_half_impl<T>::has_quiet_NaN;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::has_signaling_NaN;
constexpr const bool numeric_limits_half_impl<T>::has_signaling_NaN;
EIGEN_DIAGNOSTICS(push)
EIGEN_DISABLE_DEPRECATED_WARNING
template <typename T>
EIGEN_CONSTEXPR const std::float_denorm_style numeric_limits_half_impl<T>::has_denorm;
constexpr const std::float_denorm_style numeric_limits_half_impl<T>::has_denorm;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::has_denorm_loss;
constexpr const bool numeric_limits_half_impl<T>::has_denorm_loss;
EIGEN_DIAGNOSTICS(pop)
template <typename T>
EIGEN_CONSTEXPR const std::float_round_style numeric_limits_half_impl<T>::round_style;
constexpr const std::float_round_style numeric_limits_half_impl<T>::round_style;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_iec559;
constexpr const bool numeric_limits_half_impl<T>::is_iec559;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_bounded;
constexpr const bool numeric_limits_half_impl<T>::is_bounded;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_modulo;
constexpr const bool numeric_limits_half_impl<T>::is_modulo;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::digits;
constexpr const int numeric_limits_half_impl<T>::digits;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::digits10;
constexpr const int numeric_limits_half_impl<T>::digits10;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::max_digits10;
constexpr const int numeric_limits_half_impl<T>::max_digits10;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::radix;
constexpr const int numeric_limits_half_impl<T>::radix;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::min_exponent;
constexpr const int numeric_limits_half_impl<T>::min_exponent;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::min_exponent10;
constexpr const int numeric_limits_half_impl<T>::min_exponent10;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::max_exponent;
constexpr const int numeric_limits_half_impl<T>::max_exponent;
template <typename T>
EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::max_exponent10;
constexpr const int numeric_limits_half_impl<T>::max_exponent10;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::traps;
constexpr const bool numeric_limits_half_impl<T>::traps;
template <typename T>
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::tinyness_before;
constexpr const bool numeric_limits_half_impl<T>::tinyness_before;
} // end namespace half_impl
} // end namespace Eigen
@ -320,8 +342,7 @@ namespace half_impl {
(defined(EIGEN_HAS_HIP_FP16) && defined(HIP_DEVICE_COMPILE))
// Note: We deliberately do *not* define this to 1 even if we have Arm's native
// fp16 type since GPU half types are rather different from native CPU half types.
// TODO: Rename to something like EIGEN_HAS_NATIVE_GPU_FP16
#define EIGEN_HAS_NATIVE_FP16
#define EIGEN_HAS_NATIVE_GPU_FP16
#endif
// Intrinsics for native fp16 support. Note that on current hardware,
@ -329,7 +350,7 @@ namespace half_impl {
// versions to get the ALU speed increased), but you do save the
// conversion steps back and forth.
#if defined(EIGEN_HAS_NATIVE_FP16)
#if defined(EIGEN_HAS_NATIVE_GPU_FP16)
EIGEN_STRONG_INLINE __device__ half operator+(const half& a, const half& b) {
#if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000
return __hadd(::__half(a), ::__half(b));
@ -371,7 +392,8 @@ EIGEN_STRONG_INLINE __device__ bool operator<(const half& a, const half& b) { re
EIGEN_STRONG_INLINE __device__ bool operator<=(const half& a, const half& b) { return __hle(a, b); }
EIGEN_STRONG_INLINE __device__ bool operator>(const half& a, const half& b) { return __hgt(a, b); }
EIGEN_STRONG_INLINE __device__ bool operator>=(const half& a, const half& b) { return __hge(a, b); }
#endif
#endif // EIGEN_HAS_NATIVE_GPU_FP16
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) && !defined(EIGEN_GPU_COMPILE_PHASE)
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator+(const half& a, const half& b) { return half(vaddh_f16(a.x, b.x)); }
@ -401,16 +423,47 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const half& a, const half&
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const half& a, const half& b) { return vcleh_f16(a.x, b.x); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const half& a, const half& b) { return vcgth_f16(a.x, b.x); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half& b) { return vcgeh_f16(a.x, b.x); }
#elif defined(EIGEN_HAS_BUILTIN_FLOAT16) && !defined(EIGEN_GPU_COMPILE_PHASE)
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator+(const half& a, const half& b) { return half(a.x + b.x); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator*(const half& a, const half& b) { return half(a.x * b.x); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator-(const half& a, const half& b) { return half(a.x - b.x); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator/(const half& a, const half& b) { return half(a.x / b.x); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator-(const half& a) { return half(-a.x); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator+=(half& a, const half& b) {
a = a + b;
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator*=(half& a, const half& b) {
a = a * b;
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator-=(half& a, const half& b) {
a = a - b;
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator/=(half& a, const half& b) {
a = a / b;
return a;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator==(const half& a, const half& b) { return a.x == b.x; }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator!=(const half& a, const half& b) { return a.x != b.x; }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const half& a, const half& b) { return a.x < b.x; }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const half& a, const half& b) { return a.x <= b.x; }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const half& a, const half& b) { return a.x > b.x; }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half& b) { return a.x >= b.x; }
// We need to distinguish clang as the CUDA compiler from clang as the host compiler,
// invoked by NVCC (e.g. on MacOS). The former needs to see both host and device implementation
// of the functions, while the latter can only deal with one of them.
#elif !defined(EIGEN_HAS_NATIVE_FP16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for half floats
#elif !defined(EIGEN_HAS_NATIVE_GPU_FP16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for half floats
#if EIGEN_COMP_CLANG && defined(EIGEN_GPUCC)
// We need to provide emulated *host-side* FP16 operators for clang.
#pragma push_macro("EIGEN_DEVICE_FUNC")
#undef EIGEN_DEVICE_FUNC
#if defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_HAS_NATIVE_FP16)
#if defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_HAS_NATIVE_GPU_FP16)
#define EIGEN_DEVICE_FUNC __host__
#else // both host and device need emulated ops.
#define EIGEN_DEVICE_FUNC __host__ __device__
@ -458,6 +511,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half&
#if EIGEN_COMP_CLANG && defined(EIGEN_GPUCC)
#pragma pop_macro("EIGEN_DEVICE_FUNC")
#endif
#endif // Emulate support for half floats
// Division by an index. Do it in full float precision to avoid accuracy
@ -493,7 +547,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator--(half& a, int) {
// these in hardware. If we need more performance on older/other CPUs, they are
// also possible to vectorize directly.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x) {
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x) {
// We cannot simply do a "return __half_raw(x)" here, because __half_raw is union type
// in the hip_fp16 header file, and that will trigger a compile error
// On the other hand, having anything but a return statement also triggers a compile error
@ -515,6 +569,8 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC numext::uint16_t raw_half_as_uint16(const
// For SYCL, cl::sycl::half is _Float16, so cast directly.
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
return numext::bit_cast<numext::uint16_t>(h.x);
#elif defined(EIGEN_HAS_BUILTIN_FLOAT16)
return numext::bit_cast<numext::uint16_t>(h.x);
#elif defined(SYCL_DEVICE_ONLY)
return numext::bit_cast<numext::uint16_t>(h);
#else
@ -528,6 +584,16 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) {
__half tmp_ff = __float2half(ff);
return *(__half_raw*)&tmp_ff;
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
__half_raw h;
h.x = static_cast<__fp16>(ff);
return h;
#elif defined(EIGEN_HAS_BUILTIN_FLOAT16)
__half_raw h;
h.x = static_cast<_Float16>(ff);
return h;
#elif defined(EIGEN_HAS_FP16_C)
__half_raw h;
#if EIGEN_COMP_MSVC
@ -538,11 +604,6 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) {
#endif
return h;
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
__half_raw h;
h.x = static_cast<__fp16>(ff);
return h;
#else
uint32_t f_bits = Eigen::numext::bit_cast<uint32_t>(ff);
const uint32_t f32infty_bits = {255 << 23};
@ -595,6 +656,8 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h) {
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
return __half2float(h);
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16)
return static_cast<float>(h.x);
#elif defined(EIGEN_HAS_FP16_C)
#if EIGEN_COMP_MSVC
// MSVC does not have scalar instructions.
@ -602,8 +665,6 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h) {
#else
return _cvtsh_ss(h.x);
#endif
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
return static_cast<float>(h.x);
#else
const float magic = Eigen::numext::bit_cast<float>(static_cast<uint32_t>(113 << 23));
const uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift
@ -628,7 +689,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h) {
// --- standard functions ---
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isinf)(const half& a) {
#ifdef EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16)
return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) == 0x7c00;
#else
return (a.x & 0x7fff) == 0x7c00;
@ -638,7 +699,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isnan)(const half& a) {
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
return __hisnan(a);
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16)
return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) > 0x7c00;
#else
return (a.x & 0x7fff) > 0x7c00;
@ -651,6 +712,11 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isfinite)(const half& a) {
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half abs(const half& a) {
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
return half(vabsh_f16(a.x));
#elif defined(EIGEN_HAS_BUILTIN_FLOAT16)
half result;
result.x =
numext::bit_cast<_Float16>(static_cast<numext::uint16_t>(numext::bit_cast<numext::uint16_t>(a.x) & 0x7FFF));
return result;
#else
half result;
result.x = a.x & 0x7FFF;
@ -734,26 +800,9 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half fmod(const half& a, const half& b) {
return half(::fmodf(float(a), float(b)));
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(min)(const half& a, const half& b) {
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
return __hlt(b, a) ? b : a;
#else
const float f1 = static_cast<float>(a);
const float f2 = static_cast<float>(b);
return f2 < f1 ? b : a;
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(max)(const half& a, const half& b) {
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
return __hlt(a, b) ? b : a;
#else
const float f1 = static_cast<float>(a);
const float f2 = static_cast<float>(b);
return f1 < f2 ? b : a;
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(min)(const half& a, const half& b) { return b < a ? b : a; }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(max)(const half& a, const half& b) { return a < b ? b : a; }
#ifndef EIGEN_NO_IO
EIGEN_ALWAYS_INLINE std::ostream& operator<<(std::ostream& os, const half& v) {
@ -794,31 +843,29 @@ template <>
struct NumTraits<Eigen::half> : GenericNumTraits<Eigen::half> {
enum { IsSigned = true, IsInteger = false, IsComplex = false, RequireInitialization = false };
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half epsilon() {
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half epsilon() {
return half_impl::raw_uint16_to_half(0x0800);
}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half dummy_precision() {
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half dummy_precision() {
return half_impl::raw_uint16_to_half(0x211f); // Eigen::half(1e-2f);
}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half highest() {
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half highest() {
return half_impl::raw_uint16_to_half(0x7bff);
}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half lowest() {
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half lowest() {
return half_impl::raw_uint16_to_half(0xfbff);
}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half infinity() {
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half infinity() {
return half_impl::raw_uint16_to_half(0x7c00);
}
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half quiet_NaN() {
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half quiet_NaN() {
return half_impl::raw_uint16_to_half(0x7e00);
}
};
} // end namespace Eigen
#if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
#pragma pop_macro("EIGEN_CONSTEXPR")
#endif
#undef _EIGEN_MAYBE_CONSTEXPR
namespace Eigen {
namespace numext {
@ -976,6 +1023,36 @@ struct cast_impl<half, float> {
}
};
#ifdef EIGEN_VECTORIZE_FMA
template <>
EIGEN_DEVICE_FUNC inline half pmadd(const half& a, const half& b, const half& c) {
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
return half(vfmah_f16(a.x, b.x, c.x));
#elif defined(EIGEN_VECTORIZE_AVX512FP16)
// Reduces to vfmadd213sh.
return half(_mm_cvtsh_h(_mm_fmadd_ph(_mm_set_sh(a.x), _mm_set_sh(b.x), _mm_set_sh(c.x))));
#else
// Emulate FMA via float.
return half(static_cast<float>(a) * static_cast<float>(b) + static_cast<float>(c));
#endif
}
template <>
EIGEN_DEVICE_FUNC inline half pmsub(const half& a, const half& b, const half& c) {
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
return half(vfmah_f16(a.x, b.x, -c.x));
#elif defined(EIGEN_VECTORIZE_AVX512FP16)
// Reduces to vfmadd213sh.
return half(_mm_cvtsh_h(_mm_fmadd_ph(_mm_set_sh(a.x), _mm_set_sh(b.x), -_mm_set_sh(c.x))));
#else
// Emulate FMA via float.
return half(static_cast<float>(a) * static_cast<float>(b) - static_cast<float>(c));
#endif
}
#endif
} // namespace internal
} // namespace Eigen

View File

@ -285,6 +285,8 @@
#ifdef __AVX512FP16__
#ifdef __AVX512VL__
#define EIGEN_VECTORIZE_AVX512FP16
// Built-in _Float16.
#define EIGEN_HAS_BUILTIN_FLOAT16 1
#else
#if EIGEN_COMP_GNUC
#error Please add -mavx512vl to your compiler flags: compiling with -mavx512fp16 alone without AVX512-VL is not supported.

View File

@ -7,7 +7,8 @@
// Include this header to be able to print Packets while debugging.
template <typename Packet,
typename EnableIf = std::enable_if_t<Eigen::internal::unpacket_traits<Packet>::vectorizable> >
typename EnableIf = std::enable_if_t<(Eigen::internal::unpacket_traits<Packet>::vectorizable ||
Eigen::internal::unpacket_traits<Packet>::size > 1)> >
std::ostream& operator<<(std::ostream& os, const Packet& packet) {
using Scalar = typename Eigen::internal::unpacket_traits<Packet>::type;
Scalar v[Eigen::internal::unpacket_traits<Packet>::size];

View File

@ -26,19 +26,19 @@ inline T REF_MUL(const T& a, const T& b) {
}
template <typename T>
inline T REF_MADD(const T& a, const T& b, const T& c) {
return a * b + c;
return internal::pmadd(a, b, c);
}
template <typename T>
inline T REF_MSUB(const T& a, const T& b, const T& c) {
return a * b - c;
return internal::pmsub(a, b, c);
}
template <typename T>
inline T REF_NMADD(const T& a, const T& b, const T& c) {
return c - a * b;
return internal::pnmadd(a, b, c);
}
template <typename T>
inline T REF_NMSUB(const T& a, const T& b, const T& c) {
return test::negate(a * b + c);
return internal::pnmsub(a, b, c);
}
template <typename T>
inline T REF_DIV(const T& a, const T& b) {