mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-18 07:39:37 +08:00
Use native _Float16 for AVX512FP16 and update vectorization.
This allows us to do faster native scalar operations. Half and quarter packets were also updated to use the native type when it is available.

Benchmark improvement (Time/CPU columns are the relative change from old to new; negative means the new build is faster):

```
Comparing ./2910_without_float16 to ./2910_with_float16
Benchmark                                   Time      CPU   Time Old   Time New    CPU Old    CPU New
------------------------------------------------------------------------------------------------------
BM_CalcMat<float>/10000/768/500          -0.0041  -0.0040   58276392   58039442   58273420   58039582
BM_CalcMat<_Float16>/10000/768/500       +0.0073  +0.0073  642506339  647214446  642481384  647188303
BM_CalcMat<Eigen::half>/10000/768/500    -0.3170  -0.3170   92511115   63182101   92506771   63179258
BM_CalcVec<float>/10000/768/500          +0.0022  +0.0022    5198157    5209469    5197913    5209334
BM_CalcVec<_Float16>/10000/768/500       +0.0025  +0.0026   10133324   10159111   10132641   10158507
BM_CalcVec<Eigen::half>/10000/768/500    -0.7760  -0.7760   45337937   10156952   45336532   10156389
OVERALL_GEOMEAN                          -0.2677  -0.2677          0          0          0          0
```

