Fix problem with array conversions BF16->F32 in Power.

This commit is contained in:
Chip Kerchner 2023-02-13 21:30:45 +00:00 committed by Rasmus Munk Larsen
parent 77b48c440e
commit 4a03409569

View File

@ -189,11 +189,43 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat
ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0);
typedef typename DataMapper::LinearMapper LinearMapper;
Packet8us z = pset1<Packet8us>(0);
for(Index j = 0; j < cols; j++){
const LinearMapper res2 = res.getLinearMapper(0, j);
float *result2 = result + j*rows;
BFLOAT16_UNROLL
for(Index i = 0; i < rows; i++){
Index i = 0;
for(; i + 32 <= rows; i+=32){
Packet8us r32_0 = res2.template loadPacket<Packet8bf>(i + 0).m_val;
Packet8us r32_1 = res2.template loadPacket<Packet8bf>(i + 8).m_val;
Packet8us r32_2 = res2.template loadPacket<Packet8bf>(i + 16).m_val;
Packet8us r32_3 = res2.template loadPacket<Packet8bf>(i + 24).m_val;
pstore(result2 + i + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_0)));
pstore(result2 + i + 4, reinterpret_cast<Packet4f>(vec_mergel(z, r32_0)));
pstore(result2 + i + 8, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_1)));
pstore(result2 + i + 12, reinterpret_cast<Packet4f>(vec_mergel(z, r32_1)));
pstore(result2 + i + 16, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_2)));
pstore(result2 + i + 20, reinterpret_cast<Packet4f>(vec_mergel(z, r32_2)));
pstore(result2 + i + 24, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_3)));
pstore(result2 + i + 28, reinterpret_cast<Packet4f>(vec_mergel(z, r32_3)));
}
for(; i + 16 <= rows; i+=16){
Packet8us r32_0 = res2.template loadPacket<Packet8bf>(i + 0).m_val;
Packet8us r32_1 = res2.template loadPacket<Packet8bf>(i + 8).m_val;
pstore(result2 + i + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_0)));
pstore(result2 + i + 4, reinterpret_cast<Packet4f>(vec_mergel(z, r32_0)));
pstore(result2 + i + 8, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_1)));
pstore(result2 + i + 12, reinterpret_cast<Packet4f>(vec_mergel(z, r32_1)));
}
for(; i + 8 <= rows; i+=8){
Packet8us r32_0 = res2.template loadPacket<Packet8bf>(i + 0).m_val;
pstore(result2 + i + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_0)));
pstore(result2 + i + 4, reinterpret_cast<Packet4f>(vec_mergel(z, r32_0)));
}
for(; i + 4 <= rows; i+=4){
Packet8us r32_0 = res2.template loadPacketPartial<Packet8bf>(i + 0, 4).m_val;
pstore(result2 + i + 0, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_0)));
}
for(; i < rows; i++){
result2[i] = res2(i);
}
}