mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-19 16:19:37 +08:00
Add support for Armv8.2-a __fp16
Armv8.2-a provides a native half-precision floating point (__fp16 aka. float16_t). This patch introduces * __fp16 as underlying type of Eigen::half if this type is available * the packet types Packet4hf and Packet8hf representing float16x4_t and float16x8_t respectively * packet-math for the above packets with corresponding scalar type Eigen::half The packet-math functionality has been implemented by Ashutosh Sharma <ashutosh.sharma@amperecomputing.com>. This closes #1940.
This commit is contained in:
parent
a725a3233c
commit
e265f7ed8e
@ -1846,6 +1846,11 @@ template<> struct random_impl<bool>
|
||||
{
|
||||
return random<int>(0,1)==0 ? false : true;
|
||||
}
|
||||
|
||||
static inline bool run(const bool& a, const bool& b)
|
||||
{
|
||||
return random<int>(a, b)==0 ? false : true;
|
||||
}
|
||||
};
|
||||
|
||||
template<> struct scalar_fuzzy_impl<bool>
|
||||
|
@ -77,6 +77,30 @@ struct default_digits_impl<T,false,true> // Integer
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
namespace numext {
|
||||
/** \internal bit-wise cast without changing the underlying bit representation. */
|
||||
|
||||
// TODO: Replace by std::bit_cast (available in C++20)
|
||||
template <typename Tgt, typename Src>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Tgt bit_cast(const Src& src) {
|
||||
#if EIGEN_HAS_TYPE_TRAITS
|
||||
// The behaviour of memcpy is not specified for non-trivially copyable types
|
||||
EIGEN_STATIC_ASSERT(std::is_trivially_copyable<Src>::value, THIS_TYPE_IS_NOT_SUPPORTED);
|
||||
EIGEN_STATIC_ASSERT(std::is_trivially_copyable<Tgt>::value && std::is_default_constructible<Tgt>::value,
|
||||
THIS_TYPE_IS_NOT_SUPPORTED);
|
||||
#endif
|
||||
|
||||
EIGEN_STATIC_ASSERT(sizeof(Src) == sizeof(Tgt), THIS_TYPE_IS_NOT_SUPPORTED);
|
||||
Tgt tgt;
|
||||
EIGEN_USING_STD(memcpy)
|
||||
memcpy(&tgt, &src, sizeof(Tgt));
|
||||
return tgt;
|
||||
}
|
||||
|
||||
/** \internal extract the bits of the float \a x */
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC numext::uint32_t as_uint(float x) { return bit_cast<numext::uint32_t>(x); }
|
||||
} // namespace numext
|
||||
|
||||
/** \class NumTraits
|
||||
* \ingroup Core_Module
|
||||
*
|
||||
|
@ -44,8 +44,7 @@
|
||||
|
||||
#include <sstream>
|
||||
|
||||
|
||||
#if defined(EIGEN_HAS_GPU_FP16)
|
||||
#if defined(EIGEN_HAS_GPU_FP16) || defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
// When compiling with GPU support, the "__half_raw" base class as well as
|
||||
// some other routines are defined in the GPU compiler header files
|
||||
// (cuda_fp16.h, hip_fp16.h), and they are not tagged constexpr
|
||||
@ -81,9 +80,16 @@ namespace half_impl {
|
||||
// Make our own __half_raw definition that is similar to CUDA's.
|
||||
struct __half_raw {
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw() : x(0) {}
|
||||
explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(unsigned short raw) : x(raw) {}
|
||||
unsigned short x;
|
||||
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(numext::uint16_t raw) : x(numext::bit_cast<__fp16>(raw)) {
|
||||
}
|
||||
__fp16 x;
|
||||
#else
|
||||
explicit EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw(numext::uint16_t raw) : x(raw) {}
|
||||
numext::uint16_t x;
|
||||
#endif
|
||||
};
|
||||
|
||||
#elif defined(EIGEN_HAS_HIP_FP16)
|
||||
// Nothing to do here
|
||||
// HIP fp16 header file has a definition for __half_raw
|
||||
@ -98,7 +104,7 @@ typedef cl::sycl::half __half_raw;
|
||||
|
||||
#endif
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(unsigned short x);
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x);
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff);
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h);
|
||||
|
||||
@ -160,6 +166,7 @@ struct half : public half_impl::half_base {
|
||||
: half_impl::half_base(half_impl::float_to_half_rtne(static_cast<float>(val))) {}
|
||||
explicit EIGEN_DEVICE_FUNC half(float f)
|
||||
: half_impl::half_base(half_impl::float_to_half_rtne(f)) {}
|
||||
|
||||
// Following the convention of numpy, converting between complex and
|
||||
// float will lead to loss of imag value.
|
||||
template<typename RealScalar>
|
||||
@ -168,7 +175,11 @@ struct half : public half_impl::half_base {
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(bool) const {
|
||||
// +0.0 and -0.0 become false, everything else becomes true.
|
||||
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
return (numext::bit_cast<numext::uint16_t>(x) & 0x7fff) != 0;
|
||||
#else
|
||||
return (x & 0x7fff) != 0;
|
||||
#endif
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(signed char) const {
|
||||
return static_cast<signed char>(half_impl::half_to_float(*this));
|
||||
@ -179,8 +190,8 @@ struct half : public half_impl::half_base {
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(short) const {
|
||||
return static_cast<short>(half_impl::half_to_float(*this));
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(unsigned short) const {
|
||||
return static_cast<unsigned short>(half_impl::half_to_float(*this));
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(numext::uint16_t) const {
|
||||
return static_cast<numext::uint16_t>(half_impl::half_to_float(*this));
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(int) const {
|
||||
return static_cast<int>(half_impl::half_to_float(*this));
|
||||
@ -272,6 +283,9 @@ namespace half_impl {
|
||||
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && \
|
||||
EIGEN_CUDA_ARCH >= 530) || \
|
||||
(defined(EIGEN_HAS_HIP_FP16) && defined(HIP_DEVICE_COMPILE))
|
||||
// Note: We deliberatly do *not* define this to 1 even if we have Arm's native
|
||||
// fp16 type since GPU halfs are rather different from native CPU halfs.
|
||||
// TODO: Rename to something like EIGEN_HAS_NATIVE_GPU_FP16
|
||||
#define EIGEN_HAS_NATIVE_FP16
|
||||
#endif
|
||||
|
||||
@ -340,13 +354,62 @@ EIGEN_STRONG_INLINE __device__ bool operator > (const half& a, const half& b) {
|
||||
EIGEN_STRONG_INLINE __device__ bool operator >= (const half& a, const half& b) {
|
||||
return __hge(a, b);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator + (const half& a, const half& b) {
|
||||
return half(vaddh_f16(a.x, b.x));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator * (const half& a, const half& b) {
|
||||
return half(vmulh_f16(a.x, b.x));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a, const half& b) {
|
||||
return half(vsubh_f16(a.x, b.x));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, const half& b) {
|
||||
return half(vdivh_f16(a.x, b.x));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator - (const half& a) {
|
||||
return half(vnegh_f16(a.x));
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator += (half& a, const half& b) {
|
||||
a = half(vaddh_f16(a.x, b.x));
|
||||
return a;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator *= (half& a, const half& b) {
|
||||
a = half(vmulh_f16(a.x, b.x));
|
||||
return a;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator -= (half& a, const half& b) {
|
||||
a = half(vsubh_f16(a.x, b.x));
|
||||
return a;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator /= (half& a, const half& b) {
|
||||
a = half(vdivh_f16(a.x, b.x));
|
||||
return a;
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator == (const half& a, const half& b) {
|
||||
return vceqh_f16(a.x, b.x);
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator != (const half& a, const half& b) {
|
||||
return !vceqh_f16(a.x, b.x);
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator < (const half& a, const half& b) {
|
||||
return vclth_f16(a.x, b.x);
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator <= (const half& a, const half& b) {
|
||||
return vcleh_f16(a.x, b.x);
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator > (const half& a, const half& b) {
|
||||
return vcgth_f16(a.x, b.x);
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator >= (const half& a, const half& b) {
|
||||
return vcgeh_f16(a.x, b.x);
|
||||
}
|
||||
// We need to distinguish ‘clang as the CUDA compiler’ from ‘clang as the host compiler,
|
||||
// invoked by NVCC’ (e.g. on MacOS). The former needs to see both host and device implementation
|
||||
// of the functions, while the latter can only deal with one of them.
|
||||
#if !defined(EIGEN_HAS_NATIVE_FP16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for half floats
|
||||
#elif !defined(EIGEN_HAS_NATIVE_FP16) || (EIGEN_COMP_CLANG && !EIGEN_COMP_NVCC) // Emulate support for half floats
|
||||
|
||||
#if EIGEN_COMP_CLANG && defined(EIGEN_CUDACC)
|
||||
// We need to provide emulated *host-side* FP16 operators for clang.
|
||||
@ -361,7 +424,6 @@ EIGEN_STRONG_INLINE __device__ bool operator >= (const half& a, const half& b) {
|
||||
|
||||
// Definitions for CPUs and older HIP+CUDA, mostly working through conversion
|
||||
// to/from fp32.
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator + (const half& a, const half& b) {
|
||||
return half(float(a) + float(b));
|
||||
}
|
||||
@ -430,10 +492,10 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half operator / (const half& a, Index b) {
|
||||
// these in hardware. If we need more performance on older/other CPUs, they are
|
||||
// also possible to vectorize directly.
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(unsigned short x) {
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR __half_raw raw_uint16_to_half(numext::uint16_t x) {
|
||||
// We cannot simply do a "return __half_raw(x)" here, because __half_raw is union type
|
||||
// in the hip_fp16 header file, and that will trigger a compile error
|
||||
// On the other hand, having anythion but a return statement also triggers a compile error
|
||||
// On the other hand, having anything but a return statement also triggers a compile error
|
||||
// because this is constexpr function.
|
||||
// Fortunately, since we need to disable EIGEN_CONSTEXPR for GPU anyway, we can get out
|
||||
// of this catch22 by having separate bodies for GPU / non GPU
|
||||
@ -462,6 +524,11 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) {
|
||||
h.x = _cvtss_sh(ff, 0);
|
||||
return h;
|
||||
|
||||
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
__half_raw h;
|
||||
h.x = static_cast<__fp16>(ff);
|
||||
return h;
|
||||
|
||||
#else
|
||||
float32_bits f; f.f = ff;
|
||||
|
||||
@ -470,7 +537,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) {
|
||||
const float32_bits denorm_magic = { ((127 - 15) + (23 - 10) + 1) << 23 };
|
||||
unsigned int sign_mask = 0x80000000u;
|
||||
__half_raw o;
|
||||
o.x = static_cast<unsigned short>(0x0u);
|
||||
o.x = static_cast<numext::uint16_t>(0x0u);
|
||||
|
||||
unsigned int sign = f.u & sign_mask;
|
||||
f.u ^= sign;
|
||||
@ -490,7 +557,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) {
|
||||
f.f += denorm_magic.f;
|
||||
|
||||
// and one integer subtract of the bias later, we have our final float!
|
||||
o.x = static_cast<unsigned short>(f.u - denorm_magic.u);
|
||||
o.x = static_cast<numext::uint16_t>(f.u - denorm_magic.u);
|
||||
} else {
|
||||
unsigned int mant_odd = (f.u >> 13) & 1; // resulting mantissa is odd
|
||||
|
||||
@ -501,11 +568,11 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC __half_raw float_to_half_rtne(float ff) {
|
||||
// rounding bias part 2
|
||||
f.u += mant_odd;
|
||||
// take the bits!
|
||||
o.x = static_cast<unsigned short>(f.u >> 13);
|
||||
o.x = static_cast<numext::uint16_t>(f.u >> 13);
|
||||
}
|
||||
}
|
||||
|
||||
o.x |= static_cast<unsigned short>(sign >> 16);
|
||||
o.x |= static_cast<numext::uint16_t>(sign >> 16);
|
||||
return o;
|
||||
#endif
|
||||
}
|
||||
@ -514,10 +581,10 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h) {
|
||||
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300) || \
|
||||
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
|
||||
return __half2float(h);
|
||||
|
||||
#elif defined(EIGEN_HAS_FP16_C)
|
||||
return _cvtsh_ss(h.x);
|
||||
|
||||
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
return static_cast<float>(h.x);
|
||||
#else
|
||||
const float32_bits magic = { 113 << 23 };
|
||||
const unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift
|
||||
@ -543,12 +610,18 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC float half_to_float(__half_raw h) {
|
||||
// --- standard functions ---
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isinf)(const half& a) {
|
||||
#ifdef EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC
|
||||
return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) == 0x7c00;
|
||||
#else
|
||||
return (a.x & 0x7fff) == 0x7c00;
|
||||
#endif
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isnan)(const half& a) {
|
||||
#if (defined(EIGEN_HAS_CUDA_FP16) && defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 530) || \
|
||||
(defined(EIGEN_HAS_HIP_FP16) && defined(EIGEN_HIP_DEVICE_COMPILE))
|
||||
return __hisnan(a);
|
||||
#elif defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) > 0x7c00;
|
||||
#else
|
||||
return (a.x & 0x7fff) > 0x7c00;
|
||||
#endif
|
||||
@ -558,9 +631,13 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool (isfinite)(const half& a) {
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half abs(const half& a) {
|
||||
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
return half(vabsh_f16(a.x));
|
||||
#else
|
||||
half result;
|
||||
result.x = a.x & 0x7FFF;
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half exp(const half& a) {
|
||||
#if (EIGEN_CUDA_SDK_VER >= 80000 && defined EIGEN_CUDA_ARCH && EIGEN_CUDA_ARCH >= 530) || \
|
||||
@ -717,9 +794,13 @@ template<> struct NumTraits<Eigen::half>
|
||||
|
||||
// C-like standard mathematical functions and trancendentals.
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half fabsh(const Eigen::half& a) {
|
||||
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
return Eigen::half(vabsh_f16(a.x));
|
||||
#else
|
||||
Eigen::half result;
|
||||
result.x = a.x & 0x7FFF;
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half exph(const Eigen::half& a) {
|
||||
return Eigen::half(::expf(float(a)));
|
||||
@ -778,7 +859,7 @@ __device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor(Eigen::half var, int laneM
|
||||
defined(EIGEN_HIPCC)
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half __ldg(const Eigen::half* ptr) {
|
||||
return Eigen::half_impl::raw_uint16_to_half(
|
||||
__ldg(reinterpret_cast<const unsigned short*>(ptr)));
|
||||
__ldg(reinterpret_cast<const numext::uint16_t*>(ptr)));
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -3771,6 +3771,650 @@ template<> EIGEN_STRONG_INLINE Packet2d psqrt(const Packet2d& _x){ return vsqrt_
|
||||
|
||||
#endif // EIGEN_ARCH_ARM64
|
||||
|
||||
// Do we have an fp16 types and supporting Neon intrinsics?
|
||||
#if EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
|
||||
typedef float16x4_t Packet4hf;
|
||||
typedef float16x8_t Packet8hf;
|
||||
|
||||
// TODO(tellenbach): Enable packets of size 8 as soon as the GEBP can handle them
|
||||
template <>
|
||||
struct packet_traits<Eigen::half> : default_packet_traits {
|
||||
typedef Packet4hf type;
|
||||
typedef Packet4hf half;
|
||||
enum {
|
||||
Vectorizable = 1,
|
||||
AlignedOnScalar = 1,
|
||||
size = 4,
|
||||
HasHalfPacket = 0,
|
||||
|
||||
HasCmp = 1,
|
||||
HasCast = 1,
|
||||
HasAdd = 1,
|
||||
HasSub = 1,
|
||||
HasShift = 1,
|
||||
HasMul = 1,
|
||||
HasNegate = 1,
|
||||
HasAbs = 1,
|
||||
HasArg = 0,
|
||||
HasAbs2 = 1,
|
||||
HasAbsDiff = 0,
|
||||
HasMin = 1,
|
||||
HasMax = 1,
|
||||
HasConj = 1,
|
||||
HasSetLinear = 0,
|
||||
HasBlend = 0,
|
||||
HasInsert = 1,
|
||||
HasReduxp = 1,
|
||||
HasDiv = 1,
|
||||
HasFloor = 1,
|
||||
HasSin = 0,
|
||||
HasCos = 0,
|
||||
HasLog = 0,
|
||||
HasExp = 0,
|
||||
HasSqrt = 1
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct unpacket_traits<Packet4hf> {
|
||||
typedef Eigen::half type;
|
||||
typedef Packet4hf half;
|
||||
enum {
|
||||
size = 4,
|
||||
alignment = Aligned16,
|
||||
vectorizable = true,
|
||||
masked_load_available = false,
|
||||
masked_store_available = false
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct unpacket_traits<Packet8hf> {
|
||||
typedef Eigen::half type;
|
||||
typedef Packet8hf half;
|
||||
enum {
|
||||
size = 8,
|
||||
alignment = Aligned16,
|
||||
vectorizable = true,
|
||||
masked_load_available = false,
|
||||
masked_store_available = false
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf pset1<Packet8hf>(const Eigen::half& from) {
|
||||
return vdupq_n_f16(from.x);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf pset1<Packet4hf>(const Eigen::half& from) {
|
||||
return vdup_n_f16(from.x);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf plset<Packet8hf>(const Eigen::half& a) {
|
||||
const float16_t f[] = {0, 1, 2, 3, 4, 5, 6, 7};
|
||||
Packet8hf countdown = vld1q_f16(f);
|
||||
return vaddq_f16(pset1<Packet8hf>(a), countdown);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf plset<Packet4hf>(const Eigen::half& a) {
|
||||
const float16_t f[] = {0, 1, 2, 3};
|
||||
Packet4hf countdown = vld1_f16(f);
|
||||
return vadd_f16(pset1<Packet4hf>(a), countdown);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf padd<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
|
||||
return vaddq_f16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf padd<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
|
||||
return vadd_f16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf psub<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
|
||||
return vsubq_f16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf psub<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
|
||||
return vsub_f16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf pnegate(const Packet8hf& a) {
|
||||
return vnegq_f16(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf pnegate(const Packet4hf& a) {
|
||||
return vneg_f16(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf pconj(const Packet8hf& a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf pconj(const Packet4hf& a) {
|
||||
return a;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf pmul<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
|
||||
return vmulq_f16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf pmul<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
|
||||
return vmul_f16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf pdiv<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
|
||||
return vdivq_f16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf pdiv<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
|
||||
return vdiv_f16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf pmadd(const Packet8hf& a, const Packet8hf& b, const Packet8hf& c) {
|
||||
return vfmaq_f16(c, a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf pmadd(const Packet4hf& a, const Packet4hf& b, const Packet4hf& c) {
|
||||
return vfma_f16(c, a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf pmin<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
|
||||
return vminq_f16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf pmin<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
|
||||
return vmin_f16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf pmax<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
|
||||
return vmaxq_f16(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf pmax<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
|
||||
return vmax_f16(a, b);
|
||||
}
|
||||
|
||||
#define EIGEN_MAKE_ARM_FP16_CMP_8(name) \
|
||||
template <> \
|
||||
EIGEN_STRONG_INLINE Packet8hf pcmp_##name(const Packet8hf& a, const Packet8hf& b) { \
|
||||
return vreinterpretq_f16_u16(vc##name##q_f16(a, b)); \
|
||||
}
|
||||
|
||||
#define EIGEN_MAKE_ARM_FP16_CMP_4(name) \
|
||||
template <> \
|
||||
EIGEN_STRONG_INLINE Packet4hf pcmp_##name(const Packet4hf& a, const Packet4hf& b) { \
|
||||
return vreinterpret_f16_u16(vc##name##_f16(a, b)); \
|
||||
}
|
||||
|
||||
EIGEN_MAKE_ARM_FP16_CMP_8(eq)
|
||||
EIGEN_MAKE_ARM_FP16_CMP_8(lt)
|
||||
EIGEN_MAKE_ARM_FP16_CMP_8(le)
|
||||
|
||||
EIGEN_MAKE_ARM_FP16_CMP_4(eq)
|
||||
EIGEN_MAKE_ARM_FP16_CMP_4(lt)
|
||||
EIGEN_MAKE_ARM_FP16_CMP_4(le)
|
||||
|
||||
#undef EIGEN_MAKE_ARM_FP16_CMP_8
|
||||
#undef EIGEN_MAKE_ARM_FP16_CMP_4
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf pcmp_lt_or_nan<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
|
||||
return vreinterpretq_f16_u16(vmvnq_u16(vcgeq_f16(a, b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf pcmp_lt_or_nan<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
|
||||
return vreinterpret_f16_u16(vmvn_u16(vcge_f16(a, b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf pfloor<Packet8hf>(const Packet8hf& a) {
|
||||
const Packet8hf cst_1 = pset1<Packet8hf>(Eigen::half(1.0f));
|
||||
/* perform a floorf */
|
||||
Packet8hf tmp = vcvtq_f16_s16(vcvtq_s16_f16(a));
|
||||
|
||||
/* if greater, substract 1 */
|
||||
uint16x8_t mask = vcgtq_f16(tmp, a);
|
||||
mask = vandq_u16(mask, vreinterpretq_u16_f16(cst_1));
|
||||
return vsubq_f16(tmp, vreinterpretq_f16_u16(mask));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf pfloor<Packet4hf>(const Packet4hf& a) {
|
||||
const Packet4hf cst_1 = pset1<Packet4hf>(Eigen::half(1.0f));
|
||||
/* perform a floorf */
|
||||
Packet4hf tmp = vcvt_f16_s16(vcvt_s16_f16(a));
|
||||
|
||||
/* if greater, substract 1 */
|
||||
uint16x4_t mask = vcgt_f16(tmp, a);
|
||||
mask = vand_u16(mask, vreinterpret_u16_f16(cst_1));
|
||||
return vsub_f16(tmp, vreinterpret_f16_u16(mask));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf psqrt<Packet8hf>(const Packet8hf& a) {
|
||||
return vsqrtq_f16(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf psqrt<Packet4hf>(const Packet4hf& a) {
|
||||
return vsqrt_f16(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf pand<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
|
||||
return vreinterpretq_f16_u16(vandq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf pand<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
|
||||
return vreinterpret_f16_u16(vand_u16(vreinterpret_u16_f16(a), vreinterpret_u16_f16(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf por<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
|
||||
return vreinterpretq_f16_u16(vorrq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf por<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
|
||||
return vreinterpret_f16_u16(vorr_u16(vreinterpret_u16_f16(a), vreinterpret_u16_f16(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf pxor<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
|
||||
return vreinterpretq_f16_u16(veorq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf pxor<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
|
||||
return vreinterpret_f16_u16(veor_u16(vreinterpret_u16_f16(a), vreinterpret_u16_f16(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf pandnot<Packet8hf>(const Packet8hf& a, const Packet8hf& b) {
|
||||
return vreinterpretq_f16_u16(vbicq_u16(vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf pandnot<Packet4hf>(const Packet4hf& a, const Packet4hf& b) {
|
||||
return vreinterpret_f16_u16(vbic_u16(vreinterpret_u16_f16(a), vreinterpret_u16_f16(b)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf pload<Packet8hf>(const Eigen::half* from) {
|
||||
EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f16(reinterpret_cast<const float16_t*>(from));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf pload<Packet4hf>(const Eigen::half* from) {
|
||||
EIGEN_DEBUG_ALIGNED_LOAD return vld1_f16(reinterpret_cast<const float16_t*>(from));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf ploadu<Packet8hf>(const Eigen::half* from) {
|
||||
EIGEN_DEBUG_UNALIGNED_LOAD return vld1q_f16(reinterpret_cast<const float16_t*>(from));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf ploadu<Packet4hf>(const Eigen::half* from) {
|
||||
EIGEN_DEBUG_UNALIGNED_LOAD return vld1_f16(reinterpret_cast<const float16_t*>(from));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf ploaddup<Packet8hf>(const Eigen::half* from) {
|
||||
Packet8hf packet;
|
||||
packet[0] = from[0].x;
|
||||
packet[1] = from[0].x;
|
||||
packet[2] = from[1].x;
|
||||
packet[3] = from[1].x;
|
||||
packet[4] = from[2].x;
|
||||
packet[5] = from[2].x;
|
||||
packet[6] = from[3].x;
|
||||
packet[7] = from[3].x;
|
||||
return packet;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf ploaddup<Packet4hf>(const Eigen::half* from) {
|
||||
float16x4_t packet;
|
||||
float16_t* tmp;
|
||||
tmp = (float16_t*)&packet;
|
||||
tmp[0] = from[0].x;
|
||||
tmp[1] = from[0].x;
|
||||
tmp[2] = from[1].x;
|
||||
tmp[3] = from[1].x;
|
||||
return packet;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf ploadquad<Packet8hf>(const Eigen::half* from) {
|
||||
Packet4hf lo, hi;
|
||||
lo = vld1_dup_f16(reinterpret_cast<const float16_t*>(from));
|
||||
hi = vld1_dup_f16(reinterpret_cast<const float16_t*>(from+1));
|
||||
return vcombine_f16(lo, hi);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC inline Packet8hf pinsertfirst(const Packet8hf& a, Eigen::half b) { return vsetq_lane_f16(b.x, a, 0); }
|
||||
|
||||
EIGEN_DEVICE_FUNC inline Packet4hf pinsertfirst(const Packet4hf& a, Eigen::half b) { return vset_lane_f16(b.x, a, 0); }
|
||||
|
||||
template <>
|
||||
EIGEN_DEVICE_FUNC inline Packet8hf pselect(const Packet8hf& mask, const Packet8hf& a, const Packet8hf& b) {
|
||||
return vbslq_f16(vreinterpretq_u16_f16(mask), a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_DEVICE_FUNC inline Packet4hf pselect(const Packet4hf& mask, const Packet4hf& a, const Packet4hf& b) {
|
||||
return vbsl_f16(vreinterpret_u16_f16(mask), a, b);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC inline Packet8hf pinsertlast(const Packet8hf& a, Eigen::half b) { return vsetq_lane_f16(b.x, a, 7); }
|
||||
|
||||
EIGEN_DEVICE_FUNC inline Packet4hf pinsertlast(const Packet4hf& a, Eigen::half b) { return vset_lane_f16(b.x, a, 3); }
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet8hf& from) {
|
||||
EIGEN_DEBUG_ALIGNED_STORE vst1q_f16(reinterpret_cast<float16_t*>(to), from);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet4hf& from) {
|
||||
EIGEN_DEBUG_ALIGNED_STORE vst1_f16(reinterpret_cast<float16_t*>(to), from);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet8hf& from) {
|
||||
EIGEN_DEBUG_UNALIGNED_STORE vst1q_f16(reinterpret_cast<float16_t*>(to), from);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE void pstoreu<Eigen::half>(Eigen::half* to, const Packet4hf& from) {
|
||||
EIGEN_DEBUG_UNALIGNED_STORE vst1_f16(reinterpret_cast<float16_t*>(to), from);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_DEVICE_FUNC inline Packet8hf pgather<Eigen::half, Packet8hf>(const Eigen::half* from, Index stride) {
|
||||
Packet8hf res = pset1<Packet8hf>(Eigen::half(0.f));
|
||||
res = vsetq_lane_f16(from[0 * stride].x, res, 0);
|
||||
res = vsetq_lane_f16(from[1 * stride].x, res, 1);
|
||||
res = vsetq_lane_f16(from[2 * stride].x, res, 2);
|
||||
res = vsetq_lane_f16(from[3 * stride].x, res, 3);
|
||||
res = vsetq_lane_f16(from[4 * stride].x, res, 4);
|
||||
res = vsetq_lane_f16(from[5 * stride].x, res, 5);
|
||||
res = vsetq_lane_f16(from[6 * stride].x, res, 6);
|
||||
res = vsetq_lane_f16(from[7 * stride].x, res, 7);
|
||||
return res;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_DEVICE_FUNC inline Packet4hf pgather<Eigen::half, Packet4hf>(const Eigen::half* from, Index stride) {
|
||||
Packet4hf res = pset1<Packet4hf>(Eigen::half(0.f));
|
||||
res = vset_lane_f16(from[0 * stride].x, res, 0);
|
||||
res = vset_lane_f16(from[1 * stride].x, res, 1);
|
||||
res = vset_lane_f16(from[2 * stride].x, res, 2);
|
||||
res = vset_lane_f16(from[3 * stride].x, res, 3);
|
||||
return res;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_DEVICE_FUNC inline void pscatter<Eigen::half, Packet8hf>(Eigen::half* to, const Packet8hf& from, Index stride) {
|
||||
to[stride * 0].x = vgetq_lane_f16(from, 0);
|
||||
to[stride * 1].x = vgetq_lane_f16(from, 1);
|
||||
to[stride * 2].x = vgetq_lane_f16(from, 2);
|
||||
to[stride * 3].x = vgetq_lane_f16(from, 3);
|
||||
to[stride * 4].x = vgetq_lane_f16(from, 4);
|
||||
to[stride * 5].x = vgetq_lane_f16(from, 5);
|
||||
to[stride * 6].x = vgetq_lane_f16(from, 6);
|
||||
to[stride * 7].x = vgetq_lane_f16(from, 7);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_DEVICE_FUNC inline void pscatter<Eigen::half, Packet4hf>(Eigen::half* to, const Packet4hf& from, Index stride) {
|
||||
to[stride * 0].x = vget_lane_f16(from, 0);
|
||||
to[stride * 1].x = vget_lane_f16(from, 1);
|
||||
to[stride * 2].x = vget_lane_f16(from, 2);
|
||||
to[stride * 3].x = vget_lane_f16(from, 3);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE void prefetch<Eigen::half>(const Eigen::half* addr) {
|
||||
EIGEN_ARM_PREFETCH(addr);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Eigen::half pfirst<Packet8hf>(const Packet8hf& a) {
|
||||
float16_t x[8];
|
||||
vst1q_f16(x, a);
|
||||
Eigen::half h;
|
||||
h.x = x[0];
|
||||
return h;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Eigen::half pfirst<Packet4hf>(const Packet4hf& a) {
|
||||
float16_t x[4];
|
||||
vst1_f16(x, a);
|
||||
Eigen::half h;
|
||||
h.x = x[0];
|
||||
return h;
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8hf preverse(const Packet8hf& a) {
|
||||
float16x4_t a_lo, a_hi;
|
||||
Packet8hf a_r64;
|
||||
|
||||
a_r64 = vrev64q_f16(a);
|
||||
a_lo = vget_low_f16(a_r64);
|
||||
a_hi = vget_high_f16(a_r64);
|
||||
return vcombine_f16(a_hi, a_lo);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf preverse<Packet4hf>(const Packet4hf& a) {
|
||||
return vrev64_f16(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8hf pabs<Packet8hf>(const Packet8hf& a) {
|
||||
return vabsq_f16(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet4hf pabs<Packet4hf>(const Packet4hf& a) {
|
||||
return vabs_f16(a);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Eigen::half predux<Packet8hf>(const Packet8hf& a) {
|
||||
float16x4_t a_lo, a_hi, sum;
|
||||
|
||||
a_lo = vget_low_f16(a);
|
||||
a_hi = vget_high_f16(a);
|
||||
sum = vpadd_f16(a_lo, a_hi);
|
||||
sum = vpadd_f16(sum, sum);
|
||||
sum = vpadd_f16(sum, sum);
|
||||
|
||||
Eigen::half h;
|
||||
h.x = vget_lane_f16(sum, 0);
|
||||
return h;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Eigen::half predux<Packet4hf>(const Packet4hf& a) {
|
||||
float16x4_t sum;
|
||||
|
||||
sum = vpadd_f16(a, a);
|
||||
sum = vpadd_f16(sum, sum);
|
||||
Eigen::half h;
|
||||
h.x = vget_lane_f16(sum, 0);
|
||||
return h;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet8hf>(const Packet8hf& a) {
|
||||
float16x4_t a_lo, a_hi, prod;
|
||||
|
||||
a_lo = vget_low_f16(a);
|
||||
a_hi = vget_high_f16(a);
|
||||
prod = vmul_f16(a_lo, a_hi);
|
||||
prod = vmul_f16(prod, vrev64_f16(prod));
|
||||
|
||||
Eigen::half h;
|
||||
h.x = vget_lane_f16(prod, 0) * vget_lane_f16(prod, 1);
|
||||
return h;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet4hf>(const Packet4hf& a) {
|
||||
float16x4_t prod;
|
||||
prod = vmul_f16(a, vrev64_f16(a));
|
||||
Eigen::half h;
|
||||
h.x = vget_lane_f16(prod, 0) * vget_lane_f16(prod, 1);
|
||||
return h;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Eigen::half predux_min<Packet8hf>(const Packet8hf& a) {
|
||||
float16x4_t a_lo, a_hi, min;
|
||||
|
||||
a_lo = vget_low_f16(a);
|
||||
a_hi = vget_high_f16(a);
|
||||
min = vpmin_f16(a_lo, a_hi);
|
||||
min = vpmin_f16(min, min);
|
||||
min = vpmin_f16(min, min);
|
||||
|
||||
Eigen::half h;
|
||||
h.x = vget_lane_f16(min, 0);
|
||||
return h;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Eigen::half predux_min<Packet4hf>(const Packet4hf& a) {
|
||||
Packet4hf tmp;
|
||||
tmp = vpmin_f16(a, a);
|
||||
tmp = vpmin_f16(tmp, tmp);
|
||||
Eigen::half h;
|
||||
h.x = vget_lane_f16(tmp, 0);
|
||||
return h;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Eigen::half predux_max<Packet8hf>(const Packet8hf& a) {
|
||||
float16x4_t a_lo, a_hi, max;
|
||||
|
||||
a_lo = vget_low_f16(a);
|
||||
a_hi = vget_high_f16(a);
|
||||
max = vpmax_f16(a_lo, a_hi);
|
||||
max = vpmax_f16(max, max);
|
||||
max = vpmax_f16(max, max);
|
||||
|
||||
Eigen::half h;
|
||||
h.x = vget_lane_f16(max, 0);
|
||||
return h;
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Eigen::half predux_max<Packet4hf>(const Packet4hf& a) {
|
||||
Packet4hf tmp;
|
||||
tmp = vpmax_f16(a, a);
|
||||
tmp = vpmax_f16(tmp, tmp);
|
||||
Eigen::half h;
|
||||
h.x = vget_lane_f16(tmp, 0);
|
||||
return h;
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8hf, 4>& kernel) {
|
||||
EIGEN_ALIGN16 Eigen::half in[4][8];
|
||||
|
||||
pstore<Eigen::half>(in[0], kernel.packet[0]);
|
||||
pstore<Eigen::half>(in[1], kernel.packet[1]);
|
||||
pstore<Eigen::half>(in[2], kernel.packet[2]);
|
||||
pstore<Eigen::half>(in[3], kernel.packet[3]);
|
||||
|
||||
EIGEN_ALIGN16 Eigen::half out[4][8];
|
||||
|
||||
EIGEN_UNROLL_LOOP
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
EIGEN_UNROLL_LOOP
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
out[i][j] = in[j][2*i];
|
||||
}
|
||||
EIGEN_UNROLL_LOOP
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
out[i][j+4] = in[j][2*i+1];
|
||||
}
|
||||
}
|
||||
|
||||
kernel.packet[0] = pload<Packet8hf>(out[0]);
|
||||
kernel.packet[1] = pload<Packet8hf>(out[1]);
|
||||
kernel.packet[2] = pload<Packet8hf>(out[2]);
|
||||
kernel.packet[3] = pload<Packet8hf>(out[3]);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet4hf, 4>& kernel) {
|
||||
EIGEN_ALIGN16 float16x4x4_t tmp_x4;
|
||||
float16_t* tmp = (float16_t*)&kernel;
|
||||
tmp_x4 = vld4_f16(tmp);
|
||||
|
||||
kernel.packet[0] = tmp_x4.val[0];
|
||||
kernel.packet[1] = tmp_x4.val[1];
|
||||
kernel.packet[2] = tmp_x4.val[2];
|
||||
kernel.packet[3] = tmp_x4.val[3];
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8hf, 8>& kernel) {
|
||||
float16x8x2_t T_1[4];
|
||||
|
||||
T_1[0] = vuzpq_f16(kernel.packet[0], kernel.packet[1]);
|
||||
T_1[1] = vuzpq_f16(kernel.packet[2], kernel.packet[3]);
|
||||
T_1[2] = vuzpq_f16(kernel.packet[4], kernel.packet[5]);
|
||||
T_1[3] = vuzpq_f16(kernel.packet[6], kernel.packet[7]);
|
||||
|
||||
float16x8x2_t T_2[4];
|
||||
T_2[0] = vuzpq_f16(T_1[0].val[0], T_1[1].val[0]);
|
||||
T_2[1] = vuzpq_f16(T_1[0].val[1], T_1[1].val[1]);
|
||||
T_2[2] = vuzpq_f16(T_1[2].val[0], T_1[3].val[0]);
|
||||
T_2[3] = vuzpq_f16(T_1[2].val[1], T_1[3].val[1]);
|
||||
|
||||
float16x8x2_t T_3[4];
|
||||
T_3[0] = vuzpq_f16(T_2[0].val[0], T_2[2].val[0]);
|
||||
T_3[1] = vuzpq_f16(T_2[0].val[1], T_2[2].val[1]);
|
||||
T_3[2] = vuzpq_f16(T_2[1].val[0], T_2[3].val[0]);
|
||||
T_3[3] = vuzpq_f16(T_2[1].val[1], T_2[3].val[1]);
|
||||
|
||||
kernel.packet[0] = T_3[0].val[0];
|
||||
kernel.packet[1] = T_3[2].val[0];
|
||||
kernel.packet[2] = T_3[1].val[0];
|
||||
kernel.packet[3] = T_3[3].val[0];
|
||||
kernel.packet[4] = T_3[0].val[1];
|
||||
kernel.packet[5] = T_3[2].val[1];
|
||||
kernel.packet[6] = T_3[1].val[1];
|
||||
kernel.packet[7] = T_3[3].val[1];
|
||||
}
|
||||
#endif // end EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
@ -414,6 +414,13 @@
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// Following the Arm ACLE arm_neon.h should also include arm_fp16.h but not all
|
||||
// compilers seem to follow this. We therefore include it explicitly.
|
||||
// See also: https://bugs.llvm.org/show_bug.cgi?id=47955
|
||||
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
#include <arm_fp16.h>
|
||||
#endif
|
||||
|
||||
#if defined(__F16C__) && (!defined(EIGEN_GPUCC) && (!defined(EIGEN_COMP_CLANG) || EIGEN_COMP_CLANG>=380))
|
||||
// We can use the optimized fp16 to float and float to fp16 conversion routines
|
||||
#define EIGEN_HAS_FP16_C
|
||||
|
@ -258,12 +258,47 @@
|
||||
#define EIGEN_ARCH_ARM64 0
|
||||
#endif
|
||||
|
||||
/// \internal EIGEN_ARCH_ARM_OR_ARM64 set to 1 if the architecture is ARM or ARM64
|
||||
#if EIGEN_ARCH_ARM || EIGEN_ARCH_ARM64
|
||||
#define EIGEN_ARCH_ARM_OR_ARM64 1
|
||||
#else
|
||||
#define EIGEN_ARCH_ARM_OR_ARM64 0
|
||||
#endif
|
||||
|
||||
/// \internal EIGEN_HAS_ARM64_FP16 set to 1 if the architecture provides an IEEE
|
||||
/// compliant Arm fp16 type
|
||||
#if EIGEN_ARCH_ARM64
|
||||
#ifndef EIGEN_HAS_ARM64_FP16
|
||||
#if defined(__ARM_FP16_FORMAT_IEEE)
|
||||
#define EIGEN_HAS_ARM64_FP16 1
|
||||
#else
|
||||
#define EIGEN_HAS_ARM64_FP16 0
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
/// \internal EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC set to 1 if the architecture
|
||||
/// supports Neon vector intrinsics for fp16.
|
||||
#if EIGEN_ARCH_ARM64
|
||||
#ifndef EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
|
||||
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
|
||||
#define EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC 1
|
||||
#else
|
||||
#define EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC 0
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
/// \internal EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC set to 1 if the architecture
|
||||
/// supports Neon scalar intrinsics for fp16.
|
||||
#if EIGEN_ARCH_ARM64
|
||||
#ifndef EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC
|
||||
#if defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC)
|
||||
#define EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC 1
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
/// \internal EIGEN_ARCH_MIPS set to 1 if the architecture is MIPS
|
||||
#if defined(__mips__) || defined(__mips)
|
||||
#define EIGEN_ARCH_MIPS 1
|
||||
|
@ -684,15 +684,6 @@ template<> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
|
||||
bool not_equal_strict(const double& x,const double& y) { return std::not_equal_to<double>()(x,y); }
|
||||
#endif
|
||||
|
||||
/** \internal extract the bits of the float \a x */
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC unsigned int as_uint(float x)
|
||||
{
|
||||
unsigned int ret;
|
||||
EIGEN_USING_STD(memcpy)
|
||||
memcpy(&ret, &x, sizeof(float));
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // end namespace numext
|
||||
|
||||
} // end namespace Eigen
|
||||
|
@ -22,6 +22,9 @@ void test_conversion()
|
||||
{
|
||||
using Eigen::half_impl::__half_raw;
|
||||
|
||||
// We don't use a uint16_t raw member x if the platform has native Arm __fp16
|
||||
// support
|
||||
#if !defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
// Conversion from float.
|
||||
VERIFY_IS_EQUAL(half(1.0f).x, 0x3c00);
|
||||
VERIFY_IS_EQUAL(half(0.5f).x, 0x3800);
|
||||
@ -53,6 +56,41 @@ void test_conversion()
|
||||
// Conversion from bool.
|
||||
VERIFY_IS_EQUAL(half(false).x, 0x0000);
|
||||
VERIFY_IS_EQUAL(half(true).x, 0x3c00);
|
||||
#endif
|
||||
|
||||
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
// Conversion from float.
|
||||
VERIFY_IS_EQUAL(half(1.0f).x, __fp16(1.0f));
|
||||
VERIFY_IS_EQUAL(half(0.5f).x, __fp16(0.5f));
|
||||
VERIFY_IS_EQUAL(half(0.33333f).x, __fp16(0.33333f));
|
||||
VERIFY_IS_EQUAL(half(0.0f).x, __fp16(0.0f));
|
||||
VERIFY_IS_EQUAL(half(-0.0f).x, __fp16(-0.0f));
|
||||
VERIFY_IS_EQUAL(half(65504.0f).x, __fp16(65504.0f));
|
||||
VERIFY_IS_EQUAL(half(65536.0f).x, __fp16(65536.0f)); // Becomes infinity.
|
||||
|
||||
// Denormals.
|
||||
VERIFY_IS_EQUAL(half(-5.96046e-08f).x, __fp16(-5.96046e-08f));
|
||||
VERIFY_IS_EQUAL(half(5.96046e-08f).x, __fp16(5.96046e-08f));
|
||||
VERIFY_IS_EQUAL(half(1.19209e-07f).x, __fp16(1.19209e-07f));
|
||||
|
||||
// Verify round-to-nearest-even behavior.
|
||||
float val1 = float(half(__half_raw(0x3c00)));
|
||||
float val2 = float(half(__half_raw(0x3c01)));
|
||||
float val3 = float(half(__half_raw(0x3c02)));
|
||||
VERIFY_IS_EQUAL(half(0.5f * (val1 + val2)).x, __fp16(0.5f * (val1 + val2)));
|
||||
VERIFY_IS_EQUAL(half(0.5f * (val2 + val3)).x, __fp16(0.5f * (val2 + val3)));
|
||||
|
||||
// Conversion from int.
|
||||
VERIFY_IS_EQUAL(half(-1).x, __fp16(-1));
|
||||
VERIFY_IS_EQUAL(half(0).x, __fp16(0));
|
||||
VERIFY_IS_EQUAL(half(1).x, __fp16(1));
|
||||
VERIFY_IS_EQUAL(half(2).x, __fp16(2));
|
||||
VERIFY_IS_EQUAL(half(3).x, __fp16(3));
|
||||
|
||||
// Conversion from bool.
|
||||
VERIFY_IS_EQUAL(half(false).x, __fp16(false));
|
||||
VERIFY_IS_EQUAL(half(true).x, __fp16(true));
|
||||
#endif
|
||||
|
||||
// Conversion to float.
|
||||
VERIFY_IS_EQUAL(float(half(__half_raw(0x0000))), 0.0f);
|
||||
@ -92,6 +130,15 @@ void test_conversion()
|
||||
VERIFY((numext::isinf)(half(1.0 / 0.0)));
|
||||
VERIFY((numext::isinf)(half(-1.0 / 0.0)));
|
||||
#endif
|
||||
|
||||
// Conversion to bool
|
||||
VERIFY(!static_cast<bool>(half(0.0)));
|
||||
VERIFY(!static_cast<bool>(half(-0.0)));
|
||||
VERIFY(static_cast<bool>(half(__half_raw(0x7bff))));
|
||||
VERIFY(static_cast<bool>(half(-0.33333)));
|
||||
VERIFY(static_cast<bool>(half(1.0)));
|
||||
VERIFY(static_cast<bool>(half(-1.0)));
|
||||
VERIFY(static_cast<bool>(half(-5.96046e-08f)));
|
||||
}
|
||||
|
||||
void test_numtraits()
|
||||
@ -108,8 +155,12 @@ void test_numtraits()
|
||||
VERIFY(NumTraits<half>::IsSigned);
|
||||
|
||||
VERIFY_IS_EQUAL( std::numeric_limits<half>::infinity().x, half(std::numeric_limits<float>::infinity()).x );
|
||||
|
||||
// If we have a native fp16 types this becomes a nan == nan comparision so we have to disable it
|
||||
#if !defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC)
|
||||
VERIFY_IS_EQUAL( std::numeric_limits<half>::quiet_NaN().x, half(std::numeric_limits<float>::quiet_NaN()).x );
|
||||
VERIFY_IS_EQUAL( std::numeric_limits<half>::signaling_NaN().x, half(std::numeric_limits<float>::signaling_NaN()).x );
|
||||
#endif
|
||||
VERIFY( (std::numeric_limits<half>::min)() > half(0.f) );
|
||||
VERIFY( (std::numeric_limits<half>::denorm_min)() > half(0.f) );
|
||||
VERIFY( (std::numeric_limits<half>::min)()/half(2) > half(0.f) );
|
||||
@ -218,8 +269,8 @@ void test_trigonometric_functions()
|
||||
VERIFY_IS_APPROX(numext::cos(half(0.0f)), half(cosf(0.0f)));
|
||||
VERIFY_IS_APPROX(cos(half(0.0f)), half(cosf(0.0f)));
|
||||
VERIFY_IS_APPROX(numext::cos(half(EIGEN_PI)), half(cosf(EIGEN_PI)));
|
||||
//VERIFY_IS_APPROX(numext::cos(half(EIGEN_PI/2)), half(cosf(EIGEN_PI/2)));
|
||||
//VERIFY_IS_APPROX(numext::cos(half(3*EIGEN_PI/2)), half(cosf(3*EIGEN_PI/2)));
|
||||
// VERIFY_IS_APPROX(numext::cos(half(EIGEN_PI/2)), half(cosf(EIGEN_PI/2)));
|
||||
// VERIFY_IS_APPROX(numext::cos(half(3*EIGEN_PI/2)), half(cosf(3*EIGEN_PI/2)));
|
||||
VERIFY_IS_APPROX(numext::cos(half(3.5f)), half(cosf(3.5f)));
|
||||
|
||||
VERIFY_IS_APPROX(numext::sin(half(0.0f)), half(sinf(0.0f)));
|
||||
|
@ -246,6 +246,7 @@ void packetmath_boolean_mask_ops() {
|
||||
data1[i] = Scalar(i);
|
||||
data1[i + PacketSize] = internal::random<bool>() ? data1[i] : Scalar(0);
|
||||
}
|
||||
|
||||
CHECK_CWISE2_IF(true, internal::pcmp_eq, internal::pcmp_eq);
|
||||
|
||||
//Test (-0) == (0) for signed operations
|
||||
|
Loading…
x
Reference in New Issue
Block a user