New VSX version of BF16 GEMV (Power) - up to 6.7X faster
parent 29c8e3c754
commit 03f646b7e3
@@ -2920,13 +2920,13 @@ EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float *result, Index cols,
   }
 }
 
-template<Index num_acc>
-EIGEN_ALWAYS_INLINE void zeroAccumulators(Packet4f (&acc)[num_acc][4])
+template<Index num_acc, Index size = 4>
+EIGEN_ALWAYS_INLINE void zeroAccumulators(Packet4f (&acc)[num_acc][size])
 {
   Packet4f z = pset1<Packet4f>(float(0));
 
   for(Index k = 0; k < num_acc; k++) {
-    for(Index j = 0; j < 4; j++) {
+    for(Index j = 0; j < size; j++) {
       acc[k][j] = z;
     }
   }
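Note on this change: because `size` defaults to 4, existing GEMM call sites compile unchanged, while the VSX GEMV kernels added below instantiate the same helper for their 2-wide accumulator tiles (`zeroAccumulators<num_acc, 2>`). A minimal standalone sketch of the pattern, with a plain `float4` struct standing in for Eigen's `Packet4f`:

#include <cstddef>

using Index = std::ptrdiff_t;

// Stand-in for Packet4f (illustrative only).
struct float4 { float v[4]; };

// Same shape as the new helper: `size` defaults to 4 so old call sites
// are unaffected; the GEMV path instantiates size == 2.
template<Index num_acc, Index size = 4>
inline void zeroAccumulators(float4 (&acc)[num_acc][size])
{
  for (Index k = 0; k < num_acc; k++) {
    for (Index j = 0; j < size; j++) {
      acc[k][j] = float4{{0.f, 0.f, 0.f, 0.f}};
    }
  }
}

int main()
{
  float4 gemmAcc[2][4];
  float4 gemvAcc[8][2];
  zeroAccumulators<2>(gemmAcc);    // size defaults to 4 (GEMM tiles)
  zeroAccumulators<8, 2>(gemvAcc); // 2-wide tiles for the VSX GEMV
}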
@@ -3246,59 +3246,6 @@ void gemmbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16*
 
+#include "MatrixVectorProduct.h"
+
-template<const Index size, bool non_unit_stride, Index delta>
-EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Index resInc, Index extra)
-{
-  if (non_unit_stride) {
-    if (size < 8) {
-      pscatter_partial(dst + delta*resInc, data, resInc, extra);
-    } else {
-      pscatter(dst + delta*resInc, data, resInc);
-    }
-  } else {
-    if (size < 8) {
-      pstoreu_partial(dst + delta, data, extra);
-    } else {
-      pstoreu(dst + delta, data);
-    }
-  }
-}
-
-template<const Index size, bool non_unit_stride = false>
-EIGEN_ALWAYS_INLINE void convertPointerF32toBF16VSX(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc = 1)
-{
-  constexpr Index extra = ((size < 8) ? 8 : size);
-  for(; i + size <= rows; i += extra, dst += extra*resInc){
-    PacketBlock<Packet8bf,(size+7)/8> r32;
-    r32.packet[0] = convertF32toBF16VSX(result + i + 0);
-    if (size >= 16) {
-      r32.packet[1] = convertF32toBF16VSX(result + i + 8);
-    }
-    if (size >= 32) {
-      r32.packet[2] = convertF32toBF16VSX(result + i + 16);
-      r32.packet[3] = convertF32toBF16VSX(result + i + 24);
-    }
-    storeBF16fromResult<size, non_unit_stride, 0>(dst, r32.packet[0], resInc, rows & 7);
-    if (size >= 16) {
-      storeBF16fromResult<size, non_unit_stride, 8>(dst, r32.packet[1], resInc);
-    }
-    if (size >= 32) {
-      storeBF16fromResult<size, non_unit_stride, 16>(dst, r32.packet[2], resInc);
-      storeBF16fromResult<size, non_unit_stride, 24>(dst, r32.packet[3], resInc);
-    }
-  }
-}
-
-template<bool non_unit_stride = false>
-EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16VSX(float *result, Index rows, bfloat16* dst, Index resInc = 1)
-{
-  Index i = 0;
-  convertPointerF32toBF16VSX<32,non_unit_stride>(i, result, rows, dst, resInc);
-  convertPointerF32toBF16VSX<16,non_unit_stride>(i, result, rows, dst, resInc);
-  convertPointerF32toBF16VSX<8,non_unit_stride>(i, result, rows, dst, resInc);
-  convertPointerF32toBF16VSX<1,non_unit_stride>(i, result, rows, dst, resInc);
-}
-
 /************************************
  * ppc64le template specializations *
  * **********************************/
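The helpers removed here (and now pulled back in through the new `#include "MatrixVectorProduct.h"`) pick their store flavor entirely at compile time: `size` and the stride flag are template parameters, so each instantiation collapses to a single scatter/contiguous, full/partial store sequence. A scalar model of that dispatch (stand-in types, not Eigen API):

#include <cstddef>
#include <cstdint>

using Index = std::ptrdiff_t;
using bf16 = std::uint16_t; // raw bfloat16 bits, stand-in type

// Scalar model: size < 8 means only `extra` tail elements are stored
// (pstoreu_partial/pscatter_partial); resInc > 1 selects the scatter path.
template<Index size, bool non_unit_stride, Index delta>
inline void storeBF16model(bf16* dst, const bf16 (&data)[8], Index resInc, Index extra = 0)
{
  const Index n = (size < 8) ? extra : 8;
  for (Index i = 0; i < n; i++) {
    if (non_unit_stride)
      dst[(delta + i) * resInc] = data[i]; // pscatter(_partial) equivalent
    else
      dst[delta + i] = data[i];            // pstoreu(_partial) equivalent
  }
}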
@@ -96,6 +96,12 @@ EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float *result, Index cols,
 template<bool rhsExtraCols, bool lhsExtraRows>
 EIGEN_ALWAYS_INLINE void storeResults(Packet4f (&acc)[4], Index rows, const Packet4f pAlpha, float* result, Index extra_cols, Index extra_rows);
 
+template<Index num_acc, bool extraRows, Index size = 4>
+EIGEN_ALWAYS_INLINE void outputVecColResults(Packet4f (&acc)[num_acc][size], float *result, Packet4f pAlpha, Index extra_rows);
+
+template<Index num_acc, Index size = 4>
+EIGEN_ALWAYS_INLINE void outputVecResults(Packet4f (&acc)[num_acc][size], float *result, Packet4f pAlpha);
+
 template<typename Packet>
 EIGEN_ALWAYS_INLINE Packet ploadLhs(const __UNPACK_TYPE__(Packet)* lhs);
 
@@ -44,6 +44,7 @@ EIGEN_ALWAYS_INLINE void KLoop
 {
   Packet8bf lhs[num_lhs], rhs[num_rhs];
 
+  BFLOAT16_UNROLL
   for(Index i = 0; i < (num_rhs - (rhsExtraCols ? 1 : 0)); i++){
     rhs[i] = loadRhsBfloat16<zero>(indexB + k*4, strideB, i);
   }
@@ -52,8 +53,21 @@ EIGEN_ALWAYS_INLINE void KLoop
   }
 
   indexA += k*(lhsExtraRows ? extra_rows : num_packets);
-  for(Index j = 0; j < num_lhs; j++) {
-    lhs[j] = loadBfloat16<zero>(indexA + j*(zero ? 4 : 8)); // a packet of bfloat16 has 8 elements
-  }
+  if (num_lhs == 1) {
+    lhs[0] = loadBfloat16<zero>(indexA);
+  } else {
+    BFLOAT16_UNROLL
+    for(Index j = 0; j < num_lhs; j += 2) {
+      Packet8bf lhs1 = ploadu<Packet8bf>(indexA + (j + 0)*(zero ? 4 : 8));
+      if (zero) {
+        Packet8bf lhs2 = pset1<Packet8bf>(Eigen::bfloat16(0));
+        lhs[j + 0] = vec_mergeh(lhs1.m_val, lhs2.m_val);
+        lhs[j + 1] = vec_mergel(lhs1.m_val, lhs2.m_val);
+      } else {
+        lhs[j + 0] = lhs1;
+        lhs[j + 1] = ploadu<Packet8bf>(indexA + (j + 1)*8);
+      }
+    }
+  }
 
   BFLOAT16_UNROLL
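In the `zero` path only four real lhs values are loaded per packet; `vec_mergeh`/`vec_mergel` spread them across two registers so each 16-bit pair holds one real bfloat16 and one zero, and the rank-2 MMA update then accumulates only real data. A scalar model of the merge semantics (big-endian lane order):

#include <cstdint>

using bf16 = std::uint16_t; // raw bfloat16 bits (stand-in)

// Model of vec_mergeh/vec_mergel on 8 x 16-bit lanes with a zero second
// operand: interleave the loaded values with zero lanes.
void mergeWithZero(const bf16 (&src)[8], bf16 (&hi)[8], bf16 (&lo)[8])
{
  for (int i = 0; i < 4; i++) {
    hi[2*i]     = src[i];     // vec_mergeh: elements 0..3
    hi[2*i + 1] = 0;
    lo[2*i]     = src[4 + i]; // vec_mergel: elements 4..7
    lo[2*i + 1] = 0;
  }
}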
@@ -84,7 +98,9 @@ EIGEN_ALWAYS_INLINE void disassembleAccumulators(__vector_quad (&quad_acc)[num_a
 template<Index num_acc, bool rhsExtraCols, bool lhsExtraRows, Index num_rhs, Index num_lhs>
 EIGEN_ALWAYS_INLINE void outputResults(Packet4f (&acc)[num_acc][4], Index rows, const Packet4f pAlpha, float* result, const Index extra_cols, Index extra_rows)
 {
+  BFLOAT16_UNROLL
   for(Index i = 0, k = 0; i < num_rhs - (rhsExtraCols ? 1 : 0); i++, result += 4*rows){
+    BFLOAT16_UNROLL
     for(Index j = 0; j < num_lhs; j++, k++) {
       storeResults<false, lhsExtraRows>(acc[k], rows, pAlpha, result + j*4, extra_cols, extra_rows);
     }
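`BFLOAT16_UNROLL` now annotates both loop levels; it is a full-unroll hint for these fixed-trip-count loops. Its definition lives elsewhere in the Eigen sources; the sketch below uses the common `GCC unroll` pragma spelling, which is an assumption rather than the verified macro body:

// Assumed spelling of a full-unroll hint (illustrative only).
#define MY_UNROLL _Pragma("GCC unroll 8")

void axpy8(float* acc, const float* a, const float* b)
{
  MY_UNROLL
  for (int i = 0; i < 8; i++) {
    acc[i] += a[i] * b[i]; // constant trip count, fully unrollable
  }
}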
@@ -339,29 +355,6 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat
 #undef MAX_BFLOAT16_ACC
 
 #if !EIGEN_ALTIVEC_DISABLE_MMA
-template<bool extraRows>
-EIGEN_ALWAYS_INLINE void outputVecCol(Packet4f acc, float *result, Packet4f pAlpha, Index extra_rows)
-{
-  Packet4f d0 = ploadu<Packet4f>(result);
-  d0 = pmadd(acc, pAlpha, d0);
-  if (extraRows) {
-    pstoreu_partial(result, d0, extra_rows);
-  } else {
-    pstoreu(result, d0);
-  }
-}
-
-template<Index num_acc, bool extraRows>
-EIGEN_ALWAYS_INLINE void outputVecColResults(Packet4f (&acc)[num_acc][4], float *result, Packet4f pAlpha, Index extra_rows)
-{
-  for(Index k = 0; k < num_acc - (extraRows ? 1 : 0); k++) {
-    outputVecCol<false>(acc[k][0], result + k*4, pAlpha, extra_rows);
-  }
-  if (extraRows) {
-    outputVecCol<true>(acc[num_acc - 1][0], result + (num_acc - 1)*4, pAlpha, extra_rows);
-  }
-}
-
 template<Index num_acc, typename LhsMapper, bool zero>
 EIGEN_ALWAYS_INLINE void loadVecLoop(Index k, LhsMapper& lhs, Packet8bf (&a0)[num_acc], Packet8bf b1)
 {
@@ -396,6 +389,7 @@ EIGEN_ALWAYS_INLINE void vecColLoop(Index j, LhsMapper& lhs, RhsMapper& rhs, __v
   }
 
   LhsMapper lhs2 = lhs.getSubMapper(0, j);
+  BFLOAT16_UNROLL
   for(Index k = 0; k < num_acc; k += 2) {
     loadVecLoop<num_acc, LhsMapper, zero>(k, lhs2, a0, b1);
   }
@@ -557,6 +551,7 @@ EIGEN_ALWAYS_INLINE void preduxVecResults2(Packet4f (&acc)[num_acc][4], Index k)
 template<Index num_acc>
 EIGEN_ALWAYS_INLINE void preduxVecResults(Packet4f (&acc)[num_acc][4])
 {
+  BFLOAT16_UNROLL
   for(Index k = 0; k < num_acc; k += 4) {
     preduxVecResults2<num_acc>(acc, k + 0);
     if (num_acc > (k + 2)) {
@@ -566,27 +561,6 @@ EIGEN_ALWAYS_INLINE void preduxVecResults(Packet4f (&acc)[num_acc][4])
   }
 }
 
-template<Index num_acc>
-EIGEN_ALWAYS_INLINE void outputVecResults(Packet4f (&acc)[num_acc][4], float *result, Packet4f pAlpha)
-{
-  constexpr Index extra = num_acc & 3;
-
-  for(Index k = 0; k < num_acc; k += 4) {
-    Packet4f d0 = ploadu<Packet4f>(result + k);
-    d0 = pmadd(acc[k + 0][0], pAlpha, d0);
-
-    if (num_acc > (k + 3)) {
-      pstoreu(result + k, d0);
-    } else {
-      if (extra == 3) {
-        pstoreu_partial(result + k, d0, extra);
-      } else {
-        memcpy((void *)(result + k), (void *)(&d0), sizeof(float) * extra);
-      }
-    }
-  }
-}
-
 template<Index num_acc, typename LhsMapper, typename RhsMapper, bool extra>
 EIGEN_ALWAYS_INLINE void multVecLoop(__vector_quad (&quad_acc)[num_acc], const LhsMapper& lhs, RhsMapper& rhs, Index j, Index extra_cols)
 {
@@ -599,6 +573,7 @@ EIGEN_ALWAYS_INLINE void multVecLoop(__vector_quad (&quad_acc)[num_acc], const L
   }
 
   const LhsMapper lhs2 = lhs.getSubMapper(0, j);
+  BFLOAT16_UNROLL
   for(Index k = 0; k < num_acc; k++) {
     if (extra) {
       a0[k] = lhs2.template loadPacketPartial<Packet8bf>(k, 0, extra_cols);
@@ -464,6 +464,492 @@ EIGEN_STRONG_INLINE void gemv_col(
   }
 }
 
+template<bool extraRows>
+EIGEN_ALWAYS_INLINE void outputVecCol(Packet4f acc, float *result, Packet4f pAlpha, Index extra_rows)
+{
+  Packet4f d0 = ploadu<Packet4f>(result);
+  d0 = pmadd(acc, pAlpha, d0);
+  if (extraRows) {
+    pstoreu_partial(result, d0, extra_rows);
+  } else {
+    pstoreu(result, d0);
+  }
+}
+
+template<Index num_acc, bool extraRows, Index size>
+EIGEN_ALWAYS_INLINE void outputVecColResults(Packet4f (&acc)[num_acc][size], float *result, Packet4f pAlpha, Index extra_rows)
+{
+  constexpr Index real_acc = (num_acc - (extraRows ? 1 : 0));
+  for(Index k = 0; k < real_acc; k++) {
+    outputVecCol<false>(acc[k][0], result + k*4, pAlpha, extra_rows);
+  }
+  if (extraRows) {
+    outputVecCol<true>(acc[real_acc][0], result + real_acc*4, pAlpha, extra_rows);
+  }
+}
+
+static Packet16uc p16uc_MERGE16_32_V1 = { 0, 1, 16,17, 0, 1, 16,17, 0, 1, 16,17, 0, 1, 16,17 };
+static Packet16uc p16uc_MERGE16_32_V2 = { 2, 3, 18,19, 2, 3, 18,19, 2, 3, 18,19, 2, 3, 18,19 };
+
+template<Index num_acc, typename LhsMapper, bool zero>
+EIGEN_ALWAYS_INLINE void loadVecLoopVSX(Index k, LhsMapper& lhs, Packet4f (&a0)[num_acc][2])
+{
+  Packet8bf c0 = lhs.template loadPacket<Packet8bf>(k*4, 0);
+  Packet8bf b1;
+  if (!zero) {
+    b1 = lhs.template loadPacket<Packet8bf>(k*4, 1);
+
+    a0[k + 0][1] = oneConvertBF16Hi(b1.m_val);
+  }
+  a0[k + 0][0] = oneConvertBF16Hi(c0.m_val);
+
+  if (num_acc > (k + 1)) {
+    a0[k + 1][0] = oneConvertBF16Lo(c0.m_val);
+    if (!zero) {
+      a0[k + 1][1] = oneConvertBF16Lo(b1.m_val);
+    }
+  }
+}
+
+template<Index num_acc, bool zero>
+EIGEN_ALWAYS_INLINE void multVecVSX(Packet4f (&acc)[num_acc][2], Packet4f (&a0)[num_acc][2], Packet4f (&b0)[2])
+{
+  for(Index k = 0; k < num_acc; k++) {
+    for(Index i = 0; i < (zero ? 1 : 2); i++) {
+      acc[k][i] = pmadd(b0[i], a0[k][i], acc[k][i]);
+    }
+  }
+}
+
+template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero>
+EIGEN_ALWAYS_INLINE void vecColLoopVSX(Index j, LhsMapper& lhs, RhsMapper& rhs, Packet4f (&acc)[num_acc][2])
+{
+  Packet4f a0[num_acc][2], b0[2];
+  Packet8bf b2 = rhs.template loadPacket<Packet8bf>(j + 0);
+
+  b0[0] = oneConvertBF16Perm(b2.m_val, p16uc_MERGE16_32_V1);
+  if (!zero) {
+    b0[1] = oneConvertBF16Perm(b2.m_val, p16uc_MERGE16_32_V2);
+  }
+
+  LhsMapper lhs2 = lhs.getSubMapper(0, j);
+  for(Index k = 0; k < num_acc; k += 2) {
+    loadVecLoopVSX<num_acc, LhsMapper, zero>(k, lhs2, a0);
+  }
+
+  multVecVSX<num_acc, zero>(acc, a0, b0);
+}
+
+template<Index num_acc>
+EIGEN_ALWAYS_INLINE void addResultsVSX(Packet4f (&acc)[num_acc][2])
+{
+  for(Index i = 0; i < num_acc; i++) {
+    acc[i][0] = acc[i][0] + acc[i][1];
+  }
+}
+
+// Uses 2X the accumulators or 4X the number of VSX registers
+#define MAX_BFLOAT16_VEC_ACC_VSX 8
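About the two permute masks above: each repeats one byte pair of the source across all four 32-bit lanes (bytes 0-1 for V1, bytes 2-3 for V2), paired with bytes 16-17/18-19 of the permute's second operand, which these kernels appear to supply as zeros. Since bfloat16 is the upper half of an IEEE-754 binary32, the result is one rhs element, widened to float and broadcast to every lane. A scalar sketch of what `oneConvertBF16Perm` achieves under that assumption:

#include <cstdint>
#include <cstring>

// bfloat16 keeps the sign, exponent and top 7 mantissa bits of a float,
// so widening is "bits << 16" reinterpreted as float.
float bf16_to_f32(std::uint16_t bits)
{
  std::uint32_t widened = static_cast<std::uint32_t>(bits) << 16;
  float f;
  std::memcpy(&f, &widened, sizeof(f));
  return f;
}

// Model of oneConvertBF16Perm with p16uc_MERGE16_32_V1/V2 (assuming a
// zero second permute operand): broadcast element `elem` as float.
void splatBF16asF32(const std::uint16_t (&src)[8], int elem, float (&out)[4])
{
  for (int i = 0; i < 4; i++) {
    out[i] = bf16_to_f32(src[elem]);
  }
}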
+
+template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
+void colVSXVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
+{
+  constexpr Index step = (num_acc * 4);
+  const Index extra_rows = (extraRows) ? (rows & 3) : 0;
+  constexpr bool multiIters = !extraRows && (num_acc == MAX_BFLOAT16_VEC_ACC_VSX);
+
+  do{
+    Packet4f acc[num_acc][2];
+
+    zeroAccumulators<num_acc, 2>(acc);
+
+    LhsMapper lhs2 = lhs.getSubMapper(row, 0);
+    for(Index j = 0; j + 2 <= cend; j += 2) {
+      vecColLoopVSX<num_acc, LhsMapper, RhsMapper, false>(j, lhs2, rhs, acc);
+    }
+    if (cend & 1) {
+      vecColLoopVSX<num_acc, LhsMapper, RhsMapper, true>(cend - 1, lhs2, rhs, acc);
+    }
+
+    addResultsVSX<num_acc>(acc);
+
+    outputVecColResults<num_acc, extraRows, 2>(acc, result, pAlpha, extra_rows);
+
+    result += step;
+  } while(multiIters && (step <= rows - (row += step)));
+}
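The loop body processes `num_acc * 4` rows per pass and walks the columns in pairs, keeping two partial sums per accumulator (`acc[k][0]`/`acc[k][1]`) that `addResultsVSX` folds at the end; an odd trailing column goes through the `zero` variant. A scalar reference of the same blocking (float inputs for clarity; the kernel widens bf16 on the fly):

#include <cstddef>

using Index = std::ptrdiff_t;

// y[r] += sum_j A(r, j) * x[j] for a block of R rows, columns in pairs.
void colBlockReference(const float* A, Index lda, const float* x,
                       float* y, Index R, Index cend)
{
  for (Index r = 0; r < R; r++) {
    float acc0 = 0.f, acc1 = 0.f; // like acc[k][0] / acc[k][1]
    Index j = 0;
    for (; j + 2 <= cend; j += 2) {
      acc0 += A[r + j*lda]       * x[j];
      acc1 += A[r + (j + 1)*lda] * x[j + 1];
    }
    if (cend & 1) acc0 += A[r + j*lda] * x[j]; // odd tail column
    y[r] += acc0 + acc1;                       // addResultsVSX equivalent
  }
}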
+
+template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
+EIGEN_ALWAYS_INLINE void colVSXVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
+{
+  if (MAX_BFLOAT16_VEC_ACC_VSX > num_acc) {
+    colVSXVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+  }
+}
+
+template<typename LhsMapper, typename RhsMapper, bool extraRows>
+EIGEN_ALWAYS_INLINE void colVSXVecColLoopBodyExtra(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
+{
+  switch ((rows - row) >> 2) {
+  case 7:
+    colVSXVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 6:
+    colVSXVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 5:
+    colVSXVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 4:
+    colVSXVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 3:
+    colVSXVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 2:
+    colVSXVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 1:
+    colVSXVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    break;
+  default:
+    if (extraRows) {
+      colVSXVecColLoopBody<1, LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result);
+    }
+    break;
+  }
+}
+
+template<typename LhsMapper, typename RhsMapper>
+EIGEN_ALWAYS_INLINE void calcVSXVecColLoops(Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
+{
+  Index row = 0;
+  if (rows >= (MAX_BFLOAT16_VEC_ACC_VSX * 4)) {
+    colVSXVecColLoopBody<MAX_BFLOAT16_VEC_ACC_VSX, LhsMapper, RhsMapper, false>(row, cend, rows, lhs, rhs, pAlpha, result);
+    result += row;
+  }
+  if (rows & 3) {
+    colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result);
+  } else {
+    colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, false>(row, cend, rows, lhs, rhs, pAlpha, result);
+  }
+}
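The driver keeps the fully unrolled 8-accumulator body running while at least 32 rows remain, then jumps through the switch so every possible remainder gets its own compile-time instantiation instead of a scalar fallback (the ExtraN helper also grows the body by one accumulator to absorb a sub-quad row tail). The pattern in miniature, with row-quads meaning groups of 4 rows:

#include <cstddef>

using Index = std::ptrdiff_t;

template<Index N>
void body(Index& row, Index rows) { /* process N row-quads */ row += 4*N; }

void calcLoops(Index rows)
{
  Index row = 0;
  if (rows >= 8*4) {
    do {
      body<8>(row, rows);        // widest unrolled body
    } while (rows - row >= 8*4);
  }
  switch ((rows - row) >> 2) {   // remaining full row-quads
  case 7: body<7>(row, rows); break;
  case 6: body<6>(row, rows); break;
  case 5: body<5>(row, rows); break;
  case 4: body<4>(row, rows); break;
  case 3: body<3>(row, rows); break;
  case 2: body<2>(row, rows); break;
  case 1: body<1>(row, rows); break;
  default: break;                // only a sub-quad tail remains
  }
}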
+
+template<const Index size, bool inc, Index delta>
+EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Index resInc, Index extra)
+{
+  if (inc) {
+    if (size < 8) {
+      pscatter_partial(dst + delta*resInc, data, resInc, extra);
+    } else {
+      pscatter(dst + delta*resInc, data, resInc);
+    }
+  } else {
+    if (size < 8) {
+      pstoreu_partial(dst + delta, data, extra);
+    } else {
+      pstoreu(dst + delta, data);
+    }
+  }
+}
+
+template<const Index size, bool inc = false>
+EIGEN_ALWAYS_INLINE void convertPointerF32toBF16VSX(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc = 1)
+{
+  constexpr Index extra = ((size < 8) ? 8 : size);
+  for(; i + size <= rows; i += extra, dst += extra*resInc){
+    PacketBlock<Packet8bf,(size+7)/8> r32;
+    r32.packet[0] = convertF32toBF16VSX(result + i + 0);
+    if (size >= 16) {
+      r32.packet[1] = convertF32toBF16VSX(result + i + 8);
+    }
+    if (size >= 32) {
+      r32.packet[2] = convertF32toBF16VSX(result + i + 16);
+      r32.packet[3] = convertF32toBF16VSX(result + i + 24);
+    }
+    storeBF16fromResult<size, inc, 0>(dst, r32.packet[0], resInc, rows & 7);
+    if (size >= 16) {
+      storeBF16fromResult<size, inc, 8>(dst, r32.packet[1], resInc);
+    }
+    if (size >= 32) {
+      storeBF16fromResult<size, inc, 16>(dst, r32.packet[2], resInc);
+      storeBF16fromResult<size, inc, 24>(dst, r32.packet[3], resInc);
+    }
+  }
+}
+
+template<bool inc = false>
+EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16VSX(float *result, Index rows, bfloat16* dst, Index resInc = 1)
+{
+  Index i = 0;
+  convertPointerF32toBF16VSX<32,inc>(i, result, rows, dst, resInc);
+  convertPointerF32toBF16VSX<16,inc>(i, result, rows, dst, resInc);
+  convertPointerF32toBF16VSX<8,inc>(i, result, rows, dst, resInc);
+  convertPointerF32toBF16VSX<1,inc>(i, result, rows, dst, resInc);
+}
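Conversion back to bfloat16 walks the float buffer with progressively narrower chunks: 32, 16 and 8 floats per vector store, then a masked tail. Each stage advances the shared cursor `i`, so later stages only see what is left. The cascade in miniature:

#include <cstddef>

using Index = std::ptrdiff_t;

template<Index size>
void convertStage(Index& i, Index rows)
{
  // For size < 8 the real code still steps a full 8-wide packet but
  // stores only rows & 7 elements (partial store / scatter).
  const Index step = (size < 8) ? 8 : size;
  for (; i + size <= rows; i += step) {
    /* convert and store `step` floats starting at i */
  }
}

void convertAll(Index rows)
{
  Index i = 0;
  convertStage<32>(i, rows);
  convertStage<16>(i, rows);
  convertStage<8>(i, rows);
  convertStage<1>(i, rows); // masked tail of 1..7 elements
}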
+
+template<typename LhsMapper, typename RhsMapper>
+void gemv_bfloat16_col(
+  Index rows, Index cols,
+  const LhsMapper& alhs,
+  const RhsMapper& rhs,
+  bfloat16* res, Index resIncr,
+  bfloat16 alpha)
+{
+  typedef typename RhsMapper::LinearMapper LinearMapper;
+
+  EIGEN_UNUSED_VARIABLE(resIncr);
+  eigen_internal_assert(resIncr == 1);
+
+  // The following copy tells the compiler that lhs's attributes are not modified outside this function
+  // This helps GCC to generate proper code.
+  LhsMapper lhs(alhs);
+  RhsMapper rhs2(rhs);
+
+  const Index lhsStride = lhs.stride();
+
+  // TODO: improve the following heuristic:
+  const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(bfloat16) < 16000 ? 16 : 8);
+  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
+  Packet4f pAlpha = pset1<Packet4f>(falpha);
+
+  ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);
+
+  convertArrayPointerBF16toF32(result, 1, rows, res);
+
+  for (Index j2 = 0; j2 < cols; j2 += block_cols)
+  {
+    Index jend = numext::mini(j2 + block_cols, cols);
+
+    LhsMapper lhs2 = lhs.getSubMapper(0, j2);
+    LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
+    calcVSXVecColLoops<LhsMapper, LinearMapper>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
+  }
+
+  convertArrayPointerF32toBF16VSX(result, rows, res);
+}
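Numerically the column kernel is a plain GEMV accumulated in float: the bfloat16 result vector is widened once on entry, updated in float across all column blocks, and narrowed back exactly once at the end, so bf16 rounding happens a single time. A scalar reference of the flow (alpha is applied per stored block in the kernel; folding it into each column, as here, is equivalent):

#include <cstddef>

using Index = std::ptrdiff_t;

// result = result + alpha * A * x, col-major A, all math in float.
void gemvColReference(Index rows, Index cols, const float* A, Index lda,
                      const float* x, float* result, float alpha)
{
  for (Index j = 0; j < cols; j++) {
    const float b = alpha * x[j];
    for (Index r = 0; r < rows; r++) {
      result[r] += A[r + j*lda] * b;
    }
  }
}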
+
+template<Index num_acc, Index size>
+EIGEN_ALWAYS_INLINE void outputVecResults(Packet4f (&acc)[num_acc][size], float *result, Packet4f pAlpha)
+{
+  constexpr Index extra = num_acc & 3;
+
+  for(Index k = 0; k < num_acc; k += 4) {
+    Packet4f d0 = ploadu<Packet4f>(result + k);
+    d0 = pmadd(acc[k + 0][0], pAlpha, d0);
+
+    if (num_acc > (k + 3)) {
+      pstoreu(result + k, d0);
+    } else {
+      if (extra == 3) {
+        pstoreu_partial(result + k, d0, extra);
+      } else {
+        memcpy((void *)(result + k), (void *)(&d0), sizeof(float) * extra);
+      }
+    }
+  }
+}
+
+template<Index num_acc>
+EIGEN_ALWAYS_INLINE void preduxVecResults2VSX(Packet4f (&acc)[num_acc][2], Index k)
+{
+  if (num_acc > (k + 1)) {
+    acc[k][1] = vec_mergel(acc[k + 0][0], acc[k + 1][0]);
+    acc[k][0] = vec_mergeh(acc[k + 0][0], acc[k + 1][0]);
+    acc[k][0] = acc[k][0] + acc[k][1];
+    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 8);
+  } else {
+    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 8);
+#ifdef _BIG_ENDIAN
+    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 12);
+#else
+    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 4);
+#endif
+  }
+}
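`preduxVecResults2VSX` is a 4-lane horizontal add: when two accumulators remain, the mergeh/mergel pair transposes them so both reductions proceed in one register, and the `vec_sld` (shift-left-double) steps fold halves onto each other. A scalar model of the single-accumulator path:

// Model of the vec_sld folding: v[0..3] collapse to one sum.
float horizontalAdd4(const float (&v)[4])
{
  // vec_sld(v, v, 8) adds lanes 2,3 onto lanes 0,1 ...
  float s0 = v[0] + v[2];
  float s1 = v[1] + v[3];
  // ... the final 4-byte (or 12-byte, big-endian) shift folds the pair.
  return s0 + s1;
}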
+
+template<Index num_acc>
+EIGEN_ALWAYS_INLINE void preduxVecResultsVSX(Packet4f (&acc)[num_acc][2])
+{
+  for(Index k = 0; k < num_acc; k += 4) {
+    preduxVecResults2VSX<num_acc>(acc, k + 0);
+    if (num_acc > (k + 2)) {
+      preduxVecResults2VSX<num_acc>(acc, k + 2);
+#ifdef EIGEN_VECTORIZE_VSX
+      acc[k + 0][0] = reinterpret_cast<Packet4f>(vec_mergeh(reinterpret_cast<Packet2ul>(acc[k + 0][0]), reinterpret_cast<Packet2ul>(acc[k + 2][0])));
+#else
+      acc[k + 0][0] = reinterpret_cast<Packet4f>(vec_perm(acc[k + 0][0],acc[k + 2][0],p16uc_TRANSPOSE64_HI));
+#endif
+    }
+  }
+}
+
+#ifndef _ARCH_PWR9
+EIGEN_ALWAYS_INLINE Packet8us loadPacketPartialZero(Packet8us data, Index extra_cols)
+{
+  Packet16uc shift = pset1<Packet16uc>(8 * 2 * (8 - extra_cols));
+#ifdef _BIG_ENDIAN
+  return reinterpret_cast<Packet8us>(vec_slo(vec_sro(reinterpret_cast<Packet16uc>(data), shift), shift));
+#else
+  return reinterpret_cast<Packet8us>(vec_sro(vec_slo(reinterpret_cast<Packet16uc>(data), shift), shift));
+#endif
+}
+#endif
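Pre-Power9 targets lack a partial vector load that zeroes the untouched lanes, so `loadPacketPartialZero` clears them by shifting the register out and back by `2 * (8 - extra_cols)` bytes (`vec_sro`/`vec_slo` shift by whole octets, with the count carried in a vector). The net effect, in scalar form:

#include <cstdint>

// Keep the first `extra` 16-bit lanes, zero the rest — what the
// vec_sro/vec_slo round trip achieves without touching memory.
void zeroTailLanes(std::uint16_t (&data)[8], int extra)
{
  for (int i = extra; i < 8; i++) {
    data[i] = 0;
  }
}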
+
+template<Index num_acc, typename LhsMapper, typename RhsMapper, bool extra>
+EIGEN_ALWAYS_INLINE void multVSXVecLoop(Packet4f (&acc)[num_acc][2], const LhsMapper& lhs, RhsMapper& rhs, Index j, Index extra_cols)
+{
+  Packet4f a0[num_acc][2], b0[2];
+  Packet8bf a1, b1;
+
+  if (extra) {
+    b1 = rhs.template loadPacketPartial<Packet8bf>(j, extra_cols);
+#ifndef _ARCH_PWR9
+    b1 = loadPacketPartialZero(b1.m_val, extra_cols);
+#endif
+  } else {
+    b1 = rhs.template loadPacket<Packet8bf>(j);
+  }
+  b0[0] = oneConvertBF16Hi(b1.m_val);
+  b0[1] = oneConvertBF16Lo(b1.m_val);
+
+  const LhsMapper lhs2 = lhs.getSubMapper(0, j);
+  for(Index k = 0; k < num_acc; k++) {
+    if (extra) {
+      a1 = lhs2.template loadPacketPartial<Packet8bf>(k, 0, extra_cols);
+#ifndef _ARCH_PWR9
+      a1 = loadPacketPartialZero(a1.m_val, extra_cols);
+#endif
+    } else {
+      a1 = lhs2.template loadPacket<Packet8bf>(k, 0);
+    }
+    a0[k][0] = oneConvertBF16Hi(a1.m_val);
+    a0[k][1] = oneConvertBF16Lo(a1.m_val);
+  }
+
+  multVecVSX<num_acc, false>(acc, a0, b0);
+}
+
+template<Index num_acc, typename LhsMapper, typename RhsMapper>
+EIGEN_ALWAYS_INLINE void vecVSXLoop(Index cols, const LhsMapper& lhs, RhsMapper& rhs, Packet4f (&acc)[num_acc][2], Index extra_cols)
+{
+  Index j = 0;
+  for(; j + 8 <= cols; j += 8){
+    multVSXVecLoop<num_acc, LhsMapper, RhsMapper, false>(acc, lhs, rhs, j, extra_cols);
+  }
+
+  if (extra_cols) {
+    multVSXVecLoop<num_acc, LhsMapper, RhsMapper, true>(acc, lhs, rhs, j, extra_cols);
+  }
+}
+
+template<const Index num_acc, typename LhsMapper, typename RhsMapper>
+void colVSXVecLoopBody(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
+{
+  constexpr bool multiIters = (num_acc == MAX_BFLOAT16_VEC_ACC_VSX);
+  const Index extra_cols = (cols & 7);
+
+  do{
+    Packet4f acc[num_acc][2];
+
+    zeroAccumulators<num_acc, 2>(acc);
+
+    const LhsMapper lhs2 = lhs.getSubMapper(row, 0);
+    vecVSXLoop<num_acc, LhsMapper, RhsMapper>(cols, lhs2, rhs, acc, extra_cols);
+
+    addResultsVSX<num_acc>(acc);
+
+    preduxVecResultsVSX<num_acc>(acc);
+
+    outputVecResults<num_acc, 2>(acc, result, pAlpha);
+
+    result += num_acc;
+  } while(multiIters && (num_acc <= rows - (row += num_acc)));
+}
+
+template<const Index num_acc, typename LhsMapper, typename RhsMapper>
+EIGEN_ALWAYS_INLINE void colVSXVecLoopBodyExtraN(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
+{
+  if (MAX_BFLOAT16_VEC_ACC_VSX > num_acc) {
+    colVSXVecLoopBody<num_acc, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+  }
+}
+
+template<typename LhsMapper, typename RhsMapper>
+EIGEN_ALWAYS_INLINE void colVSXVecLoopBodyExtra(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
+{
+  switch (rows - row) {
+  case 7:
+    colVSXVecLoopBodyExtraN<7, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 6:
+    colVSXVecLoopBodyExtraN<6, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 5:
+    colVSXVecLoopBodyExtraN<5, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 4:
+    colVSXVecLoopBodyExtraN<4, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 3:
+    colVSXVecLoopBodyExtraN<3, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 2:
+    colVSXVecLoopBodyExtraN<2, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+    break;
+  case 1:
+    colVSXVecLoopBodyExtraN<1, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+    break;
+  }
+}
+
+template<typename LhsMapper, typename RhsMapper>
+EIGEN_ALWAYS_INLINE void calcVSXVecLoops(Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
+{
+  Index row = 0;
+  if (rows >= MAX_BFLOAT16_VEC_ACC_VSX) {
+    colVSXVecLoopBody<MAX_BFLOAT16_VEC_ACC_VSX, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+    result += row;
+  }
+  colVSXVecLoopBodyExtra<LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
+}
+
+template<typename LhsMapper, typename RhsMapper>
+EIGEN_STRONG_INLINE void gemv_bfloat16_row(
+  Index rows, Index cols,
+  const LhsMapper& alhs,
+  const RhsMapper& rhs,
+  bfloat16* res, Index resIncr,
+  bfloat16 alpha)
+{
+  typedef typename RhsMapper::LinearMapper LinearMapper;
+
+  // The following copy tells the compiler that lhs's attributes are not modified outside this function
+  // This helps GCC to generate proper code.
+  LhsMapper lhs(alhs);
+  LinearMapper rhs2 = rhs.getLinearMapper(0, 0);
+
+  eigen_internal_assert(rhs.stride() == 1);
+
+  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
+  const Packet4f pAlpha = pset1<Packet4f>(falpha);
+
+  ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);
+  if (resIncr == 1) {
+    convertArrayPointerBF16toF32(result, 1, rows, res);
+  } else {
+    convertArrayPointerBF16toF32<true>(result, 1, rows, res, resIncr);
+  }
+  calcVSXVecLoops<LhsMapper, LinearMapper>(cols, rows, lhs, rhs2, pAlpha, result);
+  if (resIncr == 1) {
+    convertArrayPointerF32toBF16VSX(result, rows, res);
+  } else {
+    convertArrayPointerF32toBF16VSX<true>(result, rows, res, resIncr);
+  }
+}
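Unlike the column kernel, the row kernel accepts a strided result vector: when `resIncr != 1` the bf16/f32 conversions at entry and exit use the gather/scatter variants, while the arithmetic in between is one float dot product per row. A scalar reference:

#include <cstddef>

using Index = std::ptrdiff_t;

// result[r] += alpha * dot(A row r, x), row-major walk of the lhs.
void gemvRowReference(Index rows, Index cols, const float* A, Index lda,
                      const float* x, float* result, float alpha)
{
  for (Index r = 0; r < rows; r++) {
    float dot = 0.f;
    for (Index j = 0; j < cols; j++) {
      dot += A[r*lda + j] * x[j];
    }
    result[r] += alpha * dot; // pAlpha applied at output, as in the kernel
  }
}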
+
+#undef MAX_BFLOAT16_VEC_ACC_VSX
 
 const Packet16uc p16uc_COMPLEX32_XORFLIP = { 0x44,0x55,0x66,0x77, 0x00,0x11,0x22,0x33, 0xcc,0xdd,0xee,0xff, 0x88,0x99,0xaa,0xbb };
 const Packet16uc p16uc_COMPLEX64_XORFLIP = { 0x88,0x99,0xaa,0xbb, 0xcc,0xdd,0xee,0xff, 0x00,0x11,0x22,0x33, 0x44,0x55,0x66,0x77 };
@@ -2062,6 +2548,13 @@ EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(float)
 EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(double)
 
+#ifdef USE_GEMV_MMA
+#define gemv_bf16_col gemvMMA_bfloat16_col
+#define gemv_bf16_row gemvMMA_bfloat16_row
+#else
+#define gemv_bf16_col gemv_bfloat16_col
+#define gemv_bf16_row gemv_bfloat16_row
+#endif
+
 #define EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL_BFLOAT16() \
 template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
 struct general_matrix_vector_product<Index, bfloat16, LhsMapper, ColMajor, ConjugateLhs, bfloat16, RhsMapper, ConjugateRhs, Version> \
@@ -2072,7 +2565,7 @@ struct general_matrix_vector_product<Index, bfloat16, LhsMapper, ColMajor, Conju
     const RhsMapper& rhs, \
     bfloat16* res, Index resIncr, \
     bfloat16 alpha) { \
-    gemvMMA_bfloat16_col<LhsMapper, RhsMapper>(rows, cols, lhs, rhs, res, resIncr, alpha); \
+    gemv_bf16_col<LhsMapper, RhsMapper>(rows, cols, lhs, rhs, res, resIncr, alpha); \
   } \
 };
 
@@ -2086,13 +2579,12 @@ struct general_matrix_vector_product<Index, bfloat16, LhsMapper, RowMajor, Conju
     const RhsMapper& rhs, \
     bfloat16* res, Index resIncr, \
     bfloat16 alpha) { \
-    gemvMMA_bfloat16_row<LhsMapper, RhsMapper>(rows, cols, lhs, rhs, res, resIncr, alpha); \
+    gemv_bf16_row<LhsMapper, RhsMapper>(rows, cols, lhs, rhs, res, resIncr, alpha); \
   } \
 };
 
 EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL_BFLOAT16()
 EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW_BFLOAT16()
-#endif
 
 template<typename ResScalar, typename PResPacket, typename ResPacket, typename LhsPacket, typename RhsPacket>
 EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(PResPacket& a0, PResPacket& b0, ResPacket& a1, ResPacket& b1)