Specialized loadColData correctly - fix previous BF16 GEMV MR

This commit is contained in:
Chip Kerchner 2023-05-04 16:38:17 +00:00 committed by Charles Schlosser
parent 2af03fb685
commit b8208b363c
2 changed files with 21 additions and 7 deletions

View File

@@ -102,7 +102,7 @@ EIGEN_ALWAYS_INLINE void outputVecColResults(Packet4f (&acc)[num_acc][size], flo
template<Index num_acc, Index size = 4> template<Index num_acc, Index size = 4>
EIGEN_ALWAYS_INLINE void outputVecResults(Packet4f (&acc)[num_acc][size], float *result, Packet4f pAlpha); EIGEN_ALWAYS_INLINE void outputVecResults(Packet4f (&acc)[num_acc][size], float *result, Packet4f pAlpha);
template<typename RhsMapper, bool linear, std::enable_if_t<linear, bool>> template<typename RhsMapper, bool linear>
EIGEN_ALWAYS_INLINE Packet8bf loadColData(RhsMapper& rhs, Index j); EIGEN_ALWAYS_INLINE Packet8bf loadColData(RhsMapper& rhs, Index j);
template<typename Packet> template<typename Packet>

View File

@@ -521,16 +521,30 @@ EIGEN_ALWAYS_INLINE void multVecVSX(Packet4f (&acc)[num_acc][2], Packet4f (&a0)[
} }
} }
template<typename RhsMapper, bool linear, std::enable_if_t<linear, bool> = true> template<typename RhsMapper, bool linear>
EIGEN_ALWAYS_INLINE Packet8bf loadColData(RhsMapper& rhs, Index j) struct loadColData_impl
{ {
return rhs.template loadPacket<Packet8bf>(j + 0); // linear == false
} static EIGEN_ALWAYS_INLINE Packet8bf run(RhsMapper& rhs, Index j)
{
return pgather<bfloat16, Packet8bf>(&rhs(j + 0, 0), rhs.stride());
}
};
template<typename RhsMapper, bool linear, std::enable_if_t<!linear, bool> = true> template<typename RhsMapper>
struct loadColData_impl<RhsMapper, true>
{
// linear == true
static EIGEN_ALWAYS_INLINE Packet8bf run(RhsMapper& rhs, Index j)
{
return rhs.template loadPacket<Packet8bf>(j + 0);
}
};
template<typename RhsMapper, bool linear>
EIGEN_ALWAYS_INLINE Packet8bf loadColData(RhsMapper& rhs, Index j) EIGEN_ALWAYS_INLINE Packet8bf loadColData(RhsMapper& rhs, Index j)
{ {
return pgather<bfloat16, Packet8bf>(&rhs(j + 0, 0), rhs.stride()); return loadColData_impl<RhsMapper, linear>::run(rhs, j);
} }
template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero, bool linear> template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero, bool linear>