Better performance for Power10 using more load and store vector pairs for GEMV

This commit is contained in:
Chip Kerchner 2022-06-27 18:11:55 +00:00 committed by Rasmus Munk Larsen
parent 0e18714167
commit c603275dc9

View File

@ -207,12 +207,8 @@ EIGEN_ALWAYS_INLINE void storeMaddData(ResScalar* res, ResScalar& alpha, ResScal
} \
}
#if EIGEN_COMP_LLVM
#define GEMV_LOADPAIR2_COL_MMA(iter1, iter2)
#else
#define GEMV_LOADPAIR2_COL_MMA(iter1, iter2) \
b##iter1 = *reinterpret_cast<__vector_pair *>(res + i + ((iter2) * ResPacketSize));
#endif
#define GEMV_LOAD2_COL_MMA(iter1, iter2, iter3, N) \
if (GEMV_GETN(N) > iter1) { \
@ -231,8 +227,9 @@ EIGEN_ALWAYS_INLINE void storeMaddData(ResScalar* res, ResScalar& alpha, ResScal
#if EIGEN_COMP_LLVM
#define GEMV_WORKPAIR2_COL_MMA(iter2, iter3, iter4) \
ResPacket f##iter2[2]; \
f##iter2[0] = pmadd(result##iter2.packet[0], palpha, ploadu<ResPacket>(res + i + ((iter4) * ResPacketSize))); \
f##iter2[1] = pmadd(result##iter3.packet[(iter2 == iter3) ? 2 : 0], palpha, ploadu<ResPacket>(res + i + (((iter4) + 1) * ResPacketSize))); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(f##iter2), &b##iter2); \
f##iter2[0] = pmadd(result##iter2.packet[0], palpha, f##iter2[0]); \
f##iter2[1] = pmadd(result##iter3.packet[(iter2 == iter3) ? 2 : 0], palpha, f##iter2[1]); \
GEMV_BUILDPAIR_MMA(b##iter2, f##iter2[0], f##iter2[1]);
#else
#define GEMV_WORKPAIR2_COL_MMA(iter2, iter3, iter4) \
@ -932,7 +929,7 @@ EIGEN_ALWAYS_INLINE void pstoreu_pmadd_complex(PResPacket& c0, PResPacket& c1, A
{
PResPacket c2 = pcplxflipconj(c0);
PResPacket c3 = pcplxflipconj(c1);
#if EIGEN_COMP_LLVM || !defined(_ARCH_PWR10)
#if !defined(_ARCH_PWR10)
ScalarPacket c4 = pload_complex<ResPacket>(res + (iter2 * ResPacketSize));
ScalarPacket c5 = pload_complex<ResPacket>(res + ((iter2 + 1) * ResPacketSize));
PResPacket c6 = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c4, b0));
@ -941,6 +938,13 @@ EIGEN_ALWAYS_INLINE void pstoreu_pmadd_complex(PResPacket& c0, PResPacket& c1, A
pstoreu(res + ((iter2 + 1) * ResPacketSize), c7);
#else
__vector_pair a = *reinterpret_cast<__vector_pair *>(res + (iter2 * ResPacketSize));
#if EIGEN_COMP_LLVM
PResPacket c6[2];
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(c6), &a);
c6[0] = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c0.v, c2.v, c6[0].v, b0));
c6[1] = PResPacket(pmadd_complex<ScalarPacket, AlphaData>(c1.v, c3.v, c6[1].v, b0));
GEMV_BUILDPAIR_MMA(a, c6[0].v, c6[1].v);
#else
if (GEMV_IS_COMPLEX_FLOAT) {
__asm__ ("xvmaddasp %L0,%x1,%x2\n\txvmaddasp %0,%x1,%x3" : "+&d" (a) : "wa" (b0.separate.r.v), "wa" (c0.v), "wa" (c1.v));
__asm__ ("xvmaddasp %L0,%x1,%x2\n\txvmaddasp %0,%x1,%x3" : "+&d" (a) : "wa" (b0.separate.i.v), "wa" (c2.v), "wa" (c3.v));
@ -948,6 +952,7 @@ EIGEN_ALWAYS_INLINE void pstoreu_pmadd_complex(PResPacket& c0, PResPacket& c1, A
__asm__ ("xvmaddadp %L0,%x1,%x2\n\txvmaddadp %0,%x1,%x3" : "+&d" (a) : "wa" (b0.separate.r.v), "wa" (c0.v), "wa" (c1.v));
__asm__ ("xvmaddadp %L0,%x1,%x2\n\txvmaddadp %0,%x1,%x3" : "+&d" (a) : "wa" (b0.separate.i.v), "wa" (c2.v), "wa" (c3.v));
}
#endif
*reinterpret_cast<__vector_pair *>(res + (iter2 * ResPacketSize)) = a;
#endif
}