Fix Power GEMV order of operations in predux for MMA.

This commit is contained in:
Chip Kerchner 2022-04-11 21:29:05 +00:00 committed by Antonio Sánchez
parent a81bba962a
commit 53eec53d2a

View File

@ -120,8 +120,8 @@ EIGEN_ALWAYS_INLINE void storeMaddData(ResScalar* res, ResScalar& alpha, ResScal
GEMV_BUILDPAIR_MMA(b##iter1, GEMV_LOADPACKET_COL(iter2), GEMV_LOADPACKET_COL((iter2) + 1));
#else
/* GEMV_LOADPAIR_COL_MMA(iter1, iter2): fills b##iter1 (a __vector_pair) from the lhs column at (i, j).
 *
 * Two forms of the body are present in this listing:
 *  - an inline-asm form that issues `lxvp` with iter1 * 32 as the immediate displacement and
 *    the address of lhs(i + 0, j) as the base register;
 *  - a plain C++ form that takes the address of lhs(i + (iter1 * 32) / sizeof(LhsScalar), j)
 *    and reads it through a __vector_pair pointer.
 * Both express a 32-byte step per iter1. iter2 is not referenced by either form.
 *
 * NOTE(review): this looks like a rendered diff with the +/- markers stripped, so presumably only
 * one of the two forms belongs in the real file - confirm against the repository.
 */
#define GEMV_LOADPAIR_COL_MMA(iter1, iter2) \
const LhsScalar& src##iter1 = lhs(i + 0, j); \
__asm__ ("lxvp %x0,%1(%2)" : "=wa" (b##iter1) : "K" (iter1 * 32), "a" (&src##iter1));
const LhsScalar& src##iter1 = lhs(i + ((iter1 * 32) / sizeof(LhsScalar)), j); \
b##iter1 = *reinterpret_cast<__vector_pair *>(const_cast<LhsScalar *>(&src##iter1));
#endif
#define GEMV_LOAD1A_COL_MMA(iter, N) \
@ -1412,8 +1412,8 @@ EIGEN_ALWAYS_INLINE void disassembleResults(__vector_quad* c0, PacketBlock<Scala
#else
#define GEMV_LOADPAIR_COL_COMPLEX_MMA(iter1, iter2) \
if (sizeof(LhsPacket) == 16) { \
const LhsScalar& src = lhs(i + 0, j); \
__asm__ ("lxvp %x0,%1(%2)" : "=wa" (a##iter1) : "K" (iter1 * 32), "a" (&src)); \
const LhsScalar& src = lhs(i + ((32 * iter1) / sizeof(LhsScalar)), j); \
a##iter1 = *reinterpret_cast<__vector_pair *>(const_cast<LhsScalar *>(&src)); \
EIGEN_UNUSED_VARIABLE(f##iter1); \
} else { \
f##iter1 = lhs.template load<PLhsPacket, Unaligned>(i + ((iter2) * ResPacketSize), j); \
@ -1719,17 +1719,26 @@ template <typename Scalar, int N> struct ScalarBlock {
};
#ifdef USE_GEMV_MMA
/* Byte-selection mask for vec_perm: indices 0x0c..0x0f pick bytes 12..15 of the first input and
 * 0x1c..0x1f pick bytes 12..15 of the second input; that 8-byte pattern is repeated twice.
 * Used by predux_real to pair up the last 4-byte element of two packets. */
static Packet16uc p16uc_ELEMENT_3 = { 0x0c,0x0d,0x0e,0x0f, 0x1c,0x1d,0x1e,0x1f, 0x0c,0x0d,0x0e,0x0f, 0x1c,0x1d,0x1e,0x1f };
/** \internal predux (add elements of a vector) from a MMA accumulator - real results
 *
 * Disassembles the accumulators \a acc0 and \a acc1 into four ResPacket rows each and returns a
 * two-entry ScalarBlock: entry 0 is reduced from \a acc0, entry 1 from \a acc1.
 *
 * NOTE(review): this listing declares cc0 twice and has two return paths, which is consistent
 * with a rendered diff whose +/- markers were stripped. Presumably the ScalarBlock/scalar-sum
 * lines are the pre-change form and the union/vector lines are the post-change form - confirm
 * against the repository before relying on either.
 */
template<typename ResScalar, typename ResPacket>
EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_real(__vector_quad* acc0, __vector_quad* acc1)
{
/* Scalar form: plain result holder. */
ScalarBlock<ResScalar, 2> cc0;
/* Vector form: lets the two results be written as one 8-byte value (cd) and returned as a
 * ScalarBlock (cs). Assumes sizeof(ScalarBlock<ResScalar, 2>) == sizeof(double), i.e. a 4-byte
 * ResScalar - TODO confirm. */
union {
ScalarBlock<ResScalar, 2> cs;
double cd;
} cc0;
PacketBlock<ResPacket, 4> result0, result1;
/* Copy each accumulator out into four packets (rows 0..3). */
__builtin_mma_disassemble_acc(&result0.packet, acc0);
__builtin_mma_disassemble_acc(&result1.packet, acc1);
/* Scalar form: sum the diagonal elements [k][k] strictly left to right,
 * i.e. ((d0 + d1) + d2) + d3, once per accumulator. */
cc0.scalar[0] = result0.packet[0][0] + result0.packet[1][1] + result0.packet[2][2] + result0.packet[3][3];
cc0.scalar[1] = result1.packet[0][0] + result1.packet[1][1] + result1.packet[2][2] + result1.packet[3][3];
return cc0;
/* Vector form: combine row k of acc0 with row k of acc1 using a different merge per row
 * (mergeh for row 0, mergeo for row 1, mergel for row 2, a byte permute for row 3).
 * Presumably this leaves element k of both rows in the first two lanes of result0.packet[k] -
 * verify against the intrinsic definitions for the target endianness. */
result0.packet[0] = vec_mergeh(result0.packet[0], result1.packet[0]);
result0.packet[1] = vec_mergeo(result0.packet[1], result1.packet[1]);
result0.packet[2] = vec_mergel(result0.packet[2], result1.packet[2]);
result0.packet[3] = vec_perm(result0.packet[3], result1.packet[3], p16uc_ELEMENT_3);
/* Reduce as (row0 + row2) + (row1 + row3); note this association differs from the
 * left-to-right order used by the scalar form above. */
result0.packet[0] = vec_add(vec_add(result0.packet[0], result0.packet[2]), vec_add(result0.packet[1], result0.packet[3]));
/* Take the first 8 bytes of the summed vector as one double and hand them back through the
 * union as the two-entry ScalarBlock. */
cc0.cd = pfirst(reinterpret_cast<Packet2d>(result0.packet[0]));
return cc0.cs;
}
template<>