mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-29 17:55:13 +08:00
Add MMA to BF16 GEMV - 5.0-6.3X faster (for Power)
This commit is contained in:
parent
2067b54b13
commit
9d72412385
@ -146,8 +146,8 @@ EIGEN_ALWAYS_INLINE void colLoopBodyIter(Index depth, Index rows, const Packet4f
|
|||||||
|
|
||||||
zeroAccumulators<num_acc>(quad_acc);
|
zeroAccumulators<num_acc>(quad_acc);
|
||||||
|
|
||||||
Index k;
|
Index k = 0;
|
||||||
for(k = 0; k + 2 <= depth; k += 2){
|
for(Index j = depth >> 1; j--; k += 2){
|
||||||
KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(indexA, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
|
KLoop<num_acc, num_packets, false, rhsExtraCols, lhsExtraRows, num_rhs, num_lhs>(indexA, indexB, quad_acc, strideB, k, offsetB, extra_cols, extra_rows);
|
||||||
}
|
}
|
||||||
if(depth&1){
|
if(depth&1){
|
||||||
@ -356,7 +356,6 @@ template<typename DataMapper>
|
|||||||
void gemmMMAbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16* indexB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
|
void gemmMMAbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat16* indexB, Index rows, Index depth, Index cols, bfloat16 alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
|
||||||
{
|
{
|
||||||
float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
|
float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
|
||||||
if (falpha == float(0)) return;
|
|
||||||
const Packet4f pAlpha = pset1<Packet4f>(falpha);
|
const Packet4f pAlpha = pset1<Packet4f>(falpha);
|
||||||
ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0);
|
ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0);
|
||||||
|
|
||||||
@ -395,6 +394,488 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* indexA, const bfloat
|
|||||||
convertArrayF32toBF16<DataMapper>(result, cols, rows, res);
|
convertArrayF32toBF16<DataMapper>(result, cols, rows, res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Writes one packet of bfloat16 values starting `delta` elements into `dst`.
// With `inc` set the destination is strided by `resInc` (scatter); otherwise
// the elements are written contiguously. A `size` of 4 writes only 4 elements
// of `data`; every other size writes the whole packet.
template<const Index size, bool inc, Index delta>
EIGEN_ALWAYS_INLINE void storeBF16fromResult(bfloat16* dst, Packet8bf data, Index resInc)
{
  constexpr bool halfPacket = (size == 4);

  if (!inc) {
    // Contiguous destination: plain (partial) unaligned store.
    bfloat16* out = dst + delta;
    if (halfPacket) {
      pstoreu_partial(out, data, 4);
    } else {
      pstoreu(out, data);
    }
    return;
  }

  // Strided destination: the element offset is scaled by the stride.
  bfloat16* out = dst + delta*resInc;
  if (halfPacket) {
    pscatter_partial(out, data, resInc, 4);
  } else {
    pscatter(out, data, resInc);
  }
}
|
||||||
|
|
||||||
|
// Converts every whole group of `size` floats remaining in result[i, rows)
// to bfloat16 and stores it through `dst`. Both `i` and `dst` are passed by
// reference and are advanced past the converted elements, so the caller can
// continue with a smaller group size.
template<const Index size, bool inc>
EIGEN_ALWAYS_INLINE void convertPointerF32toBF16(Index& i, float* result, Index rows, bfloat16*& dst, Index resInc)
{
  constexpr Index numPackets = (size + 4) / 8;

  for (Index left = (rows - i) / size; left != 0; --left) {
    float* src = result + i;
    PacketBlock<Packet8bf, numPackets> converted;

    // Convert first, then store, one Packet8bf per 8 floats.
    converted.packet[0] = convertF32toBF16<size != 4>(src);
    if (size >= 16) {
      converted.packet[1] = convertF32toBF16<true>(src + 8);
    }
    if (size >= 32) {
      converted.packet[2] = convertF32toBF16<true>(src + 16);
      converted.packet[3] = convertF32toBF16<true>(src + 24);
    }

    storeBF16fromResult<size, inc, 0>(dst, converted.packet[0], resInc);
    if (size >= 16) {
      storeBF16fromResult<size, inc, 8>(dst, converted.packet[1], resInc);
    }
    if (size >= 32) {
      storeBF16fromResult<size, inc, 16>(dst, converted.packet[2], resInc);
      storeBF16fromResult<size, inc, 24>(dst, converted.packet[3], resInc);
    }

    i += size;
    dst += size*resInc;
  }
}
|
||||||
|
|
||||||
|
template<bool inc, Index delta>
|
||||||
|
EIGEN_ALWAYS_INLINE Packet8bf loadBF16fromResult(bfloat16* src, Index resInc)
|
||||||
|
{
|
||||||
|
if (inc) {
|
||||||
|
return pgather<bfloat16, Packet8bf>(src + delta*resInc, resInc);
|
||||||
|
} else {
|
||||||
|
return ploadu<Packet8bf>(src + delta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loads groups of `size` bfloat16 values from `src` and stores them as floats
// into result[i...] for as long as a whole group fits below `rows`. `i` and
// `src` are passed by reference and are advanced past what was consumed.
template<const Index size, bool inc>
EIGEN_ALWAYS_INLINE void convertPointerBF16toF32(Index& i, float *result, Index rows, bfloat16*& src, Index resInc)
{
  constexpr Index numPackets = (size + 4) / 8;

  while (i + size <= rows) {
    PacketBlock<Packet8bf, numPackets> loaded;

    loaded.packet[0] = loadBF16fromResult<inc, 0>(src, resInc);
    if (size >= 16) {
      loaded.packet[1] = loadBF16fromResult<inc, 8>(src, resInc);
    }
    if (size >= 32) {
      loaded.packet[2] = loadBF16fromResult<inc, 16>(src, resInc);
      loaded.packet[3] = loadBF16fromResult<inc, 24>(src, resInc);
    }

    storeConvertBlockBF16<size>(result + i, loaded);

    i += size;
    src += size*resInc;
  }
}
|
||||||
|
|
||||||
|
template<bool inc = false>
|
||||||
|
EIGEN_ALWAYS_INLINE void convertArrayPointerBF16toF32(float *result, Index rows, bfloat16* src, Index resInc = 1)
|
||||||
|
{
|
||||||
|
Index i = 0;
|
||||||
|
convertPointerBF16toF32<32, inc>(i, result, rows, src, resInc);
|
||||||
|
convertPointerBF16toF32<16, inc>(i, result, rows, src, resInc);
|
||||||
|
convertPointerBF16toF32<8, inc>(i, result, rows, src, resInc);
|
||||||
|
convertPointerBF16toF32<4, inc>(i, result, rows, src, resInc);
|
||||||
|
for(; i < rows; i++, src += resInc){
|
||||||
|
result[i] = Eigen::bfloat16_impl::bfloat16_to_float(*src);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<bool inc = false>
|
||||||
|
EIGEN_ALWAYS_INLINE void convertArrayPointerF32toBF16(float *result, Index rows, bfloat16* dst, Index resInc = 1)
|
||||||
|
{
|
||||||
|
Index i = 0;
|
||||||
|
convertPointerF32toBF16<32,inc>(i, result, rows, dst, resInc);
|
||||||
|
convertPointerF32toBF16<16,inc>(i, result, rows, dst, resInc);
|
||||||
|
convertPointerF32toBF16<8,inc>(i, result, rows, dst, resInc);
|
||||||
|
convertPointerF32toBF16<4,inc>(i, result, rows, dst, resInc);
|
||||||
|
for(; i < rows; i++, dst += resInc){
|
||||||
|
*dst = Eigen::bfloat16(result[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<bool extraRows>
|
||||||
|
EIGEN_ALWAYS_INLINE void outputVecCol(Packet4f acc, float *result, Packet4f pAlpha, Index extra_rows)
|
||||||
|
{
|
||||||
|
Packet4f d0 = ploadu<Packet4f>(result);
|
||||||
|
d0 = pmadd(acc, pAlpha, d0);
|
||||||
|
if (extraRows) {
|
||||||
|
pstoreu_partial(result, d0, extra_rows);
|
||||||
|
} else {
|
||||||
|
pstoreu(result, d0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<Index num_acc, bool extraRows>
|
||||||
|
EIGEN_ALWAYS_INLINE void outputVecColResults(Packet4f (&acc)[num_acc][4], float *result, Packet4f pAlpha, Index extra_rows)
|
||||||
|
{
|
||||||
|
for(Index k = 0; k < num_acc - (extraRows ? 1 : 0); k++) {
|
||||||
|
outputVecCol<false>(acc[k][0], result + k*4, pAlpha, extra_rows);
|
||||||
|
}
|
||||||
|
if (extraRows) {
|
||||||
|
outputVecCol<true>(acc[num_acc - 1][0], result + (num_acc - 1)*4, pAlpha, extra_rows);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<Index num_acc, typename LhsMapper, bool zero>
|
||||||
|
EIGEN_ALWAYS_INLINE void loadVecLoop(Index k, LhsMapper& lhs, Packet8bf (&a0)[num_acc], Packet8bf b1)
|
||||||
|
{
|
||||||
|
a0[k + 0] = lhs.template loadPacket<Packet8bf>(k*4, 0);
|
||||||
|
if (!zero) {
|
||||||
|
b1 = lhs.template loadPacket<Packet8bf>(k*4, 1);
|
||||||
|
}
|
||||||
|
if (num_acc > (k + 1)) {
|
||||||
|
a0[k + 1] = vec_mergel(a0[k + 0].m_val, b1.m_val);
|
||||||
|
}
|
||||||
|
a0[k + 0] = vec_mergeh(a0[k + 0].m_val, b1.m_val);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<Index num_acc>
|
||||||
|
EIGEN_ALWAYS_INLINE void multVec(__vector_quad (&quad_acc)[num_acc], Packet8bf (&a0)[num_acc], Packet8bf b0)
|
||||||
|
{
|
||||||
|
BFLOAT16_UNROLL
|
||||||
|
for(Index k = 0; k < num_acc; k++) {
|
||||||
|
__builtin_mma_xvbf16ger2pp(&(quad_acc[k]), reinterpret_cast<Packet16uc>(b0.m_val), reinterpret_cast<Packet16uc>(a0[k].m_val));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero>
|
||||||
|
EIGEN_ALWAYS_INLINE void vecColLoop(Index j, LhsMapper& lhs, RhsMapper& rhs, __vector_quad (&quad_acc)[num_acc])
|
||||||
|
{
|
||||||
|
Packet8bf a0[num_acc];
|
||||||
|
Packet8bf b1 = pset1<Packet8bf>(Eigen::bfloat16(0));
|
||||||
|
Packet8bf b0 = rhs.template loadPacket<Packet8bf>(j + 0);
|
||||||
|
|
||||||
|
if (zero) {
|
||||||
|
b0 = vec_mergeh(b0.m_val, b1.m_val);
|
||||||
|
}
|
||||||
|
|
||||||
|
LhsMapper lhs2 = lhs.getSubMapper(0, j);
|
||||||
|
for(Index k = 0; k < num_acc; k += 2) {
|
||||||
|
loadVecLoop<num_acc, LhsMapper, zero>(k, lhs2, a0, b1);
|
||||||
|
}
|
||||||
|
|
||||||
|
multVec<num_acc>(quad_acc, a0, b0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Largest number of MMA accumulators (__vector_quad) the bfloat16 GEMV kernels
// below use in a single pass; smaller counts handle the remaining rows.
#define MAX_BFLOAT16_VEC_ACC 8
|
||||||
|
|
||||||
|
// Column-major GEMV inner kernel: handles num_acc*4 rows starting at `row`
// over `cend` columns and accumulates alpha * product into `result`.
// The body repeats only for the full-width, no-extra-rows instantiation, and
// only while another full step of rows fits; `row` is advanced only in that
// case. With `extraRows`, the last accumulator writes just (rows & 3) floats.
template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  constexpr Index step = (num_acc * 4);
  constexpr bool multiIters = !extraRows && (num_acc == MAX_BFLOAT16_VEC_ACC);
  const Index extra_rows = extraRows ? (rows & 3) : 0;
  const bool oddCols = (cend & 1) != 0;

  for (;;) {
    Packet4f acc[num_acc][4];
    __vector_quad quad_acc[num_acc];

    zeroAccumulators<num_acc>(quad_acc);

    LhsMapper lhsRows = lhs.getSubMapper(row, 0);

    // Columns are consumed in pairs; an odd trailing column is zero padded.
    Index col = 0;
    for (Index left = cend >> 1; left != 0; --left, col += 2) {
      vecColLoop<num_acc, LhsMapper, RhsMapper, false>(col, lhsRows, rhs, quad_acc);
    }
    if (oddCols) {
      vecColLoop<num_acc, LhsMapper, RhsMapper, true>(col, lhsRows, rhs, quad_acc);
    }

    disassembleAccumulators<num_acc>(quad_acc, acc);

    outputVecColResults<num_acc, extraRows>(acc, result, pAlpha, extra_rows);

    result += step;

    if (!multiIters) {
      break;
    }
    row += step;
    if (step > rows - row) {
      break;
    }
  }
}
|
||||||
|
|
||||||
|
// Dispatch helper for the leftover rows: forwards to colVecColLoopBody with
// `num_acc` accumulators, plus one more when `extraRows` is set. Nothing is
// done unless num_acc is below MAX_BFLOAT16_VEC_ACC.
template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  constexpr Index accCount = num_acc + (extraRows ? 1 : 0);

  if (MAX_BFLOAT16_VEC_ACC > num_acc) {
    colVecColLoopBody<accCount, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
  }
}
|
||||||
|
|
||||||
|
template<typename LhsMapper, typename RhsMapper, bool extraRows>
|
||||||
|
EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtra(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
|
||||||
|
{
|
||||||
|
switch ((rows - row) >> 2) {
|
||||||
|
case 7:
|
||||||
|
colVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
|
||||||
|
break;
|
||||||
|
case 6:
|
||||||
|
colVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
|
||||||
|
break;
|
||||||
|
case 5:
|
||||||
|
colVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
colVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
colVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
colVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
|
||||||
|
break;
|
||||||
|
case 1:
|
||||||
|
colVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
if (extraRows) {
|
||||||
|
colVecColLoopBody<1, LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename LhsMapper, typename RhsMapper>
|
||||||
|
EIGEN_ALWAYS_INLINE void calcVecColLoops(Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
|
||||||
|
{
|
||||||
|
Index row = 0;
|
||||||
|
if (rows >= (MAX_BFLOAT16_VEC_ACC * 4)) {
|
||||||
|
colVecColLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper, false>(row, cend, rows, lhs, rhs, pAlpha, result);
|
||||||
|
result += row;
|
||||||
|
}
|
||||||
|
if (rows & 3) {
|
||||||
|
colVecColLoopBodyExtra<LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result);
|
||||||
|
} else {
|
||||||
|
colVecColLoopBodyExtra<LhsMapper, RhsMapper, false>(row, cend, rows, lhs, rhs, pAlpha, result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename LhsMapper, typename RhsMapper>
|
||||||
|
void gemvMMA_bfloat16_col(
|
||||||
|
Index rows, Index cols,
|
||||||
|
const LhsMapper& alhs,
|
||||||
|
const RhsMapper& rhs,
|
||||||
|
bfloat16* res, Index resIncr,
|
||||||
|
bfloat16 alpha)
|
||||||
|
{
|
||||||
|
typedef typename RhsMapper::LinearMapper LinearMapper;
|
||||||
|
|
||||||
|
EIGEN_UNUSED_VARIABLE(resIncr);
|
||||||
|
eigen_internal_assert(resIncr == 1);
|
||||||
|
|
||||||
|
// The following copy tells the compiler that lhs's attributes are not modified outside this function
|
||||||
|
// This helps GCC to generate proper code.
|
||||||
|
LhsMapper lhs(alhs);
|
||||||
|
RhsMapper rhs2(rhs);
|
||||||
|
|
||||||
|
const Index lhsStride = lhs.stride();
|
||||||
|
|
||||||
|
// TODO: improve the following heuristic:
|
||||||
|
const Index block_cols = cols < 128 ? cols : (lhsStride * sizeof(bfloat16) < 16000 ? 16 : 8);
|
||||||
|
float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
|
||||||
|
Packet4f pAlpha = pset1<Packet4f>(falpha);
|
||||||
|
|
||||||
|
ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);
|
||||||
|
|
||||||
|
convertArrayPointerBF16toF32(result, rows, res);
|
||||||
|
|
||||||
|
for (Index j2 = 0; j2 < cols; j2 += block_cols)
|
||||||
|
{
|
||||||
|
Index jend = numext::mini(j2 + block_cols, cols);
|
||||||
|
|
||||||
|
LhsMapper lhs2 = lhs.getSubMapper(0, j2);
|
||||||
|
LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
|
||||||
|
calcVecColLoops<LhsMapper, LinearMapper>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
convertArrayPointerF32toBF16(result, rows, res);
|
||||||
|
}
|
||||||
|
|
||||||
|
// vec_perm byte mask used by preduxVecResults2: selects bytes 0x0c-0x0f of the
// first operand followed by bytes 0x0c-0x0f of the second operand (indices
// 0x1c-0x1f), and repeats that 8-byte pattern in the upper half.
static Packet16uc p16uc_ELEMENT_VEC3 = { 0x0c,0x0d,0x0e,0x0f, 0x1c,0x1d,0x1e,0x1f, 0x0c,0x0d,0x0e,0x0f, 0x1c,0x1d,0x1e,0x1f };
|
||||||
|
|
||||||
|
template<Index num_acc>
|
||||||
|
EIGEN_ALWAYS_INLINE void preduxVecResults2(Packet4f (&acc)[num_acc][4], Index k)
|
||||||
|
{
|
||||||
|
if (num_acc > (k + 1)) {
|
||||||
|
acc[k][0] = vec_mergeh(acc[k][0], acc[k + 1][0]);
|
||||||
|
acc[k][1] = vec_mergeo(acc[k][1], acc[k + 1][1]);
|
||||||
|
acc[k][2] = vec_mergel(acc[k][2], acc[k + 1][2]);
|
||||||
|
acc[k][3] = vec_perm(acc[k][3], acc[k + 1][3], p16uc_ELEMENT_VEC3);
|
||||||
|
|
||||||
|
acc[k][0] = (acc[k][0] + acc[k][2]) + (acc[k][1] + acc[k][3]);
|
||||||
|
} else {
|
||||||
|
acc[k][0] = vec_mergeh(acc[k][0], acc[k][1]);
|
||||||
|
acc[k][0] += vec_mergel(acc[k][2], acc[k][3]);
|
||||||
|
#ifdef _BIG_ENDIAN
|
||||||
|
acc[k][0] += vec_sld(acc[k][0], acc[k][0], 12);
|
||||||
|
#else
|
||||||
|
acc[k][0] += vec_sld(acc[k][0], acc[k][0], 4);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<Index num_acc>
|
||||||
|
EIGEN_ALWAYS_INLINE void preduxVecResults(Packet4f (&acc)[num_acc][4])
|
||||||
|
{
|
||||||
|
for(Index k = 0; k < num_acc; k += 4) {
|
||||||
|
preduxVecResults2<num_acc>(acc, k + 0);
|
||||||
|
if (num_acc > (k + 2)) {
|
||||||
|
preduxVecResults2<num_acc>(acc, k + 2);
|
||||||
|
acc[k + 0][0] = reinterpret_cast<Packet4f>(vec_mergeh(reinterpret_cast<Packet2ul>(acc[k + 0][0]), reinterpret_cast<Packet2ul>(acc[k + 2][0])));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<Index num_acc>
|
||||||
|
EIGEN_ALWAYS_INLINE void outputVecResults(Packet4f (&acc)[num_acc][4], float *result, Packet4f pAlpha)
|
||||||
|
{
|
||||||
|
constexpr Index extra = num_acc & 3;
|
||||||
|
|
||||||
|
for(Index k = 0; k < num_acc; k += 4) {
|
||||||
|
Packet4f d0 = ploadu<Packet4f>(result + k);
|
||||||
|
d0 = pmadd(acc[k + 0][0], pAlpha, d0);
|
||||||
|
|
||||||
|
if (num_acc > (k + 3)) {
|
||||||
|
pstoreu(result + k, d0);
|
||||||
|
} else {
|
||||||
|
if (extra == 3) {
|
||||||
|
pstoreu_partial(result + k, d0, extra);
|
||||||
|
} else if (extra == 2) {
|
||||||
|
Packet2ul d1 = reinterpret_cast<Packet2ul>(d0);
|
||||||
|
*(unsigned long long *)(result + k) = d1[0];
|
||||||
|
} else {
|
||||||
|
Packet4i d1 = reinterpret_cast<Packet4i>(d0);
|
||||||
|
*(unsigned int *)(result + k) = d1[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<Index num_acc, typename LhsMapper, typename RhsMapper, bool extra>
|
||||||
|
EIGEN_ALWAYS_INLINE void multVecLoop(__vector_quad (&quad_acc)[num_acc], const LhsMapper& lhs, RhsMapper& rhs, Index j, Index extra_cols)
|
||||||
|
{
|
||||||
|
Packet8bf a0[num_acc], b0;
|
||||||
|
|
||||||
|
if (extra) {
|
||||||
|
b0 = rhs.template loadPacketPartial<Packet8bf>(j, extra_cols);
|
||||||
|
} else {
|
||||||
|
b0 = rhs.template loadPacket<Packet8bf>(j);
|
||||||
|
}
|
||||||
|
|
||||||
|
const LhsMapper lhs2 = lhs.getSubMapper(0, j);
|
||||||
|
for(Index k = 0; k < num_acc; k++) {
|
||||||
|
if (extra) {
|
||||||
|
a0[k] = lhs2.template loadPacketPartial<Packet8bf>(k, 0, extra_cols);
|
||||||
|
} else {
|
||||||
|
a0[k] = lhs2.template loadPacket<Packet8bf>(k, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
multVec<num_acc>(quad_acc, a0, b0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<Index num_acc, typename LhsMapper, typename RhsMapper>
|
||||||
|
EIGEN_ALWAYS_INLINE void vecLoop(Index cols, const LhsMapper& lhs, RhsMapper& rhs, __vector_quad (&quad_acc)[num_acc], Index extra_cols)
|
||||||
|
{
|
||||||
|
Index j = 0;
|
||||||
|
for(Index k = cols >> 3; k--; j += 8) {
|
||||||
|
multVecLoop<num_acc, LhsMapper, RhsMapper, false>(quad_acc, lhs, rhs, j, extra_cols);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (extra_cols) {
|
||||||
|
multVecLoop<num_acc, LhsMapper, RhsMapper, true>(quad_acc, lhs, rhs, j, extra_cols);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Row-major GEMV inner kernel: computes `num_acc` rows starting at `row` over
// all `cols` columns and accumulates alpha * product into `result`. The body
// repeats only for the full-width instantiation
// (num_acc == MAX_BFLOAT16_VEC_ACC) and only while another num_acc rows fit;
// `row` is advanced only in that case.
template<const Index num_acc, typename LhsMapper, typename RhsMapper>
void colVecLoopBody(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
{
  constexpr bool multiIters = (num_acc == MAX_BFLOAT16_VEC_ACC);
  const Index extra_cols = (cols & 7);

  for (;;) {
    Packet4f acc[num_acc][4];
    __vector_quad quad_acc[num_acc];

    zeroAccumulators<num_acc>(quad_acc);

    const LhsMapper lhsRows = lhs.getSubMapper(row, 0);
    vecLoop<num_acc, LhsMapper, RhsMapper>(cols, lhsRows, rhs, quad_acc, extra_cols);

    disassembleAccumulators<num_acc>(quad_acc, acc);

    preduxVecResults<num_acc>(acc);

    outputVecResults<num_acc>(acc, result, pAlpha);

    result += num_acc;

    if (!multiIters) {
      break;
    }
    row += num_acc;
    if (num_acc > rows - row) {
      break;
    }
  }
}
|
||||||
|
|
||||||
|
template<typename LhsMapper, typename RhsMapper>
|
||||||
|
EIGEN_ALWAYS_INLINE void colVecLoopBodyExtra(Index& row, Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
|
||||||
|
{
|
||||||
|
switch (rows - row) {
|
||||||
|
case 7:
|
||||||
|
colVecLoopBody<7, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
|
||||||
|
break;
|
||||||
|
case 6:
|
||||||
|
colVecLoopBody<6, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
|
||||||
|
break;
|
||||||
|
case 5:
|
||||||
|
colVecLoopBody<5, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
colVecLoopBody<4, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
colVecLoopBody<3, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
colVecLoopBody<2, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
|
||||||
|
break;
|
||||||
|
case 1:
|
||||||
|
colVecLoopBody<1, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename LhsMapper, typename RhsMapper>
|
||||||
|
EIGEN_ALWAYS_INLINE void calcVecLoops(Index cols, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
|
||||||
|
{
|
||||||
|
Index row = 0;
|
||||||
|
if (rows >= MAX_BFLOAT16_VEC_ACC) {
|
||||||
|
colVecLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
|
||||||
|
result += row;
|
||||||
|
}
|
||||||
|
colVecLoopBodyExtra<LhsMapper, RhsMapper>(row, cols, rows, lhs, rhs, pAlpha, result);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename LhsMapper, typename RhsMapper>
|
||||||
|
EIGEN_STRONG_INLINE void gemvMMA_bfloat16_row(
|
||||||
|
Index rows, Index cols,
|
||||||
|
const LhsMapper& alhs,
|
||||||
|
const RhsMapper& rhs,
|
||||||
|
bfloat16* res, Index resIncr,
|
||||||
|
bfloat16 alpha)
|
||||||
|
{
|
||||||
|
typedef typename RhsMapper::LinearMapper LinearMapper;
|
||||||
|
|
||||||
|
// The following copy tells the compiler that lhs's attributes are not modified outside this function
|
||||||
|
// This helps GCC to generate proper code.
|
||||||
|
LhsMapper lhs(alhs);
|
||||||
|
LinearMapper rhs2 = rhs.getLinearMapper(0, 0);
|
||||||
|
|
||||||
|
eigen_internal_assert(rhs.stride() == 1);
|
||||||
|
|
||||||
|
float falpha = Eigen::bfloat16_impl::bfloat16_to_float(alpha);
|
||||||
|
const Packet4f pAlpha = pset1<Packet4f>(falpha);
|
||||||
|
|
||||||
|
ei_declare_aligned_stack_constructed_variable(float, result, rows, 0);
|
||||||
|
if (resIncr == 1) {
|
||||||
|
convertArrayPointerBF16toF32(result, rows, res);
|
||||||
|
} else {
|
||||||
|
convertArrayPointerBF16toF32<true>(result, rows, res, resIncr);
|
||||||
|
}
|
||||||
|
calcVecLoops<LhsMapper, LinearMapper>(cols, rows, lhs, rhs2, pAlpha, result);
|
||||||
|
if (resIncr == 1) {
|
||||||
|
convertArrayPointerF32toBF16(result, rows, res);
|
||||||
|
} else {
|
||||||
|
convertArrayPointerF32toBF16<true>(result, rows, res, resIncr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif //EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H
|
#endif //EIGEN_MATRIX_PRODUCT_MMA_BFLOAT16_ALTIVEC_H
|
||||||
|
@ -2061,6 +2061,39 @@ EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL(double)
|
|||||||
EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(float)
|
EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(float)
|
||||||
EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(double)
|
EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW(double)
|
||||||
|
|
||||||
|
#ifdef USE_GEMV_MMA
|
||||||
|
// Declares the general_matrix_vector_product specialization for bfloat16 with
// a column-major lhs; its run() forwards to gemvMMA_bfloat16_col.
#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL_BFLOAT16() \
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
struct general_matrix_vector_product<Index, bfloat16, LhsMapper, ColMajor, ConjugateLhs, bfloat16, RhsMapper, ConjugateRhs, Version> \
{ \
  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( \
    Index rows, Index cols, const LhsMapper& lhs, const RhsMapper& rhs, \
    bfloat16* res, Index resIncr, bfloat16 alpha) \
  { \
    gemvMMA_bfloat16_col<LhsMapper, RhsMapper>(rows, cols, lhs, rhs, res, resIncr, alpha); \
  } \
};
|
||||||
|
|
||||||
|
// Declares the general_matrix_vector_product specialization for bfloat16 with
// a row-major lhs; its run() forwards to gemvMMA_bfloat16_row.
#define EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW_BFLOAT16() \
template<typename Index, typename LhsMapper, bool ConjugateLhs, typename RhsMapper, bool ConjugateRhs, int Version> \
struct general_matrix_vector_product<Index, bfloat16, LhsMapper, RowMajor, ConjugateLhs, bfloat16, RhsMapper, ConjugateRhs, Version> \
{ \
  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( \
    Index rows, Index cols, const LhsMapper& lhs, const RhsMapper& rhs, \
    bfloat16* res, Index resIncr, bfloat16 alpha) \
  { \
    gemvMMA_bfloat16_row<LhsMapper, RhsMapper>(rows, cols, lhs, rhs, res, resIncr, alpha); \
  } \
};
|
||||||
|
|
||||||
|
EIGEN_POWER_GEMV_REAL_SPECIALIZE_COL_BFLOAT16()
|
||||||
|
EIGEN_POWER_GEMV_REAL_SPECIALIZE_ROW_BFLOAT16()
|
||||||
|
#endif
|
||||||
|
|
||||||
template<typename ResScalar, typename PResPacket, typename ResPacket, typename LhsPacket, typename RhsPacket>
|
template<typename ResScalar, typename PResPacket, typename ResPacket, typename LhsPacket, typename RhsPacket>
|
||||||
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(PResPacket& a0, PResPacket& b0, ResPacket& a1, ResPacket& b1)
|
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(PResPacket& a0, PResPacket& b0, ResPacket& a1, ResPacket& b1)
|
||||||
{
|
{
|
||||||
|
Loading…
x
Reference in New Issue
Block a user