mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-13 20:26:03 +08:00
New BF16 pcast functions and move type casting to TypeCasting.h
This commit is contained in:
parent
17b5b4de58
commit
3f3ce214e6
@ -220,6 +220,7 @@ using std::ptrdiff_t;
|
||||
#include "src/Core/arch/SSE/Complex.h"
|
||||
#elif defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX)
|
||||
#include "src/Core/arch/AltiVec/PacketMath.h"
|
||||
#include "src/Core/arch/AltiVec/TypeCasting.h"
|
||||
#include "src/Core/arch/AltiVec/MathFunctions.h"
|
||||
#include "src/Core/arch/AltiVec/Complex.h"
|
||||
#elif defined EIGEN_VECTORIZE_NEON
|
||||
|
@ -2697,102 +2697,6 @@ template<> EIGEN_STRONG_INLINE Packet16uc pblend(const Selector<16>& ifPacket, c
|
||||
return vec_sel(elsePacket, thenPacket, mask);
|
||||
}
|
||||
|
||||
template <>
|
||||
struct type_casting_traits<float, int> {
|
||||
enum {
|
||||
VectorizedCast = 1,
|
||||
SrcCoeffRatio = 1,
|
||||
TgtCoeffRatio = 1
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_casting_traits<int, float> {
|
||||
enum {
|
||||
VectorizedCast = 1,
|
||||
SrcCoeffRatio = 1,
|
||||
TgtCoeffRatio = 1
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_casting_traits<bfloat16, unsigned short int> {
|
||||
enum {
|
||||
VectorizedCast = 1,
|
||||
SrcCoeffRatio = 1,
|
||||
TgtCoeffRatio = 1
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_casting_traits<unsigned short int, bfloat16> {
|
||||
enum {
|
||||
VectorizedCast = 1,
|
||||
SrcCoeffRatio = 1,
|
||||
TgtCoeffRatio = 1
|
||||
};
|
||||
};
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4i pcast<Packet4f, Packet4i>(const Packet4f& a) {
|
||||
return vec_cts(a,0);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4ui pcast<Packet4f, Packet4ui>(const Packet4f& a) {
|
||||
return vec_ctu(a,0);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4i, Packet4f>(const Packet4i& a) {
|
||||
return vec_ctf(a,0);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4ui, Packet4f>(const Packet4ui& a) {
|
||||
return vec_ctf(a,0);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8us pcast<Packet8bf, Packet8us>(const Packet8bf& a) {
|
||||
Packet4f float_even = Bf16ToF32Even(a);
|
||||
Packet4f float_odd = Bf16ToF32Odd(a);
|
||||
Packet4ui int_even = pcast<Packet4f, Packet4ui>(float_even);
|
||||
Packet4ui int_odd = pcast<Packet4f, Packet4ui>(float_odd);
|
||||
const EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF);
|
||||
Packet4ui low_even = pand<Packet4ui>(int_even, p4ui_low_mask);
|
||||
Packet4ui low_odd = pand<Packet4ui>(int_odd, p4ui_low_mask);
|
||||
|
||||
//Check values that are bigger than USHRT_MAX (0xFFFF)
|
||||
Packet4bi overflow_selector;
|
||||
if(vec_any_gt(int_even, p4ui_low_mask)){
|
||||
overflow_selector = vec_cmpgt(int_even, p4ui_low_mask);
|
||||
low_even = vec_sel(low_even, p4ui_low_mask, overflow_selector);
|
||||
}
|
||||
if(vec_any_gt(int_odd, p4ui_low_mask)){
|
||||
overflow_selector = vec_cmpgt(int_odd, p4ui_low_mask);
|
||||
low_odd = vec_sel(low_even, p4ui_low_mask, overflow_selector);
|
||||
}
|
||||
|
||||
return pmerge(low_even, low_odd);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8bf pcast<Packet8us, Packet8bf>(const Packet8us& a) {
|
||||
//short -> int -> float -> bfloat16
|
||||
const EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF);
|
||||
Packet4ui int_cast = reinterpret_cast<Packet4ui>(a);
|
||||
Packet4ui int_even = pand<Packet4ui>(int_cast, p4ui_low_mask);
|
||||
Packet4ui int_odd = plogical_shift_right<16>(int_cast);
|
||||
Packet4f float_even = pcast<Packet4ui, Packet4f>(int_even);
|
||||
Packet4f float_odd = pcast<Packet4ui, Packet4f>(int_odd);
|
||||
return F32ToBf16(float_even, float_odd);
|
||||
}
|
||||
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i,Packet4f>(const Packet4f& a) {
|
||||
return reinterpret_cast<Packet4i>(a);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f,Packet4i>(const Packet4i& a) {
|
||||
return reinterpret_cast<Packet4f>(a);
|
||||
}
|
||||
|
||||
|
||||
|
||||
//---------- double ----------
|
||||
#ifdef EIGEN_VECTORIZE_VSX
|
||||
@ -2805,7 +2709,6 @@ typedef Packet2ul Packet2bl;
|
||||
typedef __vector __bool long Packet2bl;
|
||||
#endif
|
||||
|
||||
static Packet2l p2l_ONE = { 1, 1 };
|
||||
static Packet2l p2l_ZERO = reinterpret_cast<Packet2l>(p4i_ZERO);
|
||||
static Packet2ul p2ul_SIGN = { 0x8000000000000000ull, 0x8000000000000000ull };
|
||||
static Packet2ul p2ul_PREV0DOT5 = { 0x3FDFFFFFFFFFFFFFull, 0x3FDFFFFFFFFFFFFFull };
|
||||
@ -3082,34 +2985,10 @@ template<> EIGEN_STRONG_INLINE Packet2d psignbit(const Packet2d& a)
|
||||
return reinterpret_cast<Packet2d>(vec_perm(tmp, tmp, p16uc_DUPSIGN));
|
||||
}
|
||||
#endif
|
||||
// VSX support varies between different compilers and even different
|
||||
// versions of the same compiler. For gcc version >= 4.9.3, we can use
|
||||
// vec_cts to efficiently convert Packet2d to Packet2l. Otherwise, use
|
||||
// a slow version that works with older compilers.
|
||||
// Update: apparently vec_cts/vec_ctf intrinsics for 64-bit doubles
|
||||
// are buggy, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70963
|
||||
template<>
|
||||
inline Packet2l pcast<Packet2d, Packet2l>(const Packet2d& x) {
|
||||
#if EIGEN_GNUC_STRICT_AT_LEAST(7,1,0)
|
||||
return vec_cts(x, 0); // TODO: check clang version.
|
||||
#else
|
||||
double tmp[2];
|
||||
memcpy(tmp, &x, sizeof(tmp));
|
||||
Packet2l l = { static_cast<long long>(tmp[0]),
|
||||
static_cast<long long>(tmp[1]) };
|
||||
return l;
|
||||
#endif
|
||||
}
|
||||
|
||||
template<>
|
||||
inline Packet2d pcast<Packet2l, Packet2d>(const Packet2l& x) {
|
||||
unsigned long long tmp[2];
|
||||
memcpy(tmp, &x, sizeof(tmp));
|
||||
Packet2d d = { static_cast<double>(tmp[0]),
|
||||
static_cast<double>(tmp[1]) };
|
||||
return d;
|
||||
}
|
||||
template<> inline Packet2l pcast<Packet2d, Packet2l>(const Packet2d& x);
|
||||
|
||||
template<> inline Packet2d pcast<Packet2l, Packet2d>(const Packet2l& x);
|
||||
|
||||
// Packet2l shifts.
|
||||
// For POWER8 we simply use vec_sr/l.
|
||||
@ -3290,7 +3169,7 @@ ptranspose(PacketBlock<Packet2d,2>& kernel) {
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, const Packet2d& thenPacket, const Packet2d& elsePacket) {
|
||||
Packet2l select = { ifPacket.select[0], ifPacket.select[1] };
|
||||
Packet2bl mask = reinterpret_cast<Packet2bl>( vec_cmpeq(reinterpret_cast<Packet2d>(select), reinterpret_cast<Packet2d>(p2l_ONE)) );
|
||||
Packet2ul mask = reinterpret_cast<Packet2ul>(pnegate(reinterpret_cast<Packet2l>(select)));
|
||||
return vec_sel(elsePacket, thenPacket, mask);
|
||||
}
|
||||
|
||||
|
178
Eigen/src/Core/arch/AltiVec/TypeCasting.h
Normal file
178
Eigen/src/Core/arch/AltiVec/TypeCasting.h
Normal file
@ -0,0 +1,178 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// Copyright (C) 2019 Rasmus Munk Larsen <rmlarsen@google.com>
|
||||
// Copyright (C) 2023 Chip Kerchner (chip.kerchner@ibm.com)
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#ifndef EIGEN_TYPE_CASTING_ALTIVEC_H
|
||||
#define EIGEN_TYPE_CASTING_ALTIVEC_H
|
||||
|
||||
#include "../../InternalHeaderCheck.h"
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
template <>
|
||||
struct type_casting_traits<float, int> {
|
||||
enum {
|
||||
VectorizedCast = 1,
|
||||
SrcCoeffRatio = 1,
|
||||
TgtCoeffRatio = 1
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_casting_traits<int, float> {
|
||||
enum {
|
||||
VectorizedCast = 1,
|
||||
SrcCoeffRatio = 1,
|
||||
TgtCoeffRatio = 1
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_casting_traits<bfloat16, unsigned short int> {
|
||||
enum {
|
||||
VectorizedCast = 1,
|
||||
SrcCoeffRatio = 1,
|
||||
TgtCoeffRatio = 1
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_casting_traits<unsigned short int, bfloat16> {
|
||||
enum {
|
||||
VectorizedCast = 1,
|
||||
SrcCoeffRatio = 1,
|
||||
TgtCoeffRatio = 1
|
||||
};
|
||||
};
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4i pcast<Packet4f, Packet4i>(const Packet4f& a) {
|
||||
return vec_cts(a,0);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4ui pcast<Packet4f, Packet4ui>(const Packet4f& a) {
|
||||
return vec_ctu(a,0);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4i, Packet4f>(const Packet4i& a) {
|
||||
return vec_ctf(a,0);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet4ui, Packet4f>(const Packet4ui& a) {
|
||||
return vec_ctf(a,0);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8us pcast<Packet8bf, Packet8us>(const Packet8bf& a) {
|
||||
Packet4f float_even = Bf16ToF32Even(a);
|
||||
Packet4f float_odd = Bf16ToF32Odd(a);
|
||||
Packet4ui int_even = pcast<Packet4f, Packet4ui>(float_even);
|
||||
Packet4ui int_odd = pcast<Packet4f, Packet4ui>(float_odd);
|
||||
const EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF);
|
||||
Packet4ui low_even = pand<Packet4ui>(int_even, p4ui_low_mask);
|
||||
Packet4ui low_odd = pand<Packet4ui>(int_odd, p4ui_low_mask);
|
||||
|
||||
//Check values that are bigger than USHRT_MAX (0xFFFF)
|
||||
Packet4bi overflow_selector;
|
||||
if(vec_any_gt(int_even, p4ui_low_mask)){
|
||||
overflow_selector = vec_cmpgt(int_even, p4ui_low_mask);
|
||||
low_even = vec_sel(low_even, p4ui_low_mask, overflow_selector);
|
||||
}
|
||||
if(vec_any_gt(int_odd, p4ui_low_mask)){
|
||||
overflow_selector = vec_cmpgt(int_odd, p4ui_low_mask);
|
||||
low_odd = vec_sel(low_even, p4ui_low_mask, overflow_selector);
|
||||
}
|
||||
|
||||
return pmerge(low_even, low_odd);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8bf pcast<Packet8us, Packet8bf>(const Packet8us& a) {
|
||||
//short -> int -> float -> bfloat16
|
||||
const EIGEN_DECLARE_CONST_FAST_Packet4ui(low_mask, 0x0000FFFF);
|
||||
Packet4ui int_cast = reinterpret_cast<Packet4ui>(a);
|
||||
Packet4ui int_even = pand<Packet4ui>(int_cast, p4ui_low_mask);
|
||||
Packet4ui int_odd = plogical_shift_right<16>(int_cast);
|
||||
Packet4f float_even = pcast<Packet4ui, Packet4f>(int_even);
|
||||
Packet4f float_odd = pcast<Packet4ui, Packet4f>(int_odd);
|
||||
return F32ToBf16(float_even, float_odd);
|
||||
}
|
||||
|
||||
template <>
|
||||
struct type_casting_traits<bfloat16, float> {
|
||||
enum {
|
||||
VectorizedCast = 1,
|
||||
SrcCoeffRatio = 1,
|
||||
TgtCoeffRatio = 2
|
||||
};
|
||||
};
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pcast<Packet8bf, Packet4f>(const Packet8bf& a) {
|
||||
Packet8us z = pset1<Packet8us>(0);
|
||||
#ifdef _BIG_ENDIAN
|
||||
return reinterpret_cast<Packet4f>(vec_mergeh(a.m_val, z));
|
||||
#else
|
||||
return reinterpret_cast<Packet4f>(vec_mergeh(z, a.m_val));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
struct type_casting_traits<float, bfloat16> {
|
||||
enum {
|
||||
VectorizedCast = 1,
|
||||
SrcCoeffRatio = 2,
|
||||
TgtCoeffRatio = 1
|
||||
};
|
||||
};
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8bf pcast<Packet4f, Packet8bf>(const Packet4f& a, const Packet4f &b) {
|
||||
return F32ToBf16Both(a, b);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4i preinterpret<Packet4i,Packet4f>(const Packet4f& a) {
|
||||
return reinterpret_cast<Packet4i>(a);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f preinterpret<Packet4f,Packet4i>(const Packet4i& a) {
|
||||
return reinterpret_cast<Packet4f>(a);
|
||||
}
|
||||
|
||||
#ifdef EIGEN_VECTORIZE_VSX
|
||||
// VSX support varies between different compilers and even different
|
||||
// versions of the same compiler. For gcc version >= 4.9.3, we can use
|
||||
// vec_cts to efficiently convert Packet2d to Packet2l. Otherwise, use
|
||||
// a slow version that works with older compilers.
|
||||
// Update: apparently vec_cts/vec_ctf intrinsics for 64-bit doubles
|
||||
// are buggy, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70963
|
||||
template<>
|
||||
inline Packet2l pcast<Packet2d, Packet2l>(const Packet2d& x) {
|
||||
#if EIGEN_GNUC_STRICT_AT_LEAST(7,1,0)
|
||||
return vec_cts(x, 0); // TODO: check clang version.
|
||||
#else
|
||||
double tmp[2];
|
||||
memcpy(tmp, &x, sizeof(tmp));
|
||||
Packet2l l = { static_cast<long long>(tmp[0]),
|
||||
static_cast<long long>(tmp[1]) };
|
||||
return l;
|
||||
#endif
|
||||
}
|
||||
|
||||
template<>
|
||||
inline Packet2d pcast<Packet2l, Packet2d>(const Packet2l& x) {
|
||||
unsigned long long tmp[2];
|
||||
memcpy(tmp, &x, sizeof(tmp));
|
||||
Packet2d d = { static_cast<double>(tmp[0]),
|
||||
static_cast<double>(tmp[1]) };
|
||||
return d;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_TYPE_CASTING_ALTIVEC_H
|
Loading…
x
Reference in New Issue
Block a user