Fixes #2910.
This commit is contained in:
parent
0259a52b0e
commit
3580a38298
16
Eigen/Core
16
Eigen/Core
@ -193,21 +193,27 @@ using std::ptrdiff_t;
|
||||
#include "src/Core/arch/Default/GenericPacketMathFunctionsFwd.h"
|
||||
|
||||
#if defined EIGEN_VECTORIZE_AVX512
|
||||
#include "src/Core/arch/SSE/PacketMath.h"
|
||||
#include "src/Core/arch/AVX/PacketMath.h"
|
||||
#include "src/Core/arch/AVX512/PacketMath.h"
|
||||
#if defined EIGEN_VECTORIZE_AVX512FP16
|
||||
#include "src/Core/arch/AVX512/PacketMathFP16.h"
|
||||
#endif
|
||||
#include "src/Core/arch/SSE/PacketMath.h"
|
||||
#include "src/Core/arch/SSE/TypeCasting.h"
|
||||
#include "src/Core/arch/SSE/Complex.h"
|
||||
#include "src/Core/arch/AVX/PacketMath.h"
|
||||
#include "src/Core/arch/AVX/TypeCasting.h"
|
||||
#include "src/Core/arch/AVX/Complex.h"
|
||||
#include "src/Core/arch/AVX512/PacketMath.h"
|
||||
#include "src/Core/arch/AVX512/TypeCasting.h"
|
||||
#if defined EIGEN_VECTORIZE_AVX512FP16
|
||||
#include "src/Core/arch/AVX512/TypeCastingFP16.h"
|
||||
#endif
|
||||
#include "src/Core/arch/SSE/Complex.h"
|
||||
#include "src/Core/arch/AVX/Complex.h"
|
||||
#include "src/Core/arch/AVX512/Complex.h"
|
||||
#include "src/Core/arch/SSE/MathFunctions.h"
|
||||
#include "src/Core/arch/AVX/MathFunctions.h"
|
||||
#include "src/Core/arch/AVX512/MathFunctions.h"
|
||||
#if defined EIGEN_VECTORIZE_AVX512FP16
|
||||
#include "src/Core/arch/AVX512/MathFunctionsFP16.h"
|
||||
#endif
|
||||
#include "src/Core/arch/AVX512/TrsmKernel.h"
|
||||
#elif defined EIGEN_VECTORIZE_AVX
|
||||
// Use AVX for floats and doubles, SSE for integers
|
||||
|
@ -76,7 +76,7 @@ struct generic_rsqrt_newton_step {
|
||||
static_assert(Steps > 0, "Steps must be at least 1.");
|
||||
using Scalar = typename unpacket_traits<Packet>::type;
|
||||
EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE Packet run(const Packet& a, const Packet& approx_rsqrt) {
|
||||
constexpr Scalar kMinusHalf = Scalar(-1) / Scalar(2);
|
||||
const Scalar kMinusHalf = Scalar(-1) / Scalar(2);
|
||||
const Packet cst_minus_half = pset1<Packet>(kMinusHalf);
|
||||
const Packet cst_minus_one = pset1<Packet>(Scalar(-1));
|
||||
|
||||
|
@ -106,6 +106,8 @@ BF16_PACKET_FUNCTION(Packet8f, Packet8bf, prsqrt)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psin)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, psqrt)
|
||||
BF16_PACKET_FUNCTION(Packet8f, Packet8bf, ptanh)
|
||||
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pcos)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pexp)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, pexp2)
|
||||
@ -118,6 +120,7 @@ F16_PACKET_FUNCTION(Packet8f, Packet8h, prsqrt)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, psin)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, psqrt)
|
||||
F16_PACKET_FUNCTION(Packet8f, Packet8h, ptanh)
|
||||
#endif
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
|
@ -1839,10 +1839,13 @@ EIGEN_STRONG_INLINE Packet8ui pabs(const Packet8ui& a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h psignbit(const Packet8h& a) {
|
||||
return _mm_cmpgt_epi16(_mm_setzero_si128(), a);
|
||||
}
|
||||
#endif // EIGEN_VECTORIZE_AVX512FP16
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8bf psignbit(const Packet8bf& a) {
|
||||
return _mm_cmpgt_epi16(_mm_setzero_si128(), a);
|
||||
@ -2044,10 +2047,13 @@ EIGEN_STRONG_INLINE bool predux_any(const Packet8ui& x) {
|
||||
return _mm256_movemask_ps(_mm256_castsi256_ps(x)) != 0;
|
||||
}
|
||||
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE bool predux_any(const Packet8h& x) {
|
||||
return _mm_movemask_epi8(x) != 0;
|
||||
}
|
||||
#endif // EIGEN_VECTORIZE_AVX512FP16
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE bool predux_any(const Packet8bf& x) {
|
||||
return _mm_movemask_epi8(x) != 0;
|
||||
@ -2211,7 +2217,6 @@ struct unpacket_traits<Packet8h> {
|
||||
};
|
||||
typedef Packet8h half;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) {
|
||||
@ -2446,14 +2451,12 @@ EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet8h>(Eigen::half* to, const
|
||||
to[stride * 7] = aux[7];
|
||||
}
|
||||
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Eigen::half predux<Packet8h>(const Packet8h& a) {
|
||||
Packet8f af = half2float(a);
|
||||
float reduced = predux<Packet8f>(af);
|
||||
return Eigen::half(reduced);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Eigen::half predux_max<Packet8h>(const Packet8h& a) {
|
||||
@ -2553,6 +2556,8 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet8h, 4>& kernel) {
|
||||
kernel.packet[3] = pload<Packet8h>(out[3]);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// BFloat16 implementation.
|
||||
|
||||
EIGEN_STRONG_INLINE Packet8f Bf16ToF32(const Packet8bf& a) {
|
||||
|
@ -279,20 +279,22 @@ EIGEN_STRONG_INLINE Packet2l preinterpret<Packet2l, Packet4l>(const Packet4l& a)
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8f pcast<Packet8h, Packet8f>(const Packet8h& a) {
|
||||
return half2float(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8f pcast<Packet8bf, Packet8f>(const Packet8bf& a) {
|
||||
return Bf16ToF32(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h pcast<Packet8f, Packet8h>(const Packet8f& a) {
|
||||
return float2half(a);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8f pcast<Packet8bf, Packet8f>(const Packet8bf& a) {
|
||||
return Bf16ToF32(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8bf pcast<Packet8f, Packet8bf>(const Packet8f& a) {
|
||||
|
@ -47,16 +47,16 @@ EIGEN_STRONG_INLINE Packet16bf pldexp(const Packet16bf& a, const Packet16bf& exp
|
||||
|
||||
#if EIGEN_FAST_MATH
|
||||
template <>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f psqrt<Packet16f>(const Packet16f& _x) {
|
||||
return generic_sqrt_newton_step<Packet16f>::run(_x, _mm512_rsqrt14_ps(_x));
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f psqrt<Packet16f>(const Packet16f& x) {
|
||||
return generic_sqrt_newton_step<Packet16f>::run(x, _mm512_rsqrt14_ps(x));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d psqrt<Packet8d>(const Packet8d& _x) {
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d psqrt<Packet8d>(const Packet8d& x) {
|
||||
#ifdef EIGEN_VECTORIZE_AVX512ER
|
||||
return generic_sqrt_newton_step<Packet8d, /*Steps=*/1>::run(_x, _mm512_rsqrt28_pd(_x));
|
||||
return generic_sqrt_newton_step<Packet8d, /*Steps=*/1>::run(x, _mm512_rsqrt28_pd(x));
|
||||
#else
|
||||
return generic_sqrt_newton_step<Packet8d, /*Steps=*/2>::run(_x, _mm512_rsqrt14_pd(_x));
|
||||
return generic_sqrt_newton_step<Packet8d, /*Steps=*/2>::run(x, _mm512_rsqrt14_pd(x));
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
@ -80,19 +80,19 @@ EIGEN_STRONG_INLINE Packet16f prsqrt<Packet16f>(const Packet16f& x) {
|
||||
#elif EIGEN_FAST_MATH
|
||||
|
||||
template <>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f prsqrt<Packet16f>(const Packet16f& _x) {
|
||||
return generic_rsqrt_newton_step<Packet16f, /*Steps=*/1>::run(_x, _mm512_rsqrt14_ps(_x));
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f prsqrt<Packet16f>(const Packet16f& x) {
|
||||
return generic_rsqrt_newton_step<Packet16f, /*Steps=*/1>::run(x, _mm512_rsqrt14_ps(x));
|
||||
}
|
||||
#endif
|
||||
|
||||
// prsqrt for double.
|
||||
#if EIGEN_FAST_MATH
|
||||
template <>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d prsqrt<Packet8d>(const Packet8d& _x) {
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d prsqrt<Packet8d>(const Packet8d& x) {
|
||||
#ifdef EIGEN_VECTORIZE_AVX512ER
|
||||
return generic_rsqrt_newton_step<Packet8d, /*Steps=*/1>::run(_x, _mm512_rsqrt28_pd(_x));
|
||||
return generic_rsqrt_newton_step<Packet8d, /*Steps=*/1>::run(x, _mm512_rsqrt28_pd(x));
|
||||
#else
|
||||
return generic_rsqrt_newton_step<Packet8d, /*Steps=*/2>::run(_x, _mm512_rsqrt14_pd(_x));
|
||||
return generic_rsqrt_newton_step<Packet8d, /*Steps=*/2>::run(x, _mm512_rsqrt14_pd(x));
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -118,6 +118,8 @@ BF16_PACKET_FUNCTION(Packet16f, Packet16bf, prsqrt)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psin)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, psqrt)
|
||||
BF16_PACKET_FUNCTION(Packet16f, Packet16bf, ptanh)
|
||||
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pcos)
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp)
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, pexp2)
|
||||
@ -130,6 +132,7 @@ F16_PACKET_FUNCTION(Packet16f, Packet16h, prsqrt)
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, psin)
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, psqrt)
|
||||
F16_PACKET_FUNCTION(Packet16f, Packet16h, ptanh)
|
||||
#endif // EIGEN_VECTORIZE_AVX512FP16
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
|
75
Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h
Normal file
75
Eigen/src/Core/arch/AVX512/MathFunctionsFP16.h
Normal file
@ -0,0 +1,75 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2025 The Eigen Authors.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#ifndef EIGEN_MATH_FUNCTIONS_FP16_AVX512_H
|
||||
#define EIGEN_MATH_FUNCTIONS_FP16_AVX512_H
|
||||
|
||||
// IWYU pragma: private
|
||||
#include "../../InternalHeaderCheck.h"
|
||||
|
||||
namespace Eigen {
|
||||
namespace internal {
|
||||
|
||||
// Builds a Packet32h from two Packet16h values: `a` is placed in the low
// 256 bits and `b` is inserted at 256-bit position 1 (the high half).
// Only bit-casts and a lane insert are used; no value conversion happens.
EIGEN_STRONG_INLINE Packet32h combine2Packet16h(const Packet16h& a, const Packet16h& b) {
  const __m512i low_only = _mm512_castsi256_si512(_mm256_castph_si256(a));
  const __m512i joined = _mm512_inserti64x4(low_only, _mm256_castph_si256(b), 1);
  return _mm512_castsi512_ph(joined);
}
|
||||
|
||||
// Splits a Packet32h into its two 256-bit halves.
//   x: source packet.
//   a: receives the low 256 bits of `x`.
//   b: receives 256-bit position 1 (the high half) of `x`.
EIGEN_STRONG_INLINE void extract2Packet16h(const Packet32h& x, Packet16h& a, Packet16h& b) {
  // Reinterpret once as integer bits; both halves are taken from this view.
  const auto bits = _mm512_castph_si512(x);
  a = _mm256_castsi256_ph(_mm512_castsi512_si256(bits));
  b = _mm256_castsi256_ph(_mm512_extracti64x4_epi64(bits, 1));
}
|
||||
|
||||
#define _EIGEN_GENERATE_FP16_MATH_FUNCTION(func) \
|
||||
template <> \
|
||||
EIGEN_STRONG_INLINE Packet8h func<Packet8h>(const Packet8h& a) { \
|
||||
return float2half(func(half2float(a))); \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
EIGEN_STRONG_INLINE Packet16h func<Packet16h>(const Packet16h& a) { \
|
||||
return float2half(func(half2float(a))); \
|
||||
} \
|
||||
\
|
||||
template <> \
|
||||
EIGEN_STRONG_INLINE Packet32h func<Packet32h>(const Packet32h& a) { \
|
||||
Packet16h low; \
|
||||
Packet16h high; \
|
||||
extract2Packet16h(a, low, high); \
|
||||
return combine2Packet16h(func(low), func(high)); \
|
||||
}
|
||||
|
||||
_EIGEN_GENERATE_FP16_MATH_FUNCTION(psin)
|
||||
_EIGEN_GENERATE_FP16_MATH_FUNCTION(pcos)
|
||||
_EIGEN_GENERATE_FP16_MATH_FUNCTION(plog)
|
||||
_EIGEN_GENERATE_FP16_MATH_FUNCTION(plog2)
|
||||
_EIGEN_GENERATE_FP16_MATH_FUNCTION(plog1p)
|
||||
_EIGEN_GENERATE_FP16_MATH_FUNCTION(pexp)
|
||||
_EIGEN_GENERATE_FP16_MATH_FUNCTION(pexpm1)
|
||||
_EIGEN_GENERATE_FP16_MATH_FUNCTION(pexp2)
|
||||
_EIGEN_GENERATE_FP16_MATH_FUNCTION(ptanh)
|
||||
#undef _EIGEN_GENERATE_FP16_MATH_FUNCTION
|
||||
|
||||
// pfrexp
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pfrexp<Packet32h>(const Packet32h& a, Packet32h& exponent) {
|
||||
return pfrexp_generic(a, exponent);
|
||||
}
|
||||
|
||||
// pldexp
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pldexp<Packet32h>(const Packet32h& a, const Packet32h& exponent) {
|
||||
return pldexp_generic(a, exponent);
|
||||
}
|
||||
|
||||
} // end namespace internal
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_MATH_FUNCTIONS_FP16_AVX512_H
|
@ -40,6 +40,10 @@ typedef eigen_packet_wrapper<__m256i, 1> Packet16h;
|
||||
#endif
|
||||
typedef eigen_packet_wrapper<__m256i, 2> Packet16bf;
|
||||
|
||||
typedef eigen_packet_wrapper<__m512i, 6> Packet32s;
|
||||
typedef eigen_packet_wrapper<__m256i, 6> Packet16s;
|
||||
typedef eigen_packet_wrapper<__m128i, 6> Packet8s;
|
||||
|
||||
template <>
|
||||
struct is_arithmetic<__m512> {
|
||||
enum { value = true };
|
||||
@ -248,6 +252,39 @@ struct unpacket_traits<Packet16h> {
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct unpacket_traits<Packet32s> {
|
||||
typedef numext::int16_t type;
|
||||
typedef Packet16s half;
|
||||
enum {
|
||||
size = 32,
|
||||
alignment = Aligned64,
|
||||
vectorizable = false,
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct unpacket_traits<Packet16s> {
|
||||
typedef numext::int16_t type;
|
||||
typedef Packet8s half;
|
||||
enum {
|
||||
size = 16,
|
||||
alignment = Aligned32,
|
||||
vectorizable = false,
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct unpacket_traits<Packet8s> {
|
||||
typedef numext::int16_t type;
|
||||
typedef Packet8s half;
|
||||
enum {
|
||||
size = 8,
|
||||
alignment = Aligned16,
|
||||
vectorizable = false,
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16f pset1<Packet16f>(const float& from) {
|
||||
return _mm512_set1_ps(from);
|
||||
@ -1335,10 +1372,13 @@ EIGEN_STRONG_INLINE Packet8l pabs(const Packet8l& a) {
|
||||
return _mm512_abs_epi64(a);
|
||||
}
|
||||
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h psignbit(const Packet16h& a) {
|
||||
return _mm256_srai_epi16(a, 15);
|
||||
}
|
||||
#endif // EIGEN_VECTORIZE_AVX512FP16
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16bf psignbit(const Packet16bf& a) {
|
||||
return _mm256_srai_epi16(a, 15);
|
||||
@ -2199,6 +2239,7 @@ EIGEN_STRONG_INLINE Packet8d pblend(const Selector<8>& ifPacket, const Packet8d&
|
||||
}
|
||||
|
||||
// Packet math for Eigen::half
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h pset1<Packet16h>(const Eigen::half& from) {
|
||||
return _mm256_set1_epi16(from.x);
|
||||
@ -2369,7 +2410,6 @@ EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) {
|
||||
return _mm256_xor_si256(a, sign_mask);
|
||||
}
|
||||
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h padd<Packet16h>(const Packet16h& a, const Packet16h& b) {
|
||||
Packet16f af = half2float(a);
|
||||
@ -2408,8 +2448,6 @@ EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& from) {
|
||||
return half(predux(from_float));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h predux_half_dowto4<Packet16h>(const Packet16h& a) {
|
||||
Packet8h lane0 = _mm256_extractf128_si256(a, 0);
|
||||
@ -2643,6 +2681,8 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16h, 4>& kernel) {
|
||||
kernel.packet[3] = pload<Packet16h>(out[3]);
|
||||
}
|
||||
|
||||
#endif // EIGEN_VECTORIZE_AVX512FP16
|
||||
|
||||
template <>
|
||||
struct is_arithmetic<Packet16bf> {
|
||||
enum { value = true };
|
||||
@ -3095,6 +3135,158 @@ EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet16bf, 4>& kernel) {
|
||||
kernel.packet[3] = _mm256_permute2x128_si256(abcd_8b, abcd_cf, 0x31);
|
||||
}
|
||||
|
||||
// Minimal implementation of 16-bit int packets for use in pfrexp, pldexp.
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32s pset1<Packet32s>(const numext::int16_t& x) {
|
||||
return _mm512_set1_epi16(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16s pset1<Packet16s>(const numext::int16_t& x) {
|
||||
return _mm256_set1_epi16(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8s pset1<Packet8s>(const numext::int16_t& x) {
|
||||
return _mm_set1_epi16(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet32s>(numext::int16_t* out, const Packet32s& x) {
|
||||
_mm512_storeu_epi16(out, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet16s>(numext::int16_t* out, const Packet16s& x) {
|
||||
_mm256_storeu_epi16(out, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE void pstore<numext::int16_t, Packet8s>(numext::int16_t* out, const Packet8s& x) {
|
||||
_mm_storeu_epi16(out, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet32s>(numext::int16_t* out, const Packet32s& x) {
|
||||
_mm512_storeu_epi16(out, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet16s>(numext::int16_t* out, const Packet16s& x) {
|
||||
_mm256_storeu_epi16(out, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE void pstoreu<numext::int16_t, Packet8s>(numext::int16_t* out, const Packet8s& x) {
|
||||
_mm_storeu_epi16(out, x);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32s padd(const Packet32s& a, const Packet32s& b) {
|
||||
return _mm512_add_epi16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16s padd(const Packet16s& a, const Packet16s& b) {
|
||||
return _mm256_add_epi16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8s padd(const Packet8s& a, const Packet8s& b) {
|
||||
return _mm_add_epi16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32s psub(const Packet32s& a, const Packet32s& b) {
|
||||
return _mm512_sub_epi16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16s psub(const Packet16s& a, const Packet16s& b) {
|
||||
return _mm256_sub_epi16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8s psub(const Packet8s& a, const Packet8s& b) {
|
||||
return _mm_sub_epi16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32s pmul(const Packet32s& a, const Packet32s& b) {
|
||||
return _mm512_mullo_epi16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16s pmul(const Packet16s& a, const Packet16s& b) {
|
||||
return _mm256_mullo_epi16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8s pmul(const Packet8s& a, const Packet8s& b) {
|
||||
return _mm_mullo_epi16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32s pnegate(const Packet32s& a) {
|
||||
return _mm512_sub_epi16(_mm512_setzero_si512(), a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16s pnegate(const Packet16s& a) {
|
||||
return _mm256_sub_epi16(_mm256_setzero_si256(), a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8s pnegate(const Packet8s& a) {
|
||||
return _mm_sub_epi16(_mm_setzero_si128(), a);
|
||||
}
|
||||
|
||||
template <int N>
|
||||
EIGEN_STRONG_INLINE Packet32s parithmetic_shift_right(Packet32s a) {
|
||||
return _mm512_srai_epi16(a, N);
|
||||
}
|
||||
|
||||
template <int N>
|
||||
EIGEN_STRONG_INLINE Packet16s parithmetic_shift_right(Packet16s a) {
|
||||
return _mm256_srai_epi16(a, N);
|
||||
}
|
||||
|
||||
template <int N>
|
||||
EIGEN_STRONG_INLINE Packet8s parithmetic_shift_right(Packet8s a) {
|
||||
return _mm_srai_epi16(a, N);
|
||||
}
|
||||
|
||||
template <int N>
|
||||
EIGEN_STRONG_INLINE Packet32s plogical_shift_left(Packet32s a) {
|
||||
return _mm512_slli_epi16(a, N);
|
||||
}
|
||||
|
||||
template <int N>
|
||||
EIGEN_STRONG_INLINE Packet16s plogical_shift_left(Packet16s a) {
|
||||
return _mm256_slli_epi16(a, N);
|
||||
}
|
||||
|
||||
template <int N>
|
||||
EIGEN_STRONG_INLINE Packet8s plogical_shift_left(Packet8s a) {
|
||||
return _mm_slli_epi16(a, N);
|
||||
}
|
||||
|
||||
template <int N>
|
||||
EIGEN_STRONG_INLINE Packet32s plogical_shift_right(Packet32s a) {
|
||||
return _mm512_srli_epi16(a, N);
|
||||
}
|
||||
|
||||
template <int N>
|
||||
EIGEN_STRONG_INLINE Packet16s plogical_shift_right(Packet16s a) {
|
||||
return _mm256_srli_epi16(a, N);
|
||||
}
|
||||
|
||||
template <int N>
|
||||
EIGEN_STRONG_INLINE Packet8s plogical_shift_right(Packet8s a) {
|
||||
return _mm_srli_epi16(a, N);
|
||||
}
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -237,16 +237,12 @@ EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i, Packet16i>(const Packet16i&
|
||||
return _mm512_castsi512_si128(a);
|
||||
}
|
||||
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h preinterpret<Packet8h, Packet16h>(const Packet16h& a) {
|
||||
return _mm256_castsi256_si128(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8bf preinterpret<Packet8bf, Packet16bf>(const Packet16bf& a) {
|
||||
return _mm256_castsi256_si128(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16f pcast<Packet16h, Packet16f>(const Packet16h& a) {
|
||||
return half2float(a);
|
||||
@ -257,6 +253,13 @@ EIGEN_STRONG_INLINE Packet16h pcast<Packet16f, Packet16h>(const Packet16f& a) {
|
||||
return float2half(a);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8bf preinterpret<Packet8bf, Packet16bf>(const Packet16bf& a) {
|
||||
return _mm256_castsi256_si128(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16f pcast<Packet16bf, Packet16f>(const Packet16bf& a) {
|
||||
return Bf16ToF32(a);
|
||||
@ -267,68 +270,6 @@ EIGEN_STRONG_INLINE Packet16bf pcast<Packet16f, Packet16bf>(const Packet16f& a)
|
||||
return F32ToBf16(a);
|
||||
}
|
||||
|
||||
#ifdef EIGEN_VECTORIZE_AVX512FP16
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h preinterpret<Packet16h, Packet32h>(const Packet32h& a) {
|
||||
return _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h preinterpret<Packet8h, Packet32h>(const Packet32h& a) {
|
||||
return _mm256_castsi256_si128(preinterpret<Packet16h>(a));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16f pcast<Packet32h, Packet16f>(const Packet32h& a) {
|
||||
// Discard second-half of input.
|
||||
Packet16h low = _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0));
|
||||
return _mm512_cvtxph_ps(_mm256_castsi256_ph(low));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pcast<Packet16f, Packet32h>(const Packet16f& a, const Packet16f& b) {
|
||||
__m512d result = _mm512_undefined_pd();
|
||||
result = _mm512_insertf64x4(
|
||||
result, _mm256_castsi256_pd(_mm512_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 0);
|
||||
result = _mm512_insertf64x4(
|
||||
result, _mm256_castsi256_pd(_mm512_cvtps_ph(b, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 1);
|
||||
return _mm512_castpd_ph(result);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8f pcast<Packet16h, Packet8f>(const Packet16h& a) {
|
||||
// Discard second-half of input.
|
||||
Packet8h low = _mm_castps_si128(_mm256_extractf32x4_ps(_mm256_castsi256_ps(a), 0));
|
||||
return _mm256_cvtxph_ps(_mm_castsi128_ph(low));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h pcast<Packet8f, Packet16h>(const Packet8f& a, const Packet8f& b) {
|
||||
__m256d result = _mm256_undefined_pd();
|
||||
result = _mm256_insertf64x2(result,
|
||||
_mm_castsi128_pd(_mm256_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 0);
|
||||
result = _mm256_insertf64x2(result,
|
||||
_mm_castsi128_pd(_mm256_cvtps_ph(b, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)), 1);
|
||||
return _mm256_castpd_si256(result);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4f pcast<Packet8h, Packet4f>(const Packet8h& a) {
|
||||
Packet8f full = _mm256_cvtxph_ps(_mm_castsi128_ph(a));
|
||||
// Discard second-half of input.
|
||||
return _mm256_extractf32x4_ps(full, 0);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h pcast<Packet4f, Packet8h>(const Packet4f& a, const Packet4f& b) {
|
||||
__m256 result = _mm256_undefined_ps();
|
||||
result = _mm256_insertf128_ps(result, a, 0);
|
||||
result = _mm256_insertf128_ps(result, b, 1);
|
||||
return _mm256_cvtps_ph(result, _MM_FROUND_TO_NEAREST_INT);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
130
Eigen/src/Core/arch/AVX512/TypeCastingFP16.h
Normal file
130
Eigen/src/Core/arch/AVX512/TypeCastingFP16.h
Normal file
@ -0,0 +1,130 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2025 The Eigen Authors.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#ifndef EIGEN_TYPE_CASTING_FP16_AVX512_H
|
||||
#define EIGEN_TYPE_CASTING_FP16_AVX512_H
|
||||
|
||||
// IWYU pragma: private
|
||||
#include "../../InternalHeaderCheck.h"
|
||||
|
||||
namespace Eigen {
|
||||
namespace internal {
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32s preinterpret<Packet32s, Packet32h>(const Packet32h& a) {
|
||||
return _mm512_castph_si512(a);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16s preinterpret<Packet16s, Packet16h>(const Packet16h& a) {
|
||||
return _mm256_castph_si256(a);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8s preinterpret<Packet8s, Packet8h>(const Packet8h& a) {
|
||||
return _mm_castph_si128(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h preinterpret<Packet32h, Packet32s>(const Packet32s& a) {
|
||||
return _mm512_castsi512_ph(a);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h preinterpret<Packet16h, Packet16s>(const Packet16s& a) {
|
||||
return _mm256_castsi256_ph(a);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h preinterpret<Packet8h, Packet8s>(const Packet8s& a) {
|
||||
return _mm_castsi128_ph(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16f pcast<Packet16h, Packet16f>(const Packet16h& a) {
|
||||
return half2float(a);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8f pcast<Packet8h, Packet8f>(const Packet8h& a) {
|
||||
return half2float(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h pcast<Packet16f, Packet16h>(const Packet16f& a) {
|
||||
return float2half(a);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h pcast<Packet8f, Packet8h>(const Packet8f& a) {
|
||||
return float2half(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16f pcast<Packet32h, Packet16f>(const Packet32h& a) {
|
||||
// Discard second-half of input.
|
||||
Packet16h low = _mm256_castpd_ph(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0));
|
||||
return _mm512_cvtxph_ps(low);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8f pcast<Packet16h, Packet8f>(const Packet16h& a) {
|
||||
// Discard second-half of input.
|
||||
Packet8h low = _mm_castps_ph(_mm256_extractf32x4_ps(_mm256_castph_ps(a), 0));
|
||||
return _mm256_cvtxph_ps(low);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4f pcast<Packet8h, Packet4f>(const Packet8h& a) {
|
||||
Packet8f full = _mm256_cvtxph_ps(a);
|
||||
// Discard second-half of input.
|
||||
return _mm256_extractf32x4_ps(full, 0);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pcast<Packet16f, Packet32h>(const Packet16f& a, const Packet16f& b) {
|
||||
__m512 result = _mm512_castsi512_ps(_mm512_castsi256_si512(_mm256_castph_si256(_mm512_cvtxps_ph(a))));
|
||||
result = _mm512_insertf32x8(result, _mm256_castph_ps(_mm512_cvtxps_ph(b)), 1);
|
||||
return _mm512_castps_ph(result);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h pcast<Packet8f, Packet16h>(const Packet8f& a, const Packet8f& b) {
|
||||
__m256 result = _mm256_castsi256_ps(_mm256_castsi128_si256(_mm_castph_si128(_mm256_cvtxps_ph(a))));
|
||||
result = _mm256_insertf32x4(result, _mm_castph_ps(_mm256_cvtxps_ph(b)), 1);
|
||||
return _mm256_castps_ph(result);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h pcast<Packet4f, Packet8h>(const Packet4f& a, const Packet4f& b) {
|
||||
__m256 result = _mm256_castsi256_ps(_mm256_castsi128_si256(_mm_castps_si128(a)));
|
||||
result = _mm256_insertf128_ps(result, b, 1);
|
||||
return _mm256_cvtxps_ph(result);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32s pcast<Packet32h, Packet32s>(const Packet32h& a) {
|
||||
return _mm512_cvtph_epi16(a);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16s pcast<Packet16h, Packet16s>(const Packet16h& a) {
|
||||
return _mm256_cvtph_epi16(a);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8s pcast<Packet8h, Packet8s>(const Packet8h& a) {
|
||||
return _mm_cvtph_epi16(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pcast<Packet32s, Packet32h>(const Packet32s& a) {
|
||||
return _mm512_cvtepi16_ph(a);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h pcast<Packet16s, Packet16h>(const Packet16s& a) {
|
||||
return _mm256_cvtepi16_ph(a);
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h pcast<Packet8s, Packet8h>(const Packet8s& a) {
|
||||
return _mm_cvtepi16_ph(a);
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace Eigen
|
||||
|
||||
#endif // EIGEN_TYPE_CASTING_FP16_AVX512_H
|
@ -37,21 +37,23 @@
|
||||
// IWYU pragma: private
|
||||
#include "../../InternalHeaderCheck.h"
|
||||
|
||||
#if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
// When compiling with GPU support, the "__half_raw" base class as well as
|
||||
// some other routines are defined in the GPU compiler header files
|
||||
// (cuda_fp16.h, hip_fp16.h), and they are not tagged constexpr
|
||||
// As a consequence, we get compile failures when compiling Eigen with
|
||||
// GPU support. Hence the need to disable EIGEN_CONSTEXPR when building
|
||||
// Eigen with GPU support
|
||||
#pragma push_macro("EIGEN_CONSTEXPR")
|
||||
#undef EIGEN_CONSTEXPR
|
||||
#define EIGEN_CONSTEXPR
|
||||
// Eigen with GPU support.
|
||||
// Any functions that require `numext::bit_cast` may also not be constexpr,
|
||||
// including any native types when setting via raw bit values.
|
||||
#if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16)
|
||||
#define _EIGEN_MAYBE_CONSTEXPR
|
||||
#else
|
||||
#define _EIGEN_MAYBE_CONSTEXPR constexpr
|
||||
#endif
|
||||
|
||||
#define F16_PACKET_FUNCTION(PACKET_F, PACKET_F16, METHOD) \
|
||||
template <> \
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_UNUSED PACKET_F16 METHOD<PACKET_F16>(const PACKET_F16& _x) { \
|
||||
EIGEN_UNUSED EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC PACKET_F16 METHOD<PACKET_F16>(const PACKET_F16& _x) { \
|
||||
return float2half(METHOD<PACKET_F>(half2float(_x))); \
|
||||
}
|
||||
|
||||
@ -81,8 +83,10 @@ namespace half_impl {
|
||||
// Making the host side compile phase of hipcc use the same Eigen::half impl, as the gcc compile, resolves
|
||||
// this error, and hence the following convoluted #if condition
|
||||
#if !defined(EIGEN_HAS_GPU_FP16) || !defined(EIGEN_GPU_COMPILE_PHASE)
|
||||
|
||||
// Make our own __half_raw definition that is similar to CUDA's.
|
||||
struct __half_raw {
|
||||
struct construct_from_rep_tag {};
|
||||
#if (defined(EIGEN_HAS_GPU_FP16) && !defined(EIGEN_GPU_COMPILE_PHASE))
|
||||
// Eigen::half can be used as the datatype for shared memory declarations (in Eigen and TF)
|
||||
// The element type for shared memory cannot have non-trivial constructors
|
||||
@ -91,43 +95,53 @@ struct __half_raw {
|
||||
// hence the need for this
|
||||
EIGEN_DEVICE_FUNC __half_raw() {}
|
||||
#else
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw() : x(0) {}
|
||||
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR __half_raw() : x(0) {}
|
||||
#endif
|
||||
|
||||
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(numext::uint16_t raw) : x(numext::bit_cast<__fp16>(raw)) {}
|
||||
explicit EIGEN_DEVICE_FUNC __half_raw(numext::uint16_t raw) : x(numext::bit_cast<__fp16>(raw)) {}
|
||||
EIGEN_DEVICE_FUNC constexpr __half_raw(construct_from_rep_tag, __fp16 rep) : x{rep} {}
|
||||
__fp16 x;
|
||||
#elif defined(EIGEN_HAS_BUILTIN_FLOAT16)
|
||||
explicit EIGEN_DEVICE_FUNC __half_raw(numext::uint16_t raw) : x(numext::bit_cast<_Float16>(raw)) {}
|
||||
EIGEN_DEVICE_FUNC constexpr __half_raw(construct_from_rep_tag, _Float16 rep) : x{rep} {}
|
||||
_Float16 x;
|
||||
#else
|
||||
explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(numext::uint16_t raw) : x(raw) {}
|
||||
explicit EIGEN_DEVICE_FUNC constexpr __half_raw(numext::uint16_t raw) : x(raw) {}
|
||||
EIGEN_DEVICE_FUNC constexpr __half_raw(construct_from_rep_tag, numext::uint16_t rep) : x{rep} {}
|
||||
numext::uint16_t x;
|
||||
#endif
|
||||
};
|
||||
|
||||
#elif defined(EIGEN_HAS_HIP_FP16)
|
||||
// Nothing to do here
|
||||
// HIP GPU compile phase: nothing to do here.
|
||||
// HIP fp16 header file has a definition for __half_raw
|
||||
#elif defined(EIGEN_HAS_CUDA_FP16)
|
||||
|
||||
// CUDA GPU compile phase.
|
||||
#if EIGEN_CUDA_SDK_VER < 90000
|
||||
// In CUDA < 9.0, __half is the equivalent of CUDA 9's __half_raw
|
||||
typedef __half __half_raw;
|
||||
#endif // defined(EIGEN_HAS_CUDA_FP16)
|
||||
|
||||
#elif defined(SYCL_DEVICE_ONLY)
|
||||
typedef cl::sycl::half __half_raw;
|
||||
#endif
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x);
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x);
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff);
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h);
|
||||
|
||||
struct half_base : public __half_raw {
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base() {}
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half_raw& h) : __half_raw(h) {}
|
||||
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half_base() {}
|
||||
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half_base(const __half_raw& h) : __half_raw(h) {}
|
||||
|
||||
#if defined(EIGEN_HAS_GPU_FP16)
|
||||
#if defined(EIGEN_HAS_HIP_FP16)
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half& h) { x = __half_as_ushort(h); }
|
||||
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half_base(const __half& h) { x = __half_as_ushort(h); }
|
||||
#elif defined(EIGEN_HAS_CUDA_FP16)
|
||||
#if EIGEN_CUDA_SDK_VER >= 90000
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half_base(const __half& h) : __half_raw(*(__half_raw*)&h) {}
|
||||
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half_base(const __half& h) : __half_raw(*(__half_raw*)&h) {}
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
@ -156,21 +170,29 @@ struct half : public half_impl::half_base {
|
||||
#endif
|
||||
#endif
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half() {}
|
||||
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half() {}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half_raw& h) : half_impl::half_base(h) {}
|
||||
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(const __half_raw& h) : half_impl::half_base(h) {}
|
||||
|
||||
#if defined(EIGEN_HAS_GPU_FP16)
|
||||
#if defined(EIGEN_HAS_HIP_FP16)
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
|
||||
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
|
||||
#elif defined(EIGEN_HAS_CUDA_FP16)
|
||||
#if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
|
||||
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(const __half& h) : half_impl::half_base(h) {}
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR half(bool b)
|
||||
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
explicit EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(__fp16 b)
|
||||
: half(__half_raw(__half_raw::construct_from_rep_tag(), b)) {}
|
||||
#elif defined(EIGEN_HAS_BUILTIN_FLOAT16)
|
||||
explicit EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(_Float16 b)
|
||||
: half(__half_raw(__half_raw::construct_from_rep_tag(), b)) {}
|
||||
#endif
|
||||
|
||||
explicit EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR half(bool b)
|
||||
: half_impl::half_base(half_impl::raw_uint16_to_half(b ? 0x3c00 : 0)) {}
|
||||
template <class T>
|
||||
explicit EIGEN_DEVICE_FUNC half(T val)
|
||||
@ -201,99 +223,99 @@ struct half : public half_impl::half_base {
|
||||
namespace half_impl {
|
||||
template <typename = void>
|
||||
struct numeric_limits_half_impl {
|
||||
static EIGEN_CONSTEXPR const bool is_specialized = true;
|
||||
static EIGEN_CONSTEXPR const bool is_signed = true;
|
||||
static EIGEN_CONSTEXPR const bool is_integer = false;
|
||||
static EIGEN_CONSTEXPR const bool is_exact = false;
|
||||
static EIGEN_CONSTEXPR const bool has_infinity = true;
|
||||
static EIGEN_CONSTEXPR const bool has_quiet_NaN = true;
|
||||
static EIGEN_CONSTEXPR const bool has_signaling_NaN = true;
|
||||
static constexpr const bool is_specialized = true;
|
||||
static constexpr const bool is_signed = true;
|
||||
static constexpr const bool is_integer = false;
|
||||
static constexpr const bool is_exact = false;
|
||||
static constexpr const bool has_infinity = true;
|
||||
static constexpr const bool has_quiet_NaN = true;
|
||||
static constexpr const bool has_signaling_NaN = true;
|
||||
EIGEN_DIAGNOSTICS(push)
|
||||
EIGEN_DISABLE_DEPRECATED_WARNING
|
||||
static EIGEN_CONSTEXPR const std::float_denorm_style has_denorm = std::denorm_present;
|
||||
static EIGEN_CONSTEXPR const bool has_denorm_loss = false;
|
||||
static constexpr const std::float_denorm_style has_denorm = std::denorm_present;
|
||||
static constexpr const bool has_denorm_loss = false;
|
||||
EIGEN_DIAGNOSTICS(pop)
|
||||
static EIGEN_CONSTEXPR const std::float_round_style round_style = std::round_to_nearest;
|
||||
static EIGEN_CONSTEXPR const bool is_iec559 = true;
|
||||
static constexpr const std::float_round_style round_style = std::round_to_nearest;
|
||||
static constexpr const bool is_iec559 = true;
|
||||
// The C++ standard defines this as "true if the set of values representable
|
||||
// by the type is finite." Half has finite precision.
|
||||
static EIGEN_CONSTEXPR const bool is_bounded = true;
|
||||
static EIGEN_CONSTEXPR const bool is_modulo = false;
|
||||
static EIGEN_CONSTEXPR const int digits = 11;
|
||||
static EIGEN_CONSTEXPR const int digits10 =
|
||||
static constexpr const bool is_bounded = true;
|
||||
static constexpr const bool is_modulo = false;
|
||||
static constexpr const int digits = 11;
|
||||
static constexpr const int digits10 =
|
||||
3; // according to http://half.sourceforge.net/structstd_1_1numeric__limits_3_01half__float_1_1half_01_4.html
|
||||
static EIGEN_CONSTEXPR const int max_digits10 =
|
||||
static constexpr const int max_digits10 =
|
||||
5; // according to http://half.sourceforge.net/structstd_1_1numeric__limits_3_01half__float_1_1half_01_4.html
|
||||
static EIGEN_CONSTEXPR const int radix = std::numeric_limits<float>::radix;
|
||||
static EIGEN_CONSTEXPR const int min_exponent = -13;
|
||||
static EIGEN_CONSTEXPR const int min_exponent10 = -4;
|
||||
static EIGEN_CONSTEXPR const int max_exponent = 16;
|
||||
static EIGEN_CONSTEXPR const int max_exponent10 = 4;
|
||||
static EIGEN_CONSTEXPR const bool traps = std::numeric_limits<float>::traps;
|
||||
static constexpr const int radix = std::numeric_limits<float>::radix;
|
||||
static constexpr const int min_exponent = -13;
|
||||
static constexpr const int min_exponent10 = -4;
|
||||
static constexpr const int max_exponent = 16;
|
||||
static constexpr const int max_exponent10 = 4;
|
||||
static constexpr const bool traps = std::numeric_limits<float>::traps;
|
||||
// IEEE754: "The implementer shall choose how tininess is detected, but shall
|
||||
// detect tininess in the same way for all operations in radix two"
|
||||
static EIGEN_CONSTEXPR const bool tinyness_before = std::numeric_limits<float>::tinyness_before;
|
||||
static constexpr const bool tinyness_before = std::numeric_limits<float>::tinyness_before;
|
||||
|
||||
static EIGEN_CONSTEXPR Eigen::half(min)() { return Eigen::half_impl::raw_uint16_to_half(0x0400); }
|
||||
static EIGEN_CONSTEXPR Eigen::half lowest() { return Eigen::half_impl::raw_uint16_to_half(0xfbff); }
|
||||
static EIGEN_CONSTEXPR Eigen::half(max)() { return Eigen::half_impl::raw_uint16_to_half(0x7bff); }
|
||||
static EIGEN_CONSTEXPR Eigen::half epsilon() { return Eigen::half_impl::raw_uint16_to_half(0x1400); }
|
||||
static EIGEN_CONSTEXPR Eigen::half round_error() { return Eigen::half_impl::raw_uint16_to_half(0x3800); }
|
||||
static EIGEN_CONSTEXPR Eigen::half infinity() { return Eigen::half_impl::raw_uint16_to_half(0x7c00); }
|
||||
static EIGEN_CONSTEXPR Eigen::half quiet_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7e00); }
|
||||
static EIGEN_CONSTEXPR Eigen::half signaling_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7d00); }
|
||||
static EIGEN_CONSTEXPR Eigen::half denorm_min() { return Eigen::half_impl::raw_uint16_to_half(0x0001); }
|
||||
static _EIGEN_MAYBE_CONSTEXPR Eigen::half(min)() { return Eigen::half_impl::raw_uint16_to_half(0x0400); }
|
||||
static _EIGEN_MAYBE_CONSTEXPR Eigen::half lowest() { return Eigen::half_impl::raw_uint16_to_half(0xfbff); }
|
||||
static _EIGEN_MAYBE_CONSTEXPR Eigen::half(max)() { return Eigen::half_impl::raw_uint16_to_half(0x7bff); }
|
||||
static _EIGEN_MAYBE_CONSTEXPR Eigen::half epsilon() { return Eigen::half_impl::raw_uint16_to_half(0x1400); }
|
||||
static _EIGEN_MAYBE_CONSTEXPR Eigen::half round_error() { return Eigen::half_impl::raw_uint16_to_half(0x3800); }
|
||||
static _EIGEN_MAYBE_CONSTEXPR Eigen::half infinity() { return Eigen::half_impl::raw_uint16_to_half(0x7c00); }
|
||||
static _EIGEN_MAYBE_CONSTEXPR Eigen::half quiet_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7e00); }
|
||||
static _EIGEN_MAYBE_CONSTEXPR Eigen::half signaling_NaN() { return Eigen::half_impl::raw_uint16_to_half(0x7d00); }
|
||||
static _EIGEN_MAYBE_CONSTEXPR Eigen::half denorm_min() { return Eigen::half_impl::raw_uint16_to_half(0x0001); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_specialized;
|
||||
constexpr const bool numeric_limits_half_impl<T>::is_specialized;
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_signed;
|
||||
constexpr const bool numeric_limits_half_impl<T>::is_signed;
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_integer;
|
||||
constexpr const bool numeric_limits_half_impl<T>::is_integer;
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_exact;
|
||||
constexpr const bool numeric_limits_half_impl<T>::is_exact;
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::has_infinity;
|
||||
constexpr const bool numeric_limits_half_impl<T>::has_infinity;
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::has_quiet_NaN;
|
||||
constexpr const bool numeric_limits_half_impl<T>::has_quiet_NaN;
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::has_signaling_NaN;
|
||||
constexpr const bool numeric_limits_half_impl<T>::has_signaling_NaN;
|
||||
EIGEN_DIAGNOSTICS(push)
|
||||
EIGEN_DISABLE_DEPRECATED_WARNING
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const std::float_denorm_style numeric_limits_half_impl<T>::has_denorm;
|
||||
constexpr const std::float_denorm_style numeric_limits_half_impl<T>::has_denorm;
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::has_denorm_loss;
|
||||
constexpr const bool numeric_limits_half_impl<T>::has_denorm_loss;
|
||||
EIGEN_DIAGNOSTICS(pop)
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const std::float_round_style numeric_limits_half_impl<T>::round_style;
|
||||
constexpr const std::float_round_style numeric_limits_half_impl<T>::round_style;
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_iec559;
|
||||
constexpr const bool numeric_limits_half_impl<T>::is_iec559;
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_bounded;
|
||||
constexpr const bool numeric_limits_half_impl<T>::is_bounded;
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::is_modulo;
|
||||
constexpr const bool numeric_limits_half_impl<T>::is_modulo;
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::digits;
|
||||
constexpr const int numeric_limits_half_impl<T>::digits;
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::digits10;
|
||||
constexpr const int numeric_limits_half_impl<T>::digits10;
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::max_digits10;
|
||||
constexpr const int numeric_limits_half_impl<T>::max_digits10;
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::radix;
|
||||
constexpr const int numeric_limits_half_impl<T>::radix;
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::min_exponent;
|
||||
constexpr const int numeric_limits_half_impl<T>::min_exponent;
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::min_exponent10;
|
||||
constexpr const int numeric_limits_half_impl<T>::min_exponent10;
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::max_exponent;
|
||||
constexpr const int numeric_limits_half_impl<T>::max_exponent;
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const int numeric_limits_half_impl<T>::max_exponent10;
|
||||
constexpr const int numeric_limits_half_impl<T>::max_exponent10;
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::traps;
|
||||
constexpr const bool numeric_limits_half_impl<T>::traps;
|
||||
template <typename T>
|
||||
EIGEN_CONSTEXPR const bool numeric_limits_half_impl<T>::tinyness_before;
|
||||
constexpr const bool numeric_limits_half_impl<T>::tinyness_before;
|
||||
} // end namespace half_impl
|
||||
} // end namespace Eigen
|
||||
|
||||
@ -320,8 +342,7 @@ namespace half_impl {
|
||||
(defined(EIGEN_HAS_HIP_FP16) && defined(HIP_DEVICE_COMPILE))
|
||||
// Note: We deliberately do *not* define this to 1 even if we have Arm's native
|
||||
// fp16 type since GPU half types are rather different from native CPU half types.
|
||||
// TODO: Rename to something like EIGEN_HAS_NATIVE_GPU_FP16
|
||||
#define EIGEN_HAS_NATIVE_FP16
|
||||
#define EIGEN_HAS_NATIVE_GPU_FP16
|
||||
#endif
|
||||
|
||||
// Intrinsics for native fp16 support. Note that on current hardware,
|
||||
@ -329,7 +350,7 @@ namespace half_impl {
|
||||
// versions to get the ALU speed increased), but you do save the
|
||||
// conversion steps back and forth.
|
||||
|
||||
#if defined(EIGEN_HAS_NATIVE_FP16)
|
||||
#if defined(EIGEN_HAS_NATIVE_GPU_FP16)
|
||||
EIGEN_STRONG_INLINE __device__ half operator+(const half& a, const half& b) {
|
||||
#if defined(EIGEN_CUDA_SDK_VER) && EIGEN_CUDA_SDK_VER >= 90000
|
||||
return __hadd(::__half(a), ::__half(b));
|
||||
@ -371,7 +392,8 @@ EIGEN_STRONG_INLINE __device__ bool operator<(const half& a, const half& b) { re
|
||||
EIGEN_STRONG_INLINE __device__ bool operator<=(const half& a, const half& b) { return __hle(a, b); }
|
||||
EIGEN_STRONG_INLINE __device__ bool operator>(const half& a, const half& b) { return __hgt(a, b); }
|
||||
EIGEN_STRONG_INLINE __device__ bool operator>=(const half& a, const half& b) { return __hge(a, b); }
|
||||
#endif
|
||||
|
||||
#endif // EIGEN_HAS_NATIVE_GPU_FP16
|
||||
|
||||
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) && !defined(EIGEN_GPU_COMPILE_PHASE)
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator+(const half& a, const half& b) { return half(vaddh_f16(a.x, b.x)); }
|
||||
@ -401,16 +423,47 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const half& a, const half&
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const half& a, const half& b) { return vcleh_f16(a.x, b.x); }
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const half& a, const half& b) { return vcgth_f16(a.x, b.x); }
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half& b) { return vcgeh_f16(a.x, b.x); }
|
||||
|
||||
#elif defined(EIGEN_HAS_BUILTIN_FLOAT16) && !defined(EIGEN_GPU_COMPILE_PHASE)
|
||||
|
||||
// Arithmetic on Eigen::half when the representation `x` is a built-in _Float16:
// each operator applies the corresponding native operation to the stored
// values and wraps the result back into a half.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator+(const half& a, const half& b) {
  const _Float16 sum = a.x + b.x;
  return half(sum);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator*(const half& a, const half& b) {
  const _Float16 product = a.x * b.x;
  return half(product);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator-(const half& a, const half& b) {
  const _Float16 difference = a.x - b.x;
  return half(difference);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator/(const half& a, const half& b) {
  const _Float16 quotient = a.x / b.x;
  return half(quotient);
}
// Unary negation.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator-(const half& a) {
  const _Float16 negated = -a.x;
  return half(negated);
}
|
||||
// Compound assignments: each one evaluates the matching binary operator,
// stores the result in `a`, and returns a reference to `a`.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator+=(half& a, const half& b) { return a = a + b; }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator*=(half& a, const half& b) { return a = a * b; }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator-=(half& a, const half& b) { return a = a - b; }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator/=(half& a, const half& b) { return a = a / b; }
|
||||
// Comparisons on Eigen::half delegate to the native comparison of the stored
// _Float16 values.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator==(const half& a, const half& b) {
  return a.x == b.x;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator!=(const half& a, const half& b) {
  return a.x != b.x;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const half& a, const half& b) {
  return a.x < b.x;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const half& a, const half& b) {
  return a.x <= b.x;
}
// `a > b` and `a >= b` are written with the operands swapped; this is the same
// comparison, including for unordered (NaN) operands where both forms are false.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const half& a, const half& b) {
  return b.x < a.x;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half& b) {
  return b.x <= a.x;
}
|
||||
|
||||
// We need to distinguish ‘clang as the CUDA compiler’ from ‘clang as the host compiler,
|
||||
// invoked by NVCC’ (e.g. on MacOS). The former needs to see both host and device implementation
|
||||
// of the functions, while the latter can only deal with one of them.
|
||||
#elif !defined(EIGEN_HAS_NATIVE_FP16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for half floats
|
||||
#elif !defined(EIGEN_HAS_NATIVE_GPU_FP16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for half floats
|
||||
|
||||
#if EIGEN_COMP_CLANG && defined(EIGEN_GPUCC)
|
||||
// We need to provide emulated *host-side* FP16 operators for clang.
|
||||
#pragma push_macro("EIGEN_DEVICE_FUNC")
|
||||
#undef EIGEN_DEVICE_FUNC
|
||||
#if defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_HAS_NATIVE_FP16)
|
||||
#if defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_HAS_NATIVE_GPU_FP16)
|
||||
#define EIGEN_DEVICE_FUNC __host__
|
||||
#else // both host and device need emulated ops.
|
||||
#define EIGEN_DEVICE_FUNC __host__ __device__
|
||||
@ -458,6 +511,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half&
|
||||
#if EIGEN_COMP_CLANG && defined(EIGEN_GPUCC)
|
||||
#pragma pop_macro("EIGEN_DEVICE_FUNC")
|
||||
#endif
|
||||
|
||||
#endif // Emulate support for half floats
|
||||
|
||||
// Division by an index. Do it in full float precision to avoid accuracy
|
||||
@ -493,7 +547,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator--(half& a, int) {
|
||||
// these in hardware. If we need more performance on older/other CPUs, they are
|
||||
// also possible to vectorize directly.
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x) {
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x) {
|
||||
// We cannot simply do a "return __half_raw(x)" here, because __half_raw is a union type
|
||||
// in the hip_fp16 header file, and that will trigger a compile error
|
||||
// On the other hand, having anything but a return statement also triggers a compile error
|
||||
@ -515,6 +569,8 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC numext::uint16_t raw_half_as_uint16(const
|
||||
// For SYCL, cl::sycl::half is _Float16, so cast directly.
|
||||
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
return numext::bit_cast<numext::uint16_t>(h.x);
|
||||
#elif defined(EIGEN_HAS_BUILTIN_FLOAT16)
|
||||
return numext::bit_cast<numext::uint16_t>(h.x);
|
||||
#elif defined(SYCL_DEVICE_ONLY)
|
||||
return numext::bit_cast<numext::uint16_t>(h);
|
||||
#else
|
||||
@ -528,6 +584,16 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) {
|
||||
__half tmp_ff = __float2half(ff);
|
||||
return *(__half_raw*)&tmp_ff;
|
||||
|
||||
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
__half_raw h;
|
||||
h.x = static_cast<__fp16>(ff);
|
||||
return h;
|
||||
|
||||
#elif defined(EIGEN_HAS_BUILTIN_FLOAT16)
|
||||
__half_raw h;
|
||||
h.x = static_cast<_Float16>(ff);
|
||||
return h;
|
||||
|
||||
#elif defined(EIGEN_HAS_FP16_C)
|
||||
__half_raw h;
|
||||
#if EIGEN_COMP_MSVC
|
||||
@ -538,11 +604,6 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) {
|
||||
#endif
|
||||
return h;
|
||||
|
||||
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
__half_raw h;
|
||||
h.x = static_cast<__fp16>(ff);
|
||||
return h;
|
||||
|
||||
#else
|
||||
uint32_t f_bits = Eigen::numext::bit_cast<uint32_t>(ff);
|
||||
const uint32_t f32infty_bits = {255 << 23};
|
||||
@ -595,6 +656,8 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h) {
|
||||
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
|
||||
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
|
||||
return __half2float(h);
|
||||
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16)
|
||||
return static_cast<float>(h.x);
|
||||
#elif defined(EIGEN_HAS_FP16_C)
|
||||
#if EIGEN_COMP_MSVC
|
||||
// MSVC does not have scalar instructions.
|
||||
@ -602,8 +665,6 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h) {
|
||||
#else
|
||||
return _cvtsh_ss(h.x);
|
||||
#endif
|
||||
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
return static_cast<float>(h.x);
|
||||
#else
|
||||
const float magic = Eigen::numext::bit_cast<float>(static_cast<uint32_t>(113 << 23));
|
||||
const uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift
|
||||
@ -628,7 +689,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h) {
|
||||
// --- standard functions ---
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isinf)(const half& a) {
|
||||
#ifdef EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC
|
||||
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16)
|
||||
return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) == 0x7c00;
|
||||
#else
|
||||
return (a.x & 0x7fff) == 0x7c00;
|
||||
@ -638,7 +699,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isnan)(const half& a) {
|
||||
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
|
||||
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
|
||||
return __hisnan(a);
|
||||
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16)
|
||||
return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) > 0x7c00;
|
||||
#else
|
||||
return (a.x & 0x7fff) > 0x7c00;
|
||||
@ -651,6 +712,11 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isfinite)(const half& a) {
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half abs(const half& a) {
|
||||
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
return half(vabsh_f16(a.x));
|
||||
#elif defined(EIGEN_HAS_BUILTIN_FLOAT16)
|
||||
half result;
|
||||
result.x =
|
||||
numext::bit_cast<_Float16>(static_cast<numext::uint16_t>(numext::bit_cast<numext::uint16_t>(a.x) & 0x7FFF));
|
||||
return result;
|
||||
#else
|
||||
half result;
|
||||
result.x = a.x & 0x7FFF;
|
||||
@ -734,26 +800,9 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half fmod(const half& a, const half& b) {
|
||||
return half(::fmodf(float(a), float(b)));
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(min)(const half& a, const half& b) {
|
||||
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
|
||||
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
|
||||
return __hlt(b, a) ? b : a;
|
||||
#else
|
||||
const float f1 = static_cast<float>(a);
|
||||
const float f2 = static_cast<float>(b);
|
||||
return f2 < f1 ? b : a;
|
||||
#endif
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(max)(const half& a, const half& b) {
|
||||
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
|
||||
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
|
||||
return __hlt(a, b) ? b : a;
|
||||
#else
|
||||
const float f1 = static_cast<float>(a);
|
||||
const float f2 = static_cast<float>(b);
|
||||
return f1 < f2 ? b : a;
|
||||
#endif
|
||||
}
|
||||
// Returns `b` when `b < a` holds, and `a` in every other case (so `a` wins ties).
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(min)(const half& a, const half& b) {
  if (b < a) {
    return b;
  }
  return a;
}
|
||||
|
||||
// Returns `b` when `a < b` holds, and `a` in every other case (so `a` wins ties).
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half(max)(const half& a, const half& b) {
  if (a < b) {
    return b;
  }
  return a;
}
|
||||
|
||||
#ifndef EIGEN_NO_IO
|
||||
EIGEN_ALWAYS_INLINE std::ostream& operator<<(std::ostream& os, const half& v) {
|
||||
@ -794,31 +843,29 @@ template <>
|
||||
struct NumTraits<Eigen::half> : GenericNumTraits<Eigen::half> {
|
||||
enum { IsSigned = true, IsInteger = false, IsComplex = false, RequireInitialization = false };
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half epsilon() {
|
||||
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half epsilon() {
|
||||
return half_impl::raw_uint16_to_half(0x0800);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half dummy_precision() {
|
||||
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half dummy_precision() {
|
||||
return half_impl::raw_uint16_to_half(0x211f); // Eigen::half(1e-2f);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half highest() {
|
||||
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half highest() {
|
||||
return half_impl::raw_uint16_to_half(0x7bff);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half lowest() {
|
||||
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half lowest() {
|
||||
return half_impl::raw_uint16_to_half(0xfbff);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half infinity() {
|
||||
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half infinity() {
|
||||
return half_impl::raw_uint16_to_half(0x7c00);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half quiet_NaN() {
|
||||
EIGEN_DEVICE_FUNC _EIGEN_MAYBE_CONSTEXPR static EIGEN_STRONG_INLINE Eigen::half quiet_NaN() {
|
||||
return half_impl::raw_uint16_to_half(0x7e00);
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
#pragma pop_macro("EIGEN_CONSTEXPR")
|
||||
#endif
|
||||
#undef _EIGEN_MAYBE_CONSTEXPR
|
||||
|
||||
namespace Eigen {
|
||||
namespace numext {
|
||||
@ -976,6 +1023,36 @@ struct cast_impl<half, float> {
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef EIGEN_VECTORIZE_FMA
|
||||
|
||||
template <>
|
||||
EIGEN_DEVICE_FUNC inline half pmadd(const half& a, const half& b, const half& c) {
|
||||
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
return half(vfmah_f16(a.x, b.x, c.x));
|
||||
#elif defined(EIGEN_VECTORIZE_AVX512FP16)
|
||||
// Reduces to vfmadd213sh.
|
||||
return half(_mm_cvtsh_h(_mm_fmadd_ph(_mm_set_sh(a.x), _mm_set_sh(b.x), _mm_set_sh(c.x))));
|
||||
#else
|
||||
// Emulate FMA via float.
|
||||
return half(static_cast<float>(a) * static_cast<float>(b) + static_cast<float>(c));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_DEVICE_FUNC inline half pmsub(const half& a, const half& b, const half& c) {
|
||||
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
return half(vfmah_f16(a.x, b.x, -c.x));
|
||||
#elif defined(EIGEN_VECTORIZE_AVX512FP16)
|
||||
// Reduces to vfmadd213sh.
|
||||
return half(_mm_cvtsh_h(_mm_fmadd_ph(_mm_set_sh(a.x), _mm_set_sh(b.x), -_mm_set_sh(c.x))));
|
||||
#else
|
||||
// Emulate FMA via float.
|
||||
return half(static_cast<float>(a) * static_cast<float>(b) - static_cast<float>(c));
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace internal
|
||||
} // namespace Eigen
|
||||
|
||||
|
@ -285,6 +285,8 @@
|
||||
#ifdef __AVX512FP16__
|
||||
#ifdef __AVX512VL__
|
||||
#define EIGEN_VECTORIZE_AVX512FP16
|
||||
// Built-in _Float16.
|
||||
#define EIGEN_HAS_BUILTIN_FLOAT16 1
|
||||
#else
|
||||
#if EIGEN_COMP_GNUC
|
||||
#error Please add -mavx512vl to your compiler flags: compiling with -mavx512fp16 alone without AVX512-VL is not supported.
|
||||
|
@ -7,7 +7,8 @@
|
||||
// Include this header to be able to print Packets while debugging.
|
||||
|
||||
template <typename Packet,
|
||||
typename EnableIf = std::enable_if_t<Eigen::internal::unpacket_traits<Packet>::vectorizable> >
|
||||
typename EnableIf = std::enable_if_t<(Eigen::internal::unpacket_traits<Packet>::vectorizable ||
|
||||
Eigen::internal::unpacket_traits<Packet>::size > 1)> >
|
||||
std::ostream& operator<<(std::ostream& os, const Packet& packet) {
|
||||
using Scalar = typename Eigen::internal::unpacket_traits<Packet>::type;
|
||||
Scalar v[Eigen::internal::unpacket_traits<Packet>::size];
|
||||
|
@ -26,19 +26,19 @@ inline T REF_MUL(const T& a, const T& b) {
|
||||
}
|
||||
template <typename T>
|
||||
inline T REF_MADD(const T& a, const T& b, const T& c) {
|
||||
return a * b + c;
|
||||
return internal::pmadd(a, b, c);
|
||||
}
|
||||
template <typename T>
|
||||
inline T REF_MSUB(const T& a, const T& b, const T& c) {
|
||||
return a * b - c;
|
||||
return internal::pmsub(a, b, c);
|
||||
}
|
||||
template <typename T>
|
||||
inline T REF_NMADD(const T& a, const T& b, const T& c) {
|
||||
return c - a * b;
|
||||
return internal::pnmadd(a, b, c);
|
||||
}
|
||||
template <typename T>
|
||||
inline T REF_NMSUB(const T& a, const T& b, const T& c) {
|
||||
return test::negate(a * b + c);
|
||||
return internal::pnmsub(a, b, c);
|
||||
}
|
||||
template <typename T>
|
||||
inline T REF_DIV(const T& a, const T& b) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user