Add MMA to BF16 GEMV - 5.0-6.3X faster (for Power)

Chip Kerchner 2023-03-13 19:37:13 +00:00, committed by Rasmus Munk Larsen
parent 2067b54b13
commit 9d72412385
2 changed files with 517 additions and 3 deletions


@@ -146,8 +146,8 @@ EIGEN_ALWAYS_INLINE void colLoopBodyIter(Index depth, Index rows, const Packet4f
   zeroAccumulators<num_acc>(quad_acc);
-  Index k;
-  for(k = 0; k + 2 <= depth; k += 2){
+  Index k = 0;
+  for(Index j = depth >> 1; j--; k += 2){
     KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(indexA, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
   }
   if(depth&1){
@@ -356,7 +356,6 @@ template<typename DataMapper>
 void gemmMMAbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16* indexB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
 {
   float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
-  if (falpha == float(0)) return;
   const Packet4f pAlpha = pset1<Packet4f>(falpha);
   ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0);

@@ -395,6 +394,488 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat
   convertArrayF32toBF16<DataMapper>(result, cols, rows, res);
 }
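
// Helpers for moving GEMV data between the float accumulation buffer and the
// caller's bfloat16 storage. storeBF16fromResult writes one Packet8bf of
// results, contiguously or strided (scatter); size == 4 uses partial stores.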
template<const Index size, bool inc, Index delta>
EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Index resInc)
{
  if (inc) {
    if (size == 4) {
      pscatter_partial(dst + delta*resInc, data, resInc, 4);
    } else {
      pscatter(dst + delta*resInc, data, resInc);
    }
  } else {
    if (size == 4) {
      pstoreu_partial(dst + delta, data, 4);
    } else {
      pstoreu(dst + delta, data);
    }
  }
}
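
// Convert blocks of `size` floats to bfloat16 and store them via
// storeBF16fromResult, advancing i and dst; invoked with size = 32/16/8/4.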
template<const Index size, bool inc>
EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc)
{
  for(Index j = (rows - i) / size; j--; i += size, dst += size*resInc){
    PacketBlock<Packet8bf,(size+4)/8> r32;
    r32.packet[0] = convertF32toBF16<size != 4>(result + i + 0);
    if (size >= 16) {
      r32.packet[1] = convertF32toBF16<true>(result + i + 8);
    }
    if (size >= 32) {
      r32.packet[2] = convertF32toBF16<true>(result + i + 16);
      r32.packet[3] = convertF32toBF16<true>(result + i + 24);
    }
    storeBF16fromResult<size, inc, 0>(dst, r32.packet[0], resInc);
    if (size >= 16) {
      storeBF16fromResult<size, inc, 8>(dst, r32.packet[1], resInc);
    }
    if (size >= 32) {
      storeBF16fromResult<size, inc, 16>(dst, r32.packet[2], resInc);
      storeBF16fromResult<size, inc, 24>(dst, r32.packet[3], resInc);
    }
  }
}
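
// Load one Packet8bf of bfloat16 values, contiguously or gathered with
// stride resInc, then widen blocks of `size` to float in the result buffer.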
template<bool inc, Index delta>
EIGEN_ALWAYS_INLINE Packet8bf loadBF16fromResult(bfloat16* src, Index resInc)
{
  if (inc) {
    return pgather<bfloat16, Packet8bf>(src + delta*resInc, resInc);
  } else {
    return ploadu<Packet8bf>(src + delta);
  }
}

template<const Index size, bool inc>
EIGEN_ALWAYS_INLINE void convertPointerBF16toF32(Index& i, float *result, Index rows, bfloat16*& src, Index resInc)
{
  for(; i + size <= rows; i += size, src += size*resInc){
    PacketBlock<Packet8bf,(size+4)/8> r32;
    r32.packet[0] = loadBF16fromResult<inc, 0>(src, resInc);
    if (size >= 16) {
      r32.packet[1] = loadBF16fromResult<inc, 8>(src, resInc);
    }
    if (size >= 32) {
      r32.packet[2] = loadBF16fromResult<inc, 16>(src, resInc);
      r32.packet[3] = loadBF16fromResult<inc, 24>(src, resInc);
    }
    storeConvertBlockBF16<size>(result + i, r32);
  }
}
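
// Array-level drivers: convert a whole vector between bfloat16 and float in
// blocks of 32/16/8/4 elements, with a scalar loop for the remainder.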
template<bool inc = false>
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float *result, Index rows, bfloat16* src, Index resInc = 1)
{
  Index i = 0;
  convertPointerBF16toF32<32, inc>(i, result, rows, src, resInc);
  convertPointerBF16toF32<16, inc>(i, result, rows, src, resInc);
  convertPointerBF16toF32<8, inc>(i, result, rows, src, resInc);
  convertPointerBF16toF32<4, inc>(i, result, rows, src, resInc);
  for(; i < rows; i++, src += resInc){
    result[i] = Eigen::bfloat16_impl::bfloat16_to_float(*src);
  }
}

template<bool inc = false>
EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16(float *result, Index rows, bfloat16* dst, Index resInc = 1)
{
  Index i = 0;
  convertPointerF32toBF16<32,inc>(i, result, rows, dst, resInc);
  convertPointerF32toBF16<16,inc>(i, result, rows, dst, resInc);
  convertPointerF32toBF16<8,inc>(i, result, rows, dst, resInc);
  convertPointerF32toBF16<4,inc>(i, result, rows, dst, resInc);
  for(; i < rows; i++, dst += resInc){
    *dst = Eigen::bfloat16(result[i]);
  }
}
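
// Scale the column-major accumulators by alpha and add them into the float
// result vector, using a partial store for a trailing group of < 4 rows.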
template<bool extraRows>
EIGEN_ALWAYS_INLINE void outputVecCol(Packet4f acc, float *result, Packet4f pAlpha, Index extra_rows)
{
  Packet4f d0 = ploadu<Packet4f>(result);
  d0 = pmadd(acc, pAlpha, d0);
  if (extraRows) {
    pstoreu_partial(result, d0, extra_rows);
  } else {
    pstoreu(result, d0);
  }
}

template<Index num_acc, bool extraRows>
EIGEN_ALWAYS_INLINE void outputVecColResults(Packet4f (&acc)[num_acc][4], float *result, Packet4f pAlpha, Index extra_rows)
{
  for(Index k = 0; k < num_acc - (extraRows ? 1 : 0); k++) {
    outputVecCol<false>(acc[k][0], result + k*4, pAlpha, extra_rows);
  }
  if (extraRows) {
    outputVecCol<true>(acc[num_acc - 1][0], result + (num_acc - 1)*4, pAlpha, extra_rows);
  }
}
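
// loadVecLoop interleaves a pair of LHS columns (b1 stays zero for an odd
// depth) so bfloat16 pairs line up for the MMA rank-2 update; multVec issues
// xvbf16ger2pp on every accumulator.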
template<Index num_acc, typename LhsMapper, bool zero>
EIGEN_ALWAYS_INLINE void loadVecLoop(Index k, LhsMapper& lhs, Packet8bf (&a0)[num_acc], Packet8bf b1)
{
  a0[k + 0] = lhs.template loadPacket<Packet8bf>(k*4, 0);
  if (!zero) {
    b1 = lhs.template loadPacket<Packet8bf>(k*4, 1);
  }
  if (num_acc > (k + 1)) {
    a0[k + 1] = vec_mergel(a0[k + 0].m_val, b1.m_val);
  }
  a0[k + 0] = vec_mergeh(a0[k + 0].m_val, b1.m_val);
}

template<Index num_acc>
EIGEN_ALWAYS_INLINE void multVec(__vector_quad (&quad_acc)[num_acc], Packet8bf (&a0)[num_acc], Packet8bf b0)
{
  BFLOAT16_UNROLL
  for(Index k = 0; k < num_acc; k++) {
    __builtin_mma_xvbf16ger2pp(&(quad_acc[k]), reinterpret_cast<Packet16uc>(b0.m_val), reinterpret_cast<Packet16uc>(a0[k].m_val));
  }
}
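
// Process one pair of RHS coefficients (zero-padded when `zero` is set, for
// an odd depth) against num_acc groups of four LHS rows.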
template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero>
EIGEN_ALWAYS_INLINE void vecColLoop(Index j, LhsMapper& lhs, RhsMapper& rhs, __vector_quad (&quad_acc)[num_acc])
{
  Packet8bf a0[num_acc];
  Packet8bf b1 = pset1<Packet8bf>(Eigen::bfloat16(0));
  Packet8bf b0 = rhs.template loadPacket<Packet8bf>(j + 0);
  if (zero) {
    b0 = vec_mergeh(b0.m_val, b1.m_val);
  }
  LhsMapper lhs2 = lhs.getSubMapper(0, j);
  for(Index k = 0; k < num_acc; k += 2) {
    loadVecLoop<num_acc, LhsMapper, zero>(k, lhs2, a0, b1);
  }
  multVec<num_acc>(quad_acc, a0, b0);
}
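
// Column-major loop body: each MMA accumulator covers four rows; walk the
// column block two columns at a time, then fold the accumulators into the
// float result. Loops while a full set of 8 accumulators still fits.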
#define MAX_BFLOAT16_VEC_ACC 8

template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  constexpr Index step = (num_acc * 4);
  const Index extra_rows = (extraRows) ? (rows & 3) : 0;
  constexpr bool multiIters = !extraRows && (num_acc == MAX_BFLOAT16_VEC_ACC);

  do{
    Packet4f acc[num_acc][4];
    __vector_quad quad_acc[num_acc];

    zeroAccumulators<num_acc>(quad_acc);

    LhsMapper lhs2 = lhs.getSubMapper(row, 0);
    Index j = 0;
    for(Index k = cend >> 1; k--; j += 2) {
      vecColLoop<num_acc, LhsMapper, RhsMapper, false>(j, lhs2, rhs, quad_acc);
    }
    if (cend & 1) {
      vecColLoop<num_acc, LhsMapper, RhsMapper, true>(j, lhs2, rhs, quad_acc);
    }

    disassembleAccumulators<num_acc>(quad_acc, acc);

    outputVecColResults<num_acc, extraRows>(acc, result, pAlpha, extra_rows);

    result += step;
  } while(multiIters && (step <= rows - (row += step)));
}
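
// Tail dispatch on the rows remaining (in groups of four); when a partial
// group of < 4 rows exists, one extra accumulator handles it.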
template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  if (MAX_BFLOAT16_VEC_ACC > num_acc) {
    colVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
  }
}

template<typename LhsMapper, typename RhsMapper, bool extraRows>
EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtra(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  switch ((rows - row) >> 2) {
    case 7:
      colVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 6:
      colVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 5:
      colVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 4:
      colVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 3:
      colVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 2:
      colVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 1:
      colVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    default:
      if (extraRows) {
        colVecColLoopBody<1, LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result);
      }
      break;
  }
}
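
// Drive full-width passes first, then hand the leftover rows to the tail
// dispatcher.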
template<typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void calcVecColLoops(Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  Index row = 0;
  if (rows >= (MAX_BFLOAT16_VEC_ACC * 4)) {
    colVecColLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper, false>(row, cend, rows, lhs, rhs, pAlpha, result);
    result += row;
  }
  if (rows & 3) {
    colVecColLoopBodyExtra<LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result);
  } else {
    colVecColLoopBodyExtra<LhsMapper, RhsMapper, false>(row, cend, rows, lhs, rhs, pAlpha, result);
  }
}
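
// Column-major bfloat16 GEMV kernel: widen res to a float buffer, process the
// matrix in column blocks sized by the cache heuristic, then narrow back.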
template<typename LhsMapper, typename RhsMapper>
void gemvMMA_bfloat16_col(
  Index rows, Index cols,
  const LhsMapper& alhs,
  const RhsMapper& rhs,
  bfloat16* res, Index resIncr,
  bfloat16 alpha)
{
  typedef typename RhsMapper::LinearMapper LinearMapper;

  EIGEN_UNUSED_VARIABLE(resIncr);
  eigen_internal_assert(resIncr == 1);

  // The following copy tells the compiler that lhs's attributes are not modified outside this function
  // This helps GCC to generate proper code.
  LhsMapper lhs(alhs);
  RhsMapper rhs2(rhs);

  const Index lhsStride = lhs.stride();

  // TODO: improve the following heuristic:
  const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(bfloat16) < 16000 ? 16 : 8);
  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
  Packet4f pAlpha = pset1<Packet4f>(falpha);

  ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);

  convertArrayPointerBF16toF32(result, rows, res);
  for (Index j2 = 0; j2 < cols; j2 += block_cols)
  {
    Index jend = numext::mini(j2 + block_cols, cols);

    LhsMapper lhs2 = lhs.getSubMapper(0, j2);
    LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
    calcVecColLoops<LhsMapper, LinearMapper>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
  }
  convertArrayPointerF32toBF16(result, rows, res);
}
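
// Horizontal reduction: merge and permute lanes so the four partial sums in
// each accumulator collapse to one float per result row.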
static Packet16uc p16uc_ELEMENT_VEC3 = { 0x0c,0x0d,0x0e,0x0f, 0x1c,0x1d,0x1e,0x1f, 0x0c,0x0d,0x0e,0x0f, 0x1c,0x1d,0x1e,0x1f };

template<Index num_acc>
EIGEN_ALWAYS_INLINE void preduxVecResults2(Packet4f (&acc)[num_acc][4], Index k)
{
  if (num_acc > (k + 1)) {
    acc[k][0] = vec_mergeh(acc[k][0], acc[k + 1][0]);
    acc[k][1] = vec_mergeo(acc[k][1], acc[k + 1][1]);
    acc[k][2] = vec_mergel(acc[k][2], acc[k + 1][2]);
    acc[k][3] = vec_perm(acc[k][3], acc[k + 1][3], p16uc_ELEMENT_VEC3);
    acc[k][0] = (acc[k][0] + acc[k][2]) + (acc[k][1] + acc[k][3]);
  } else {
    acc[k][0] = vec_mergeh(acc[k][0], acc[k][1]);
    acc[k][0] += vec_mergel(acc[k][2], acc[k][3]);
#ifdef _BIG_ENDIAN
    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 12);
#else
    acc[k][0] += vec_sld(acc[k][0], acc[k][0], 4);
#endif
  }
}

template<Index num_acc>
EIGEN_ALWAYS_INLINE void preduxVecResults(Packet4f (&acc)[num_acc][4])
{
  for(Index k = 0; k < num_acc; k += 4) {
    preduxVecResults2<num_acc>(acc, k + 0);
    if (num_acc > (k + 2)) {
      preduxVecResults2<num_acc>(acc, k + 2);
      acc[k + 0][0] = reinterpret_cast<Packet4f>(vec_mergeh(reinterpret_cast<Packet2ul>(acc[k + 0][0]), reinterpret_cast<Packet2ul>(acc[k + 2][0])));
    }
  }
}
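
// Scale reduced sums by alpha and add into the result, four rows at a time;
// tails of 3/2/1 rows use a partial store, an 8-byte store, or a 4-byte store.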
template<Index num_acc>
EIGEN_ALWAYS_INLINE void outputVecResults(Packet4f (&acc)[num_acc][4], float *result, Packet4f pAlpha)
{
  constexpr Index extra = num_acc & 3;

  for(Index k = 0; k < num_acc; k += 4) {
    Packet4f d0 = ploadu<Packet4f>(result + k);
    d0 = pmadd(acc[k + 0][0], pAlpha, d0);

    if (num_acc > (k + 3)) {
      pstoreu(result + k, d0);
    } else {
      if (extra == 3) {
        pstoreu_partial(result + k, d0, extra);
      } else if (extra == 2) {
        Packet2ul d1 = reinterpret_cast<Packet2ul>(d0);
        *(unsigned long long *)(result + k) = d1[0];
      } else {
        Packet4i d1 = reinterpret_cast<Packet4i>(d0);
        *(unsigned int *)(result + k) = d1[0];
      }
    }
  }
}
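
// Row-major inner loops: walk eight bfloat16 columns of one RHS vector and
// num_acc LHS rows per step, with partial loads for the final < 8 columns.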
template<Index num_acc, typename LhsMapper, typename RhsMapper, bool extra>
EIGEN_ALWAYS_INLINE void multVecLoop(__vector_quad (&quad_acc)[num_acc], const LhsMapper& lhs, RhsMapper& rhs, Index j, Index extra_cols)
{
  Packet8bf a0[num_acc], b0;

  if (extra) {
    b0 = rhs.template loadPacketPartial<Packet8bf>(j, extra_cols);
  } else {
    b0 = rhs.template loadPacket<Packet8bf>(j);
  }

  const LhsMapper lhs2 = lhs.getSubMapper(0, j);
  for(Index k = 0; k < num_acc; k++) {
    if (extra) {
      a0[k] = lhs2.template loadPacketPartial<Packet8bf>(k, 0, extra_cols);
    } else {
      a0[k] = lhs2.template loadPacket<Packet8bf>(k, 0);
    }
  }

  multVec<num_acc>(quad_acc, a0, b0);
}

template<Index num_acc, typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void vecLoop(Index cols, const LhsMapper& lhs, RhsMapper& rhs, __vector_quad (&quad_acc)[num_acc], Index extra_cols)
{
  Index j = 0;
  for(Index k = cols >> 3; k--; j += 8) {
    multVecLoop<num_acc, LhsMapper, RhsMapper, false>(quad_acc, lhs, rhs, j, extra_cols);
  }
  if (extra_cols) {
    multVecLoop<num_acc, LhsMapper, RhsMapper, true>(quad_acc, lhs, rhs, j, extra_cols);
  }
}
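
// Row-major loop body: one MMA accumulator per result row, then a switch on
// the remaining row count for the tail.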
template<const Index num_acc, typename LhsMapper, typename RhsMapper>
void colVecLoopBody(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  constexpr bool multiIters = (num_acc == MAX_BFLOAT16_VEC_ACC);
  const Index extra_cols = (cols & 7);

  do{
    Packet4f acc[num_acc][4];
    __vector_quad quad_acc[num_acc];

    zeroAccumulators<num_acc>(quad_acc);

    const LhsMapper lhs2 = lhs.getSubMapper(row, 0);
    vecLoop<num_acc, LhsMapper, RhsMapper>(cols, lhs2, rhs, quad_acc, extra_cols);

    disassembleAccumulators<num_acc>(quad_acc, acc);

    preduxVecResults<num_acc>(acc);

    outputVecResults<num_acc>(acc, result, pAlpha);

    result += num_acc;
  } while(multiIters && (num_acc <= rows - (row += num_acc)));
}

template<typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void colVecLoopBodyExtra(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  switch (rows - row) {
    case 7:
      colVecLoopBody<7, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 6:
      colVecLoopBody<6, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 5:
      colVecLoopBody<5, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 4:
      colVecLoopBody<4, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 3:
      colVecLoopBody<3, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 2:
      colVecLoopBody<2, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
    case 1:
      colVecLoopBody<1, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
      break;
  }
}
template<typename LhsMapper, typename RhsMapper>
EIGEN_ALWAYS_INLINE void calcVecLoops(Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  Index row = 0;
  if (rows >= MAX_BFLOAT16_VEC_ACC) {
    colVecLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
    result += row;
  }
  colVecLoopBodyExtra<LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
}
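
// Row-major bfloat16 GEMV kernel: widen res to float (honoring resIncr),
// run the MMA loops, and narrow the result back to bfloat16.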
template<typename LhsMapper, typename RhsMapper>
EIGEN_STRONG_INLINE void gemvMMA_bfloat16_row(
  Index rows, Index cols,
  const LhsMapper& alhs,
  const RhsMapper& rhs,
  bfloat16* res, Index resIncr,
  bfloat16 alpha)
{
  typedef typename RhsMapper::LinearMapper LinearMapper;

  // The following copy tells the compiler that lhs's attributes are not modified outside this function
  // This helps GCC to generate proper code.
  LhsMapper lhs(alhs);
  LinearMapper rhs2 = rhs.getLinearMapper(0, 0);

  eigen_internal_assert(rhs.stride() == 1);

  float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
  const Packet4f pAlpha = pset1<Packet4f>(falpha);

  ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);
  if (resIncr == 1) {
    convertArrayPointerBF16toF32(result, rows, res);
  } else {
    convertArrayPointerBF16toF32<true>(result, rows, res, resIncr);
  }
  calcVecLoops<LhsMapper, LinearMapper>(cols, rows, lhs, rhs2, pAlpha, result);
  if (resIncr == 1) {
    convertArrayPointerF32toBF16(result, rows, res);
  } else {
    convertArrayPointerF32toBF16<true>(result, rows, res, resIncr);
  }
}
}
}

#endif //EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H


@@ -2061,6 +2061,39 @@ EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(double)
EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(float)
EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(double)
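
// Specializations routing Eigen's GEMV dispatch for bfloat16 to the MMA
// kernels above (compiled only when USE_GEMV_MMA is defined).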
#ifdef USE_GEMV_MMA
#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL_BFLOAT16() \
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
struct general_matrix_vector_product<Index, bfloat16, LhsMapper, ColMajor, ConjugateLhs, bfloat16, RhsMapper, ConjugateRhs, Version> \
{ \
  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( \
    Index rows, Index cols, \
    const LhsMapper& lhs, \
    const RhsMapper& rhs, \
    bfloat16* res, Index resIncr, \
    bfloat16 alpha) { \
    gemvMMA_bfloat16_col<LhsMapper, RhsMapper>(rows, cols, lhs, rhs, res, resIncr, alpha); \
  } \
};

#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW_BFLOAT16() \
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
struct general_matrix_vector_product<Index, bfloat16, LhsMapper, RowMajor, ConjugateLhs, bfloat16, RhsMapper, ConjugateRhs, Version> \
{ \
  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( \
    Index rows, Index cols, \
    const LhsMapper& lhs, \
    const RhsMapper& rhs, \
    bfloat16* res, Index resIncr, \
    bfloat16 alpha) { \
    gemvMMA_bfloat16_row<LhsMapper, RhsMapper>(rows, cols, lhs, rhs, res, resIncr, alpha); \
  } \
};

EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL_BFLOAT16()
EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW_BFLOAT16()
#endif
template<typename ResScalar, typename PResPacket, typename ResPacket, typename LhsPacket, typename RhsPacket>
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(PResPacket& a0, PResPacket& b0, ResPacket& a1, ResPacket& b1)
{
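
For context, a minimal sketch of the kind of call that exercises these kernels. This example is illustrative, not part of the commit; it assumes an Eigen build targeting Power10 with MMA enabled (e.g. GCC with -mcpu=power10, so USE_GEMV_MMA is in effect), and the sizes are arbitrary.

#include <Eigen/Dense>
#include <iostream>

int main() {
  using MatBF16 = Eigen::Matrix<Eigen::bfloat16, Eigen::Dynamic, Eigen::Dynamic>;
  using VecBF16 = Eigen::Matrix<Eigen::bfloat16, Eigen::Dynamic, 1>;

  MatBF16 A = MatBF16::Random(512, 256);
  VecBF16 x = VecBF16::Random(256);

  // Matrix-vector product; with the specializations above, Eigen's default
  // column-major storage dispatches to gemvMMA_bfloat16_col.
  VecBF16 y = A * x;

  // A transposed product maps onto the row-major path, gemvMMA_bfloat16_row.
  VecBF16 z = A.transpose() * y;

  std::cout << Eigen::bfloat16_impl::bfloat16_to_float(z(0)) << std::endl;
  return 0;
}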