mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-13 01:43:13 +08:00
Add support for AVX512-FP16 for vectorizing half precision math
This commit is contained in:
parent
76a669fb45
commit
7a3b667c43
@ -258,6 +258,12 @@ if(NOT MSVC)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512dq -mfma")
|
||||
message(STATUS "Enabling AVX512DQ in tests/examples")
|
||||
endif()
|
||||
|
||||
option(EIGEN_TEST_AVX512FP16 "Enable/Disable AVX512-FP16 in tests/examples" OFF)
|
||||
if(EIGEN_TEST_AVX512FP16)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx512f -mfma -mavx512vl -mavx512fp16")
|
||||
message(STATUS "Enabling AVX512-FP16 in tests/examples")
|
||||
endif()
|
||||
|
||||
option(EIGEN_TEST_F16C "Enable/Disable F16C in tests/examples" OFF)
|
||||
if(EIGEN_TEST_F16C)
|
||||
|
@ -178,6 +178,9 @@ using std::ptrdiff_t;
|
||||
#include "src/Core/arch/Default/GenericPacketMathFunctionsFwd.h"
|
||||
|
||||
#if defined EIGEN_VECTORIZE_AVX512
|
||||
#if defined EIGEN_VECTORIZE_AVX512FP16
|
||||
#include "src/Core/arch/AVX512/PacketMathFP16.h"
|
||||
#endif
|
||||
#include "src/Core/arch/SSE/PacketMath.h"
|
||||
#include "src/Core/arch/SSE/TypeCasting.h"
|
||||
#include "src/Core/arch/SSE/Complex.h"
|
||||
|
@ -33,7 +33,9 @@ namespace internal {
|
||||
typedef __m256 Packet8f;
|
||||
typedef eigen_packet_wrapper<__m256i, 0> Packet8i;
|
||||
typedef __m256d Packet4d;
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
typedef eigen_packet_wrapper<__m128i, 2> Packet8h;
|
||||
#endif
|
||||
typedef eigen_packet_wrapper<__m128i, 3> Packet8bf;
|
||||
|
||||
#ifdef EIGEN_VECTORIZE_AVX2
|
||||
@ -45,7 +47,9 @@ template<> struct is_arithmetic<__m256> { enum { value = true }; };
|
||||
template<> struct is_arithmetic<__m256i> { enum { value = true }; };
|
||||
template<> struct is_arithmetic<__m256d> { enum { value = true }; };
|
||||
template<> struct is_arithmetic<Packet8i> { enum { value = true }; };
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
template<> struct is_arithmetic<Packet8h> { enum { value = true }; };
|
||||
#endif
|
||||
template<> struct is_arithmetic<Packet8bf> { enum { value = true }; };
|
||||
#ifdef EIGEN_VECTORIZE_AVX2
|
||||
template<> struct is_arithmetic<Packet4l> { enum { value = true }; };
|
||||
@ -1347,8 +1351,9 @@ template<> EIGEN_STRONG_INLINE Packet4d pblend(const Selector<4>& ifPacket, cons
|
||||
}
|
||||
|
||||
// Packet math for Eigen::half
|
||||
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
template<> struct unpacket_traits<Packet8h> { typedef Eigen::half type; enum {size=8, alignment=Aligned16, vectorizable=true, masked_load_available=false, masked_store_available=false}; typedef Packet8h half; };
|
||||
#endif
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8h pset1<Packet8h>(const Eigen::half& from) {
|
||||
return _mm_set1_epi16(numext::bit_cast<numext::uint16_t>(from));
|
||||
@ -1495,6 +1500,7 @@ template<> EIGEN_STRONG_INLINE Packet8h pnegate(const Packet8h& a) {
|
||||
return _mm_xor_si128(a, sign_mask);
|
||||
}
|
||||
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
template<> EIGEN_STRONG_INLINE Packet8h padd<Packet8h>(const Packet8h& a, const Packet8h& b) {
|
||||
Packet8f af = half2float(a);
|
||||
Packet8f bf = half2float(b);
|
||||
@ -1522,6 +1528,7 @@ template<> EIGEN_STRONG_INLINE Packet8h pdiv<Packet8h>(const Packet8h& a, const
|
||||
Packet8f rf = pdiv(af, bf);
|
||||
return float2half(rf);
|
||||
}
|
||||
#endif
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8h pgather<Eigen::half, Packet8h>(const Eigen::half* from, Index stride)
|
||||
{
|
||||
@ -1550,11 +1557,14 @@ template<> EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet8h>(Eigen::half*
|
||||
to[stride*7] = aux[7];
|
||||
}
|
||||
|
||||
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
template<> EIGEN_STRONG_INLINE Eigen::half predux<Packet8h>(const Packet8h& a) {
|
||||
Packet8f af = half2float(a);
|
||||
float reduced = predux<Packet8f>(af);
|
||||
return Eigen::half(reduced);
|
||||
}
|
||||
#endif
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Eigen::half predux_max<Packet8h>(const Packet8h& a) {
|
||||
Packet8f af = half2float(a);
|
||||
|
@ -40,7 +40,9 @@ namespace internal {
|
||||
typedef __m512 Packet16f;
|
||||
typedef __m512i Packet16i;
|
||||
typedef __m512d Packet8d;
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
typedef eigen_packet_wrapper<__m256i, 1> Packet16h;
|
||||
#endif
|
||||
typedef eigen_packet_wrapper<__m256i, 2> Packet16bf;
|
||||
|
||||
template <>
|
||||
@ -56,6 +58,7 @@ struct is_arithmetic<__m512d> {
|
||||
enum { value = true };
|
||||
};
|
||||
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
template<> struct is_arithmetic<Packet16h> { enum { value = true }; };
|
||||
|
||||
template <>
|
||||
@ -100,6 +103,7 @@ struct packet_traits<half> : default_packet_traits {
|
||||
HasRint = 1
|
||||
};
|
||||
};
|
||||
#endif
|
||||
|
||||
template<> struct packet_traits<float> : default_packet_traits
|
||||
{
|
||||
@ -196,12 +200,14 @@ struct unpacket_traits<Packet16i> {
|
||||
enum { size = 16, alignment=Aligned64, vectorizable=true, masked_load_available=false, masked_store_available=false };
|
||||
};
|
||||
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
template<>
|
||||
struct unpacket_traits<Packet16h> {
|
||||
typedef Eigen::half type;
|
||||
typedef Packet8h half;
|
||||
enum {size=16, alignment=Aligned32, vectorizable=true, masked_load_available=false, masked_store_available=false};
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16f pset1<Packet16f>(const float& from) {
|
||||
@ -1975,6 +1981,7 @@ template<> EIGEN_STRONG_INLINE Packet16h pnegate(const Packet16h& a) {
|
||||
return _mm256_xor_si256(a, sign_mask);
|
||||
}
|
||||
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
template<> EIGEN_STRONG_INLINE Packet16h padd<Packet16h>(const Packet16h& a, const Packet16h& b) {
|
||||
Packet16f af = half2float(a);
|
||||
Packet16f bf = half2float(b);
|
||||
@ -2008,6 +2015,8 @@ template<> EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& from) {
|
||||
return half(predux(from_float));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h predux_half_dowto4<Packet16h>(const Packet16h& a) {
|
||||
Packet8h lane0 = _mm256_extractf128_si256(a, 0);
|
||||
|
870
Eigen/src/Core/arch/AVX512/PacketMathFP16.h
Normal file
870
Eigen/src/Core/arch/AVX512/PacketMathFP16.h
Normal file
@ -0,0 +1,870 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
//
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#ifndef EIGEN_PACKET_MATH_FP16_AVX512_H
|
||||
#define EIGEN_PACKET_MATH_FP16_AVX512_H
|
||||
|
||||
#include "../../InternalHeaderCheck.h"
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
|
||||
// Disable the code for older versions of gcc that don't support many of the required avx512 math instrinsics.
|
||||
#if EIGEN_GNUC_AT_LEAST(5, 3) || EIGEN_COMP_CLANG || EIGEN_COMP_MSVC >= 1923 || EIGEN_COMP_ICC >= 1900
|
||||
#define EIGEN_HAS_AVX512_MATH 1
|
||||
#else
|
||||
#define EIGEN_HAS_AVX512_MATH 0
|
||||
#endif
|
||||
|
||||
typedef __m512h Packet32h;
|
||||
typedef eigen_packet_wrapper<__m256i, 1> Packet16h;
|
||||
typedef eigen_packet_wrapper<__m128i, 2> Packet8h;
|
||||
|
||||
template <>
|
||||
struct is_arithmetic<Packet8h> {
|
||||
enum { value = true };
|
||||
};
|
||||
|
||||
template <>
|
||||
struct packet_traits<half> : default_packet_traits {
|
||||
typedef Packet32h type;
|
||||
typedef Packet16h half;
|
||||
enum {
|
||||
Vectorizable = 1,
|
||||
AlignedOnScalar = 1,
|
||||
size = 32,
|
||||
HasHalfPacket = 1,
|
||||
|
||||
HasCmp = 1,
|
||||
HasAdd = 1,
|
||||
HasSub = 1,
|
||||
HasMul = 1,
|
||||
HasDiv = 1,
|
||||
HasNegate = 1,
|
||||
HasAbs = 1,
|
||||
HasAbs2 = 0,
|
||||
HasMin = 1,
|
||||
HasMax = 1,
|
||||
HasConj = 1,
|
||||
HasSetLinear = 0,
|
||||
// These ones should be implemented in future
|
||||
HasLog = EIGEN_HAS_AVX512_MATH,
|
||||
HasLog1p = EIGEN_HAS_AVX512_MATH,
|
||||
HasExp = EIGEN_HAS_AVX512_MATH,
|
||||
HasExpm1 = EIGEN_HAS_AVX512_MATH,
|
||||
HasSqrt = EIGEN_HAS_AVX512_MATH,
|
||||
HasRsqrt = EIGEN_HAS_AVX512_MATH,
|
||||
HasBessel = 0, // EIGEN_HAS_AVX512_MATH,
|
||||
HasNdtri = 0, // EIGEN_HAS_AVX512_MATH,
|
||||
HasSin = EIGEN_FAST_MATH,
|
||||
HasCos = EIGEN_FAST_MATH,
|
||||
HasTanh = EIGEN_FAST_MATH,
|
||||
HasErf = 0, // EIGEN_FAST_MATH,
|
||||
HasBlend = 0,
|
||||
HasRound = 1,
|
||||
HasFloor = 1,
|
||||
HasCeil = 1,
|
||||
HasRint = 1
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct unpacket_traits<Packet32h> {
|
||||
typedef Eigen::half type;
|
||||
typedef Packet16h half;
|
||||
enum {
|
||||
size = 32,
|
||||
alignment = Aligned64,
|
||||
vectorizable = true,
|
||||
masked_load_available = false,
|
||||
masked_store_available = false
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct unpacket_traits<Packet16h> {
|
||||
typedef Eigen::half type;
|
||||
typedef Packet8h half;
|
||||
enum {
|
||||
size = 16,
|
||||
alignment = Aligned32,
|
||||
vectorizable = true,
|
||||
masked_load_available = false,
|
||||
masked_store_available = false
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct unpacket_traits<Packet8h> {
|
||||
typedef Eigen::half type;
|
||||
typedef Packet8h half;
|
||||
enum {
|
||||
size = 8,
|
||||
alignment = Aligned16,
|
||||
vectorizable = true,
|
||||
masked_load_available = false,
|
||||
masked_store_available = false
|
||||
};
|
||||
};
|
||||
|
||||
// Memory functions
|
||||
|
||||
// pset1
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pset1<Packet32h>(const Eigen::half& from) {
|
||||
return _mm512_set1_ph(static_cast<_Float16>(from));
|
||||
}
|
||||
|
||||
// pset1frombits
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pset1frombits<Packet32h>(unsigned short from) {
|
||||
return _mm512_castsi512_ph(_mm512_set1_epi16(from));
|
||||
}
|
||||
|
||||
// pfirst
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Eigen::half pfirst<Packet32h>(const Packet32h& from) {
|
||||
#ifdef EIGEN_VECTORIZE_AVX512DQ
|
||||
return half_impl::raw_uint16_to_half(
|
||||
static_cast<unsigned short>(_mm256_extract_epi16(_mm512_extracti32x8_epi32(_mm512_castph_si512(from), 0), 0)));
|
||||
#else
|
||||
Eigen::half dest[32];
|
||||
_mm512_storeu_ph(dest, from);
|
||||
return dest[0];
|
||||
#endif
|
||||
}
|
||||
|
||||
// pload
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pload<Packet32h>(const Eigen::half* from) {
|
||||
EIGEN_DEBUG_ALIGNED_LOAD return _mm512_load_ph(from);
|
||||
}
|
||||
|
||||
// ploadu
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h ploadu<Packet32h>(const Eigen::half* from) {
|
||||
EIGEN_DEBUG_UNALIGNED_LOAD return _mm512_loadu_ph(from);
|
||||
}
|
||||
|
||||
// pstore
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE void pstore<half>(Eigen::half* to, const Packet32h& from) {
|
||||
EIGEN_DEBUG_ALIGNED_STORE _mm512_store_ph(to, from);
|
||||
}
|
||||
|
||||
// pstoreu
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE void pstoreu<half>(Eigen::half* to, const Packet32h& from) {
|
||||
EIGEN_DEBUG_UNALIGNED_STORE _mm512_storeu_ph(to, from);
|
||||
}
|
||||
|
||||
// ploaddup
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h ploaddup<Packet32h>(const Eigen::half* from) {
|
||||
__m512h a = _mm512_castph256_ph512(_mm256_loadu_ph(from));
|
||||
return _mm512_permutexvar_ph(_mm512_set_epi16(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, 6, 6,
|
||||
5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0),
|
||||
a);
|
||||
}
|
||||
|
||||
// ploadquad
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h ploadquad<Packet32h>(const Eigen::half* from) {
|
||||
__m512h a = _mm512_castph128_ph512(_mm_loadu_ph(from));
|
||||
return _mm512_permutexvar_ph(
|
||||
_mm512_set_epi16(7, 7, 7, 7, 6, 6, 6, 6, 5, 5, 5, 5, 4, 4, 4, 4, 3, 3, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, 0),
|
||||
a);
|
||||
}
|
||||
|
||||
// pabs
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pabs<Packet32h>(const Packet32h& a) {
|
||||
return _mm512_abs_ph(a);
|
||||
}
|
||||
|
||||
// pmin
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pmin<Packet32h>(const Packet32h& a, const Packet32h& b) {
|
||||
return _mm512_min_ph(a, b);
|
||||
}
|
||||
|
||||
// pmax
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pmax<Packet32h>(const Packet32h& a, const Packet32h& b) {
|
||||
return _mm512_max_ph(a, b);
|
||||
}
|
||||
|
||||
// plset
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h plset<Packet32h>(const half& a) {
|
||||
return _mm512_add_ph(_mm512_set1_ph(a),
|
||||
_mm512_set_ph(31.0f, 30.0f, 29.0f, 28.0f, 27.0f, 26.0f, 25.0f, 24.0f, 23.0f, 22.0f, 21.0f, 20.0f,
|
||||
19.0f, 18.0f, 17.0f, 16.0f, 15.0f, 14.0f, 13.0f, 12.0f, 11.0f, 10.0f, 9.0f, 8.0f,
|
||||
7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f));
|
||||
}
|
||||
|
||||
// por
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h por(const Packet32h& a, const Packet32h& b) {
|
||||
return _mm512_castsi512_ph(_mm512_or_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
|
||||
}
|
||||
|
||||
// pxor
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pxor(const Packet32h& a, const Packet32h& b) {
|
||||
return _mm512_castsi512_ph(_mm512_xor_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
|
||||
}
|
||||
|
||||
// pand
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pand(const Packet32h& a, const Packet32h& b) {
|
||||
return _mm512_castsi512_ph(_mm512_and_si512(_mm512_castph_si512(a), _mm512_castph_si512(b)));
|
||||
}
|
||||
|
||||
// pandnot
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pandnot(const Packet32h& a, const Packet32h& b) {
|
||||
return _mm512_castsi512_ph(_mm512_andnot_si512(_mm512_castph_si512(b), _mm512_castph_si512(a)));
|
||||
}
|
||||
|
||||
// pselect
|
||||
|
||||
template <>
|
||||
EIGEN_DEVICE_FUNC inline Packet32h pselect(const Packet32h& mask, const Packet32h& a, const Packet32h& b) {
|
||||
__mmask32 mask32 = _mm512_cmp_epi16_mask(_mm512_castph_si512(mask), _mm512_setzero_epi32(), _MM_CMPINT_EQ);
|
||||
return _mm512_mask_blend_ph(mask32, a, b);
|
||||
}
|
||||
|
||||
// pcmp_eq
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pcmp_eq(const Packet32h& a, const Packet32h& b) {
|
||||
__mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_EQ_OQ);
|
||||
return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, 0xffffu));
|
||||
}
|
||||
|
||||
// pcmp_le
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pcmp_le(const Packet32h& a, const Packet32h& b) {
|
||||
__mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_LE_OQ);
|
||||
return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, 0xffffu));
|
||||
}
|
||||
|
||||
// pcmp_lt
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pcmp_lt(const Packet32h& a, const Packet32h& b) {
|
||||
__mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_LT_OQ);
|
||||
return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi32(0), mask, 0xffffu));
|
||||
}
|
||||
|
||||
// pcmp_lt_or_nan
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pcmp_lt_or_nan(const Packet32h& a, const Packet32h& b) {
|
||||
__mmask32 mask = _mm512_cmp_ph_mask(a, b, _CMP_NGE_UQ);
|
||||
return _mm512_castsi512_ph(_mm512_mask_set1_epi16(_mm512_set1_epi16(0), mask, 0xffffu));
|
||||
}
|
||||
|
||||
// padd
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h padd<Packet32h>(const Packet32h& a, const Packet32h& b) {
|
||||
return _mm512_add_ph(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h padd<Packet16h>(const Packet16h& a, const Packet16h& b) {
|
||||
return _mm256_castph_si256(_mm256_add_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h padd<Packet8h>(const Packet8h& a, const Packet8h& b) {
|
||||
return _mm_castph_si128(_mm_add_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b)));
|
||||
}
|
||||
|
||||
// psub
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h psub<Packet32h>(const Packet32h& a, const Packet32h& b) {
|
||||
return _mm512_sub_ph(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h psub<Packet16h>(const Packet16h& a, const Packet16h& b) {
|
||||
return _mm256_castph_si256(_mm256_sub_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h psub<Packet8h>(const Packet8h& a, const Packet8h& b) {
|
||||
return _mm_castph_si128(_mm_sub_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b)));
|
||||
}
|
||||
|
||||
// pmul
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pmul<Packet32h>(const Packet32h& a, const Packet32h& b) {
|
||||
return _mm512_mul_ph(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h pmul<Packet16h>(const Packet16h& a, const Packet16h& b) {
|
||||
return _mm256_castph_si256(_mm256_mul_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h pmul<Packet8h>(const Packet8h& a, const Packet8h& b) {
|
||||
return _mm_castph_si128(_mm_mul_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b)));
|
||||
}
|
||||
|
||||
// pdiv
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pdiv<Packet32h>(const Packet32h& a, const Packet32h& b) {
|
||||
return _mm512_div_ph(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h pdiv<Packet16h>(const Packet16h& a, const Packet16h& b) {
|
||||
return _mm256_castph_si256(_mm256_div_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h pdiv<Packet8h>(const Packet8h& a, const Packet8h& b) {
|
||||
return _mm_castph_si128(_mm_div_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b)));
|
||||
}
|
||||
|
||||
// pround
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pround<Packet32h>(const Packet32h& a) {
|
||||
// Work-around for default std::round rounding mode.
|
||||
|
||||
// Mask for the sign bit
|
||||
const Packet32h signMask = pset1frombits<Packet32h>(static_cast<numext::uint16_t>(0x8000u));
|
||||
// The largest half-preicision float less than 0.5
|
||||
const Packet32h prev0dot5 = pset1frombits<Packet32h>(static_cast<numext::uint16_t>(0x37FFu));
|
||||
|
||||
return _mm512_roundscale_ph(padd(por(pand(a, signMask), prev0dot5), a), _MM_FROUND_TO_ZERO);
|
||||
}
|
||||
|
||||
// print
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h print<Packet32h>(const Packet32h& a) {
|
||||
return _mm512_roundscale_ph(a, _MM_FROUND_CUR_DIRECTION);
|
||||
}
|
||||
|
||||
// pceil
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pceil<Packet32h>(const Packet32h& a) {
|
||||
return _mm512_roundscale_ph(a, _MM_FROUND_TO_POS_INF);
|
||||
}
|
||||
|
||||
// pfloor
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pfloor<Packet32h>(const Packet32h& a) {
|
||||
return _mm512_roundscale_ph(a, _MM_FROUND_TO_NEG_INF);
|
||||
}
|
||||
|
||||
// predux
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE half predux<Packet32h>(const Packet32h& a) {
|
||||
return (half)_mm512_reduce_add_ph(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& a) {
|
||||
return (half)_mm256_reduce_add_ph(_mm256_castsi256_ph(a));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE half predux<Packet8h>(const Packet8h& a) {
|
||||
return (half)_mm_reduce_add_ph(_mm_castsi128_ph(a));
|
||||
}
|
||||
|
||||
// predux_half_dowto4
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h predux_half_dowto4<Packet32h>(const Packet32h& a) {
|
||||
#ifdef EIGEN_VECTORIZE_AVX512DQ
|
||||
__m256i lowHalf = _mm256_castps_si256(_mm512_extractf32x8_ps(_mm512_castph_ps(a), 0));
|
||||
__m256i highHalf = _mm256_castps_si256(_mm512_extractf32x8_ps(_mm512_castph_ps(a), 1));
|
||||
|
||||
return Packet16h(padd<Packet16h>(lowHalf, highHalf));
|
||||
#else
|
||||
Eigen::half data[32];
|
||||
_mm512_storeu_ph(data, a);
|
||||
|
||||
__m256i lowHalf = _mm256_castph_si256(_mm256_loadu_ph(data));
|
||||
__m256i highHalf = _mm256_castph_si256(_mm256_loadu_ph(data + 16));
|
||||
|
||||
return Packet16h(padd<Packet16h>(lowHalf, highHalf));
|
||||
#endif
|
||||
}
|
||||
|
||||
// predux_max
|
||||
|
||||
// predux_min
|
||||
|
||||
// predux_mul
|
||||
|
||||
#ifdef EIGEN_VECTORIZE_FMA
|
||||
|
||||
// pmadd
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pmadd(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
|
||||
return _mm512_fmadd_ph(a, b, c);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h pmadd(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
|
||||
return _mm256_castph_si256(_mm256_fmadd_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b), _mm256_castsi256_ph(c)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h pmadd(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
|
||||
return _mm_castph_si128(_mm_fmadd_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b), _mm_castsi128_ph(c)));
|
||||
}
|
||||
|
||||
// pmsub
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pmsub(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
|
||||
return _mm512_fmsub_ph(a, b, c);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h pmsub(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
|
||||
return _mm256_castph_si256(_mm256_fmsub_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b), _mm256_castsi256_ph(c)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h pmsub(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
|
||||
return _mm_castph_si128(_mm_fmsub_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b), _mm_castsi128_ph(c)));
|
||||
}
|
||||
|
||||
// pnmadd
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pnmadd(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
|
||||
return _mm512_fnmadd_ph(a, b, c);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h pnmadd(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
|
||||
return _mm256_castph_si256(_mm256_fnmadd_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b), _mm256_castsi256_ph(c)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h pnmadd(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
|
||||
return _mm_castph_si128(_mm_fnmadd_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b), _mm_castsi128_ph(c)));
|
||||
}
|
||||
|
||||
// pnmsub
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pnmsub(const Packet32h& a, const Packet32h& b, const Packet32h& c) {
|
||||
return _mm512_fnmsub_ph(a, b, c);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h pnmsub(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
|
||||
return _mm256_castph_si256(_mm256_fnmsub_ph(_mm256_castsi256_ph(a), _mm256_castsi256_ph(b), _mm256_castsi256_ph(c)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h pnmsub(const Packet8h& a, const Packet8h& b, const Packet8h& c) {
|
||||
return _mm_castph_si128(_mm_fnmsub_ph(_mm_castsi128_ph(a), _mm_castsi128_ph(b), _mm_castsi128_ph(c)));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// pnegate
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pnegate<Packet32h>(const Packet32h& a) {
|
||||
return _mm512_sub_ph(_mm512_set1_ph(0.0), a);
|
||||
}
|
||||
|
||||
// pconj
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pconj<Packet32h>(const Packet32h& a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
// psqrt
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h psqrt<Packet32h>(const Packet32h& a) {
|
||||
return _mm512_sqrt_ph(a);
|
||||
}
|
||||
|
||||
// prsqrt
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h prsqrt<Packet32h>(const Packet32h& a) {
|
||||
return _mm512_rsqrt_ph(a);
|
||||
}
|
||||
|
||||
// preciprocal
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h preciprocal<Packet32h>(const Packet32h& a) {
|
||||
return _mm512_rcp_ph(a);
|
||||
}
|
||||
|
||||
// ptranspose
|
||||
|
||||
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet32h, 32>& a) {
|
||||
__m512i t[32];
|
||||
|
||||
EIGEN_UNROLL_LOOP
|
||||
for (int i = 0; i < 16; i++) {
|
||||
t[2 * i] = _mm512_unpacklo_epi16(_mm512_castph_si512(a.packet[2 * i]), _mm512_castph_si512(a.packet[2 * i + 1]));
|
||||
t[2 * i + 1] =
|
||||
_mm512_unpackhi_epi16(_mm512_castph_si512(a.packet[2 * i]), _mm512_castph_si512(a.packet[2 * i + 1]));
|
||||
}
|
||||
|
||||
__m512i p[32];
|
||||
|
||||
EIGEN_UNROLL_LOOP
|
||||
for (int i = 0; i < 8; i++) {
|
||||
p[4 * i] = _mm512_unpacklo_epi32(t[4 * i], t[4 * i + 2]);
|
||||
p[4 * i + 1] = _mm512_unpackhi_epi32(t[4 * i], t[4 * i + 2]);
|
||||
p[4 * i + 2] = _mm512_unpacklo_epi32(t[4 * i + 1], t[4 * i + 3]);
|
||||
p[4 * i + 3] = _mm512_unpackhi_epi32(t[4 * i + 1], t[4 * i + 3]);
|
||||
}
|
||||
|
||||
__m512i q[32];
|
||||
|
||||
EIGEN_UNROLL_LOOP
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q[8 * i] = _mm512_unpacklo_epi64(p[8 * i], p[8 * i + 4]);
|
||||
q[8 * i + 1] = _mm512_unpackhi_epi64(p[8 * i], p[8 * i + 4]);
|
||||
q[8 * i + 2] = _mm512_unpacklo_epi64(p[8 * i + 1], p[8 * i + 5]);
|
||||
q[8 * i + 3] = _mm512_unpackhi_epi64(p[8 * i + 1], p[8 * i + 5]);
|
||||
q[8 * i + 4] = _mm512_unpacklo_epi64(p[8 * i + 2], p[8 * i + 6]);
|
||||
q[8 * i + 5] = _mm512_unpackhi_epi64(p[8 * i + 2], p[8 * i + 6]);
|
||||
q[8 * i + 6] = _mm512_unpacklo_epi64(p[8 * i + 3], p[8 * i + 7]);
|
||||
q[8 * i + 7] = _mm512_unpackhi_epi64(p[8 * i + 3], p[8 * i + 7]);
|
||||
}
|
||||
|
||||
__m512i f[32];
|
||||
|
||||
#define PACKET32H_TRANSPOSE_HELPER(X, Y) \
|
||||
do { \
|
||||
f[Y * 8] = _mm512_inserti32x4(f[Y * 8], _mm512_extracti32x4_epi32(q[X * 8], Y), X); \
|
||||
f[Y * 8 + 1] = _mm512_inserti32x4(f[Y * 8 + 1], _mm512_extracti32x4_epi32(q[X * 8 + 1], Y), X); \
|
||||
f[Y * 8 + 2] = _mm512_inserti32x4(f[Y * 8 + 2], _mm512_extracti32x4_epi32(q[X * 8 + 2], Y), X); \
|
||||
f[Y * 8 + 3] = _mm512_inserti32x4(f[Y * 8 + 3], _mm512_extracti32x4_epi32(q[X * 8 + 3], Y), X); \
|
||||
f[Y * 8 + 4] = _mm512_inserti32x4(f[Y * 8 + 4], _mm512_extracti32x4_epi32(q[X * 8 + 4], Y), X); \
|
||||
f[Y * 8 + 5] = _mm512_inserti32x4(f[Y * 8 + 5], _mm512_extracti32x4_epi32(q[X * 8 + 5], Y), X); \
|
||||
f[Y * 8 + 6] = _mm512_inserti32x4(f[Y * 8 + 6], _mm512_extracti32x4_epi32(q[X * 8 + 6], Y), X); \
|
||||
f[Y * 8 + 7] = _mm512_inserti32x4(f[Y * 8 + 7], _mm512_extracti32x4_epi32(q[X * 8 + 7], Y), X); \
|
||||
} while (false);
|
||||
|
||||
PACKET32H_TRANSPOSE_HELPER(0, 0);
|
||||
PACKET32H_TRANSPOSE_HELPER(1, 1);
|
||||
PACKET32H_TRANSPOSE_HELPER(2, 2);
|
||||
PACKET32H_TRANSPOSE_HELPER(3, 3);
|
||||
|
||||
PACKET32H_TRANSPOSE_HELPER(1, 0);
|
||||
PACKET32H_TRANSPOSE_HELPER(2, 0);
|
||||
PACKET32H_TRANSPOSE_HELPER(3, 0);
|
||||
PACKET32H_TRANSPOSE_HELPER(2, 1);
|
||||
PACKET32H_TRANSPOSE_HELPER(3, 1);
|
||||
PACKET32H_TRANSPOSE_HELPER(3, 2);
|
||||
|
||||
PACKET32H_TRANSPOSE_HELPER(0, 1);
|
||||
PACKET32H_TRANSPOSE_HELPER(0, 2);
|
||||
PACKET32H_TRANSPOSE_HELPER(0, 3);
|
||||
PACKET32H_TRANSPOSE_HELPER(1, 2);
|
||||
PACKET32H_TRANSPOSE_HELPER(1, 3);
|
||||
PACKET32H_TRANSPOSE_HELPER(2, 3);
|
||||
|
||||
#undef PACKET32H_TRANSPOSE_HELPER
|
||||
|
||||
EIGEN_UNROLL_LOOP
|
||||
for (int i = 0; i < 32; i++) {
|
||||
a.packet[i] = _mm512_castsi512_ph(f[i]);
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet32h, 4>& a) {
|
||||
__m512i p0, p1, p2, p3, t0, t1, t2, t3, a0, a1, a2, a3;
|
||||
t0 = _mm512_unpacklo_epi16(_mm512_castph_si512(a.packet[0]), _mm512_castph_si512(a.packet[1]));
|
||||
t1 = _mm512_unpackhi_epi16(_mm512_castph_si512(a.packet[0]), _mm512_castph_si512(a.packet[1]));
|
||||
t2 = _mm512_unpacklo_epi16(_mm512_castph_si512(a.packet[2]), _mm512_castph_si512(a.packet[3]));
|
||||
t3 = _mm512_unpackhi_epi16(_mm512_castph_si512(a.packet[2]), _mm512_castph_si512(a.packet[3]));
|
||||
|
||||
p0 = _mm512_unpacklo_epi32(t0, t2);
|
||||
p1 = _mm512_unpackhi_epi32(t0, t2);
|
||||
p2 = _mm512_unpacklo_epi32(t1, t3);
|
||||
p3 = _mm512_unpackhi_epi32(t1, t3);
|
||||
|
||||
a0 = p0;
|
||||
a1 = p1;
|
||||
a2 = p2;
|
||||
a3 = p3;
|
||||
|
||||
a0 = _mm512_inserti32x4(a0, _mm512_extracti32x4_epi32(p1, 0), 1);
|
||||
a1 = _mm512_inserti32x4(a1, _mm512_extracti32x4_epi32(p0, 1), 0);
|
||||
|
||||
a0 = _mm512_inserti32x4(a0, _mm512_extracti32x4_epi32(p2, 0), 2);
|
||||
a2 = _mm512_inserti32x4(a2, _mm512_extracti32x4_epi32(p0, 2), 0);
|
||||
|
||||
a0 = _mm512_inserti32x4(a0, _mm512_extracti32x4_epi32(p3, 0), 3);
|
||||
a3 = _mm512_inserti32x4(a3, _mm512_extracti32x4_epi32(p0, 3), 0);
|
||||
|
||||
a1 = _mm512_inserti32x4(a1, _mm512_extracti32x4_epi32(p2, 1), 2);
|
||||
a2 = _mm512_inserti32x4(a2, _mm512_extracti32x4_epi32(p1, 2), 1);
|
||||
|
||||
a2 = _mm512_inserti32x4(a2, _mm512_extracti32x4_epi32(p3, 2), 3);
|
||||
a3 = _mm512_inserti32x4(a3, _mm512_extracti32x4_epi32(p2, 3), 2);
|
||||
|
||||
a1 = _mm512_inserti32x4(a1, _mm512_extracti32x4_epi32(p3, 1), 3);
|
||||
a3 = _mm512_inserti32x4(a3, _mm512_extracti32x4_epi32(p1, 3), 1);
|
||||
|
||||
a.packet[0] = _mm512_castsi512_ph(a0);
|
||||
a.packet[1] = _mm512_castsi512_ph(a1);
|
||||
a.packet[2] = _mm512_castsi512_ph(a2);
|
||||
a.packet[3] = _mm512_castsi512_ph(a3);
|
||||
}
|
||||
|
||||
// preverse
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h preverse(const Packet32h& a) {
|
||||
return _mm512_permutexvar_ph(_mm512_set_epi16(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
|
||||
20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31),
|
||||
a);
|
||||
}
|
||||
|
||||
// pscatter
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE void pscatter<half, Packet32h>(half* to, const Packet32h& from, Index stride) {
|
||||
EIGEN_ALIGN64 half aux[32];
|
||||
pstore(aux, from);
|
||||
|
||||
EIGEN_UNROLL_LOOP
|
||||
for (int i = 0; i < 32; i++) {
|
||||
to[stride * i] = aux[i];
|
||||
}
|
||||
}
|
||||
|
||||
// pgather
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pgather<Eigen::half, Packet32h>(const Eigen::half* from, Index stride) {
|
||||
return _mm512_castsi512_ph(_mm512_set_epi16(
|
||||
from[31 * stride].x, from[30 * stride].x, from[29 * stride].x, from[28 * stride].x, from[27 * stride].x,
|
||||
from[26 * stride].x, from[25 * stride].x, from[24 * stride].x, from[23 * stride].x, from[22 * stride].x,
|
||||
from[21 * stride].x, from[20 * stride].x, from[19 * stride].x, from[18 * stride].x, from[17 * stride].x,
|
||||
from[16 * stride].x, from[15 * stride].x, from[14 * stride].x, from[13 * stride].x, from[12 * stride].x,
|
||||
from[11 * stride].x, from[10 * stride].x, from[9 * stride].x, from[8 * stride].x, from[7 * stride].x,
|
||||
from[6 * stride].x, from[5 * stride].x, from[4 * stride].x, from[3 * stride].x, from[2 * stride].x,
|
||||
from[1 * stride].x, from[0 * stride].x));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h pcos<Packet16h>(const Packet16h&);
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h psin<Packet16h>(const Packet16h&);
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h plog<Packet16h>(const Packet16h&);
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h plog2<Packet16h>(const Packet16h&);
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h plog1p<Packet16h>(const Packet16h&);
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h pexp<Packet16h>(const Packet16h&);
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h pexpm1<Packet16h>(const Packet16h&);
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h ptanh<Packet16h>(const Packet16h&);
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h pfrexp<Packet16h>(const Packet16h&, Packet16h&);
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h pldexp<Packet16h>(const Packet16h&, const Packet16h&);
|
||||
|
||||
EIGEN_STRONG_INLINE Packet32h combine2Packet16h(const Packet16h& a, const Packet16h& b) {
|
||||
__m512d result = _mm512_undefined_pd();
|
||||
result = _mm512_insertf64x4(result, _mm256_castsi256_pd(a), 0);
|
||||
result = _mm512_insertf64x4(result, _mm256_castsi256_pd(b), 1);
|
||||
return _mm512_castpd_ph(result);
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void extract2Packet16h(const Packet32h& x, Packet16h& a, Packet16h& b) {
|
||||
a = _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(x), 0));
|
||||
b = _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(x), 1));
|
||||
}
|
||||
|
||||
// psin
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h psin<Packet32h>(const Packet32h& a) {
|
||||
Packet16h low;
|
||||
Packet16h high;
|
||||
extract2Packet16h(a, low, high);
|
||||
|
||||
Packet16h lowOut = psin(low);
|
||||
Packet16h highOut = psin(high);
|
||||
|
||||
return combine2Packet16h(lowOut, highOut);
|
||||
}
|
||||
|
||||
// pcos
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pcos<Packet32h>(const Packet32h& a) {
|
||||
Packet16h low;
|
||||
Packet16h high;
|
||||
extract2Packet16h(a, low, high);
|
||||
|
||||
Packet16h lowOut = pcos(low);
|
||||
Packet16h highOut = pcos(high);
|
||||
|
||||
return combine2Packet16h(lowOut, highOut);
|
||||
}
|
||||
|
||||
// plog
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h plog<Packet32h>(const Packet32h& a) {
|
||||
Packet16h low;
|
||||
Packet16h high;
|
||||
extract2Packet16h(a, low, high);
|
||||
|
||||
Packet16h lowOut = plog(low);
|
||||
Packet16h highOut = plog(high);
|
||||
|
||||
return combine2Packet16h(lowOut, highOut);
|
||||
}
|
||||
|
||||
// plog2
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h plog2<Packet32h>(const Packet32h& a) {
|
||||
Packet16h low;
|
||||
Packet16h high;
|
||||
extract2Packet16h(a, low, high);
|
||||
|
||||
Packet16h lowOut = plog2(low);
|
||||
Packet16h highOut = plog2(high);
|
||||
|
||||
return combine2Packet16h(lowOut, highOut);
|
||||
}
|
||||
|
||||
// plog1p
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h plog1p<Packet32h>(const Packet32h& a) {
|
||||
Packet16h low;
|
||||
Packet16h high;
|
||||
extract2Packet16h(a, low, high);
|
||||
|
||||
Packet16h lowOut = plog1p(low);
|
||||
Packet16h highOut = plog1p(high);
|
||||
|
||||
return combine2Packet16h(lowOut, highOut);
|
||||
}
|
||||
|
||||
// pexp
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pexp<Packet32h>(const Packet32h& a) {
|
||||
Packet16h low;
|
||||
Packet16h high;
|
||||
extract2Packet16h(a, low, high);
|
||||
|
||||
Packet16h lowOut = pexp(low);
|
||||
Packet16h highOut = pexp(high);
|
||||
|
||||
return combine2Packet16h(lowOut, highOut);
|
||||
}
|
||||
|
||||
// pexpm1
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pexpm1<Packet32h>(const Packet32h& a) {
|
||||
Packet16h low;
|
||||
Packet16h high;
|
||||
extract2Packet16h(a, low, high);
|
||||
|
||||
Packet16h lowOut = pexpm1(low);
|
||||
Packet16h highOut = pexpm1(high);
|
||||
|
||||
return combine2Packet16h(lowOut, highOut);
|
||||
}
|
||||
|
||||
// ptanh
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h ptanh<Packet32h>(const Packet32h& a) {
|
||||
Packet16h low;
|
||||
Packet16h high;
|
||||
extract2Packet16h(a, low, high);
|
||||
|
||||
Packet16h lowOut = ptanh(low);
|
||||
Packet16h highOut = ptanh(high);
|
||||
|
||||
return combine2Packet16h(lowOut, highOut);
|
||||
}
|
||||
|
||||
// pfrexp
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pfrexp<Packet32h>(const Packet32h& a, Packet32h& exponent) {
|
||||
Packet16h low;
|
||||
Packet16h high;
|
||||
extract2Packet16h(a, low, high);
|
||||
|
||||
Packet16h exp1 = _mm256_undefined_si256();
|
||||
Packet16h exp2 = _mm256_undefined_si256();
|
||||
|
||||
Packet16h lowOut = pfrexp(low, exp1);
|
||||
Packet16h highOut = pfrexp(high, exp2);
|
||||
|
||||
exponent = combine2Packet16h(exp1, exp2);
|
||||
|
||||
return combine2Packet16h(lowOut, highOut);
|
||||
}
|
||||
|
||||
// pldexp
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pldexp<Packet32h>(const Packet32h& a, const Packet32h& exponent) {
|
||||
Packet16h low;
|
||||
Packet16h high;
|
||||
extract2Packet16h(a, low, high);
|
||||
|
||||
Packet16h exp1;
|
||||
Packet16h exp2;
|
||||
extract2Packet16h(exponent, exp1, exp2);
|
||||
|
||||
Packet16h lowOut = pldexp(low, exp1);
|
||||
Packet16h highOut = pldexp(high, exp2);
|
||||
|
||||
return combine2Packet16h(lowOut, highOut);
|
||||
}
|
||||
|
||||
} // end namespace internal
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_PACKET_MATH_FP16_AVX512_H
|
@ -80,6 +80,8 @@ template<> EIGEN_STRONG_INLINE Packet8d preinterpret<Packet8d, Packet8d>(const P
|
||||
return a;
|
||||
}
|
||||
|
||||
#ifndef EIGEN_VECTORIZE_AVX512FP16
|
||||
|
||||
template <>
|
||||
struct type_casting_traits<half, float> {
|
||||
enum {
|
||||
@ -106,6 +108,8 @@ template<> EIGEN_STRONG_INLINE Packet16h pcast<Packet16f, Packet16h>(const Packe
|
||||
return float2half(a);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct type_casting_traits<bfloat16, float> {
|
||||
enum {
|
||||
@ -132,6 +136,77 @@ template<> EIGEN_STRONG_INLINE Packet16bf pcast<Packet16f, Packet16bf>(const Pac
|
||||
return F32ToBf16(a);
|
||||
}
|
||||
|
||||
#ifdef EIGEN_VECTORIZE_AVX512FP16
|
||||
|
||||
template <>
|
||||
struct type_casting_traits<half, float> {
|
||||
enum {
|
||||
VectorizedCast = 1,
|
||||
SrcCoeffRatio = 1,
|
||||
TgtCoeffRatio = 2
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_casting_traits<float, half> {
|
||||
enum {
|
||||
VectorizedCast = 1,
|
||||
SrcCoeffRatio = 2,
|
||||
TgtCoeffRatio = 1
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16f pcast<Packet32h, Packet16f>(const Packet32h& a) {
|
||||
// Discard second-half of input.
|
||||
Packet16h low = _mm256_castpd_si256(_mm512_extractf64x4_pd(_mm512_castph_pd(a), 0));
|
||||
return _mm512_cvtxph_ps(_mm256_castsi256_ph(low));
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet32h pcast<Packet16f, Packet32h>(const Packet16f& a, const Packet16f& b) {
|
||||
__m512d result = _mm512_undefined_pd();
|
||||
result = _mm512_insertf64x4(result, _mm256_castsi256_pd(_mm512_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC)), 0);
|
||||
result = _mm512_insertf64x4(result, _mm256_castsi256_pd(_mm512_cvtps_ph(b, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC)), 1);
|
||||
return _mm512_castpd_ph(result);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8f pcast<Packet16h, Packet8f>(const Packet16h& a) {
|
||||
// Discard second-half of input.
|
||||
Packet8h low = _mm_castps_si128(_mm256_extractf32x4_ps(_mm256_castsi256_ps(a), 0));
|
||||
return _mm256_cvtxph_ps(_mm_castsi128_ph(low));
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16h pcast<Packet8f, Packet16h>(const Packet8f& a, const Packet8f& b) {
|
||||
__m256d result = _mm256_undefined_pd();
|
||||
result = _mm256_insertf64x2(result, _mm_castsi128_pd(_mm256_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC)), 0);
|
||||
result = _mm256_insertf64x2(result, _mm_castsi128_pd(_mm256_cvtps_ph(b, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC)), 1);
|
||||
return _mm256_castpd_si256(result);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4f pcast<Packet8h, Packet4f>(const Packet8h& a) {
|
||||
Packet8f full = _mm256_cvtxph_ps(_mm_castsi128_ph(a));
|
||||
// Discard second-half of input.
|
||||
return _mm256_extractf32x4_ps(full, 0);
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8h pcast<Packet4f, Packet8h>(const Packet4f& a, const Packet4f& b) {
|
||||
__m256 result = _mm256_undefined_ps();
|
||||
result = _mm256_insertf128_ps(result, a, 0);
|
||||
result = _mm256_insertf128_ps(result, b, 1);
|
||||
return _mm256_cvtps_ph(result, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC);
|
||||
}
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
@ -270,6 +270,17 @@
|
||||
#ifdef __AVX512BF16__
|
||||
#define EIGEN_VECTORIZE_AVX512BF16
|
||||
#endif
|
||||
#ifdef __AVX512FP16__
|
||||
#ifdef __AVX512VL__
|
||||
#define EIGEN_VECTORIZE_AVX512FP16
|
||||
#else
|
||||
#if EIGEN_COMP_GNUC
|
||||
#error Please add -mavx512vl to your compiler flags: compiling with -mavx512fp16 alone without AVX512-VL is not supported.
|
||||
#else
|
||||
#error Please enable AVX512-VL in your compiler flags (e.g. -mavx512vl): compiling with AVX512-FP16 alone without AVX512-VL is not supported.
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user