New panel modes for GEMM MMA (real & complex).

Chip Kerchner 2023-09-06 20:03:45 +00:00 committed by Rasmus Munk Larsen
parent 2c64a655fe
commit 4e598ad259
2 changed files with 631 additions and 335 deletions
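The patch lets one register-blocked GEMM pass update several right-hand-side column panels at once: the kernels gain an accItr template parameter (1, 2 or 4 panels) and the drivers dispatch through MICRO_MMA_COLS(4)/(2)/(1) and MICRO_COMPLEX_MMA_COLS(4)/(2)/(1). As a rough, hypothetical sketch of that column-panel cascade, with plain scalar loops standing in for the MMA accumulators and invented names (PANEL, process_panels, gemm_multi_cols):

#include <cstddef>

// One panel = PANEL result columns; a pass with accItr panels updates
// accItr * PANEL columns at once, mirroring the new accItr kernels.
constexpr std::size_t PANEL = 4;

template <std::size_t accItr>
void process_panels(float* C, const float* A, const float* B,
                    std::size_t rows, std::size_t depth, std::size_t col) {
  // Naive update of accItr*PANEL result columns starting at `col`; the real
  // kernels keep accItr independent MMA accumulators live per row block instead.
  for (std::size_t j = col; j < col + accItr * PANEL; ++j)
    for (std::size_t i = 0; i < rows; ++i) {
      float acc = 0.0f;
      for (std::size_t k = 0; k < depth; ++k)
        acc += A[i * depth + k] * B[j * depth + k];  // A row-major, B packed per column
      C[j * rows + i] += acc;                        // C column-major
    }
}

void gemm_multi_cols(float* C, const float* A, const float* B,
                     std::size_t rows, std::size_t depth, std::size_t cols) {
  std::size_t col = 0;
  for (; col + 4 * PANEL <= cols; col += 4 * PANEL)  // widest pass first
    process_panels<4>(C, A, B, rows, depth, col);
  for (; col + 2 * PANEL <= cols; col += 2 * PANEL)
    process_panels<2>(C, A, B, rows, depth, col);
  for (; col + PANEL <= cols; col += PANEL)
    process_panels<1>(C, A, B, rows, depth, col);
  // leftover columns (< PANEL) would go through the existing "extra" path
}

The widest pass runs first and the narrower ones pick up whatever column count remains, which is the shape both the real and the complex drivers take below.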

Eigen/src/Core/arch/AltiVec/MatrixProduct.h

@@ -49,11 +49,6 @@
#include "MatrixProductMMA.h"
#endif
/**************************************************************************************************
* TODO *
* - Check StorageOrder on dhs_pack (the innermost second loop seems unvectorized when it could). *
* - Check the possibility of transposing as GETREAL and GETIMAG when needed. *
**************************************************************************************************/
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
@@ -120,6 +115,16 @@ const static Packet16uc p16uc_GETIMAG32 = { 4, 5, 6, 7,
20, 21, 22, 23,
28, 29, 30, 31};
const static Packet16uc p16uc_GETREAL32b = { 0, 1, 2, 3,
16, 17, 18, 19,
8, 9, 10, 11,
24, 25, 26, 27};
const static Packet16uc p16uc_GETIMAG32b = { 4, 5, 6, 7,
20, 21, 22, 23,
12, 13, 14, 15,
28, 29, 30, 31};
/*********************************************
* Single precision real and complex packing *
* *******************************************/
@@ -440,6 +445,78 @@ EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,N>& block)
// General template for lhs & rhs complex packing.
template<typename Scalar, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode, bool UseLhs>
struct dhs_cpack {
template<bool transpose>
EIGEN_ALWAYS_INLINE void dhs_cblock(PacketBlock<PacketC,8>& cblock, PacketBlock<Packet,4>& block, Packet16uc permute)
{
if (transpose) {
block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, permute);
block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, permute);
block.packet[2] = vec_perm(cblock.packet[4].v, cblock.packet[5].v, permute);
block.packet[3] = vec_perm(cblock.packet[6].v, cblock.packet[7].v, permute);
Packet4f t0, t1, t2, t3;
#ifdef EIGEN_VECTORIZE_VSX
t0 = reinterpret_cast<Packet>(vec_mergeh(reinterpret_cast<Packet2ul>(block.packet[0]), reinterpret_cast<Packet2ul>(block.packet[1])));
t1 = reinterpret_cast<Packet>(vec_mergel(reinterpret_cast<Packet2ul>(block.packet[0]), reinterpret_cast<Packet2ul>(block.packet[1])));
t2 = reinterpret_cast<Packet>(vec_mergeh(reinterpret_cast<Packet2ul>(block.packet[2]), reinterpret_cast<Packet2ul>(block.packet[3])));
t3 = reinterpret_cast<Packet>(vec_mergel(reinterpret_cast<Packet2ul>(block.packet[2]), reinterpret_cast<Packet2ul>(block.packet[3])));
#else
t0 = reinterpret_cast<Packet>(vec_perm(block.packet[0], block.packet[1], p16uc_TRANSPOSE64_HI));
t1 = reinterpret_cast<Packet>(vec_perm(block.packet[0], block.packet[1], p16uc_TRANSPOSE64_LO));
t2 = reinterpret_cast<Packet>(vec_perm(block.packet[2], block.packet[3], p16uc_TRANSPOSE64_HI));
t3 = reinterpret_cast<Packet>(vec_perm(block.packet[2], block.packet[3], p16uc_TRANSPOSE64_LO));
#endif
block.packet[0] = t0;
block.packet[1] = t1;
block.packet[2] = t2;
block.packet[3] = t3;
} else {
block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, permute);
block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, permute);
block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, permute);
block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, permute);
}
}
EIGEN_ALWAYS_INLINE void dhs_ccopy(Scalar* blockAt, const DataMapper& lhs2, Index& i, Index& rir, Index& rii, Index depth, const Index vectorSize)
{
PacketBlock<Packet,4> blockr, blocki;
PacketBlock<PacketC,8> cblock;
for(; i + vectorSize <= depth; i+=vectorSize)
{
if (UseLhs) {
bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, 0, i);
} else {
bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, i, 0);
}
if(((StorageOrder == RowMajor) && UseLhs) || (((StorageOrder == ColMajor) && !UseLhs)))
{
dhs_cblock<true>(cblock, blockr, p16uc_GETREAL32b);
dhs_cblock<true>(cblock, blocki, p16uc_GETIMAG32b);
} else {
dhs_cblock<false>(cblock, blockr, p16uc_GETREAL32);
dhs_cblock<false>(cblock, blocki, p16uc_GETIMAG32);
}
if(Conjugate)
{
blocki.packet[0] = -blocki.packet[0];
blocki.packet[1] = -blocki.packet[1];
blocki.packet[2] = -blocki.packet[2];
blocki.packet[3] = -blocki.packet[3];
}
storeBlock<Scalar, Packet, 4>(blockAt + rir, blockr);
storeBlock<Scalar, Packet, 4>(blockAt + rii, blocki);
rir += 4*vectorSize;
rii += 4*vectorSize;
}
}
EIGEN_STRONG_INLINE void operator()(std::complex<Scalar>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
const Index vectorSize = quad_traits<Scalar>::vectorsize;
@@ -455,47 +532,8 @@ struct dhs_cpack {
rii = rir + vectorDelta;
for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<Packet,4> blockr, blocki;
PacketBlock<PacketC,8> cblock;
dhs_ccopy(blockAt, lhs2, i, rir, rii, depth, vectorSize);
if (UseLhs) {
bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, 0, i);
} else {
bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, i, 0);
}
blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32);
blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETREAL32);
blockr.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETREAL32);
blockr.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETREAL32);
blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETIMAG32);
blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETIMAG32);
blocki.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETIMAG32);
blocki.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETIMAG32);
if(Conjugate)
{
blocki.packet[0] = -blocki.packet[0];
blocki.packet[1] = -blocki.packet[1];
blocki.packet[2] = -blocki.packet[2];
blocki.packet[3] = -blocki.packet[3];
}
if(((StorageOrder == RowMajor) && UseLhs) || (((StorageOrder == ColMajor) && !UseLhs)))
{
ptranspose(blockr);
ptranspose(blocki);
}
storeBlock<Scalar, Packet, 4>(blockAt + rir, blockr);
storeBlock<Scalar, Packet, 4>(blockAt + rii, blocki);
rir += 4*vectorSize;
rii += 4*vectorSize;
}
for(; i < depth; i++)
{
PacketBlock<Packet,1> blockr, blocki;
@@ -592,6 +630,36 @@ struct dhs_cpack {
// General template for lhs & rhs packing.
template<typename Scalar, typename DataMapper, typename Packet, int StorageOrder, bool PanelMode, bool UseLhs>
struct dhs_pack{
template<Index n>
EIGEN_ALWAYS_INLINE void dhs_copy(Scalar* blockA, const DataMapper& lhs2, Index& i, Index& ri, Index depth, const Index vectorSize)
{
PacketBlock<Packet,4> block[n];
for(; i + n*vectorSize <= depth; i+=n*vectorSize)
{
for (Index k = 0; k < n; k++) {
if (UseLhs) {
bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block[k], lhs2, 0, i + k*vectorSize);
} else {
bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block[k], lhs2, i + k*vectorSize, 0);
}
}
if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
{
for (Index k = 0; k < n; k++) {
ptranspose(block[k]);
}
}
for (Index k = 0; k < n; k++) {
storeBlock<Scalar, Packet, 4>(blockA + ri + k*4*vectorSize, block[k]);
}
ri += n*4*vectorSize;
}
}
EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
const Index vectorSize = quad_traits<Scalar>::vectorsize;
@@ -604,24 +672,10 @@ struct dhs_pack{
if(PanelMode) ri += vectorSize*offset;
for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<Packet,4> block;
dhs_copy<4>(blockA, lhs2, i, ri, depth, vectorSize);
dhs_copy<2>(blockA, lhs2, i, ri, depth, vectorSize);
dhs_copy<1>(blockA, lhs2, i, ri, depth, vectorSize);
if (UseLhs) {
bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block, lhs2, 0, i);
} else {
bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block, lhs2, i, 0);
}
if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
{
ptranspose(block);
}
storeBlock<Scalar, Packet, 4>(blockA + ri, block);
ri += 4*vectorSize;
}
for(; i < depth; i++)
{
if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
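The new dhs_copy<n> helper takes i and ri by reference, so calling it with n = 4, 2, 1 in sequence lets each narrower pass resume exactly where the wider one stopped, leaving fewer than vectorSize depth steps for the scalar tail. A minimal, hypothetical sketch of that cascade, with memcpy standing in for bload/ptranspose/storeBlock and invented names (kVec, kRows, copy_unrolled, pack_depth):

#include <cstddef>
#include <cstring>

constexpr std::size_t kVec = 4;   // stand-in for quad_traits<Scalar>::vectorsize
constexpr std::size_t kRows = 4;  // values packed per depth step

// Copies n*kVec depth steps per iteration; i and ri advance by reference so the
// next (narrower) call continues from here, as dhs_copy<n> does.
template <std::size_t n>
void copy_unrolled(float* dst, const float* src, std::size_t& i,
                   std::size_t& ri, std::size_t depth) {
  for (; i + n * kVec <= depth; i += n * kVec) {
    std::memcpy(dst + ri, src + i * kRows, n * kVec * kRows * sizeof(float));
    ri += n * kVec * kRows;
  }
}

void pack_depth(float* dst, const float* src, std::size_t depth) {
  std::size_t i = 0, ri = 0;
  copy_unrolled<4>(dst, src, i, ri, depth);  // widest chunks first
  copy_unrolled<2>(dst, src, i, ri, depth);
  copy_unrolled<1>(dst, src, i, ri, depth);
  for (; i < depth; ++i) {                   // scalar tail, as in the packer above
    std::memcpy(dst + ri, src + i * kRows, kRows * sizeof(float));
    ri += kRows;
  }
}

Unrolling by 4 and 2 before the single-step loop only amortizes loop overhead; the packed output is identical.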
@@ -691,6 +745,39 @@ struct dhs_pack{
template<typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, true>
{
template<Index n>
EIGEN_ALWAYS_INLINE void dhs_copy(double* blockA, const DataMapper& lhs2, Index& i, Index& ri, Index depth, const Index vectorSize)
{
PacketBlock<Packet2d,2> block[n];
for(; i + n*vectorSize <= depth; i+=n*vectorSize)
{
for (Index k = 0; k < n; k++) {
if(StorageOrder == RowMajor)
{
block[k].packet[0] = lhs2.template loadPacket<Packet2d>(0, i + k*vectorSize);
block[k].packet[1] = lhs2.template loadPacket<Packet2d>(1, i + k*vectorSize);
} else {
block[k].packet[0] = lhs2.template loadPacket<Packet2d>(0, i + k*vectorSize + 0);
block[k].packet[1] = lhs2.template loadPacket<Packet2d>(0, i + k*vectorSize + 1);
}
}
if(StorageOrder == RowMajor)
{
for (Index k = 0; k < n; k++) {
ptranspose(block[k]);
}
}
for (Index k = 0; k < n; k++) {
storeBlock<double, Packet2d, 2>(blockA + ri + k*2*vectorSize, block[k]);
}
ri += n*2*vectorSize;
}
}
EIGEN_STRONG_INLINE void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
const Index vectorSize = quad_traits<double>::vectorsize;
@@ -703,24 +790,10 @@ struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, true>
if(PanelMode) ri += vectorSize*offset;
for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<Packet2d,2> block;
if(StorageOrder == RowMajor)
{
block.packet[0] = lhs2.template loadPacket<Packet2d>(0, i);
block.packet[1] = lhs2.template loadPacket<Packet2d>(1, i);
dhs_copy<4>(blockA, lhs2, i, ri, depth, vectorSize);
dhs_copy<2>(blockA, lhs2, i, ri, depth, vectorSize);
dhs_copy<1>(blockA, lhs2, i, ri, depth, vectorSize);
ptranspose(block);
} else {
block.packet[0] = lhs2.template loadPacket<Packet2d>(0, i + 0);
block.packet[1] = lhs2.template loadPacket<Packet2d>(0, i + 1);
}
storeBlock<double, Packet2d, 2>(blockA + ri, block);
ri += 2*vectorSize;
}
for(; i < depth; i++)
{
if(StorageOrder == RowMajor)
@@ -759,6 +832,53 @@ struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, true>
template<typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, false>
{
template<Index n>
EIGEN_ALWAYS_INLINE void dhs_copy(double* blockB, const DataMapper& rhs2, Index& i, Index& ri, Index depth, const Index vectorSize)
{
PacketBlock<Packet2d,2> block1[n], block2[n];
PacketBlock<Packet2d,4> block3[n];
for(; i + n*vectorSize <= depth; i+=n*vectorSize)
{
for (Index k = 0; k < n; k++) {
if(StorageOrder == ColMajor)
{
block1[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k*vectorSize, 0);
block1[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k*vectorSize, 1);
block2[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k*vectorSize, 2);
block2[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k*vectorSize, 3);
} else {
block3[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k*vectorSize + 0, 0); //[a1 a2]
block3[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k*vectorSize + 0, 2); //[a3 a4]
block3[k].packet[2] = rhs2.template loadPacket<Packet2d>(i + k*vectorSize + 1, 0); //[b1 b2]
block3[k].packet[3] = rhs2.template loadPacket<Packet2d>(i + k*vectorSize + 1, 2); //[b3 b4]
}
}
if(StorageOrder == ColMajor)
{
for (Index k = 0; k < n; k++) {
ptranspose(block1[k]);
ptranspose(block2[k]);
}
}
for (Index k = 0; k < n; k++) {
if(StorageOrder == ColMajor)
{
pstore<double>(blockB + ri + k*4*vectorSize , block1[k].packet[0]);
pstore<double>(blockB + ri + k*4*vectorSize + 2, block2[k].packet[0]);
pstore<double>(blockB + ri + k*4*vectorSize + 4, block1[k].packet[1]);
pstore<double>(blockB + ri + k*4*vectorSize + 6, block2[k].packet[1]);
} else {
storeBlock<double, Packet2d, 4>(blockB + ri + k*4*vectorSize, block3[k]);
}
}
ri += n*4*vectorSize;
}
}
EIGEN_STRONG_INLINE void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
const Index vectorSize = quad_traits<double>::vectorsize;
@@ -771,35 +891,10 @@ struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, false>
if(PanelMode) ri += offset*(2*vectorSize);
for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<Packet2d,4> block;
if(StorageOrder == ColMajor)
{
PacketBlock<Packet2d,2> block1, block2;
block1.packet[0] = rhs2.template loadPacket<Packet2d>(i, 0);
block1.packet[1] = rhs2.template loadPacket<Packet2d>(i, 1);
block2.packet[0] = rhs2.template loadPacket<Packet2d>(i, 2);
block2.packet[1] = rhs2.template loadPacket<Packet2d>(i, 3);
dhs_copy<4>(blockB, rhs2, i, ri, depth, vectorSize);
dhs_copy<2>(blockB, rhs2, i, ri, depth, vectorSize);
dhs_copy<1>(blockB, rhs2, i, ri, depth, vectorSize);
ptranspose(block1);
ptranspose(block2);
pstore<double>(blockB + ri , block1.packet[0]);
pstore<double>(blockB + ri + 2, block2.packet[0]);
pstore<double>(blockB + ri + 4, block1.packet[1]);
pstore<double>(blockB + ri + 6, block2.packet[1]);
} else {
block.packet[0] = rhs2.template loadPacket<Packet2d>(i + 0, 0); //[a1 a2]
block.packet[1] = rhs2.template loadPacket<Packet2d>(i + 0, 2); //[a3 a4]
block.packet[2] = rhs2.template loadPacket<Packet2d>(i + 1, 0); //[b1 b2]
block.packet[3] = rhs2.template loadPacket<Packet2d>(i + 1, 2); //[b3 b4]
storeBlock<double, Packet2d, 4>(blockB + ri, block);
}
ri += 4*vectorSize;
}
for(; i < depth; i++)
{
if(StorageOrder == ColMajor)
@@ -1296,6 +1391,54 @@ struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, false>
template<typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, true>
{
EIGEN_ALWAYS_INLINE void dhs_ccopy(double* blockAt, const DataMapper& lhs2, Index& i, Index& rir, Index& rii, Index depth, const Index vectorSize)
{
PacketBlock<Packet,2> blockr, blocki;
PacketBlock<PacketC,4> cblock;
for(; i + vectorSize <= depth; i+=vectorSize)
{
if(StorageOrder == ColMajor)
{
cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i + 0); //[a1 a1i]
cblock.packet[1] = lhs2.template loadPacket<PacketC>(0, i + 1); //[b1 b1i]
cblock.packet[2] = lhs2.template loadPacket<PacketC>(1, i + 0); //[a2 a2i]
cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1); //[b2 b2i]
blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[2].v); //[a1 a2]
blockr.packet[1] = vec_mergeh(cblock.packet[1].v, cblock.packet[3].v); //[b1 b2]
blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[2].v);
blocki.packet[1] = vec_mergel(cblock.packet[1].v, cblock.packet[3].v);
} else {
cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i); //[a1 a1i]
cblock.packet[1] = lhs2.template loadPacket<PacketC>(1, i); //[a2 a2i]
cblock.packet[2] = lhs2.template loadPacket<PacketC>(0, i + 1); //[b1 b1i]
cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1); //[b2 b2i]
blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v); //[a1 a2]
blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v); //[b1 b2]
blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v);
}
if(Conjugate)
{
blocki.packet[0] = -blocki.packet[0];
blocki.packet[1] = -blocki.packet[1];
}
storeBlock<double, Packet, 2>(blockAt + rir, blockr);
storeBlock<double, Packet, 2>(blockAt + rii, blocki);
rir += 2*vectorSize;
rii += 2*vectorSize;
}
}
EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
const Index vectorSize = quad_traits<double>::vectorsize;
@@ -1311,50 +1454,8 @@ struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, P
rii = rir + vectorDelta;
for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<Packet,2> blockr, blocki;
PacketBlock<PacketC,4> cblock;
dhs_ccopy(blockAt, lhs2, i, rir, rii, depth, vectorSize);
if(StorageOrder == ColMajor)
{
cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i + 0); //[a1 a1i]
cblock.packet[1] = lhs2.template loadPacket<PacketC>(0, i + 1); //[b1 b1i]
cblock.packet[2] = lhs2.template loadPacket<PacketC>(1, i + 0); //[a2 a2i]
cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1); //[b2 b2i]
blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[2].v); //[a1 a2]
blockr.packet[1] = vec_mergeh(cblock.packet[1].v, cblock.packet[3].v); //[b1 b2]
blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[2].v);
blocki.packet[1] = vec_mergel(cblock.packet[1].v, cblock.packet[3].v);
} else {
cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i); //[a1 a1i]
cblock.packet[1] = lhs2.template loadPacket<PacketC>(1, i); //[a2 a2i]
cblock.packet[2] = lhs2.template loadPacket<PacketC>(0, i + 1); //[b1 b1i]
cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1); //[b2 b2i]
blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v); //[a1 a2]
blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v); //[b1 b2]
blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v);
}
if(Conjugate)
{
blocki.packet[0] = -blocki.packet[0];
blocki.packet[1] = -blocki.packet[1];
}
storeBlock<double, Packet, 2>(blockAt + rir, blockr);
storeBlock<double, Packet, 2>(blockAt + rii, blocki);
rir += 2*vectorSize;
rii += 2*vectorSize;
}
for(; i < depth; i++)
{
PacketBlock<Packet,1> blockr, blocki;
@@ -1410,6 +1511,35 @@ struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, P
template<typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, false>
{
EIGEN_ALWAYS_INLINE void dhs_ccopy(double* blockBt, const DataMapper& rhs2, Index& i, Index& rir, Index& rii, Index depth, const Index vectorSize)
{
for(; i < depth; i++)
{
PacketBlock<PacketC,4> cblock;
PacketBlock<Packet,2> blockr, blocki;
bload<DataMapper, PacketC, 2, ColMajor, false, 4>(cblock, rhs2, i, 0);
blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v);
blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v);
blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v);
if(Conjugate)
{
blocki.packet[0] = -blocki.packet[0];
blocki.packet[1] = -blocki.packet[1];
}
storeBlock<double, Packet, 2>(blockBt + rir, blockr);
storeBlock<double, Packet, 2>(blockBt + rii, blocki);
rir += 2*vectorSize;
rii += 2*vectorSize;
}
}
EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
const Index vectorSize = quad_traits<double>::vectorsize;
@@ -1425,31 +1555,7 @@ struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, P
rii = rir + vectorDelta;
for(; i < depth; i++)
{
PacketBlock<PacketC,4> cblock;
PacketBlock<Packet,2> blockr, blocki;
bload<DataMapper, PacketC, 2, ColMajor, false, 4>(cblock, rhs2, i, 0);
blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v);
blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v);
blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v);
if(Conjugate)
{
blocki.packet[0] = -blocki.packet[0];
blocki.packet[1] = -blocki.packet[1];
}
storeBlock<double, Packet, 2>(blockBt + rir, blockr);
storeBlock<double, Packet, 2>(blockBt + rii, blocki);
rir += 2*vectorSize;
rii += 2*vectorSize;
}
dhs_ccopy(blockBt, rhs2, i, rir, rii, depth, vectorSize);
rir += ((PanelMode) ? (2*vectorSize*(2*stride - depth)) : vectorDelta);
}

Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h

@@ -42,19 +42,13 @@ EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc)
__builtin_mma_xxsetaccz(acc);
}
#ifdef USE_PARTIAL_PACKETS
template<typename DataMapper, typename Packet, bool full>
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, const Index elements, __vector_quad* acc)
#else
template<typename DataMapper, typename Packet, const Index accCols, const Index accCols2>
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, const Packet& pMask, __vector_quad* acc)
#endif
{
PacketBlock<Packet, 4> result;
__builtin_mma_disassemble_acc(&result.packet, acc);
PacketBlock<Packet, 4> tRes;
#ifdef USE_PARTIAL_PACKETS
if (full) {
EIGEN_UNUSED_VARIABLE(elements);
bload<DataMapper, Packet, 0, ColMajor, false, 4>(tRes, data, i, 0);
@@ -65,11 +59,6 @@ EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const
bscale<Packet, 4>(tRes, result, alpha);
bstore_partial<DataMapper, Packet, 4>(tRes, data, i, elements);
}
#else
bload<DataMapper, Packet, 0, ColMajor, false, 4>(tRes, data, i, 0);
bscale<Packet, 4, (accCols != accCols2)>(tRes, result, alpha, pMask);
bstore<DataMapper, Packet, 4>(tRes, data, i);
#endif
}
template<typename DataMapper, typename Packet, typename Packetc, const Index accCols, const Index accCols2>
@@ -166,78 +155,118 @@ EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV)
ploadRhsMMA(lhs, lhsV);
}
#if (EIGEN_COMP_LLVM || (__GNUC__ >= 11))
#define GEMM_MULTIPLE_COLS
// Disable in GCC until unnecessary register moves are fixed
//#if (EIGEN_COMP_LLVM || (__GNUC__ >= 11))
#if EIGEN_COMP_LLVM
#define VECTOR_PAIR_LOADS_LHS
#endif
// PEEL_MMA loop factor.
#ifdef GEMM_MULTIPLE_COLS
#define PEEL_MMA 8
#else
// Register spillage with GCC12+
#if EIGEN_COMP_LLVM || (__GNUC__ < 12) || defined(VECTOR_PAIR_LOADS_LHS)
#define PEEL_MMA 7
#else
#define PEEL_MMA 6
#endif
#endif
#define MICRO_MMA_UNROLL(func) \
func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
#define MICRO_MMA_WORK(func, type, peel) \
func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel)
if (accItr == 1) { \
func(0,type,peel,0,0) func(1,type,peel,1,0) func(2,type,peel,2,0) func(3,type,peel,3,0) \
func(4,type,peel,4,0) func(5,type,peel,5,0) func(6,type,peel,6,0) func(7,type,peel,7,0) \
} else if (accItr == 2) { \
func(0,type,peel,0,0) func(1,type,peel,0,1) func(2,type,peel,1,0) func(3,type,peel,1,1) \
func(4,type,peel,2,0) func(5,type,peel,2,1) func(6,type,peel,3,0) func(7,type,peel,3,1) \
} else { \
func(0,type,peel,0,0) func(1,type,peel,0,1) func(2,type,peel,0,2) func(3,type,peel,0,3) \
func(4,type,peel,1,0) func(5,type,peel,1,1) func(6,type,peel,1,2) func(7,type,peel,1,3) \
}
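MICRO_MMA_WORK above spreads the eight accumulators over the row unroll ("left", which picks lhsV##left) and the new column-panel index ("right", which picks rhsV##right); in every branch the expansion corresponds to iter = left*accItr + right. A tiny, hypothetical helper that prints the same mapping (print_acc_mapping is not part of the patch):

#include <cstdio>

void print_acc_mapping(int accItr) {  // accItr in {1, 2, 4}
  for (int iter = 0; iter < 8; ++iter) {
    int left  = iter / accItr;        // row block feeding lhsV##left
    int right = iter % accItr;        // rhs column panel feeding rhsV##right
    std::printf("accZero%d <- lhsV%d * rhsV%d\n", iter, left, right);
  }
}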
#define MICRO_MMA_WORK_ONE(iter, type, peel) \
if (unroll_factor > iter) { \
pgerMMA<Packet, type, false>(&accZero##iter, rhsV[peel], lhsV##iter); \
#define MICRO_MMA_WORK_ONE(iter, type, peel, left, right) \
if (unroll_factor > left) { \
pgerMMA<Packet, type, false>(&accZero##iter, rhsV##right[peel], lhsV##left); \
}
#ifdef VECTOR_PAIR_LOADS_LHS
#define MICRO_MMA_WORK_TWO(iter, type, peel) \
if (unroll_factor > iter) { \
pgerMMA<Packet, type, false>(&accZero##iter, rhsV[peel], lhsV2##iter.packet[peel & 1]); \
#define MICRO_MMA_WORK_TWO(iter, type, peel, left, right) \
if (unroll_factor > left) { \
pgerMMA<Packet, type, false>(&accZero##iter, rhsV##right[peel], lhsV2##left.packet[peel & 1]); \
}
#define MICRO_MMA_LOAD1_TWO(lhs_ptr, iter) \
if (unroll_factor > iter) { \
if (MICRO_NORMAL(iter)) { \
ploadLhsMMA(reinterpret_cast<const double*>(lhs_ptr##iter), plhsV##iter); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&lhsV2##iter.packet), &plhsV##iter); \
lhs_ptr##iter += accCols*2; \
#define MICRO_MMA_LOAD1_TWO(lhs_ptr, left) \
if (unroll_factor > left) { \
if (MICRO_NORMAL(left)) { \
ploadLhsMMA(reinterpret_cast<const double*>(lhs_ptr##left), plhsV##left); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&lhsV2##left.packet), &plhsV##left); \
lhs_ptr##left += accCols*2; \
} else { \
lhsV2##iter.packet[0] = ploadLhs<Packet>(lhs_ptr##iter); \
lhsV2##iter.packet[1] = ploadLhs<Packet>(lhs_ptr##iter + accCols2); \
lhs_ptr##iter += accCols2*2; \
EIGEN_UNUSED_VARIABLE(plhsV##iter) \
lhsV2##left.packet[0] = ploadLhs<Packet>(lhs_ptr##left); \
lhsV2##left.packet[1] = ploadLhs<Packet>(lhs_ptr##left + accCols2); \
lhs_ptr##left += accCols2*2; \
EIGEN_UNUSED_VARIABLE(plhsV##left); \
} \
} else { \
EIGEN_UNUSED_VARIABLE(lhsV2##iter); \
EIGEN_UNUSED_VARIABLE(plhsV##iter) \
EIGEN_UNUSED_VARIABLE(lhsV2##left); \
EIGEN_UNUSED_VARIABLE(plhsV##left); \
}
#define MICRO_MMA_LOAD_TWO(iter) MICRO_MMA_LOAD1_TWO(lhs_ptr, iter)
#define MICRO_MMA_LOAD_TWO(left) MICRO_MMA_LOAD1_TWO(lhs_ptr, left)
#endif
#define MICRO_MMA_UNROLL_ITER(func, val) \
func(val,0) \
if (accItr > 1) { \
func(val,1) \
if (accItr > 2) { \
func(val,2) \
func(val,3) \
} \
}
#define MICRO_MMA_LOAD_ONE_RHS1(peel, right) \
ploadRhsMMA(rhs_ptr##right + (accRows * peel), rhsV##right[peel]);
#define MICRO_MMA_LOAD_ONE_RHS(peel) \
MICRO_MMA_UNROLL_ITER(MICRO_MMA_LOAD_ONE_RHS1, peel)
#define MICRO_MMA_TYPE_PEEL(funcw, funcl, type, peel) \
if (PEEL_MMA > peel) { \
Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
ploadRhsMMA(rhs_ptr + (accRows * peel), rhsV[peel]); \
MICRO_MMA_LOAD_ONE_RHS(peel) \
MICRO_MMA_UNROLL(funcl) \
MICRO_MMA_WORK(funcw, type, peel) \
}
#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_MMA_UNROLL_TYPE_PEEL(funcw, funcl, type) \
type rhsV[8]; \
type rhsV0[8], rhsV1[(accItr > 1) ? 8 : 1], rhsV2[(accItr > 2) ? 8 : 1], rhsV3[(accItr > 2) ? 8 : 1]; \
MICRO_MMA_TYPE_PEEL(funcw,funcl,type,0) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,1) \
MICRO_MMA_TYPE_PEEL(funcw,funcl,type,2) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,3) \
MICRO_MMA_TYPE_PEEL(funcw,funcl,type,4) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,5) \
MICRO_MMA_TYPE_PEEL(funcw,funcl,type,6) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,7)
#else
#define MICRO_MMA_LOAD_TWO_RHS(peel1, right) \
ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr##right + (accRows * peel1)), prhsV##peel1); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsV##right[peel1]), &prhsV##peel1);
#define MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, peel1, peel2) \
if (PEEL_MMA > peel2) { \
PacketBlock<Packet,2> lhsV20, lhsV21, lhsV22, lhsV23, lhsV24, lhsV25, lhsV26, lhsV27; \
__vector_pair plhsV0, plhsV1, plhsV2, plhsV3, plhsV4, plhsV5, plhsV6, plhsV7; \
if (sizeof(type) == 16) { \
ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr + (accRows * peel1)), prhsV##peel1); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsV[peel1]), &prhsV##peel1); \
MICRO_MMA_UNROLL_ITER(MICRO_MMA_LOAD_TWO_RHS, peel1) \
} else { \
EIGEN_UNUSED_VARIABLE(prhsV##peel1); \
ploadRhsMMA(rhs_ptr + (accRows * peel1), rhsV[peel1]); \
ploadRhsMMA(rhs_ptr + (accRows * peel2), rhsV[peel2]); \
MICRO_MMA_LOAD_ONE_RHS(peel1) \
MICRO_MMA_LOAD_ONE_RHS(peel2) \
} \
MICRO_MMA_UNROLL(funcl2) \
MICRO_MMA_WORK(funcw2, type, peel1) \
@@ -248,7 +277,7 @@ EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV)
}
#define MICRO_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type) \
type rhsV[8]; \
type rhsV0[8], rhsV1[(accItr > 1) ? 8 : 1], rhsV2[(accItr > 2) ? 8 : 1], rhsV3[(accItr > 2) ? 8 : 1]; \
__vector_pair prhsV0, prhsV2, prhsV4, prhsV6; \
MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,0,1) \
MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,2,3) \
@@ -257,19 +286,25 @@ EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV)
#endif
#define MICRO_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \
type rhsV[1]; \
type rhsV0[1], rhsV1[1], rhsV2[1], rhsV3[1]; \
MICRO_MMA_TYPE_PEEL(funcw,funcl,type,0)
#define MICRO_MMA_UPDATE_RHS1(size, right) \
rhs_ptr##right += (accRows * size);
#define MICRO_MMA_UPDATE_RHS(size) \
MICRO_MMA_UNROLL_ITER(MICRO_MMA_UPDATE_RHS1, size)
#define MICRO_MMA_UNROLL_TYPE(MICRO_MMA_TYPE, size) \
MICRO_MMA_TYPE(MICRO_MMA_WORK_ONE, MICRO_LOAD_ONE, RhsPacket) \
rhs_ptr += (accRows * size);
MICRO_MMA_UPDATE_RHS(size)
#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_MMA_ONE_PEEL MICRO_MMA_UNROLL_TYPE(MICRO_MMA_UNROLL_TYPE_PEEL, PEEL_MMA)
#else
#define MICRO_MMA_UNROLL_TYPE2(MICRO_MMA_TYPE, size) \
MICRO_MMA_TYPE(MICRO_MMA_WORK_ONE, MICRO_LOAD_ONE, MICRO_MMA_WORK_TWO, MICRO_MMA_LOAD_TWO, RhsPacket) \
rhs_ptr += (accRows * size);
MICRO_MMA_UPDATE_RHS(size)
#define MICRO_MMA_ONE_PEEL MICRO_MMA_UNROLL_TYPE2(MICRO_MMA_UNROLL_TYPE_PEEL2, PEEL_MMA)
#endif
@@ -277,7 +312,7 @@ EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV)
#define MICRO_MMA_ONE MICRO_MMA_UNROLL_TYPE(MICRO_MMA_UNROLL_TYPE_ONE, 1)
#define MICRO_MMA_DST_PTR_ONE(iter) \
if (unroll_factor > iter) { \
if (unroll_factor * accItr > iter) { \
bsetzeroMMA(&accZero##iter); \
} else { \
EIGEN_UNUSED_VARIABLE(accZero##iter); \
@@ -289,45 +324,69 @@ EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV)
#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_PREFETCH_ONE)
#ifdef USE_PARTIAL_PACKETS
#define MICRO_MMA_STORE_ONE(iter) \
if (unroll_factor > iter) { \
storeAccumulator<DataMapper, Packet, MICRO_NORMAL_PARTIAL(iter)>(row + iter*accCols, res, pAlpha, accCols2, &accZero##iter); \
#define MICRO_MMA_STORE_ONE(iter, left, right) \
if (unroll_factor > left) { \
storeAccumulator<DataMapper, Packet, MICRO_NORMAL_PARTIAL(left)>(row + left*accCols, res##right, pAlpha, accCols2, &accZero##iter); \
}
#else
#define MICRO_MMA_STORE_ONE(iter) \
if (unroll_factor > iter) { \
storeAccumulator<DataMapper, Packet, accCols, (unroll_factor != (iter + 1)) ? accCols : accCols2>(row + iter*accCols, res, pAlpha, pMask, &accZero##iter); \
#define MICRO_MMA_ITER_UNROLL(func) \
if (accItr == 1) { \
func(0,0,0) func(1,1,0) func(2,2,0) func(3,3,0) \
func(4,4,0) func(5,5,0) func(6,6,0) func(7,7,0) \
} else if (accItr == 2) { \
func(0,0,0) func(1,0,1) func(2,1,0) func(3,1,1) \
func(4,2,0) func(5,2,1) func(6,3,0) func(7,3,1) \
} else { \
func(0,0,0) func(1,0,1) func(2,0,2) func(3,0,3) \
func(4,1,0) func(5,1,1) func(6,1,2) func(7,1,3) \
}
#endif
#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE)
#define MICRO_MMA_STORE MICRO_MMA_ITER_UNROLL(MICRO_MMA_STORE_ONE)
#ifdef USE_PARTIAL_PACKETS
template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool full>
#else
template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, const Index accCols2>
#endif
#define MICRO_MMA_EXTRA_ROWS(right) \
gemm_extra_row<Scalar, Packet, DataMapper, accRows, accCols>(res3##right, blockA, rhs_base + right*accRows*strideB, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlpha, pMask);
#define MICRO_MMA_EXTRA_ROWS1(val, right) \
MICRO_MMA_EXTRA_ROWS(right);
template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool full, const Index accItr>
EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration(
const DataMapper& res,
const DataMapper& res0,
const DataMapper& res1,
const DataMapper& res2,
const DataMapper& res3,
const Scalar* lhs_base,
const Scalar* rhs_base,
Index depth,
Index strideA,
Index strideB,
Index offsetA,
Index& row,
const Packet& pAlpha,
#ifdef USE_PARTIAL_PACKETS
Index accCols2
#else
const Packet& pMask
#endif
)
{
const Scalar* rhs_ptr = rhs_base;
const Scalar* rhs_ptr0 = rhs_base, * rhs_ptr1 = NULL, * rhs_ptr2 = NULL, * rhs_ptr3 = NULL;
const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
__vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
if (accItr > 1) {
rhs_ptr1 = rhs_base + (accRows * strideB);
} else {
EIGEN_UNUSED_VARIABLE(strideB);
EIGEN_UNUSED_VARIABLE(rhs_ptr1);
EIGEN_UNUSED_VARIABLE(res1);
}
if (accItr > 2) {
rhs_ptr2 = rhs_base + (2 * accRows * strideB);
rhs_ptr3 = rhs_base + (3 * accRows * strideB);
} else {
EIGEN_UNUSED_VARIABLE(rhs_ptr2);
EIGEN_UNUSED_VARIABLE(rhs_ptr3);
EIGEN_UNUSED_VARIABLE(res2);
EIGEN_UNUSED_VARIABLE(res3);
}
MICRO_MMA_SRC_PTR
MICRO_MMA_DST_PTR
@@ -347,17 +406,16 @@ EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration(
MICRO_UPDATE
}
#ifdef USE_PARTIAL_PACKETS
#define MICRO_MMA_UNROLL_ITER2(N, M) \
gemm_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, RhsPacket, DataMapper, accRows, accCols, !M>(res3, lhs_base, rhs_base, depth, strideA, offsetA, row, pAlpha, M ? remaining_rows : accCols); \
gemm_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, RhsPacket, DataMapper, accRows, accCols, !M, accItr>(res30, res31, res32, res33, lhs_base, rhs_base, depth, strideA, strideB, offsetA, row, pAlpha, M ? remaining_rows : accCols); \
if (M) return;
#else
#define MICRO_MMA_UNROLL_ITER2(N, M) \
gemm_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, RhsPacket, DataMapper, accRows, accCols, M ? M : accCols>(res3, lhs_base, rhs_base, depth, strideA, offsetA, row, pAlpha, pMask); \
if (M) return;
#endif
template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
#define MICRO_MMA_ROWS(n) \
while(row + n*accCols <= rows) { \
MICRO_MMA_UNROLL_ITER2(n, 0); \
}
template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, const Index accItr>
EIGEN_ALWAYS_INLINE void gemmMMA_cols(
const DataMapper& res,
const Scalar* blockA,
@@ -373,45 +431,71 @@ EIGEN_ALWAYS_INLINE void gemmMMA_cols(
const Packet& pAlpha,
const Packet& pMask)
{
const DataMapper res3 = res.getSubMapper(0, col);
const DataMapper res30 = res.getSubMapper(0, col);
const DataMapper res31 = (accItr > 1) ? res30.getSubMapper(0, accRows*1) : res30;
const DataMapper res32 = (accItr > 2) ? res30.getSubMapper(0, accRows*2) : res30;
const DataMapper res33 = (accItr > 2) ? res30.getSubMapper(0, accRows*3) : res30;
const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
const Scalar* lhs_base = blockA + accCols*offsetA;
Index row = 0;
#define MAX_MMA_UNROLL 7
while(row + MAX_MMA_UNROLL*accCols <= rows) {
MICRO_MMA_UNROLL_ITER2(MAX_MMA_UNROLL, 0);
#if MAX_MMA_UNROLL < 2
if (1) {
#elif MAX_MMA_UNROLL < 4
if (accItr <= 2) {
#else
if (accItr == 1) {
#endif
MICRO_MMA_ROWS(MAX_MMA_UNROLL);
} else if (accItr == 2) {
MICRO_MMA_ROWS(4);
} else {
MICRO_MMA_ROWS(2);
}
switch( (rows-row)/accCols ) {
#if MAX_MMA_UNROLL > 7
case 7:
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 7)
if (accItr == 1) {
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 7)
}
break;
#endif
#if MAX_MMA_UNROLL > 6
case 6:
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 6)
if (accItr == 1) {
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 6)
}
break;
#endif
#if MAX_MMA_UNROLL > 5
case 5:
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 5)
if (accItr == 1) {
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 5)
}
break;
#endif
#if MAX_MMA_UNROLL > 4
case 4:
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 4)
if (accItr == 1) {
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 4)
}
break;
#endif
#if MAX_MMA_UNROLL > 3
case 3:
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 3)
if (accItr <= 2) {
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 3)
}
break;
#endif
#if MAX_MMA_UNROLL > 2
case 2:
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 2)
if (accItr <= 2) {
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 2)
}
break;
#endif
#if MAX_MMA_UNROLL > 1
@@ -426,10 +510,16 @@ EIGEN_ALWAYS_INLINE void gemmMMA_cols(
if(remaining_rows > 0)
{
gemm_extra_row<Scalar, Packet, DataMapper, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlpha, pMask);
MICRO_MMA_UNROLL_ITER(MICRO_MMA_EXTRA_ROWS1, 0)
}
}
#define MICRO_MMA_COLS(n) \
for(; col + n*accRows <= cols; col += n*accRows) \
{ \
gemmMMA_cols<Scalar, Packet, RhsPacket2, DataMapper, accRows, accCols, n>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask); \
}
template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
@@ -444,10 +534,11 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
typedef typename std::conditional_t<(sizeof(Scalar) == sizeof(float)), RhsPacket, __vector_pair> RhsPacket2;
Index col = 0;
for(; col + accRows <= cols; col += accRows)
{
gemmMMA_cols<Scalar, Packet, RhsPacket2, DataMapper, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask);
}
#ifdef GEMM_MULTIPLE_COLS
MICRO_MMA_COLS(4);
MICRO_MMA_COLS(2);
#endif
MICRO_MMA_COLS(1);
if (col != cols)
{
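The row-unroll choices in gemmMMA_cols above (MICRO_MMA_ROWS(MAX_MMA_UNROLL), (4) or (2)) keep the product of row unroll and accItr within the eight available MMA accumulator registers. A hypothetical compile-time restatement of that budget (kAccumulators and max_row_unroll are invented names):

constexpr int kAccumulators = 8;  // the POWER MMA unit exposes 8 accumulators

constexpr int max_row_unroll(int accItr) {
  // mirrors MICRO_MMA_ROWS(MAX_MMA_UNROLL)/(4)/(2) above, with MAX_MMA_UNROLL == 7
  return accItr == 1 ? 7 : kAccumulators / accItr;
}

static_assert(max_row_unroll(1) == 7, "one panel keeps the widest row unroll");
static_assert(max_row_unroll(2) == 4, "two panels -> four row blocks");
static_assert(max_row_unroll(4) == 2, "four panels -> two row blocks");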
@@ -459,62 +550,88 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
#define advanceCols ((RhsIsReal) ? 1 : 2)
// PEEL_COMPLEX_MMA loop factor.
#ifdef GEMM_MULTIPLE_COLS
#define PEEL_COMPLEX_MMA 4
#else
#define PEEL_COMPLEX_MMA 3
#endif
#define MICRO_COMPLEX_MMA_UNROLL(func) \
func(0) func(1) func(2) func(3)
#define MICRO_COMPLEX_MMA_WORK(func, type, peel) \
func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel)
if (accItr == 1) { \
func(0,type,peel,0,0) func(1,type,peel,1,0) func(2,type,peel,2,0) func(3,type,peel,3,0) \
} else if (accItr == 2) { \
func(0,type,peel,0,0) func(1,type,peel,0,1) func(2,type,peel,1,0) func(3,type,peel,1,1) \
} else { \
func(0,type,peel,0,0) func(1,type,peel,0,1) func(2,type,peel,0,2) func(3,type,peel,0,3) \
}
#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \
if (unroll_factor > iter) { \
pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV[peel], rhsVi[peel]); \
#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel, left, right) \
if (unroll_factor > left) { \
pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##left, lhsVi##left, rhsV##right[peel], rhsVi##right[peel]); \
}
#ifdef VECTOR_PAIR_LOADS_LHS
#define MICRO_COMPLEX_MMA_WORK_TWO(iter, type, peel) \
if (unroll_factor > iter) { \
pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV2##iter.packet[peel & 1], lhsVi2##iter.packet[peel & 1], rhsV[peel], rhsVi[peel]); \
#define MICRO_COMPLEX_MMA_WORK_TWO(iter, type, peel, left, right) \
if (unroll_factor > left) { \
pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV2##left.packet[peel & 1], lhsVi2##left.packet[peel & 1], rhsV##right[peel], rhsVi##right[peel]); \
}
#define MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, iter) \
if (!LhsIsReal && (unroll_factor > iter)) { \
if (MICRO_NORMAL(iter)) { \
ploadLhsMMA(reinterpret_cast<const double*>(lhs_ptr_real##iter + imag_delta), plhsVi##iter); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&lhsVi2##iter.packet), &plhsVi##iter); \
#define MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, left) \
if (!LhsIsReal && (unroll_factor > left)) { \
if (MICRO_NORMAL(left)) { \
ploadLhsMMA(reinterpret_cast<const double*>(lhs_ptr_real##left + imag_delta), plhsVi##left); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&lhsVi2##left.packet), &plhsVi##left); \
} else { \
lhsVi2##iter.packet[0] = ploadLhs<Packet>(lhs_ptr_real##iter + imag_delta2); \
lhsVi2##iter.packet[1] = ploadLhs<Packet>(lhs_ptr_real##iter + imag_delta2 + accCols2); \
EIGEN_UNUSED_VARIABLE(plhsVi##iter) \
lhsVi2##left.packet[0] = ploadLhs<Packet>(lhs_ptr_real##left + imag_delta2); \
lhsVi2##left.packet[1] = ploadLhs<Packet>(lhs_ptr_real##left + imag_delta2 + accCols2); \
EIGEN_UNUSED_VARIABLE(plhsVi##left); \
} \
} else { \
EIGEN_UNUSED_VARIABLE(lhsVi2##iter); \
EIGEN_UNUSED_VARIABLE(plhsVi##iter) \
EIGEN_UNUSED_VARIABLE(lhsVi2##left); \
EIGEN_UNUSED_VARIABLE(plhsVi##left); \
} \
MICRO_MMA_LOAD1_TWO(lhs_ptr_real, iter)
MICRO_MMA_LOAD1_TWO(lhs_ptr_real, left)
#define MICRO_COMPLEX_MMA_LOAD_TWO(iter) MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, iter)
#define MICRO_COMPLEX_MMA_LOAD_TWO(left) MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, left)
#endif
#define MICRO_COMPLEX_MMA_LOAD_RHS1(peel, right) \
ploadRhsMMA(rhs_ptr_real##right + (accRows * peel), rhsV##right[peel]); \
if (!RhsIsReal) { \
ploadRhsMMA(rhs_ptr_imag##right + (accRows * peel), rhsVi##right[peel]); \
}
#define MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel) \
MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_LOAD_RHS1, peel)
#define MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, peel) \
if (PEEL_COMPLEX_MMA > peel) { \
Packet lhsV0, lhsV1, lhsV2, lhsV3; \
Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
ploadRhsMMA(rhs_ptr_real + (accRows * peel), rhsV[peel]); \
if(!RhsIsReal) { \
ploadRhsMMA(rhs_ptr_imag + (accRows * peel), rhsVi[peel]); \
} \
MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel) \
MICRO_COMPLEX_MMA_UNROLL(funcl) \
MICRO_COMPLEX_MMA_WORK(funcw, type, peel) \
}
#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(funcw, funcl, type) \
type rhsV[4], rhsVi[4]; \
type rhsV0[4], rhsVi0[4], rhsV1[(accItr > 1) ? 4 : 1], rhsVi1[(accItr > 1) ? 4 : 1], rhsV2[(accItr > 2) ? 4 : 1], rhsVi2[(accItr > 2) ? 4 : 1], rhsV3[(accItr > 2) ? 4 : 1], rhsVi3[(accItr > 2) ? 4 : 1]; \
MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,0) MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,1) \
MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,2) MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,3)
#else
#define MICRO_COMPLEX_MMA_LOAD_TWO_RHS(peel1, right) \
ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr_real##right + (accRows * peel1)), prhsV##peel1); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsV##right[peel1]), &prhsV##peel1); \
if(!RhsIsReal) { \
ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr_imag##right + (accRows * peel1)), prhsVi##peel1); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsVi##right[peel1]), &prhsVi##peel1); \
} else { \
EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \
}
#define MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, peel1, peel2) \
if (PEEL_COMPLEX_MMA > peel2) { \
PacketBlock<Packet,2> lhsV20, lhsV21, lhsV22, lhsV23; \
@@ -522,23 +639,12 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
__vector_pair plhsV0, plhsV1, plhsV2, plhsV3; \
__vector_pair plhsVi0, plhsVi1, plhsVi2, plhsVi3; \
if (sizeof(type) == 16) { \
ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr_real + (accRows * peel1)), prhsV##peel1); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsV[peel1]), &prhsV##peel1); \
if(!RhsIsReal) { \
ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr_imag + (accRows * peel1)), prhsVi##peel1); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsVi[peel1]), &prhsVi##peel1); \
} else { \
EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \
} \
MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_LOAD_TWO_RHS, peel1) \
} else { \
EIGEN_UNUSED_VARIABLE(prhsV##peel1); \
EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \
ploadRhsMMA(rhs_ptr_real + (accRows * peel1), rhsV[peel1]); \
ploadRhsMMA(rhs_ptr_real + (accRows * peel2), rhsV[peel2]); \
if(!RhsIsReal) { \
ploadRhsMMA(rhs_ptr_imag + (accRows * peel1), rhsVi[peel1]); \
ploadRhsMMA(rhs_ptr_imag + (accRows * peel2), rhsVi[peel2]); \
} \
MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel1); \
MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel2); \
} \
MICRO_COMPLEX_MMA_UNROLL(funcl2) \
MICRO_COMPLEX_MMA_WORK(funcw2, type, peel1) \
@@ -550,7 +656,7 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
}
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type) \
type rhsV[4], rhsVi[4]; \
type rhsV0[4], rhsVi0[4], rhsV1[(accItr > 1) ? 4 : 1], rhsVi1[(accItr > 1) ? 4 : 1], rhsV2[(accItr > 2) ? 4 : 1], rhsVi2[(accItr > 2) ? 4 : 1], rhsV3[(accItr > 2) ? 4 : 1], rhsVi3[(accItr > 2) ? 4 : 1]; \
__vector_pair prhsV0, prhsV2; \
__vector_pair prhsVi0, prhsVi2; \
MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,0,1) \
@@ -558,21 +664,26 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
#endif
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \
type rhsV[1], rhsVi[1]; \
type rhsV0[1], rhsVi0[1], rhsV1[1], rhsVi1[1], rhsV2[1], rhsVi2[1], rhsV3[1], rhsVi3[1]; \
MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,0)
#define MICRO_COMPLEX_MMA_UPDATE_RHS1(size, right) \
rhs_ptr_real##right += (accRows * size); \
if(!RhsIsReal) rhs_ptr_imag##right += (accRows * size);
#define MICRO_COMPLEX_MMA_UPDATE_RHS(size) \
MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_UPDATE_RHS1, size)
#define MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_TYPE, size) \
MICRO_COMPLEX_MMA_TYPE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_LOAD_ONE, RhsPacket) \
rhs_ptr_real += (accRows * size); \
if(!RhsIsReal) rhs_ptr_imag += (accRows * size);
MICRO_COMPLEX_MMA_UPDATE_RHS(size);
#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_COMPLEX_MMA_ONE_PEEL MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL, PEEL_COMPLEX_MMA)
#else
#define MICRO_COMPLEX_MMA_UNROLL_TYPE2(MICRO_COMPLEX_MMA_TYPE, size) \
MICRO_COMPLEX_MMA_TYPE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_LOAD_ONE, MICRO_COMPLEX_MMA_WORK_TWO, MICRO_COMPLEX_MMA_LOAD_TWO, RhsPacket) \
rhs_ptr_real += (accRows * size); \
if(!RhsIsReal) rhs_ptr_imag += (accRows * size);
MICRO_COMPLEX_MMA_UPDATE_RHS(size);
#define MICRO_COMPLEX_MMA_ONE_PEEL MICRO_COMPLEX_MMA_UNROLL_TYPE2(MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2, PEEL_COMPLEX_MMA)
#endif
@@ -580,7 +691,7 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
#define MICRO_COMPLEX_MMA_ONE MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE, 1)
#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \
if (unroll_factor > iter) { \
if (unroll_factor * accItr > iter) { \
bsetzeroMMA(&accReal##iter); \
bsetzeroMMA(&accImag##iter); \
} else { \
@@ -594,16 +705,34 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)
#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \
if (unroll_factor > iter) { \
storeComplexAccumulator<DataMapper, Packet, Packetc, accCols, (unroll_factor != (iter + 1)) ? accCols : accCols2>(row + iter*accCols, res, pAlphaReal, pAlphaImag, pMask, &accReal##iter, &accImag##iter); \
#define MICRO_COMPLEX_MMA_STORE_ONE(iter, left, right) \
if (unroll_factor > left) { \
storeComplexAccumulator<DataMapper, Packet, Packetc, accCols, (unroll_factor != (left + 1)) ? accCols : accCols2>(row + left*accCols, res##right, pAlphaReal, pAlphaImag, pMask, &accReal##iter, &accImag##iter); \
}
#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)
#define MICRO_COMPLEX_MMA_ITER_UNROLL(func) \
if (accItr == 1) { \
func(0,0,0) func(1,1,0) func(2,2,0) func(3,3,0) \
} else if (accItr == 2) { \
func(0,0,0) func(1,0,1) func(2,1,0) func(3,1,1) \
} else { \
func(0,0,0) func(1,0,1) func(2,0,2) func(3,0,3) \
}
template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_ITER_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)
#define MICRO_COMPLEX_MMA_EXTRA_ROWS(right) \
gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3##right, blockA, rhs_base + right*accRows*(RhsIsReal ? 1 : 2)*strideB, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
#define MICRO_COMPLEX_MMA_EXTRA_ROWS1(val, right) \
MICRO_COMPLEX_MMA_EXTRA_ROWS(right);
template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index accItr>
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(
const DataMapper& res,
const DataMapper& res0,
const DataMapper& res1,
const DataMapper& res2,
const DataMapper& res3,
const Scalar* lhs_base,
const Scalar* rhs_base,
Index depth,
@@ -615,14 +744,48 @@ EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(
const Packet& pAlphaImag,
const Packet& pMask)
{
const Scalar* rhs_ptr_real = rhs_base;
const Scalar* rhs_ptr_imag = NULL;
const Scalar* rhs_ptr_real0 = rhs_base, * rhs_ptr_real1 = NULL, * rhs_ptr_real2 = NULL, * rhs_ptr_real3 = NULL;
const Scalar* rhs_ptr_imag0 = NULL, * rhs_ptr_imag1 = NULL, * rhs_ptr_imag2 = NULL, * rhs_ptr_imag3 = NULL;
const Index imag_delta = accCols*strideA;
const Index imag_delta2 = accCols2*strideA;
if(!RhsIsReal) {
rhs_ptr_imag = rhs_base + accRows*strideB;
rhs_ptr_imag0 = rhs_base + accRows*strideB;
} else {
EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
EIGEN_UNUSED_VARIABLE(rhs_ptr_imag0);
}
if (accItr > 1) {
if(!RhsIsReal) {
rhs_ptr_real1 = rhs_base + (2*accRows*strideB);
rhs_ptr_imag1 = rhs_base + (3*accRows*strideB);
} else {
rhs_ptr_real1 = rhs_base + accRows*strideB;
EIGEN_UNUSED_VARIABLE(rhs_ptr_imag1);
}
} else {
EIGEN_UNUSED_VARIABLE(rhs_ptr_real1);
EIGEN_UNUSED_VARIABLE(rhs_ptr_imag1);
EIGEN_UNUSED_VARIABLE(res1);
}
if (accItr > 2) {
if(!RhsIsReal) {
rhs_ptr_real2 = rhs_base + (4*accRows*strideB);
rhs_ptr_imag2 = rhs_base + (5*accRows*strideB);
rhs_ptr_real3 = rhs_base + (6*accRows*strideB);
rhs_ptr_imag3 = rhs_base + (7*accRows*strideB);
} else {
rhs_ptr_real2 = rhs_base + (2*accRows*strideB);
rhs_ptr_real3 = rhs_base + (3*accRows*strideB);
EIGEN_UNUSED_VARIABLE(rhs_ptr_imag2);
EIGEN_UNUSED_VARIABLE(rhs_ptr_imag3);
}
} else {
EIGEN_UNUSED_VARIABLE(rhs_ptr_real2);
EIGEN_UNUSED_VARIABLE(rhs_ptr_real3);
EIGEN_UNUSED_VARIABLE(rhs_ptr_imag2);
EIGEN_UNUSED_VARIABLE(rhs_ptr_imag3);
EIGEN_UNUSED_VARIABLE(res2);
EIGEN_UNUSED_VARIABLE(res3);
}
const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_real1 = NULL;
const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_real3 = NULL;
@@ -651,10 +814,15 @@ EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(
}
#define MICRO_COMPLEX_MMA_UNROLL_ITER2(N, M) \
gemm_complex_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask); \
gemm_complex_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, accItr>(res30, res31, res32, res33, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask); \
if (M) return;
template<typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
#define MICRO_COMPLEX_MMA_ROWS(n) \
while(row + n*accCols <= rows) { \
MICRO_COMPLEX_MMA_UNROLL_ITER2(n, 0); \
}
template<typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index accItr>
EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(
const DataMapper& res,
const Scalar* blockA,
@@ -671,35 +839,50 @@ EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(
const Packet& pAlphaImag,
const Packet& pMask)
{
const DataMapper res3 = res.getSubMapper(0, col);
const DataMapper res30 = res.getSubMapper(0, col);
const DataMapper res31 = (accItr > 1) ? res30.getSubMapper(0, accRows*1) : res30;
const DataMapper res32 = (accItr > 2) ? res30.getSubMapper(0, accRows*2) : res30;
const DataMapper res33 = (accItr > 2) ? res30.getSubMapper(0, accRows*3) : res30;
const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
const Scalar* lhs_base = blockA + accCols*offsetA;
Index row = 0;
#define MAX_COMPLEX_MMA_UNROLL 4
while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) {
MICRO_COMPLEX_MMA_UNROLL_ITER2(MAX_COMPLEX_MMA_UNROLL, 0);
#if MAX_COMPLEX_MMA_UNROLL < 2
if (1) {
#elif MAX_COMPLEX_MMA_UNROLL < 4
if (accItr <= 2) {
#else
if (accItr == 1) {
#endif
MICRO_COMPLEX_MMA_ROWS(MAX_COMPLEX_MMA_UNROLL);
} else if (accItr == 2) {
MICRO_COMPLEX_MMA_ROWS(2);
} else {
MICRO_COMPLEX_MMA_ROWS(1);
}
switch( (rows-row)/accCols ) {
#if MAX_COMPLEX_MMA_UNROLL > 4
case 4:
MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 4)
break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 3
case 3:
MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 3)
if (accItr == 1) {
MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 3)
}
break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 2
case 2:
MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 2)
if (accItr == 1) {
MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 2)
}
break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 1
case 1:
MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 1)
if (accItr <= 2) {
MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 1)
}
break;
#endif
default:
@@ -709,10 +892,16 @@ EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(
if(remaining_rows > 0)
{
gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_EXTRA_ROWS1, 0)
}
}
#define MICRO_COMPLEX_MMA_COLS(n) \
for(; col + n*accRows <= cols; col += n*accRows) \
{ \
gemmMMA_complex_cols<Scalar, Packet, Packetc, RhsPacket2, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, n>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask); \
}
template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
@@ -731,10 +920,11 @@ void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsS
typedef typename std::conditional_t<(sizeof(Scalar) == sizeof(float)), RhsPacket, __vector_pair> RhsPacket2;
Index col = 0;
for(; col + accRows <= cols; col += accRows)
{
gemmMMA_complex_cols<Scalar, Packet, Packetc, RhsPacket2, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
}
#ifdef GEMM_MULTIPLE_COLS
MICRO_COMPLEX_MMA_COLS(4);
MICRO_COMPLEX_MMA_COLS(2);
#endif
MICRO_COMPLEX_MMA_COLS(1);
if (col != cols)
{
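The complex drivers follow the same pattern, but each row block needs an accReal/accImag pair, so only four pairs fit in the eight accumulators and the row unroll shrinks to 4, 2 or 1 as accItr grows (MAX_COMPLEX_MMA_UNROLL is 4). A hypothetical restatement of that budget (max_complex_row_unroll is an invented name):

constexpr int kAccumulatorPairs = 4;  // 8 MMA accumulators / (real + imag) per block

constexpr int max_complex_row_unroll(int accItr) {
  // mirrors MICRO_COMPLEX_MMA_ROWS(4)/(2)/(1) in gemmMMA_complex_cols
  return kAccumulatorPairs / accItr;
}

static_assert(max_complex_row_unroll(1) == 4, "one panel: four real/imag pairs");
static_assert(max_complex_row_unroll(2) == 2, "two panels: two pairs each");
static_assert(max_complex_row_unroll(4) == 1, "four panels: one pair each");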