Fix ColMajor BF16 GEMV for when the vector is RowMajor

Chip Kerchner 2023-05-03 20:12:50 +00:00 committed by Rasmus Munk Larsen
parent fdc749de2a
commit fda1373a15
4 changed files with 104 additions and 65 deletions
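
For context, this is the case being fixed (a hedged sketch; sizes and variable names are illustrative, not from the commit): a column-major bfloat16 matrix times a vector whose coefficients are not contiguous, e.g. a column of a row-major matrix. Before this change the column-GEMV kernels always read the rhs through a LinearMapper's loadPacket, which assumes unit stride.

    #include <Eigen/Dense>

    int main() {
      using Eigen::bfloat16;
      using Eigen::Dynamic;
      typedef Eigen::Matrix<bfloat16, Dynamic, Dynamic> MatBF16;                      // ColMajor by default
      typedef Eigen::Matrix<bfloat16, Dynamic, Dynamic, Eigen::RowMajor> RowMatBF16;

      MatBF16 A = Eigen::MatrixXf::Random(64, 32).cast<bfloat16>();
      RowMatBF16 X = Eigen::MatrixXf::Random(32, 4).cast<bfloat16>();

      // X.col(0) is a column of a RowMajor matrix, so consecutive
      // coefficients are X.cols() elements apart (stride != 1). On a
      // Power/VSX build this GEMV reaches the kernels patched below,
      // which previously loaded the rhs as if it were contiguous.
      Eigen::Matrix<bfloat16, Dynamic, 1> y = A * X.col(0);
      return y.size() == 64 ? 0 : 1;
    }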


@@ -102,6 +102,9 @@ EIGEN_ALWAYS_INLINE void outputVecColResults(Packet4f (&acc)[num_acc][size], flo
 template<Index num_acc, Index size = 4>
 EIGEN_ALWAYS_INLINE void outputVecResults(Packet4f (&acc)[num_acc][size], float *result, Packet4f pAlpha);
+template<typename RhsMapper, bool linear, std::enable_if_t<linear, bool>>
+EIGEN_ALWAYS_INLINE Packet8bf loadColData(RhsMapper& rhs, Index j);
+
 template<typename Packet>
 EIGEN_ALWAYS_INLINE Packet ploadLhs(const __UNPACK_TYPE__(Packet)* lhs);


@@ -383,12 +383,12 @@ EIGEN_ALWAYS_INLINE void multVec(__vector_quad (&quad_acc)[num_acc], Packet8bf (
   }
 }
 
-template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero>
+template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero, bool linear>
 EIGEN_ALWAYS_INLINE void vecColLoop(Index j, LhsMapper& lhs, RhsMapper& rhs, __vector_quad (&quad_acc)[num_acc])
 {
   Packet8bf a0[num_acc];
   Packet8bf b1 = pset1<Packet8bf>(Eigen::bfloat16(0));
-  Packet8bf b0 = rhs.template loadPacket<Packet8bf>(j + 0);
+  Packet8bf b0 = loadColData<RhsMapper, linear>(rhs, j);
 
   if (zero) {
     b0 = vec_mergeh(b0.m_val, b1.m_val);
@@ -405,7 +405,7 @@ EIGEN_ALWAYS_INLINE void vecColLoop(Index j, LhsMapper& lhs, RhsMapper& rhs, __v
 #define MAX_BFLOAT16_VEC_ACC 8
 
-template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
+template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
 void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   constexpr Index step = (num_acc * 4);
@@ -420,10 +420,10 @@ void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMa
     LhsMapper lhs2 = lhs.getSubMapper(row, 0);
     for(Index j = 0; j + 2 <= cend; j += 2) {
-      vecColLoop<num_acc, LhsMapper, RhsMapper, false>(j, lhs2, rhs, quad_acc);
+      vecColLoop<num_acc, LhsMapper, RhsMapper, false, linear>(j, lhs2, rhs, quad_acc);
     }
     if (cend & 1) {
-      vecColLoop<num_acc, LhsMapper, RhsMapper, true>(cend - 1, lhs2, rhs, quad_acc);
+      vecColLoop<num_acc, LhsMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, quad_acc);
     }
 
     disassembleAccumulators<num_acc>(quad_acc, acc);
@@ -434,59 +434,59 @@ void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMa
   } while(multiIters && (step <= rows - (row += step)));
 }
 
-template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
+template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
 EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   if (MAX_BFLOAT16_VEC_ACC > num_acc) {
-    colVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
   }
 }
 
-template<typename LhsMapper, typename RhsMapper, bool extraRows>
+template<typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
 EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtra(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   switch ((rows - row) >> 2) {
     case 7:
-      colVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 6:
-      colVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 5:
-      colVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 4:
-      colVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 3:
-      colVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 2:
-      colVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 1:
-      colVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    default:
      if (extraRows) {
-        colVecColLoopBody<1, LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result);
+        colVecColLoopBody<1, LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      }
      break;
  }
 }
 
-template<typename LhsMapper, typename RhsMapper>
+template<typename LhsMapper, typename RhsMapper, bool linear>
 EIGEN_ALWAYS_INLINE void calcVecColLoops(Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   Index row = 0;
   if (rows >= (MAX_BFLOAT16_VEC_ACC * 4)) {
-    colVecColLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper, false>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVecColLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     result += row;
   }
   if (rows & 3) {
-    colVecColLoopBodyExtra<LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVecColLoopBodyExtra<LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
   } else {
-    colVecColLoopBodyExtra<LhsMapper, RhsMapper, false>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVecColLoopBodyExtra<LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
   }
 }
@@ -524,8 +524,13 @@ void gemvMMA_bfloat16_col(
     Index jend = numext::mini(j2 + block_cols, cols);
     LhsMapper lhs2 = lhs.getSubMapper(0, j2);
-    LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
-    calcVecColLoops<LhsMapper, LinearMapper>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
+    if (rhs.stride() == 1) {
+      LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
+      calcVecColLoops<LhsMapper, LinearMapper, true>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
+    } else {
+      RhsMapper rhs3 = rhs2.getSubMapper(j2, 0);
+      calcVecColLoops<LhsMapper, RhsMapper, false>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
+    }
   }
 
   convertArrayPointerF32toBF16(result, rows, res);
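
The driver change just above is the heart of the fix, and the same shape appears again in the VSX-only gemv_bfloat16_col below: one runtime test of rhs.stride() selects a kernel instantiation in which `linear` is a compile-time constant, so the hot loops stay branch-free. A self-contained sketch of that dispatch pattern (the Mapper type and function names here are invented for illustration):

    #include <cstddef>

    // Hypothetical stand-in for Eigen's rhs mappers.
    struct Mapper {
      const float* data;
      std::ptrdiff_t stride;  // distance between consecutive coefficients
    };

    template <bool linear>
    float dot(const Mapper& rhs, const float* lhs, int n) {
      float sum = 0.0f;
      for (int i = 0; i < n; i++) {
        // `linear` is a template parameter, so this select folds away at
        // compile time and the unit-stride loop keeps simple addressing.
        sum += lhs[i] * (linear ? rhs.data[i] : rhs.data[i * rhs.stride]);
      }
      return sum;
    }

    float dispatch(const Mapper& rhs, const float* lhs, int n) {
      // One stride check outside the loop, mirroring rhs.stride() == 1 above.
      return (rhs.stride == 1) ? dot<true>(rhs, lhs, n) : dot<false>(rhs, lhs, n);
    }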


@@ -521,11 +521,23 @@ EIGEN_ALWAYS_INLINE void multVecVSX(Packet4f (&acc)[num_acc][2], Packet4f (&a0)[
   }
 }
 
-template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero>
+template<typename RhsMapper, bool linear, std::enable_if_t<linear, bool> = true>
+EIGEN_ALWAYS_INLINE Packet8bf loadColData(RhsMapper& rhs, Index j)
+{
+  return rhs.template loadPacket<Packet8bf>(j + 0);
+}
+
+template<typename RhsMapper, bool linear, std::enable_if_t<!linear, bool> = true>
+EIGEN_ALWAYS_INLINE Packet8bf loadColData(RhsMapper& rhs, Index j)
+{
+  return pgather<bfloat16, Packet8bf>(&rhs(j + 0, 0), rhs.stride());
+}
+
+template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero, bool linear>
 EIGEN_ALWAYS_INLINE void vecColLoopVSX(Index j, LhsMapper& lhs, RhsMapper& rhs, Packet4f (&acc)[num_acc][2])
 {
   Packet4f a0[num_acc][2], b0[2];
-  Packet8bf b2 = rhs.template loadPacket<Packet8bf>(j + 0);
+  Packet8bf b2 = loadColData<RhsMapper, linear>(rhs, j);
 
   b0[0] = oneConvertBF16Perm(b2.m_val, p16uc_MERGE16_32_V1);
   if (!zero) {
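
The two loadColData overloads above use a defaulted std::enable_if_t non-type template parameter so that, for each value of `linear`, exactly one overload is viable: contiguous packet load when the rhs is linear, pgather through the stride when it is not. The same idiom in a standalone, scalar form (all names hypothetical):

    #include <cstddef>
    #include <iostream>
    #include <type_traits>

    struct Vec {
      const float* data;
      std::ptrdiff_t stride;
    };

    // Viable only when linear == true: contiguous indexed load.
    template <bool linear, std::enable_if_t<linear, bool> = true>
    float loadAt(const Vec& v, std::ptrdiff_t j) {
      return v.data[j];
    }

    // Viable only when linear == false: gather through the stride.
    template <bool linear, std::enable_if_t<!linear, bool> = true>
    float loadAt(const Vec& v, std::ptrdiff_t j) {
      return v.data[j * v.stride];
    }

    int main() {
      float buf[8] = {0, 1, 2, 3, 4, 5, 6, 7};
      Vec contiguous{buf, 1}, strided{buf, 2};
      std::cout << loadAt<true>(contiguous, 3) << ' '   // prints 3
                << loadAt<false>(strided, 3) << '\n';   // prints 6
    }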
@@ -551,7 +563,7 @@ EIGEN_ALWAYS_INLINE void addResultsVSX(Packet4f (&acc)[num_acc][2])
 // Uses 2X the accumulators or 4X the number of VSX registers
 #define MAX_BFLOAT16_VEC_ACC_VSX 8
 
-template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
+template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
 void colVSXVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   constexpr Index step = (num_acc * 4);
@@ -565,10 +577,10 @@ void colVSXVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, Rh
     LhsMapper lhs2 = lhs.getSubMapper(row, 0);
     for(Index j = 0; j + 2 <= cend; j += 2) {
-      vecColLoopVSX<num_acc, LhsMapper, RhsMapper, false>(j, lhs2, rhs, acc);
+      vecColLoopVSX<num_acc, LhsMapper, RhsMapper, false, linear>(j, lhs2, rhs, acc);
     }
     if (cend & 1) {
-      vecColLoopVSX<num_acc, LhsMapper, RhsMapper, true>(cend - 1, lhs2, rhs, acc);
+      vecColLoopVSX<num_acc, LhsMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, acc);
     }
 
     addResultsVSX<num_acc>(acc);
@@ -579,59 +591,59 @@ void colVSXVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, Rh
   } while(multiIters && (step <= rows - (row += step)));
 }
 
-template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
+template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
 EIGEN_ALWAYS_INLINE void colVSXVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   if (MAX_BFLOAT16_VEC_ACC_VSX > num_acc) {
-    colVSXVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVSXVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
   }
 }
 
-template<typename LhsMapper, typename RhsMapper, bool extraRows>
+template<typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
 EIGEN_ALWAYS_INLINE void colVSXVecColLoopBodyExtra(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   switch ((rows - row) >> 2) {
    case 7:
-      colVSXVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVSXVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 6:
-      colVSXVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVSXVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 5:
-      colVSXVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVSXVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 4:
-      colVSXVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVSXVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 3:
-      colVSXVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVSXVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 2:
-      colVSXVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVSXVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    case 1:
-      colVSXVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVSXVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      break;
    default:
      if (extraRows) {
-        colVSXVecColLoopBody<1, LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result);
+        colVSXVecColLoopBody<1, LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
      }
      break;
  }
 }
 
-template<typename LhsMapper, typename RhsMapper>
+template<typename LhsMapper, typename RhsMapper, bool linear>
 EIGEN_ALWAYS_INLINE void calcVSXVecColLoops(Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   Index row = 0;
   if (rows >= (MAX_BFLOAT16_VEC_ACC_VSX * 4)) {
-    colVSXVecColLoopBody<MAX_BFLOAT16_VEC_ACC_VSX, LhsMapper, RhsMapper, false>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVSXVecColLoopBody<MAX_BFLOAT16_VEC_ACC_VSX, LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     result += row;
   }
   if (rows & 3) {
-    colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
   } else {
-    colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, false>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
   }
 }
@@ -724,8 +736,13 @@ void gemv_bfloat16_col(
     Index jend = numext::mini(j2 + block_cols, cols);
     LhsMapper lhs2 = lhs.getSubMapper(0, j2);
-    LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
-    calcVSXVecColLoops<LhsMapper, LinearMapper>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
+    if (rhs.stride() == 1) {
+      LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
+      calcVSXVecColLoops<LhsMapper, LinearMapper, true>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
+    } else {
+      RhsMapper rhs3 = rhs2.getSubMapper(j2, 0);
+      calcVSXVecColLoops<LhsMapper, RhsMapper, false>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
+    }
   }
 
   convertArrayPointerF32toBF16VSX(result, rows, res);
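
In the non-linear overload of loadColData, pgather<bfloat16, Packet8bf> pulls eight bfloat16 coefficients, each rhs.stride() elements apart, into one contiguous 16-byte packet. A scalar model of that gather (a stand-in, not Eigen's implementation):

    #include <cstdint>

    // Collect 8 16-bit values sitting `stride` elements apart into one
    // contiguous packet-sized buffer, like pgather<bfloat16, Packet8bf>.
    void gather_bf16(std::uint16_t out[8], const std::uint16_t* from, long stride) {
      for (int i = 0; i < 8; i++) {
        out[i] = from[i * stride];
      }
    }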


@@ -796,12 +796,20 @@ template<typename Packet> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet pgather_c
 {
   EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[unpacket_traits<Packet>::size];
   eigen_internal_assert(n <= unpacket_traits<Packet>::size && "number of elements will gather past end of packet");
-  LOAD_STORE_UNROLL_16
-  for (Index i = 0; i < n; i++) {
-    a[i] = from[i*stride];
-  }
-  // Leave rest of the array uninitialized
-  return pload_ignore<Packet>(a);
+  if (stride == 1) {
+    if (n == unpacket_traits<Packet>::size) {
+      return ploadu<Packet>(from);
+    } else {
+      return ploadu_partial<Packet>(from, n);
+    }
+  } else {
+    LOAD_STORE_UNROLL_16
+    for (Index i = 0; i < n; i++) {
+      a[i] = from[i*stride];
+    }
+    // Leave rest of the array uninitialized
+    return pload_ignore<Packet>(a);
+  }
 }
 
 template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet4f pgather<float, Packet4f>(const float* from, Index stride)
@@ -878,10 +886,18 @@ template<typename Packet> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pscatter_co
 {
   EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[unpacket_traits<Packet>::size];
   eigen_internal_assert(n <= unpacket_traits<Packet>::size && "number of elements will scatter past end of packet");
-  pstore<__UNPACK_TYPE__(Packet)>(a, from);
-  LOAD_STORE_UNROLL_16
-  for (Index i = 0; i < n; i++) {
-    to[i*stride] = a[i];
-  }
+  if (stride == 1) {
+    if (n == unpacket_traits<Packet>::size) {
+      return pstoreu(to, from);
+    } else {
+      return pstoreu_partial(to, from, n);
+    }
+  } else {
+    pstore<__UNPACK_TYPE__(Packet)>(a, from);
+    LOAD_STORE_UNROLL_16
+    for (Index i = 0; i < n; i++) {
+      to[i*stride] = a[i];
+    }
+  }
 }
@@ -1256,15 +1272,14 @@ template<typename Packet> EIGEN_ALWAYS_INLINE Packet ploadu_partial_common(const
   return vec_xl_len(const_cast<__UNPACK_TYPE__(Packet)*>(from), n * size);
 #else
   if (n) {
+    Index n2 = n * size;
+    if (16 <= n2) {
+      return ploadu<Packet>(from);
+    }
     EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) load[packet_size];
     unsigned char* load2 = reinterpret_cast<unsigned char *>(load);
     unsigned char* from2 = reinterpret_cast<unsigned char *>(const_cast<__UNPACK_TYPE__(Packet)*>(from));
-    Index n2 = n * size;
-    if (16 <= n2) {
-      pstore(load2, ploadu<Packet16uc>(from2));
-    } else {
-      memcpy((void *)load2, (void *)from2, n2);
-    }
+    memcpy((void *)load2, (void *)from2, n2);
     return pload_ignore<Packet>(load);
   } else {
     return Packet(pset1<Packet16uc>(0));
@@ -1432,16 +1447,15 @@ template<typename Packet> EIGEN_ALWAYS_INLINE void pstoreu_partial_common(__UNPA
   vec_xst_len(from, to, n * size);
 #else
   if (n) {
+    Index n2 = n * size;
+    if (16 <= n2) {
+      pstoreu(to, from);
+    }
     EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) store[packet_size];
     pstore(store, from);
     unsigned char* store2 = reinterpret_cast<unsigned char *>(store);
     unsigned char* to2 = reinterpret_cast<unsigned char *>(to);
-    Index n2 = n * size;
-    if (16 <= n2) {
-      pstoreu(to2, pload<Packet16uc>(store2));
-    } else {
-      memcpy((void *)to2, (void *)store2, n2);
-    }
+    memcpy((void *)to2, (void *)store2, n2);
   }
 #endif
 }
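
The partial load/store hunks above all exploit the same bound: n never exceeds the packet size, so n2 = n * size can only reach 16 when the "partial" request covers an entire 16-byte VSX packet, at which point an ordinary unaligned ploadu/pstoreu is legal and the aligned staging buffer plus memcpy can be bypassed. A scalar sketch of the load side under that assumption (function name hypothetical):

    #include <cstring>

    // Model of the reshaped ploadu_partial fallback for 4 x float:
    // a full-width request short-circuits to a whole-packet load; only
    // genuinely partial requests go through the byte-wise copy.
    void load_partial_f32(float dst[4], const float* src, int n /* 0..4 */) {
      int n2 = n * int(sizeof(float));
      if (16 <= n2) {                // n == 4: the whole 16-byte packet
        std::memcpy(dst, src, 16);   // stands in for ploadu<Packet>(from)
        return;
      }
      std::memcpy(dst, src, n2);     // true partial load; rest of dst untouched
    }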