diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h index 392e027c7..bdf434722 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h @@ -189,11 +189,43 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0); typedef typename DataMapper::LinearMapper LinearMapper; + Packet8us z = pset1(0); for(Index j = 0; j < cols; j++){ const LinearMapper res2 = res.getLinearMapper(0, j); float *result2 = result + j*rows; - BFLOAT16_UNROLL - for(Index i = 0; i < rows; i++){ + Index i = 0; + for(; i + 32 <= rows; i+=32){ + Packet8us r32_0 = res2.template loadPacket(i + 0).m_val; + Packet8us r32_1 = res2.template loadPacket(i + 8).m_val; + Packet8us r32_2 = res2.template loadPacket(i + 16).m_val; + Packet8us r32_3 = res2.template loadPacket(i + 24).m_val; + pstore(result2 + i + 0, reinterpret_cast(vec_mergeh(z, r32_0))); + pstore(result2 + i + 4, reinterpret_cast(vec_mergel(z, r32_0))); + pstore(result2 + i + 8, reinterpret_cast(vec_mergeh(z, r32_1))); + pstore(result2 + i + 12, reinterpret_cast(vec_mergel(z, r32_1))); + pstore(result2 + i + 16, reinterpret_cast(vec_mergeh(z, r32_2))); + pstore(result2 + i + 20, reinterpret_cast(vec_mergel(z, r32_2))); + pstore(result2 + i + 24, reinterpret_cast(vec_mergeh(z, r32_3))); + pstore(result2 + i + 28, reinterpret_cast(vec_mergel(z, r32_3))); + } + for(; i + 16 <= rows; i+=16){ + Packet8us r32_0 = res2.template loadPacket(i + 0).m_val; + Packet8us r32_1 = res2.template loadPacket(i + 8).m_val; + pstore(result2 + i + 0, reinterpret_cast(vec_mergeh(z, r32_0))); + pstore(result2 + i + 4, reinterpret_cast(vec_mergel(z, r32_0))); + pstore(result2 + i + 8, reinterpret_cast(vec_mergeh(z, r32_1))); + pstore(result2 + i + 12, reinterpret_cast(vec_mergel(z, r32_1))); + } + for(; i + 8 <= rows; i+=8){ + Packet8us r32_0 = res2.template loadPacket(i + 0).m_val; + pstore(result2 + i + 0, reinterpret_cast(vec_mergeh(z, r32_0))); + pstore(result2 + i + 4, reinterpret_cast(vec_mergel(z, r32_0))); + } + for(; i + 4 <= rows; i+=4){ + Packet8us r32_0 = res2.template loadPacketPartial(i + 0, 4).m_val; + pstore(result2 + i + 0, reinterpret_cast(vec_mergeh(z, r32_0))); + } + for(; i < rows; i++){ result2[i] = res2(i); } }