mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-10 02:39:03 +08:00
Add dynamic dispatch to BF16 GEMM (Power) and new VSX version
This commit is contained in:
parent
3026fc0d3c
commit
1148f0a9ec
File diff suppressed because it is too large
Load Diff
@ -84,6 +84,18 @@ EIGEN_ALWAYS_INLINE void gemm_complex_extra_cols(
|
||||
const Packet& pAlphaImag,
|
||||
const Packet& pMask);
|
||||
|
||||
template<typename DataMapper>
|
||||
EIGEN_ALWAYS_INLINE void convertArrayBF16toF32(float *result, Index cols, Index rows, const DataMapper& src);
|
||||
|
||||
template<const Index size, bool non_unit_stride, Index delta>
|
||||
EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Index resInc, Index extra = 0);
|
||||
|
||||
template<bool non_unit_stride = false>
|
||||
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float *result, Index cols, Index rows, bfloat16* src, Index resInc = 1);
|
||||
|
||||
template<bool rhsExtraCols, bool lhsExtraRows>
|
||||
EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result, Index extra_cols, Index extra_rows);
|
||||
|
||||
template<typename Packet>
|
||||
EIGEN_ALWAYS_INLINE Packet ploadLhs(const __UNPACK_TYPE__(Packet)* lhs);
|
||||
|
||||
|
@ -28,9 +28,7 @@
|
||||
|
||||
#include "../../InternalHeaderCheck.h"
|
||||
|
||||
#if !EIGEN_ALTIVEC_DISABLE_MMA
|
||||
#include "MatrixProductMMAbfloat16.h"
|
||||
#endif
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
|
@ -53,7 +53,7 @@ EIGEN_ALWAYS_INLINE void KLoop
|
||||
|
||||
indexA += k*(lhsExtraRows ? extra_rows : num_packets);
|
||||
for(Index j = 0; j < num_lhs; j++) {
|
||||
lhs[j] = loadBfloat16<zero>(indexA + j*(zero ? 4 : 8)); //a packet of bfloat16 has 8 elements
|
||||
lhs[j] = loadBfloat16<zero>(indexA + j*(zero ? 4 : 8)); // a packet of bfloat16 has 8 elements
|
||||
}
|
||||
|
||||
BFLOAT16_UNROLL
|
||||
@ -65,46 +65,6 @@ EIGEN_ALWAYS_INLINE void KLoop
|
||||
}
|
||||
}
|
||||
|
||||
// Loads one Packet4f from `result` with an unaligned load and returns
// pmadd(acc, pAlpha, loaded). Nothing is written back to `result` here.
EIGEN_ALWAYS_INLINE Packet4f loadAndMultiplyF32(Packet4f acc, const Packet4f pAlpha, float* result)
{
  return pmadd(acc, pAlpha, ploadu<Packet4f>(result));
}
|
||||
|
||||
template<bool lhsExtraRows>
|
||||
EIGEN_ALWAYS_INLINE void storeF32(float*& result, Packet4f result_block, Index rows, Index extra_rows)
|
||||
{
|
||||
if (lhsExtraRows) {
|
||||
pstoreu_partial(result, result_block, extra_rows);
|
||||
} else {
|
||||
pstoreu(result, result_block);
|
||||
}
|
||||
result += rows;
|
||||
}
|
||||
|
||||
template<bool rhsExtraCols, bool lhsExtraRows>
|
||||
EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result, Index extra_cols, Index extra_rows)
|
||||
{
|
||||
Index x = 0;
|
||||
if (rhsExtraCols) {
|
||||
do{
|
||||
Packet4f result_block = loadAndMultiplyF32(acc[x], pAlpha, result);
|
||||
storeF32<lhsExtraRows>(result, result_block, rows, extra_rows);
|
||||
} while (++x < extra_cols);
|
||||
} else {
|
||||
Packet4f result_block[4];
|
||||
float *result2 = result;
|
||||
do{
|
||||
result_block[x] = loadAndMultiplyF32(acc[x], pAlpha, result);
|
||||
result += rows;
|
||||
} while (++x < 4);
|
||||
x = 0;
|
||||
do{
|
||||
storeF32<lhsExtraRows>(result2, result_block[x], rows, extra_rows);
|
||||
} while (++x < 4);
|
||||
}
|
||||
}
|
||||
|
||||
template<Index num_acc>
|
||||
EIGEN_ALWAYS_INLINE void zeroAccumulators(__vector_quad (&quad_acc)[num_acc])
|
||||
{
|
||||
@ -165,17 +125,14 @@ EIGEN_ALWAYS_INLINE void colLoopBodyIter(Index depth, Index rows, const Packet4f
|
||||
template<const Index num_acc, const Index num_packets, bool rhsExtraCols, bool lhsExtraRows>
|
||||
void colLoopBody(Index& col, Index depth, Index cols, Index rows, const Packet4f pAlpha, const bfloat16* indexA, const bfloat16* indexB, Index strideB, Index offsetB, float* result)
|
||||
{
|
||||
constexpr Index step = (num_acc * 4); //each accumulator has 4 elements
|
||||
constexpr Index step = (num_acc * 4); // each accumulator has 4 elements
|
||||
const Index extra_cols = (rhsExtraCols) ? (cols & 3) : 0;
|
||||
const Index extra_rows = (lhsExtraRows) ? (rows & 3) : 0;
|
||||
constexpr bool multiIters = !rhsExtraCols && (num_acc == MAX_BFLOAT16_ACC);
|
||||
constexpr bool normIters = multiIters && ((num_acc % (num_packets / 4)) == 0);
|
||||
|
||||
do{
|
||||
if (multiIters && ((num_acc % (num_packets / 4)) == 0)) {
|
||||
colLoopBodyIter<num_acc, num_packets, rhsExtraCols, lhsExtraRows, true>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows);
|
||||
} else {
|
||||
colLoopBodyIter<num_acc, num_packets, rhsExtraCols, lhsExtraRows>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows);
|
||||
}
|
||||
colLoopBodyIter<num_acc, num_packets, rhsExtraCols, lhsExtraRows, normIters>(depth, rows, pAlpha, indexA, indexB, strideB, offsetB, result, extra_cols, extra_rows);
|
||||
|
||||
indexB += strideB*num_acc;
|
||||
result += rows*step;
|
||||
@ -239,104 +196,89 @@ EIGEN_ALWAYS_INLINE void colLoops(Index depth, Index cols, Index rows, const Pac
|
||||
}
|
||||
}
|
||||
|
||||
template<bool full = true>
|
||||
EIGEN_ALWAYS_INLINE Packet8bf convertF32toBF16(const float *res)
|
||||
{
|
||||
Packet16uc fp16_0 = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 0)));
|
||||
Packet16uc fp16_1 = (full) ? __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 4))) : fp16_0;
|
||||
return vec_pack(reinterpret_cast<Packet4ui>(fp16_0), reinterpret_cast<Packet4ui>(fp16_1));
|
||||
Packet16uc fp16[2];
|
||||
#if EIGEN_COMP_LLVM
|
||||
__vector_pair fp16_vp = *reinterpret_cast<__vector_pair *>(const_cast<float *>(res));
|
||||
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(fp16), &fp16_vp);
|
||||
fp16[0] = __builtin_vsx_xvcvspbf16(fp16[0]);
|
||||
fp16[1] = __builtin_vsx_xvcvspbf16(fp16[1]);
|
||||
#else
|
||||
fp16[0] = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 0)));
|
||||
fp16[1] = __builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(ploadu<Packet4f>(res + 4)));
|
||||
#endif
|
||||
return vec_pack(reinterpret_cast<Packet4ui>(fp16[0]), reinterpret_cast<Packet4ui>(fp16[1]));
|
||||
}
|
||||
|
||||
template<int N>
|
||||
EIGEN_ALWAYS_INLINE void storeConvertBlockBF16(float* to, PacketBlock<Packet8bf,(N+4)/8>& block)
|
||||
template<typename DataMapper, const Index size>
|
||||
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16Col(float *result, Index col, Index rows, const DataMapper& res)
|
||||
{
|
||||
Packet8us z = pset1<Packet8us>(0);
|
||||
pstore(to + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, block.packet[0].m_val)));
|
||||
if (N >= 8) {
|
||||
pstore(to + 4, reinterpret_cast<Packet4f>(vec_mergel(z, block.packet[0].m_val)));
|
||||
const DataMapper res2 = res.getSubMapper(0, col);
|
||||
Index row;
|
||||
float *result2 = result + col*rows;
|
||||
for(row = 0; row + 8 <= rows; row += 8){
|
||||
// get and save block
|
||||
PacketBlock<Packet8bf,size> block;
|
||||
for(Index j = 0; j < size; j++){
|
||||
block.packet[j] = convertF32toBF16(result2 + j*rows + row);
|
||||
}
|
||||
res2.template storePacketBlock<Packet8bf,size>(row, 0, block);
|
||||
}
|
||||
if (N >= 16) {
|
||||
pstore(to + 8, reinterpret_cast<Packet4f>(vec_mergeh(z, block.packet[1].m_val)));
|
||||
pstore(to + 12, reinterpret_cast<Packet4f>(vec_mergel(z, block.packet[1].m_val)));
|
||||
}
|
||||
if (N >= 32) {
|
||||
pstore(to + 16, reinterpret_cast<Packet4f>(vec_mergeh(z, block.packet[2].m_val)));
|
||||
pstore(to + 20, reinterpret_cast<Packet4f>(vec_mergel(z, block.packet[2].m_val)));
|
||||
pstore(to + 24, reinterpret_cast<Packet4f>(vec_mergeh(z, block.packet[3].m_val)));
|
||||
pstore(to + 28, reinterpret_cast<Packet4f>(vec_mergel(z, block.packet[3].m_val)));
|
||||
// extra rows
|
||||
if(row < rows){
|
||||
for(Index j = 0; j < size; j++){
|
||||
Packet8bf fp16 = convertF32toBF16(result2 + j*rows + row);
|
||||
res2.template storePacketPartial<Packet8bf>(row, j, fp16, rows & 7);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<const Index size, typename DataMapper>
|
||||
EIGEN_ALWAYS_INLINE void convertBF16toF32(Index& i, float *result, Index rows, const DataMapper& src)
|
||||
template<const Index size, bool non_unit_stride = false>
|
||||
EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc = 1)
|
||||
{
|
||||
for(; i + size <= rows; i += size){
|
||||
PacketBlock<Packet8bf,(size+4)/8> r32;
|
||||
r32.packet[0] = src.template loadPacket<Packet8bf>(i + 0);
|
||||
constexpr Index extra = ((size < 8) ? 8 : size);
|
||||
for(; i + size <= rows; i += extra, dst += extra*resInc){
|
||||
PacketBlock<Packet8bf,(size+7)/8> r32;
|
||||
r32.packet[0] = convertF32toBF16(result + i + 0);
|
||||
if (size >= 16) {
|
||||
r32.packet[1] = src.template loadPacket<Packet8bf>(i + 8);
|
||||
r32.packet[1] = convertF32toBF16(result + i + 8);
|
||||
}
|
||||
if (size >= 32) {
|
||||
r32.packet[2] = src.template loadPacket<Packet8bf>(i + 16);
|
||||
r32.packet[3] = src.template loadPacket<Packet8bf>(i + 24);
|
||||
r32.packet[2] = convertF32toBF16(result + i + 16);
|
||||
r32.packet[3] = convertF32toBF16(result + i + 24);
|
||||
}
|
||||
storeBF16fromResult<size, non_unit_stride, 0>(dst, r32.packet[0], resInc, rows & 7);
|
||||
if (size >= 16) {
|
||||
storeBF16fromResult<size, non_unit_stride, 8>(dst, r32.packet[1], resInc);
|
||||
}
|
||||
if (size >= 32) {
|
||||
storeBF16fromResult<size, non_unit_stride, 16>(dst, r32.packet[2], resInc);
|
||||
storeBF16fromResult<size, non_unit_stride, 24>(dst, r32.packet[3], resInc);
|
||||
}
|
||||
storeConvertBlockBF16<size>(result + i, r32);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename DataMapper>
|
||||
EIGEN_ALWAYS_INLINE void convertArrayBF16toF32(float *result, Index cols, Index rows, const DataMapper& src)
|
||||
template<bool non_unit_stride = false>
|
||||
EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16(float *result, Index rows, bfloat16* dst, Index resInc = 1)
|
||||
{
|
||||
typedef typename DataMapper::LinearMapper LinearMapper;
|
||||
for(Index j = 0; j < cols; j++, result += rows){
|
||||
const LinearMapper src2 = src.getLinearMapper(0, j);
|
||||
Index i = 0;
|
||||
convertBF16toF32<32, LinearMapper>(i, result, rows, src2);
|
||||
convertBF16toF32<16, LinearMapper>(i, result, rows, src2);
|
||||
convertBF16toF32<8, LinearMapper>(i, result, rows, src2);
|
||||
convertBF16toF32<4, LinearMapper>(i, result, rows, src2);
|
||||
for(; i < rows; i++){
|
||||
result[i] = Eigen::bfloat16_impl::bfloat16_to_float(src2(i));
|
||||
}
|
||||
}
|
||||
Index i = 0;
|
||||
convertPointerF32toBF16<32,non_unit_stride>(i, result, rows, dst, resInc);
|
||||
convertPointerF32toBF16<16,non_unit_stride>(i, result, rows, dst, resInc);
|
||||
convertPointerF32toBF16<8,non_unit_stride>(i, result, rows, dst, resInc);
|
||||
convertPointerF32toBF16<1,non_unit_stride>(i, result, rows, dst, resInc);
|
||||
}
|
||||
|
||||
template<typename DataMapper>
|
||||
EIGEN_ALWAYS_INLINE void convertArrayF32toBF16(float *result, Index cols, Index rows, const DataMapper& res)
|
||||
{
|
||||
typedef typename DataMapper::LinearMapper LinearMapper;
|
||||
Index col, row;
|
||||
Index col;
|
||||
for(col = 0; col + 4 <= cols; col += 4){
|
||||
const DataMapper res2 = res.getSubMapper(0, col);
|
||||
for(row = 0; row + 8 <= rows; row += 8){
|
||||
//get and save block
|
||||
PacketBlock<Packet8bf,4> block;
|
||||
for(Index j = 0; j < 4; j++){
|
||||
block.packet[j].m_val = convertF32toBF16(result + (col + j)*rows + row);
|
||||
}
|
||||
|
||||
res2.template storePacketBlock<Packet8bf,4>(row, 0, block);
|
||||
}
|
||||
//extra rows
|
||||
while(row < rows){
|
||||
for(Index col_off = 0; col_off < 4; col_off++){
|
||||
res2(row, col_off) = Eigen::bfloat16(result[(col+col_off)*rows+row]);
|
||||
}
|
||||
row++;
|
||||
}
|
||||
|
||||
convertArrayF32toBF16Col<DataMapper,4>(result, col, rows, res);
|
||||
}
|
||||
//extra cols
|
||||
// extra cols
|
||||
while(col < cols){
|
||||
const LinearMapper res2 = res.getLinearMapper(0, col);
|
||||
float *result2 = result + col*rows;
|
||||
for(row = 0; row + 8 <= rows; row += 8){
|
||||
Packet8bf fp16 = convertF32toBF16(result2 + row);
|
||||
res2.template storePacket<Packet8bf>(row, fp16);
|
||||
}
|
||||
for(; row < rows; row++){
|
||||
res2(row) = Eigen::bfloat16(result2[row]);
|
||||
}
|
||||
convertArrayF32toBF16Col<DataMapper,1>(result, col, rows, res);
|
||||
col++;
|
||||
}
|
||||
}
|
||||
@ -361,134 +303,42 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat
|
||||
|
||||
convertArrayBF16toF32<DataMapper>(result, cols, rows, res);
|
||||
|
||||
Index row = 0;
|
||||
|
||||
if( strideA == -1 ) strideA = depth;
|
||||
if( strideB == -1 ) strideB = depth;
|
||||
//Packing is done in blocks.
|
||||
//There's 4 possible sizes of blocks
|
||||
//Blocks of 8 columns with 16 elements (8x16)
|
||||
//Blocks of 8 columns with 8 elements (8x8). This happens when there's 16 > rows >= 8
|
||||
//Blocks of 8 columns with 4 elements (8x4). This happens when there's 8 > rows >= 4
|
||||
//Blocks of 8 columns with < 4 elements. This happens when there's less than 4 remaining rows
|
||||
// Packing is done in blocks.
|
||||
// There's 4 possible sizes of blocks
|
||||
// Blocks of 8 columns with 16 elements (8x16)
|
||||
// Blocks of 8 columns with 8 elements (8x8). This happens when there's 16 > rows >= 8
|
||||
// Blocks of 8 columns with 4 elements (8x4). This happens when there's 8 > rows >= 4
|
||||
// Blocks of 8 columns with < 4 elements. This happens when there's less than 4 remaining rows
|
||||
|
||||
//Loop for LHS standard block (8x16)
|
||||
// Loop for LHS standard block (8x16)
|
||||
Index bigSuffix = (2*8) * (strideA-offsetA);
|
||||
indexB += 4*offsetB;
|
||||
strideB *= 4;
|
||||
offsetB *= 3;
|
||||
|
||||
Index row = 0;
|
||||
while(row + 16 <= rows){
|
||||
calcColLoops<16>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
|
||||
}
|
||||
//LHS (8x8) block
|
||||
// LHS (8x8) block
|
||||
calcColLoops<8>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
|
||||
//LHS (8x4) block
|
||||
// LHS (8x4) block
|
||||
calcColLoops<4>(indexA, row, depth, cols, rows, pAlpha, indexB, strideB, offsetA, offsetB, bigSuffix, result);
|
||||
//extra rows
|
||||
// extra rows
|
||||
if(rows & 3){
|
||||
//This index is the beginning of remaining block.
|
||||
// This index is the beginning of remaining block.
|
||||
colLoops<4, true>(depth, cols, rows, pAlpha, indexA, indexB, strideB, offsetB, result + row);
|
||||
}
|
||||
|
||||
//Convert back to bfloat16
|
||||
// Convert back to bfloat16
|
||||
convertArrayF32toBF16<DataMapper>(result, cols, rows, res);
|
||||
}
|
||||
|
||||
// Writes one Packet8bf to `dst`, offset by `delta` elements.
//   inc == false : contiguous destination, written at dst + delta.
//   inc == true  : strided destination, scattered with stride `resInc`
//                  starting at dst + delta*resInc.
// When size == 4 the partial variant of the store/scatter is used with an
// element count of 4; otherwise the full packet is written.
template<const Index size, bool inc, Index delta>
EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Index resInc)
{
  constexpr bool partial = (size == 4);
  if (!inc) {
    bfloat16* out = dst + delta;
    if (partial) {
      pstoreu_partial(out, data, 4);
    } else {
      pstoreu(out, data);
    }
  } else {
    bfloat16* out = dst + delta*resInc;
    if (partial) {
      pscatter_partial(out, data, resInc, 4);
    } else {
      pscatter(out, data, resInc);
    }
  }
}
|
||||
|
||||
// Converts floats from `result` to bfloat16 and writes them to `dst`, `size`
// elements per iteration, for as long as a whole chunk of `size` still fits
// below `rows`. `i` and `dst` are references and are left positioned just
// after the last chunk handled, so the caller can continue with a smaller size.
// For size == 4 the first conversion is instantiated with full == false.
template<const Index size, bool inc>
EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc)
{
  constexpr Index numPackets = (size+4)/8;
  while (i + size <= rows) {
    const float* in = result + i;
    PacketBlock<Packet8bf,numPackets> chunk;

    // Convert first ...
    chunk.packet[0] = convertF32toBF16<size != 4>(in + 0);
    if (size >= 16) {
      chunk.packet[1] = convertF32toBF16<true>(in + 8);
    }
    if (size >= 32) {
      chunk.packet[2] = convertF32toBF16<true>(in + 16);
      chunk.packet[3] = convertF32toBF16<true>(in + 24);
    }

    // ... then store, 8 elements per packet.
    storeBF16fromResult<size, inc, 0>(dst, chunk.packet[0], resInc);
    if (size >= 16) {
      storeBF16fromResult<size, inc, 8>(dst, chunk.packet[1], resInc);
    }
    if (size >= 32) {
      storeBF16fromResult<size, inc, 16>(dst, chunk.packet[2], resInc);
      storeBF16fromResult<size, inc, 24>(dst, chunk.packet[3], resInc);
    }

    i += size;
    dst += size*resInc;
  }
}
|
||||
|
||||
template<bool inc, Index delta>
|
||||
EIGEN_ALWAYS_INLINE Packet8bf loadBF16fromResult(bfloat16* src, Index resInc)
|
||||
{
|
||||
if (inc) {
|
||||
return pgather<bfloat16, Packet8bf>(src + delta*resInc, resInc);
|
||||
} else {
|
||||
return ploadu<Packet8bf>(src + delta);
|
||||
}
|
||||
}
|
||||
|
||||
// Reads bfloat16 values from `src` (contiguous, or strided by `resInc` when
// `inc` is set), `size` elements per iteration, and hands each loaded block to
// storeConvertBlockBF16 together with the destination result + i.
// Iterates while a whole chunk of `size` fits below `rows`; `i` and `src` are
// references and are advanced past everything that was processed.
template<const Index size, bool inc>
EIGEN_ALWAYS_INLINE void convertPointerBF16toF32(Index& i, float *result, Index rows, bfloat16*& src, Index resInc)
{
  constexpr Index numPackets = (size+4)/8;
  while (i + size <= rows) {
    PacketBlock<Packet8bf,numPackets> chunk;

    chunk.packet[0] = loadBF16fromResult<inc, 0>(src, resInc);
    if (size >= 16) {
      chunk.packet[1] = loadBF16fromResult<inc, 8>(src, resInc);
    }
    if (size >= 32) {
      chunk.packet[2] = loadBF16fromResult<inc, 16>(src, resInc);
      chunk.packet[3] = loadBF16fromResult<inc, 24>(src, resInc);
    }

    storeConvertBlockBF16<size>(result + i, chunk);

    i += size;
    src += size*resInc;
  }
}
|
||||
|
||||
template<bool inc = false>
|
||||
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float *result, Index rows, bfloat16* src, Index resInc = 1)
|
||||
{
|
||||
Index i = 0;
|
||||
convertPointerBF16toF32<32, inc>(i, result, rows, src, resInc);
|
||||
convertPointerBF16toF32<16, inc>(i, result, rows, src, resInc);
|
||||
convertPointerBF16toF32<8, inc>(i, result, rows, src, resInc);
|
||||
convertPointerBF16toF32<4, inc>(i, result, rows, src, resInc);
|
||||
for(; i < rows; i++, src += resInc){
|
||||
result[i] = Eigen::bfloat16_impl::bfloat16_to_float(*src);
|
||||
}
|
||||
}
|
||||
|
||||
template<bool inc = false>
|
||||
EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16(float *result, Index rows, bfloat16* dst, Index resInc = 1)
|
||||
{
|
||||
Index i = 0;
|
||||
convertPointerF32toBF16<32,inc>(i, result, rows, dst, resInc);
|
||||
convertPointerF32toBF16<16,inc>(i, result, rows, dst, resInc);
|
||||
convertPointerF32toBF16<8,inc>(i, result, rows, dst, resInc);
|
||||
convertPointerF32toBF16<4,inc>(i, result, rows, dst, resInc);
|
||||
for(; i < rows; i++, dst += resInc){
|
||||
*dst = Eigen::bfloat16(result[i]);
|
||||
}
|
||||
}
|
||||
#undef MAX_BFLOAT16_ACC
|
||||
|
||||
#if !EIGEN_ALTIVEC_DISABLE_MMA
|
||||
template<bool extraRows>
|
||||
EIGEN_ALWAYS_INLINE void outputVecCol(Packet4f acc, float *result, Packet4f pAlpha, Index extra_rows)
|
||||
{
|
||||
@ -667,7 +517,7 @@ void gemvMMA_bfloat16_col(
|
||||
|
||||
ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);
|
||||
|
||||
convertArrayPointerBF16toF32(result, rows, res);
|
||||
convertArrayPointerBF16toF32(result, 1, rows, res);
|
||||
|
||||
for (Index j2 = 0; j2 < cols; j2 += block_cols)
|
||||
{
|
||||
@ -867,9 +717,9 @@ EIGEN_STRONG_INLINE void gemvMMA_bfloat16_row(
|
||||
|
||||
ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);
|
||||
if (resIncr == 1) {
|
||||
convertArrayPointerBF16toF32(result, rows, res);
|
||||
convertArrayPointerBF16toF32(result, 1, rows, res);
|
||||
} else {
|
||||
convertArrayPointerBF16toF32<true>(result, rows, res, resIncr);
|
||||
convertArrayPointerBF16toF32<true>(result, 1, rows, res, resIncr);
|
||||
}
|
||||
calcVecLoops<LhsMapper, LinearMapper>(cols, rows, lhs, rhs2, pAlpha, result);
|
||||
if (resIncr == 1) {
|
||||
@ -878,6 +728,10 @@ EIGEN_STRONG_INLINE void gemvMMA_bfloat16_row(
|
||||
convertArrayPointerF32toBF16<true>(result, rows, res, resIncr);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#undef MAX_BFLOAT16_VEC_ACC
|
||||
#undef BFLOAT16_UNROLL
|
||||
|
||||
}
|
||||
}
|
||||
|
@ -17,8 +17,8 @@
|
||||
#define USE_GEMV_MMA
|
||||
#endif
|
||||
|
||||
#if !EIGEN_COMP_LLVM && (__GNUC__ == 10 && __GNUC_MINOR__ <= 3)
|
||||
// Only allow one vector_pair in buggy gcc - gcc 10.3 has a bug
|
||||
#if !EIGEN_COMP_LLVM && (__GNUC__ < 11)
|
||||
// Only allow one vector_pair in buggy gcc - gcc 10.x has a bug
|
||||
#define GCC_ONE_VECTORPAIR_BUG
|
||||
#endif
|
||||
#endif
|
||||
|
@ -35,6 +35,7 @@ typedef __vector unsigned int Packet4ui;
|
||||
typedef __vector __bool int Packet4bi;
|
||||
typedef __vector short int Packet8s;
|
||||
typedef __vector unsigned short int Packet8us;
|
||||
typedef __vector __bool short Packet8bi;
|
||||
typedef __vector signed char Packet16c;
|
||||
typedef __vector unsigned char Packet16uc;
|
||||
typedef eigen_packet_wrapper<__vector unsigned short int,0> Packet8bf;
|
||||
@ -83,10 +84,7 @@ static EIGEN_DECLARE_CONST_FAST_Packet4i(MINUS16,-16); //{ -16, -16, -16, -16}
|
||||
static EIGEN_DECLARE_CONST_FAST_Packet4i(MINUS1,-1); //{ -1, -1, -1, -1}
|
||||
static EIGEN_DECLARE_CONST_FAST_Packet4ui(SIGN, 0x80000000u);
|
||||
static EIGEN_DECLARE_CONST_FAST_Packet4ui(PREV0DOT5, 0x3EFFFFFFu);
|
||||
#ifndef __POWER8_VECTOR__
|
||||
static EIGEN_DECLARE_CONST_FAST_Packet8us(ONE,1); //{ 1, 1, 1, 1, 1, 1, 1, 1}
|
||||
static EIGEN_DECLARE_CONST_FAST_Packet16uc(ONE,1);
|
||||
#endif
|
||||
static Packet4f p4f_MZERO = (Packet4f) vec_sl((Packet4ui)p4i_MINUS1, (Packet4ui)p4i_MINUS1); //{ 0x80000000, 0x80000000, 0x80000000, 0x80000000}
|
||||
#ifndef __VSX__
|
||||
static Packet4f p4f_ONE = vec_ctf(p4i_ONE, 0); //{ 1.0, 1.0, 1.0, 1.0}
|
||||
@ -116,6 +114,14 @@ static const Packet16uc p16uc_DUPLICATE16_ODD = { 2,3 ,2,3, 6,7, 6,7, 10,11, 10,
|
||||
|
||||
static Packet16uc p16uc_QUADRUPLICATE16_HI = { 0,1,0,1,0,1,0,1, 2,3,2,3,2,3,2,3 };
|
||||
|
||||
static Packet16uc p16uc_MERGEE16 = { 0,1, 16,17, 4,5, 20,21, 8,9, 24,25, 12,13, 28,29 };
|
||||
static Packet16uc p16uc_MERGEO16 = { 2,3, 18,19, 6,7, 22,23, 10,11, 26,27, 14,15, 30,31 };
|
||||
#ifdef _BIG_ENDIAN
|
||||
static Packet16uc p16uc_MERGEH16 = { 0,1, 4,5, 8,9, 12,13, 16,17, 20,21, 24,25, 28,29 };
|
||||
#else
|
||||
static Packet16uc p16uc_MERGEL16 = { 2,3, 6,7, 10,11, 14,15, 18,19, 22,23, 26,27, 30,31 };
|
||||
#endif
|
||||
|
||||
// Handle endianness properly while loading constants
|
||||
// Define global static constants:
|
||||
#ifdef _BIG_ENDIAN
|
||||
@ -537,31 +543,20 @@ EIGEN_ALWAYS_INLINE Packet pload_partial_common(const __UNPACK_TYPE__(Packet)* f
|
||||
}
|
||||
return load;
|
||||
#else
|
||||
EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) load[packet_size];
|
||||
unsigned char* load2 = reinterpret_cast<unsigned char *>(load + offset);
|
||||
unsigned char* from2 = reinterpret_cast<unsigned char *>(const_cast<__UNPACK_TYPE__(Packet)*>(from));
|
||||
Index n2 = n * size;
|
||||
Index i = 0;
|
||||
if (16 <= n2) {
|
||||
pstoreu(load2, ploadu<Packet16uc>(from2));
|
||||
i += 16;
|
||||
if (n) {
|
||||
EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) load[packet_size];
|
||||
unsigned char* load2 = reinterpret_cast<unsigned char *>(load + offset);
|
||||
unsigned char* from2 = reinterpret_cast<unsigned char *>(const_cast<__UNPACK_TYPE__(Packet)*>(from));
|
||||
Index n2 = n * size;
|
||||
if (16 <= n2) {
|
||||
pstoreu(load2, ploadu<Packet16uc>(from2));
|
||||
} else {
|
||||
memcpy((void *)load2, (void *)from2, n2);
|
||||
}
|
||||
return pload_ignore<Packet>(load);
|
||||
} else {
|
||||
return Packet(pset1<Packet16uc>(0));
|
||||
}
|
||||
if (i + 8 <= n2) {
|
||||
*reinterpret_cast<uint64_t *>(load2 + i) = *reinterpret_cast<uint64_t *>(from2 + i);
|
||||
i += 8;
|
||||
}
|
||||
if (i + 4 <= n2) {
|
||||
*reinterpret_cast<uint32_t *>(load2 + i) = *reinterpret_cast<uint32_t *>(from2 + i);
|
||||
i += 4;
|
||||
}
|
||||
if (i + 2 <= n2) {
|
||||
*reinterpret_cast<uint16_t *>(load2 + i) = *reinterpret_cast<uint16_t *>(from2 + i);
|
||||
i += 2;
|
||||
}
|
||||
if (i < n2) {
|
||||
*reinterpret_cast<uint8_t *>(load2 + i) = *reinterpret_cast<uint8_t *>(from2 + i);
|
||||
}
|
||||
return pload_ignore<Packet>(load);
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -635,7 +630,7 @@ template<> EIGEN_STRONG_INLINE void pstore<unsigned short int>(unsigned short in
|
||||
|
||||
template<> EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet8bf& from)
|
||||
{
|
||||
pstore_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from);
|
||||
pstore_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from.m_val);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE void pstore<signed char>(signed char* to, const Packet16c& from)
|
||||
@ -670,30 +665,17 @@ template<typename Packet> EIGEN_ALWAYS_INLINE void pstore_partial_common(__UNPAC
|
||||
}
|
||||
vec_xst_len(store, to, n * size);
|
||||
#else
|
||||
EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) store[packet_size];
|
||||
pstore(store, from);
|
||||
unsigned char* store2 = reinterpret_cast<unsigned char *>(store + offset);
|
||||
unsigned char* to2 = reinterpret_cast<unsigned char *>(to);
|
||||
Index n2 = n * size;
|
||||
Index i = 0;
|
||||
if (16 <= n2) {
|
||||
pstore(to2, ploadu<Packet16uc>(store2));
|
||||
i += 16;
|
||||
}
|
||||
if (i + 8 <= n2) {
|
||||
*reinterpret_cast<uint64_t *>(to2 + i) = *reinterpret_cast<uint64_t *>(store2 + i);
|
||||
i += 8;
|
||||
}
|
||||
if (i + 4 <= n2) {
|
||||
*reinterpret_cast<uint32_t *>(to2 + i) = *reinterpret_cast<uint32_t *>(store2 + i);
|
||||
i += 4;
|
||||
}
|
||||
if (i + 2 <= n2) {
|
||||
*reinterpret_cast<uint16_t *>(to2 + i) = *reinterpret_cast<uint16_t *>(store2 + i);
|
||||
i += 2;
|
||||
}
|
||||
if (i < n2) {
|
||||
*reinterpret_cast<uint8_t *>(to2 + i) = *reinterpret_cast<uint8_t *>(store2 + i);
|
||||
if (n) {
|
||||
EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) store[packet_size];
|
||||
pstore(store, from);
|
||||
unsigned char* store2 = reinterpret_cast<unsigned char *>(store + offset);
|
||||
unsigned char* to2 = reinterpret_cast<unsigned char *>(to);
|
||||
Index n2 = n * size;
|
||||
if (16 <= n2) {
|
||||
pstore(to2, ploadu<Packet16uc>(store2));
|
||||
} else {
|
||||
memcpy((void *)to2, (void *)store2, n2);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@ -720,7 +702,7 @@ template<> EIGEN_ALWAYS_INLINE void pstore_partial<unsigned short int>(unsigned
|
||||
|
||||
template<> EIGEN_ALWAYS_INLINE void pstore_partial<bfloat16>(bfloat16* to, const Packet8bf& from, const Index n, const Index offset)
|
||||
{
|
||||
pstore_partial_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from, n, offset);
|
||||
pstore_partial_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from.m_val, n, offset);
|
||||
}
|
||||
|
||||
template<> EIGEN_ALWAYS_INLINE void pstore_partial<signed char>(signed char* to, const Packet16c& from, const Index n, const Index offset)
|
||||
@ -1003,6 +985,22 @@ template<> EIGEN_STRONG_INLINE Packet4f pnegate(const Packet4f& a)
|
||||
return vec_xor(a, p4f_MZERO);
|
||||
#endif
|
||||
}
|
||||
// Negation for Packet16c.
// Without __POWER8_VECTOR__ the result is computed as (p4i_ZERO viewed as
// Packet16c) - a; with it, vec_neg is used directly.
template<> EIGEN_STRONG_INLINE Packet16c pnegate(const Packet16c& a)
{
#ifndef __POWER8_VECTOR__
  return reinterpret_cast<Packet16c>(p4i_ZERO) - a;
#else
  return vec_neg(a);
#endif
}
|
||||
// Negation for Packet8s.
// Without __POWER8_VECTOR__ the result is computed as (p4i_ZERO viewed as
// Packet8s) - a; with it, vec_neg is used directly.
template<> EIGEN_STRONG_INLINE Packet8s pnegate(const Packet8s& a)
{
#ifndef __POWER8_VECTOR__
  return reinterpret_cast<Packet8s>(p4i_ZERO) - a;
#else
  return vec_neg(a);
#endif
}
|
||||
template<> EIGEN_STRONG_INLINE Packet4i pnegate(const Packet4i& a)
|
||||
{
|
||||
#ifdef __POWER8_VECTOR__
|
||||
@ -1102,7 +1100,7 @@ template<> EIGEN_STRONG_INLINE Packet16uc pmax<Packet16uc>(const Packet16uc& a,
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pcmp_le(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmple(a,b)); }
|
||||
// To fix bug with vec_cmplt on older versions
|
||||
#if defined(__POWER8_VECTOR__) || EIGEN_COMP_LLVM
|
||||
#ifdef EIGEN_VECTORIZE_VSX
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmplt(a,b)); }
|
||||
#endif
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) { return reinterpret_cast<Packet4f>(vec_cmpeq(a,b)); }
|
||||
@ -1256,31 +1254,20 @@ template<typename Packet> EIGEN_ALWAYS_INLINE Packet ploadu_partial_common(const
|
||||
EIGEN_DEBUG_UNALIGNED_LOAD
|
||||
return vec_xl_len(const_cast<__UNPACK_TYPE__(Packet)*>(from), n * size);
|
||||
#else
|
||||
EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) load[packet_size];
|
||||
unsigned char* load2 = reinterpret_cast<unsigned char *>(load);
|
||||
unsigned char* from2 = reinterpret_cast<unsigned char *>(const_cast<__UNPACK_TYPE__(Packet)*>(from));
|
||||
Index n2 = n * size;
|
||||
Index i = 0;
|
||||
if (16 <= n2) {
|
||||
pstore(load2, ploadu<Packet16uc>(from2));
|
||||
i += 16;
|
||||
if (n) {
|
||||
EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) load[packet_size];
|
||||
unsigned char* load2 = reinterpret_cast<unsigned char *>(load);
|
||||
unsigned char* from2 = reinterpret_cast<unsigned char *>(const_cast<__UNPACK_TYPE__(Packet)*>(from));
|
||||
Index n2 = n * size;
|
||||
if (16 <= n2) {
|
||||
pstore(load2, ploadu<Packet16uc>(from2));
|
||||
} else {
|
||||
memcpy((void *)load2, (void *)from2, n2);
|
||||
}
|
||||
return pload_ignore<Packet>(load);
|
||||
} else {
|
||||
return Packet(pset1<Packet16uc>(0));
|
||||
}
|
||||
if (i + 8 <= n2) {
|
||||
*reinterpret_cast<uint64_t *>(load2 + i) = *reinterpret_cast<uint64_t *>(from2 + i);
|
||||
i += 8;
|
||||
}
|
||||
if (i + 4 <= n2) {
|
||||
*reinterpret_cast<uint32_t *>(load2 + i) = *reinterpret_cast<uint32_t *>(from2 + i);
|
||||
i += 4;
|
||||
}
|
||||
if (i + 2 <= n2) {
|
||||
*reinterpret_cast<uint16_t *>(load2 + i) = *reinterpret_cast<uint16_t *>(from2 + i);
|
||||
i += 2;
|
||||
}
|
||||
if (i < n2) {
|
||||
*reinterpret_cast<uint8_t *>(load2 + i) = *reinterpret_cast<uint8_t *>(from2 + i);
|
||||
}
|
||||
return pload_ignore<Packet>(load);
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -1422,7 +1409,7 @@ template<> EIGEN_STRONG_INLINE void pstoreu<unsigned short int>(unsigned short i
|
||||
}
|
||||
template<> EIGEN_STRONG_INLINE void pstoreu<bfloat16>(bfloat16* to, const Packet8bf& from)
|
||||
{
|
||||
pstoreu_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from);
|
||||
pstoreu_common<Packet8us>(reinterpret_cast<unsigned short int*>(to), from.m_val);
|
||||
}
|
||||
template<> EIGEN_STRONG_INLINE void pstoreu<signed char>(signed char* to, const Packet16c& from)
|
||||
{
|
||||
@ -1443,30 +1430,17 @@ template<typename Packet> EIGEN_ALWAYS_INLINE void pstoreu_partial_common(__UNPA
|
||||
EIGEN_DEBUG_UNALIGNED_STORE
|
||||
vec_xst_len(from, to, n * size);
|
||||
#else
|
||||
EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) store[packet_size];
|
||||
pstore(store, from);
|
||||
unsigned char* store2 = reinterpret_cast<unsigned char *>(store);
|
||||
unsigned char* to2 = reinterpret_cast<unsigned char *>(to);
|
||||
Index n2 = n * size;
|
||||
Index i = 0;
|
||||
if (16 <= n2) {
|
||||
pstoreu(to2, pload<Packet16uc>(store2));
|
||||
i += 16;
|
||||
}
|
||||
if (i + 8 <= n2) {
|
||||
*reinterpret_cast<uint64_t *>(to2 + i) = *reinterpret_cast<uint64_t *>(store2 + i);
|
||||
i += 8;
|
||||
}
|
||||
if (i + 4 <= n2) {
|
||||
*reinterpret_cast<uint32_t *>(to2 + i) = *reinterpret_cast<uint32_t *>(store2 + i);
|
||||
i += 4;
|
||||
}
|
||||
if (i + 2 <= n2) {
|
||||
*reinterpret_cast<uint16_t *>(to2 + i) = *reinterpret_cast<uint16_t *>(store2 + i);
|
||||
i += 2;
|
||||
}
|
||||
if (i < n2) {
|
||||
*reinterpret_cast<uint8_t *>(to2 + i) = *reinterpret_cast<uint8_t *>(store2 + i);
|
||||
if (n) {
|
||||
EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) store[packet_size];
|
||||
pstore(store, from);
|
||||
unsigned char* store2 = reinterpret_cast<unsigned char *>(store);
|
||||
unsigned char* to2 = reinterpret_cast<unsigned char *>(to);
|
||||
Index n2 = n * size;
|
||||
if (16 <= n2) {
|
||||
pstoreu(to2, pload<Packet16uc>(store2));
|
||||
} else {
|
||||
memcpy((void *)to2, (void *)store2, n2);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@ -1636,17 +1610,37 @@ EIGEN_STRONG_INLINE Packet4f Bf16ToF32Odd(const Packet8bf& bf){
|
||||
);
|
||||
}
|
||||
|
||||
// Combines two Packet4ui values into one Packet8us with a single vec_perm.
// Both inputs are reinterpreted as Packet8us; the operand order and the permute
// control vector depend on endianness:
//   big-endian:    vec_perm(odd, even, p16uc_MERGEO16)
//   little-endian: vec_perm(even, odd, p16uc_MERGEE16)
// NOTE(review): the p16uc_MERGEO16 / p16uc_MERGEE16 tables are defined outside this
// view; presumably they pick one 16-bit half of every 32-bit lane from each operand
// and interleave them — confirm against their definitions.
EIGEN_ALWAYS_INLINE Packet8us pmerge(Packet4ui even, Packet4ui odd) {
|
||||
#ifdef _BIG_ENDIAN
|
||||
return vec_perm(reinterpret_cast<Packet8us>(odd), reinterpret_cast<Packet8us>(even), p16uc_MERGEO16);
|
||||
#else
|
||||
return vec_perm(reinterpret_cast<Packet8us>(even), reinterpret_cast<Packet8us>(odd), p16uc_MERGEE16);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Simple interleaving of bool masks, prevents true values from being
|
||||
// converted to NaNs.
|
||||
// Takes two Packet4f bool masks (`even`, `odd`) and returns them combined in one
// Packet8bf without going through the rounding/NaN handling of F32ToBf16.
// NOTE(review): this span is a rendered diff. The high_mask / pand / shift / por
// sequence ending at the first return, and the single pmerge(...) return after it,
// look like the old and new bodies shown together — confirm against the real header.
EIGEN_STRONG_INLINE Packet8bf F32ToBf16Bool(Packet4f even, Packet4f odd) {
|
||||
const EIGEN_DECLARE_CONST_FAST_Packet4ui(high_mask, 0xFFFF0000);
|
||||
Packet4f bf_odd, bf_even;
|
||||
// Keep only the upper 16 bits of every 32-bit lane of `odd`.
bf_odd = pand(reinterpret_cast<Packet4f>(p4ui_high_mask), odd);
|
||||
// Move the upper 16 bits of every 32-bit lane of `even` into the lower 16 bits.
bf_even = plogical_shift_right<16>(even);
|
||||
return reinterpret_cast<Packet8us>(por<Packet4f>(bf_even, bf_odd));
|
||||
// Same interleave expressed through the shared pmerge helper.
return pmerge(reinterpret_cast<Packet4ui>(even), reinterpret_cast<Packet4ui>(odd));
|
||||
}
|
||||
|
||||
// Uncomment to compile the SUPPORT_BF16_SUBNORMALS code paths of the
// float -> bfloat16 conversions below.
//#define SUPPORT_BF16_SUBNORMALS
|
||||
|
||||
// Fallback definition of the NaN class bit passed to vec_test_data_class, used only
// when no earlier header has already defined it.
#ifndef __VEC_CLASS_FP_NAN
|
||||
#define __VEC_CLASS_FP_NAN (1<<6)
|
||||
#endif
|
||||
|
||||
// Fallback definitions of the subnormal class bits (positive, negative, and both
// combined); only needed when subnormal support is enabled and
// __VEC_CLASS_FP_SUBNORMAL is not already defined.
#if defined(SUPPORT_BF16_SUBNORMALS) && !defined(__VEC_CLASS_FP_SUBNORMAL)
|
||||
#define __VEC_CLASS_FP_SUBNORMAL_P (1<<1)
|
||||
#define __VEC_CLASS_FP_SUBNORMAL_N (1<<0)
|
||||
|
||||
#define __VEC_CLASS_FP_SUBNORMAL (__VEC_CLASS_FP_SUBNORMAL_P | __VEC_CLASS_FP_SUBNORMAL_N)
|
||||
#endif
|
||||
|
||||
EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f p4f){
|
||||
#ifdef _ARCH_PWR10
|
||||
return reinterpret_cast<Packet8us>(__builtin_vsx_xvcvspbf16(reinterpret_cast<Packet16uc>(p4f)));
|
||||
#else
|
||||
Packet4ui input = reinterpret_cast<Packet4ui>(p4f);
|
||||
Packet4ui lsb = plogical_shift_right<16>(input);
|
||||
lsb = pand<Packet4ui>(lsb, reinterpret_cast<Packet4ui>(p4i_ONE));
|
||||
@ -1655,43 +1649,202 @@ EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f p4f){
|
||||
Packet4ui rounding_bias = padd<Packet4ui>(lsb, p4ui_BIAS);
|
||||
input = padd<Packet4ui>(input, rounding_bias);
|
||||
|
||||
//Test NaN and Subnormal - Begin
|
||||
const EIGEN_DECLARE_CONST_FAST_Packet4ui(nan, 0x7FC00000);
|
||||
#ifdef _ARCH_PWR9
|
||||
Packet4bi nan_selector = vec_test_data_class(p4f, __VEC_CLASS_FP_NAN);
|
||||
input = vec_sel(input, p4ui_nan, nan_selector);
|
||||
|
||||
#ifdef SUPPORT_BF16_SUBNORMALS
|
||||
Packet4bi subnormal_selector = vec_test_data_class(p4f, __VEC_CLASS_FP_SUBNORMAL);
|
||||
input = vec_sel(input, reinterpret_cast<Packet4ui>(p4f), subnormal_selector);
|
||||
#endif
|
||||
#else
|
||||
#ifdef SUPPORT_BF16_SUBNORMALS
|
||||
//Test NaN and Subnormal
|
||||
const EIGEN_DECLARE_CONST_FAST_Packet4ui(exp_mask, 0x7F800000);
|
||||
Packet4ui exp = pand<Packet4ui>(p4ui_exp_mask, reinterpret_cast<Packet4ui>(p4f));
|
||||
|
||||
const EIGEN_DECLARE_CONST_FAST_Packet4ui(mantissa_mask, 0x7FFFFF);
|
||||
Packet4ui mantissa = pand<Packet4ui>(p4ui_mantissa_mask, reinterpret_cast<Packet4ui>(p4f));
|
||||
|
||||
const EIGEN_DECLARE_CONST_FAST_Packet4ui(max_exp, 0x7F800000);
|
||||
Packet4bi is_max_exp = vec_cmpeq(exp, p4ui_max_exp);
|
||||
Packet4bi is_zero_exp = vec_cmpeq(exp, reinterpret_cast<Packet4ui>(p4i_ZERO));
|
||||
|
||||
Packet4bi is_max_exp = vec_cmpeq(exp, p4ui_exp_mask);
|
||||
Packet4bi is_mant_zero = vec_cmpeq(mantissa, reinterpret_cast<Packet4ui>(p4i_ZERO));
|
||||
|
||||
Packet4ui nan_selector = pandnot<Packet4ui>(
|
||||
reinterpret_cast<Packet4ui>(is_max_exp),
|
||||
reinterpret_cast<Packet4ui>(is_mant_zero)
|
||||
);
|
||||
|
||||
Packet4bi is_zero_exp = vec_cmpeq(exp, reinterpret_cast<Packet4ui>(p4i_ZERO));
|
||||
|
||||
Packet4ui subnormal_selector = pandnot<Packet4ui>(
|
||||
reinterpret_cast<Packet4ui>(is_zero_exp),
|
||||
reinterpret_cast<Packet4ui>(is_mant_zero)
|
||||
);
|
||||
|
||||
const EIGEN_DECLARE_CONST_FAST_Packet4ui(nan, 0x7FC00000);
|
||||
input = vec_sel(input, p4ui_nan, nan_selector);
|
||||
input = vec_sel(input, reinterpret_cast<Packet4ui>(p4f), subnormal_selector);
|
||||
//Test NaN and Subnormal - End
|
||||
#else
|
||||
//Test only NaN
|
||||
Packet4bi nan_selector = vec_cmpeq(p4f, p4f);
|
||||
|
||||
input = vec_sel(p4ui_nan, input, nan_selector);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
input = plogical_shift_right<16>(input);
|
||||
return reinterpret_cast<Packet8us>(input);
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef _BIG_ENDIAN
|
||||
/**
|
||||
* Pack the high portion of two float Packets into one bfloat16 Packet
|
||||
*
|
||||
* @param lohi to expect either a low & high OR odd & even order
|
||||
* @param lo first source packet, reinterpreted as Packet8us before the permute
* @param hi second source packet, reinterpreted as Packet8us before the permute
* @return lohi == true:  vec_perm(lo, hi, p16uc_MERGEH16)
*         lohi == false: vec_perm(hi, lo, p16uc_MERGEE16) (operands swapped)
*
* Big-endian variant (compiled under the preceding `#ifdef _BIG_ENDIAN`).
* NOTE(review): the permute tables are defined outside this view; confirm which
* 16-bit halves they select.
*/
|
||||
template<bool lohi>
|
||||
EIGEN_ALWAYS_INLINE Packet8bf Bf16PackHigh(Packet4f lo, Packet4f hi)
|
||||
{
|
||||
if (lohi) {
|
||||
return vec_perm(reinterpret_cast<Packet8us>(lo), reinterpret_cast<Packet8us>(hi), p16uc_MERGEH16);
|
||||
} else {
|
||||
return vec_perm(reinterpret_cast<Packet8us>(hi), reinterpret_cast<Packet8us>(lo), p16uc_MERGEE16);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Pack the low portion of two float Packets into one bfloat16 Packet
|
||||
*
|
||||
* @param lohi to expect either a low & high OR odd & even order
|
||||
* @param lo first source packet
* @param hi second source packet
* @return lohi == true:  vec_pack(lo, hi) with both inputs viewed as Packet4ui
*         lohi == false: vec_perm(hi, lo, p16uc_MERGEO16) with both inputs viewed
*                        as Packet8us (operands swapped)
*
* Big-endian variant (compiled under the preceding `#ifdef _BIG_ENDIAN`).
* F32ToBf16Two also uses this helper to narrow 32-bit selector masks to 16-bit lanes.
*/
|
||||
template<bool lohi>
|
||||
EIGEN_ALWAYS_INLINE Packet8bf Bf16PackLow(Packet4f lo, Packet4f hi)
|
||||
{
|
||||
if (lohi) {
|
||||
return vec_pack(reinterpret_cast<Packet4ui>(lo), reinterpret_cast<Packet4ui>(hi));
|
||||
} else {
|
||||
return vec_perm(reinterpret_cast<Packet8us>(hi), reinterpret_cast<Packet8us>(lo), p16uc_MERGEO16);
|
||||
}
|
||||
}
|
||||
#else
|
||||
/**
* Little-endian counterpart of Bf16PackLow (compiled in the `#else` branch of
* `#ifdef _BIG_ENDIAN`).
*
* @param lohi selects between the vec_pack form (true) and the vec_perm form (false)
* @return lohi == true:  vec_pack(hi, lo) with both inputs viewed as Packet4ui
*         lohi == false: vec_perm(hi, lo, p16uc_MERGEE16) with both inputs viewed
*                        as Packet8us
*
* NOTE(review): the parameters are named (hi, lo) here, whereas the big-endian
* variant names them (lo, hi) and F32ToBf16Two calls this helper as
* Bf16PackLow<lohi>(lo, hi). The first argument therefore lands in `hi` —
* presumably intentional, confirm.
*/
template<bool lohi>
|
||||
EIGEN_ALWAYS_INLINE Packet8bf Bf16PackLow(Packet4f hi, Packet4f lo)
|
||||
{
|
||||
if (lohi) {
|
||||
return vec_pack(reinterpret_cast<Packet4ui>(hi), reinterpret_cast<Packet4ui>(lo));
|
||||
} else {
|
||||
return vec_perm(reinterpret_cast<Packet8us>(hi), reinterpret_cast<Packet8us>(lo), p16uc_MERGEE16);
|
||||
}
|
||||
}
|
||||
|
||||
/**
* Little-endian counterpart of Bf16PackHigh (compiled in the `#else` branch of
* `#ifdef _BIG_ENDIAN`).
*
* @param lohi selects which permute table is used
* @return lohi == true:  vec_perm(hi, lo, p16uc_MERGEL16)
*         lohi == false: vec_perm(hi, lo, p16uc_MERGEO16)
*         (both inputs viewed as Packet8us)
*
* NOTE(review): the parameters are named (hi, lo) here, whereas the big-endian
* variant names them (lo, hi) and F32ToBf16Two calls this helper as
* Bf16PackHigh<lohi>(lo, hi) — presumably intentional, confirm.
*/
template<bool lohi>
|
||||
EIGEN_ALWAYS_INLINE Packet8bf Bf16PackHigh(Packet4f hi, Packet4f lo)
|
||||
{
|
||||
if (lohi) {
|
||||
return vec_perm(reinterpret_cast<Packet8us>(hi), reinterpret_cast<Packet8us>(lo), p16uc_MERGEL16);
|
||||
} else {
|
||||
return vec_perm(reinterpret_cast<Packet8us>(hi), reinterpret_cast<Packet8us>(lo), p16uc_MERGEO16);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
/**
|
||||
* Convert and pack two float Packets into one bfloat16 Packet
|
||||
*
|
||||
* @param lohi to expect either a low & high OR odd & even order
|
||||
* @param lo first float packet
* @param hi second float packet
* @return the eight converted values as 16-bit lanes
*
* Outline: split every float into its packed high and low 16-bit halves, round the
* high half based on the low half, then overwrite NaN lanes (and, when
* SUPPORT_BF16_SUBNORMALS is defined, subnormal lanes) using a selector built by
* one of three preprocessor-selected strategies.
*/
|
||||
template<bool lohi = true>
|
||||
EIGEN_ALWAYS_INLINE Packet8bf F32ToBf16Two(Packet4f lo, Packet4f hi)
|
||||
{
|
||||
// p4f holds the packed high halves, p4f2 the packed low halves.
Packet8us p4f = Bf16PackHigh<lohi>(lo, hi);
|
||||
Packet8us p4f2 = Bf16PackLow<lohi>(lo, hi);
|
||||
|
||||
// Rounding: compute (lowest bit of the high half) + 0x7FFF + (low half) in 16-bit
// lanes. Only whether that sum wrapped around matters.
Packet8us lsb = pand<Packet8us>(p4f, p8us_ONE);
|
||||
EIGEN_DECLARE_CONST_FAST_Packet8us(BIAS,0x7FFFu);
|
||||
lsb = padd<Packet8us>(lsb, p8us_BIAS);
|
||||
lsb = padd<Packet8us>(lsb, p4f2);
|
||||
|
||||
// A sum smaller than one of its addends means the 16-bit addition carried out.
// Subtracting the comparison mask (all-ones where true) adds 1 to those lanes.
Packet8bi rounding_bias = vec_cmplt(lsb, p4f2);
|
||||
Packet8us input = psub<Packet8us>(p4f, reinterpret_cast<Packet8us>(rounding_bias));
|
||||
|
||||
#ifdef _ARCH_PWR9
|
||||
// Power9 and later: classify the original floats directly, then narrow the two
// 32-bit selector masks to 16-bit lanes with Bf16PackLow.
Packet4bi nan_selector_lo = vec_test_data_class(lo, __VEC_CLASS_FP_NAN);
|
||||
Packet4bi nan_selector_hi = vec_test_data_class(hi, __VEC_CLASS_FP_NAN);
|
||||
Packet8us nan_selector = Bf16PackLow<lohi>(reinterpret_cast<Packet4f>(nan_selector_lo), reinterpret_cast<Packet4f>(nan_selector_hi));
|
||||
|
||||
// NaN lanes are replaced by the BIAS pattern (0x7FFF).
input = vec_sel(input, p8us_BIAS, nan_selector);
|
||||
|
||||
#ifdef SUPPORT_BF16_SUBNORMALS
|
||||
Packet4bi subnormal_selector_lo = vec_test_data_class(lo, __VEC_CLASS_FP_SUBNORMAL);
|
||||
Packet4bi subnormal_selector_hi = vec_test_data_class(hi, __VEC_CLASS_FP_SUBNORMAL);
|
||||
Packet8us subnormal_selector = Bf16PackLow<lohi>(reinterpret_cast<Packet4f>(subnormal_selector_lo), reinterpret_cast<Packet4f>(subnormal_selector_hi));
|
||||
|
||||
// Subnormal lanes keep the unrounded high half.
input = vec_sel(input, reinterpret_cast<Packet8us>(p4f), subnormal_selector);
|
||||
#endif
|
||||
#else
|
||||
#ifdef SUPPORT_BF16_SUBNORMALS
|
||||
//Test NaN and Subnormal
|
||||
// Classification here is done on the packed high halves (p4f) only.
// NOTE(review): because the low 16 bits are not inspected, a float whose only
// non-zero mantissa bits are in its low half is not flagged as NaN / subnormal
// by these tests — confirm whether that is acceptable.
const EIGEN_DECLARE_CONST_FAST_Packet8us(exp_mask, 0x7F80);
|
||||
Packet8us exp = pand<Packet8us>(p8us_exp_mask, p4f);
|
||||
|
||||
const EIGEN_DECLARE_CONST_FAST_Packet8us(mantissa_mask, 0x7Fu);
|
||||
Packet8us mantissa = pand<Packet8us>(p8us_mantissa_mask, p4f);
|
||||
|
||||
Packet8bi is_max_exp = vec_cmpeq(exp, p8us_exp_mask);
|
||||
Packet8bi is_mant_zero = vec_cmpeq(mantissa, reinterpret_cast<Packet8us>(p4i_ZERO));
|
||||
|
||||
// Intended: NaN = all exponent bits set and mantissa bits not all zero
// (assuming pandnot(a, b) computes a & ~b).
Packet8us nan_selector = pandnot<Packet8us>(
|
||||
reinterpret_cast<Packet8us>(is_max_exp),
|
||||
reinterpret_cast<Packet8us>(is_mant_zero)
|
||||
);
|
||||
|
||||
Packet8bi is_zero_exp = vec_cmpeq(exp, reinterpret_cast<Packet8us>(p4i_ZERO));
|
||||
|
||||
// Intended: subnormal = exponent bits all zero and mantissa bits not all zero.
Packet8us subnormal_selector = pandnot<Packet8us>(
|
||||
reinterpret_cast<Packet8us>(is_zero_exp),
|
||||
reinterpret_cast<Packet8us>(is_mant_zero)
|
||||
);
|
||||
|
||||
// Using BIAS as NaN (since any or all of the last 7 bits can be set)
|
||||
input = vec_sel(input, p8us_BIAS, nan_selector);
|
||||
input = vec_sel(input, reinterpret_cast<Packet8us>(p4f), subnormal_selector);
|
||||
#else
|
||||
//Test only NaN
|
||||
// x == x is false only for NaN, so this selector is set for the NON-NaN lanes;
// the vec_sel operands below are therefore ordered the other way round than in
// the branches above.
Packet4bi nan_selector_lo = vec_cmpeq(lo, lo);
|
||||
Packet4bi nan_selector_hi = vec_cmpeq(hi, hi);
|
||||
Packet8us nan_selector = Bf16PackLow<lohi>(reinterpret_cast<Packet4f>(nan_selector_lo), reinterpret_cast<Packet4f>(nan_selector_hi));
|
||||
|
||||
input = vec_sel(p8us_BIAS, input, nan_selector);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
return input;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert and pack two float Packets into one bfloat16 Packet - low & high order
|
||||
*
* @param lo first float packet
* @param hi second float packet
* @return one Packet8bf holding the converted values of both inputs
*
* With _ARCH_PWR10 defined, each input is converted by the single-packet
* F32ToBf16 and the two results are narrowed together with vec_pack; otherwise
* the work is delegated to F32ToBf16Two with its default lohi = true.
*/
|
||||
EIGEN_STRONG_INLINE Packet8bf F32ToBf16Both(Packet4f lo, Packet4f hi)
|
||||
{
|
||||
#ifdef _ARCH_PWR10
|
||||
Packet8bf fp16_0 = F32ToBf16(lo);
|
||||
Packet8bf fp16_1 = F32ToBf16(hi);
|
||||
// Narrow the two results (viewed as 32-bit lanes) into one vector of 16-bit lanes.
return vec_pack(reinterpret_cast<Packet4ui>(fp16_0.m_val), reinterpret_cast<Packet4ui>(fp16_1.m_val));
|
||||
#else
|
||||
return F32ToBf16Two(lo, hi);
|
||||
#endif
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert and pack two float Packets into one bfloat16 Packet - odd & even order
|
||||
*
* @param even float packet supplying the even-indexed results
* @param odd float packet supplying the odd-indexed results
* @return one Packet8bf holding the converted values of both inputs
*
* NOTE(review): this span is a rendered diff. The bf_odd / bf_even / por sequence
* ending at the first return, and the `#ifdef _ARCH_PWR10 ... #endif` section after
* it, look like the old and new bodies shown together — confirm against the real
* header.
*/
|
||||
EIGEN_STRONG_INLINE Packet8bf F32ToBf16(Packet4f even, Packet4f odd){
|
||||
Packet4f bf_odd, bf_even;
|
||||
bf_odd = reinterpret_cast<Packet4f>(F32ToBf16(odd).m_val);
|
||||
// Move the converted odd values into the upper 16 bits of each 32-bit lane.
bf_odd = plogical_shift_left<16>(bf_odd);
|
||||
bf_even = reinterpret_cast<Packet4f>(F32ToBf16(even).m_val);
|
||||
return reinterpret_cast<Packet8us>(por<Packet4f>(bf_even, bf_odd));
|
||||
#ifdef _ARCH_PWR10
|
||||
// Convert each packet separately, then interleave the two results with pmerge.
return pmerge(reinterpret_cast<Packet4ui>(F32ToBf16(even).m_val), reinterpret_cast<Packet4ui>(F32ToBf16(odd).m_val));
|
||||
#else
|
||||
// lohi = false tells the two-packet path that its inputs are in even/odd order.
return F32ToBf16Two<false>(even, odd);
|
||||
#endif
|
||||
}
|
||||
#define BF16_TO_F32_UNARY_OP_WRAPPER(OP, A) \
|
||||
Packet4f a_even = Bf16ToF32Even(A);\
|
||||
@ -2493,11 +2646,7 @@ ptranspose(PacketBlock<Packet16uc,16>& kernel) {
|
||||
// Lane-wise blend for 4-lane packets: builds a mask from ifPacket.select[0..3] and
// returns vec_sel(elsePacket, thenPacket, mask).
// The mask is derived by negating the selector values, so (assuming each select
// entry is 0 or 1) a 1 becomes an all-ones lane and a 0 stays zero.
// NOTE(review): this span is a rendered diff. The `#ifdef __POWER8_VECTOR__` pair
// of `mask` declarations and the single pnegate-based declaration after it look
// like the old and new forms shown together — confirm against the real header.
template<typename Packet> EIGEN_STRONG_INLINE
|
||||
Packet pblend4(const Selector<4>& ifPacket, const Packet& thenPacket, const Packet& elsePacket) {
|
||||
Packet4ui select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3] };
|
||||
#ifdef __POWER8_VECTOR__
|
||||
Packet4ui mask = reinterpret_cast<Packet4ui>(vec_neg(reinterpret_cast<Packet4i>(select)));
|
||||
#else
|
||||
Packet4ui mask = reinterpret_cast<Packet4ui>(vec_cmpeq(reinterpret_cast<Packet4ui>(select), reinterpret_cast<Packet4ui>(p4i_ONE)));
|
||||
#endif
|
||||
Packet4ui mask = reinterpret_cast<Packet4ui>(pnegate(reinterpret_cast<Packet4i>(select)));
|
||||
return vec_sel(elsePacket, thenPacket, mask);
|
||||
}
|
||||
|
||||
@ -2512,11 +2661,7 @@ template<> EIGEN_STRONG_INLINE Packet4f pblend(const Selector<4>& ifPacket, cons
|
||||
// Lane-wise blend of two Packet8s: builds a 16-bit mask from ifPacket.select[0..7]
// and returns vec_sel(elsePacket, thenPacket, mask).
// The mask is derived by negating the selector values, so (assuming each select
// entry is 0 or 1) a 1 becomes an all-ones lane and a 0 stays zero.
// NOTE(review): this span is a rendered diff. The `#ifdef __POWER8_VECTOR__` pair
// of `mask` declarations and the single pnegate-based declaration after it look
// like the old and new forms shown together — confirm against the real header.
template<> EIGEN_STRONG_INLINE Packet8s pblend(const Selector<8>& ifPacket, const Packet8s& thenPacket, const Packet8s& elsePacket) {
|
||||
Packet8us select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3],
|
||||
ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7] };
|
||||
#ifdef __POWER8_VECTOR__
|
||||
Packet8us mask = reinterpret_cast<Packet8us>(vec_neg(reinterpret_cast<Packet8s>(select)));
|
||||
#else
|
||||
Packet8us mask = reinterpret_cast<Packet8us>(vec_cmpeq(select, p8us_ONE));
|
||||
#endif
|
||||
Packet8us mask = reinterpret_cast<Packet8us>(pnegate(reinterpret_cast<Packet8s>(select)));
|
||||
Packet8s result = vec_sel(elsePacket, thenPacket, mask);
|
||||
return result;
|
||||
}
|
||||
@ -2524,11 +2669,7 @@ template<> EIGEN_STRONG_INLINE Packet8s pblend(const Selector<8>& ifPacket, cons
|
||||
// Lane-wise blend of two Packet8us: builds a 16-bit mask from ifPacket.select[0..7]
// and returns vec_sel(elsePacket, thenPacket, mask).
// The mask is derived by negating the selector values, so (assuming each select
// entry is 0 or 1) a 1 becomes an all-ones lane and a 0 stays zero.
// NOTE(review): this span is a rendered diff. The `#ifdef __POWER8_VECTOR__` pair
// of `mask` declarations and the single pnegate-based declaration after it look
// like the old and new forms shown together — confirm against the real header.
template<> EIGEN_STRONG_INLINE Packet8us pblend(const Selector<8>& ifPacket, const Packet8us& thenPacket, const Packet8us& elsePacket) {
|
||||
Packet8us select = { ifPacket.select[0], ifPacket.select[1], ifPacket.select[2], ifPacket.select[3],
|
||||
ifPacket.select[4], ifPacket.select[5], ifPacket.select[6], ifPacket.select[7] };
|
||||
#ifdef __POWER8_VECTOR__
|
||||
Packet8us mask = reinterpret_cast<Packet8us>(vec_neg(reinterpret_cast<Packet8s>(select)));
|
||||
#else
|
||||
Packet8us mask = reinterpret_cast<Packet8us>(vec_cmpeq(reinterpret_cast<Packet8us>(select), p8us_ONE));
|
||||
#endif
|
||||
Packet8us mask = reinterpret_cast<Packet8us>(pnegate(reinterpret_cast<Packet8s>(select)));
|
||||
return vec_sel(elsePacket, thenPacket, mask);
|
||||
}
|
||||
|
||||
@ -2542,11 +2683,7 @@ template<> EIGEN_STRONG_INLINE Packet16c pblend(const Selector<16>& ifPacket, co
|
||||
ifPacket.select[8], ifPacket.select[9], ifPacket.select[10], ifPacket.select[11],
|
||||
ifPacket.select[12], ifPacket.select[13], ifPacket.select[14], ifPacket.select[15] };
|
||||
|
||||
#ifdef __POWER8_VECTOR__
|
||||
Packet16uc mask = reinterpret_cast<Packet16uc>(vec_neg(reinterpret_cast<Packet16c>(select)));
|
||||
#else
|
||||
Packet16uc mask = reinterpret_cast<Packet16uc>(vec_cmpeq(reinterpret_cast<Packet16uc>(select), p16uc_ONE));
|
||||
#endif
|
||||
Packet16uc mask = reinterpret_cast<Packet16uc>(pnegate(reinterpret_cast<Packet16c>(select)));
|
||||
return vec_sel(elsePacket, thenPacket, mask);
|
||||
}
|
||||
|
||||
@ -2556,11 +2693,7 @@ template<> EIGEN_STRONG_INLINE Packet16uc pblend(const Selector<16>& ifPacket, c
|
||||
ifPacket.select[8], ifPacket.select[9], ifPacket.select[10], ifPacket.select[11],
|
||||
ifPacket.select[12], ifPacket.select[13], ifPacket.select[14], ifPacket.select[15] };
|
||||
|
||||
#ifdef __POWER8_VECTOR__
|
||||
Packet16uc mask = reinterpret_cast<Packet16uc>(vec_neg(reinterpret_cast<Packet16c>(select)));
|
||||
#else
|
||||
Packet16uc mask = reinterpret_cast<Packet16uc>(vec_cmpeq(reinterpret_cast<Packet16uc>(select), p16uc_ONE));
|
||||
#endif
|
||||
Packet16uc mask = reinterpret_cast<Packet16uc>(pnegate(reinterpret_cast<Packet16c>(select)));
|
||||
return vec_sel(elsePacket, thenPacket, mask);
|
||||
}
|
||||
|
||||
@ -2636,10 +2769,7 @@ template<> EIGEN_STRONG_INLINE Packet8us pcast<Packet8bf, Packet8us>(const Packe
|
||||
low_odd = vec_sel(low_even, p4ui_low_mask, overflow_selector);
|
||||
}
|
||||
|
||||
low_odd = plogical_shift_left<16>(low_odd);
|
||||
|
||||
Packet4ui int_final = por<Packet4ui>(low_even, low_odd);
|
||||
return reinterpret_cast<Packet8us>(int_final);
|
||||
return pmerge(low_even, low_odd);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8bf pcast<Packet8us, Packet8bf>(const Packet8us& a) {
|
||||
@ -2937,7 +3067,21 @@ template<> EIGEN_STRONG_INLINE Packet2d preverse(const Packet2d& a)
|
||||
return vec_sld(a, a, 8);
|
||||
}
|
||||
// Absolute value of a Packet2d; forwards directly to vec_abs.
template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) { return vec_abs(a); }
|
||||
// psignbit for Packet2d: returns a packet in which every 64-bit lane is filled with
// copies of that lane's sign bit. Two implementations, chosen by __POWER8_VECTOR__.
#ifdef __POWER8_VECTOR__
|
||||
// With POWER8 vector support: shift each 64-bit lane (viewed as Packet2l) right by
// 63 with vec_sra, an arithmetic shift that replicates the top bit across the lane.
template<> EIGEN_STRONG_INLINE Packet2d psignbit(const Packet2d& a) { return (Packet2d)vec_sra((Packet2l)a, vec_splats((unsigned long long)(63))); }
|
||||
#else
|
||||
// Without it: work at byte granularity. The permute table below broadcasts one
// source byte into all eight bytes of each 64-bit half: bytes 0 and 8 on big-endian,
// bytes 7 and 15 on little-endian (presumably the byte carrying each double's sign
// bit for that byte order).
#ifdef _BIG_ENDIAN
|
||||
static Packet16uc p16uc_DUPSIGN = { 0,0,0,0, 0,0,0,0, 8,8,8,8, 8,8,8,8 };
|
||||
#else
|
||||
static Packet16uc p16uc_DUPSIGN = { 7,7,7,7, 7,7,7,7, 15,15,15,15, 15,15,15,15 };
|
||||
#endif
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet2d psignbit(const Packet2d& a)
|
||||
{
|
||||
// Shift every byte right by 7 so each byte becomes all copies of its own top bit.
Packet16c tmp = vec_sra(reinterpret_cast<Packet16c>(a), vec_splats((unsigned char)(7)));
|
||||
// Broadcast the selected byte of each 64-bit half across that half.
return reinterpret_cast<Packet2d>(vec_perm(tmp, tmp, p16uc_DUPSIGN));
|
||||
}
|
||||
#endif
|
||||
// VSX support varies between different compilers and even different
|
||||
// versions of the same compiler. For gcc version >= 4.9.3, we can use
|
||||
// vec_cts to efficiently convert Packet2d to Packet2l. Otherwise, use
|
||||
|
Loading…
x
Reference in New Issue
Block a user