Mirror of https://gitlab.com/libeigen/eigen.git
New panel modes for GEMM MMA (real & complex).
commit 4e598ad259 (parent 2c64a655fe)
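Orientation note (not part of the commit): the new multi-column mode splits the unrolled MMA accumulators between LHS row blocks and extra RHS column panels through an accItr template parameter. The hypothetical C++ sketch below only reproduces the iter -> (left, right) index mapping that the updated MICRO_MMA_WORK / MICRO_MMA_ITER_UNROLL macros in the diff spell out by hand for the real (non-complex) kernel; the accumulator names and printout are illustrative, not code from the patch.

#include <cstdio>

// Illustrative sketch only: mirrors the accItr dispatch tables in the diff,
// where accumulator `iter` uses LHS row block `left` and RHS column panel `right`.
int main() {
  const int accItrs[3] = {1, 2, 4};
  for (int accItr : accItrs) {
    std::printf("accItr = %d\n", accItr);
    for (int iter = 0; iter < 8; ++iter) {
      int left = iter / accItr;   // lhsV##left (row-direction unroll slot)
      int right = iter % accItr;  // rhsV##right / res##right (column panel)
      std::printf("  accZero%d += lhsV%d * rhsV%d[peel]\n", iter, left, right);
    }
  }
  return 0;
}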
@@ -49,11 +49,6 @@
#include "MatrixProductMMA.h"
#endif

/**************************************************************************************************
* TODO *
* - Check StorageOrder on dhs_pack (the innermost second loop seems unvectorized when it could). *
* - Check the possibility of transposing as GETREAL and GETIMAG when needed. *
**************************************************************************************************/
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"

@@ -120,6 +115,16 @@ const static Packet16uc p16uc_GETIMAG32 = { 4, 5, 6, 7,
20, 21, 22, 23,
28, 29, 30, 31};

const static Packet16uc p16uc_GETREAL32b = { 0, 1, 2, 3,
16, 17, 18, 19,
8, 9, 10, 11,
24, 25, 26, 27};

const static Packet16uc p16uc_GETIMAG32b = { 4, 5, 6, 7,
20, 21, 22, 23,
12, 13, 14, 15,
28, 29, 30, 31};

/*********************************************
* Single precision real and complex packing *
* *******************************************/
@@ -440,6 +445,78 @@ EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,N>& block)
// General template for lhs & rhs complex packing.
template<typename Scalar, typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode, bool UseLhs>
struct dhs_cpack {
template<bool transpose>
EIGEN_ALWAYS_INLINE void dhs_cblock(PacketBlock<PacketC,8>& cblock, PacketBlock<Packet,4>& block, Packet16uc permute)
{
if (transpose) {
block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, permute);
block.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, permute);
block.packet[2] = vec_perm(cblock.packet[4].v, cblock.packet[5].v, permute);
block.packet[3] = vec_perm(cblock.packet[6].v, cblock.packet[7].v, permute);

Packet4f t0, t1, t2, t3;
#ifdef EIGEN_VECTORIZE_VSX
t0 = reinterpret_cast<Packet>(vec_mergeh(reinterpret_cast<Packet2ul>(block.packet[0]), reinterpret_cast<Packet2ul>(block.packet[1])));
t1 = reinterpret_cast<Packet>(vec_mergel(reinterpret_cast<Packet2ul>(block.packet[0]), reinterpret_cast<Packet2ul>(block.packet[1])));
t2 = reinterpret_cast<Packet>(vec_mergeh(reinterpret_cast<Packet2ul>(block.packet[2]), reinterpret_cast<Packet2ul>(block.packet[3])));
t3 = reinterpret_cast<Packet>(vec_mergel(reinterpret_cast<Packet2ul>(block.packet[2]), reinterpret_cast<Packet2ul>(block.packet[3])));
#else
t0 = reinterpret_cast<Packet>(vec_perm(block.packet[0], block.packet[1], p16uc_TRANSPOSE64_HI));
t1 = reinterpret_cast<Packet>(vec_perm(block.packet[0], block.packet[1], p16uc_TRANSPOSE64_LO));
t2 = reinterpret_cast<Packet>(vec_perm(block.packet[2], block.packet[3], p16uc_TRANSPOSE64_HI));
t3 = reinterpret_cast<Packet>(vec_perm(block.packet[2], block.packet[3], p16uc_TRANSPOSE64_LO));
#endif

block.packet[0] = t0;
block.packet[1] = t1;
block.packet[2] = t2;
block.packet[3] = t3;
} else {
block.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, permute);
block.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, permute);
block.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, permute);
block.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, permute);
}
}

EIGEN_ALWAYS_INLINE void dhs_ccopy(Scalar* blockAt, const DataMapper& lhs2, Index& i, Index& rir, Index& rii, Index depth, const Index vectorSize)
{
PacketBlock<Packet,4> blockr, blocki;
PacketBlock<PacketC,8> cblock;

for(; i + vectorSize <= depth; i+=vectorSize)
{
if (UseLhs) {
bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, 0, i);
} else {
bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, i, 0);
}

if(((StorageOrder == RowMajor) && UseLhs) || (((StorageOrder == ColMajor) && !UseLhs)))
{
dhs_cblock<true>(cblock, blockr, p16uc_GETREAL32b);
dhs_cblock<true>(cblock, blocki, p16uc_GETIMAG32b);
} else {
dhs_cblock<false>(cblock, blockr, p16uc_GETREAL32);
dhs_cblock<false>(cblock, blocki, p16uc_GETIMAG32);
}

if(Conjugate)
{
blocki.packet[0] = -blocki.packet[0];
blocki.packet[1] = -blocki.packet[1];
blocki.packet[2] = -blocki.packet[2];
blocki.packet[3] = -blocki.packet[3];
}

storeBlock<Scalar, Packet, 4>(blockAt + rir, blockr);
storeBlock<Scalar, Packet, 4>(blockAt + rii, blocki);

rir += 4*vectorSize;
rii += 4*vectorSize;
}
}

EIGEN_STRONG_INLINE void operator()(std::complex<Scalar>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
const Index vectorSize = quad_traits<Scalar>::vectorsize;
@@ -455,47 +532,8 @@ struct dhs_cpack {

rii = rir + vectorDelta;

for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<Packet,4> blockr, blocki;
PacketBlock<PacketC,8> cblock;
dhs_ccopy(blockAt, lhs2, i, rir, rii, depth, vectorSize);

if (UseLhs) {
bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, 0, i);
} else {
bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, i, 0);
}

blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32);
blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETREAL32);
blockr.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETREAL32);
blockr.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETREAL32);

blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETIMAG32);
blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[5].v, p16uc_GETIMAG32);
blocki.packet[2] = vec_perm(cblock.packet[2].v, cblock.packet[6].v, p16uc_GETIMAG32);
blocki.packet[3] = vec_perm(cblock.packet[3].v, cblock.packet[7].v, p16uc_GETIMAG32);

if(Conjugate)
{
blocki.packet[0] = -blocki.packet[0];
blocki.packet[1] = -blocki.packet[1];
blocki.packet[2] = -blocki.packet[2];
blocki.packet[3] = -blocki.packet[3];
}

if(((StorageOrder == RowMajor) && UseLhs) || (((StorageOrder == ColMajor) && !UseLhs)))
{
ptranspose(blockr);
ptranspose(blocki);
}

storeBlock<Scalar, Packet, 4>(blockAt + rir, blockr);
storeBlock<Scalar, Packet, 4>(blockAt + rii, blocki);

rir += 4*vectorSize;
rii += 4*vectorSize;
}
for(; i < depth; i++)
{
PacketBlock<Packet,1> blockr, blocki;
@@ -592,6 +630,36 @@ struct dhs_cpack {
// General template for lhs & rhs packing.
template<typename Scalar, typename DataMapper, typename Packet, int StorageOrder, bool PanelMode, bool UseLhs>
struct dhs_pack{
template<Index n>
EIGEN_ALWAYS_INLINE void dhs_copy(Scalar* blockA, const DataMapper& lhs2, Index& i, Index& ri, Index depth, const Index vectorSize)
{
PacketBlock<Packet,4> block[n];

for(; i + n*vectorSize <= depth; i+=n*vectorSize)
{
for (Index k = 0; k < n; k++) {
if (UseLhs) {
bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block[k], lhs2, 0, i + k*vectorSize);
} else {
bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block[k], lhs2, i + k*vectorSize, 0);
}
}

if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
{
for (Index k = 0; k < n; k++) {
ptranspose(block[k]);
}
}

for (Index k = 0; k < n; k++) {
storeBlock<Scalar, Packet, 4>(blockA + ri + k*4*vectorSize, block[k]);
}

ri += n*4*vectorSize;
}
}

EIGEN_STRONG_INLINE void operator()(Scalar* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
const Index vectorSize = quad_traits<Scalar>::vectorsize;
@@ -604,24 +672,10 @@ struct dhs_pack{

if(PanelMode) ri += vectorSize*offset;

for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<Packet,4> block;
dhs_copy<4>(blockA, lhs2, i, ri, depth, vectorSize);
dhs_copy<2>(blockA, lhs2, i, ri, depth, vectorSize);
dhs_copy<1>(blockA, lhs2, i, ri, depth, vectorSize);

if (UseLhs) {
bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block, lhs2, 0, i);
} else {
bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block, lhs2, i, 0);
}
if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
{
ptranspose(block);
}

storeBlock<Scalar, Packet, 4>(blockA + ri, block);

ri += 4*vectorSize;
}
for(; i < depth; i++)
{
if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
@@ -691,6 +745,39 @@ struct dhs_pack{
template<typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, true>
{
template<Index n>
EIGEN_ALWAYS_INLINE void dhs_copy(double* blockA, const DataMapper& lhs2, Index& i, Index& ri, Index depth, const Index vectorSize)
{
PacketBlock<Packet2d,2> block[n];

for(; i + n*vectorSize <= depth; i+=n*vectorSize)
{
for (Index k = 0; k < n; k++) {
if(StorageOrder == RowMajor)
{
block[k].packet[0] = lhs2.template loadPacket<Packet2d>(0, i + k*vectorSize);
block[k].packet[1] = lhs2.template loadPacket<Packet2d>(1, i + k*vectorSize);
} else {
block[k].packet[0] = lhs2.template loadPacket<Packet2d>(0, i + k*vectorSize + 0);
block[k].packet[1] = lhs2.template loadPacket<Packet2d>(0, i + k*vectorSize + 1);
}
}

if(StorageOrder == RowMajor)
{
for (Index k = 0; k < n; k++) {
ptranspose(block[k]);
}
}

for (Index k = 0; k < n; k++) {
storeBlock<double, Packet2d, 2>(blockA + ri + k*2*vectorSize, block[k]);
}

ri += n*2*vectorSize;
}
}

EIGEN_STRONG_INLINE void operator()(double* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
const Index vectorSize = quad_traits<double>::vectorsize;
@@ -703,24 +790,10 @@ struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, true>

if(PanelMode) ri += vectorSize*offset;

for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<Packet2d,2> block;
if(StorageOrder == RowMajor)
{
block.packet[0] = lhs2.template loadPacket<Packet2d>(0, i);
block.packet[1] = lhs2.template loadPacket<Packet2d>(1, i);
dhs_copy<4>(blockA, lhs2, i, ri, depth, vectorSize);
dhs_copy<2>(blockA, lhs2, i, ri, depth, vectorSize);
dhs_copy<1>(blockA, lhs2, i, ri, depth, vectorSize);

ptranspose(block);
} else {
block.packet[0] = lhs2.template loadPacket<Packet2d>(0, i + 0);
block.packet[1] = lhs2.template loadPacket<Packet2d>(0, i + 1);
}

storeBlock<double, Packet2d, 2>(blockA + ri, block);

ri += 2*vectorSize;
}
for(; i < depth; i++)
{
if(StorageOrder == RowMajor)
@@ -759,6 +832,53 @@ struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, true>
template<typename DataMapper, int StorageOrder, bool PanelMode>
struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, false>
{
template<Index n>
EIGEN_ALWAYS_INLINE void dhs_copy(double* blockB, const DataMapper& rhs2, Index& i, Index& ri, Index depth, const Index vectorSize)
{
PacketBlock<Packet2d,2> block1[n], block2[n];
PacketBlock<Packet2d,4> block3[n];

for(; i + n*vectorSize <= depth; i+=n*vectorSize)
{
for (Index k = 0; k < n; k++) {
if(StorageOrder == ColMajor)
{
block1[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k*vectorSize, 0);
block1[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k*vectorSize, 1);
block2[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k*vectorSize, 2);
block2[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k*vectorSize, 3);
} else {
block3[k].packet[0] = rhs2.template loadPacket<Packet2d>(i + k*vectorSize + 0, 0); //[a1 a2]
block3[k].packet[1] = rhs2.template loadPacket<Packet2d>(i + k*vectorSize + 0, 2); //[a3 a4]
block3[k].packet[2] = rhs2.template loadPacket<Packet2d>(i + k*vectorSize + 1, 0); //[b1 b2]
block3[k].packet[3] = rhs2.template loadPacket<Packet2d>(i + k*vectorSize + 1, 2); //[b3 b4]
}
}

if(StorageOrder == ColMajor)
{
for (Index k = 0; k < n; k++) {
ptranspose(block1[k]);
ptranspose(block2[k]);
}
}

for (Index k = 0; k < n; k++) {
if(StorageOrder == ColMajor)
{
pstore<double>(blockB + ri + k*4*vectorSize , block1[k].packet[0]);
pstore<double>(blockB + ri + k*4*vectorSize + 2, block2[k].packet[0]);
pstore<double>(blockB + ri + k*4*vectorSize + 4, block1[k].packet[1]);
pstore<double>(blockB + ri + k*4*vectorSize + 6, block2[k].packet[1]);
} else {
storeBlock<double, Packet2d, 4>(blockB + ri + k*4*vectorSize, block3[k]);
}
}

ri += n*4*vectorSize;
}
}

EIGEN_STRONG_INLINE void operator()(double* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
const Index vectorSize = quad_traits<double>::vectorsize;
@@ -771,35 +891,10 @@ struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, false>

if(PanelMode) ri += offset*(2*vectorSize);

for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<Packet2d,4> block;
if(StorageOrder == ColMajor)
{
PacketBlock<Packet2d,2> block1, block2;
block1.packet[0] = rhs2.template loadPacket<Packet2d>(i, 0);
block1.packet[1] = rhs2.template loadPacket<Packet2d>(i, 1);
block2.packet[0] = rhs2.template loadPacket<Packet2d>(i, 2);
block2.packet[1] = rhs2.template loadPacket<Packet2d>(i, 3);
dhs_copy<4>(blockB, rhs2, i, ri, depth, vectorSize);
dhs_copy<2>(blockB, rhs2, i, ri, depth, vectorSize);
dhs_copy<1>(blockB, rhs2, i, ri, depth, vectorSize);

ptranspose(block1);
ptranspose(block2);

pstore<double>(blockB + ri , block1.packet[0]);
pstore<double>(blockB + ri + 2, block2.packet[0]);
pstore<double>(blockB + ri + 4, block1.packet[1]);
pstore<double>(blockB + ri + 6, block2.packet[1]);
} else {
block.packet[0] = rhs2.template loadPacket<Packet2d>(i + 0, 0); //[a1 a2]
block.packet[1] = rhs2.template loadPacket<Packet2d>(i + 0, 2); //[a3 a4]
block.packet[2] = rhs2.template loadPacket<Packet2d>(i + 1, 0); //[b1 b2]
block.packet[3] = rhs2.template loadPacket<Packet2d>(i + 1, 2); //[b3 b4]

storeBlock<double, Packet2d, 4>(blockB + ri, block);
}

ri += 4*vectorSize;
}
for(; i < depth; i++)
{
if(StorageOrder == ColMajor)
@@ -1296,6 +1391,54 @@ struct dhs_pack<bfloat16, DataMapper, Packet8bf, StorageOrder, PanelMode, false>
template<typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, true>
{
EIGEN_ALWAYS_INLINE void dhs_ccopy(double* blockAt, const DataMapper& lhs2, Index& i, Index& rir, Index& rii, Index depth, const Index vectorSize)
{
PacketBlock<Packet,2> blockr, blocki;
PacketBlock<PacketC,4> cblock;

for(; i + vectorSize <= depth; i+=vectorSize)
{
if(StorageOrder == ColMajor)
{
cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i + 0); //[a1 a1i]
cblock.packet[1] = lhs2.template loadPacket<PacketC>(0, i + 1); //[b1 b1i]

cblock.packet[2] = lhs2.template loadPacket<PacketC>(1, i + 0); //[a2 a2i]
cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1); //[b2 b2i]

blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[2].v); //[a1 a2]
blockr.packet[1] = vec_mergeh(cblock.packet[1].v, cblock.packet[3].v); //[b1 b2]

blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[2].v);
blocki.packet[1] = vec_mergel(cblock.packet[1].v, cblock.packet[3].v);
} else {
cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i); //[a1 a1i]
cblock.packet[1] = lhs2.template loadPacket<PacketC>(1, i); //[a2 a2i]

cblock.packet[2] = lhs2.template loadPacket<PacketC>(0, i + 1); //[b1 b1i]
cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1); //[b2 b2i

blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v); //[a1 a2]
blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v); //[b1 b2]

blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v);
}

if(Conjugate)
{
blocki.packet[0] = -blocki.packet[0];
blocki.packet[1] = -blocki.packet[1];
}

storeBlock<double, Packet, 2>(blockAt + rir, blockr);
storeBlock<double, Packet, 2>(blockAt + rii, blocki);

rir += 2*vectorSize;
rii += 2*vectorSize;
}
}

EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockA, const DataMapper& lhs, Index depth, Index rows, Index stride, Index offset)
{
const Index vectorSize = quad_traits<double>::vectorsize;
@@ -1311,50 +1454,8 @@ struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, P

rii = rir + vectorDelta;

for(; i + vectorSize <= depth; i+=vectorSize)
{
PacketBlock<Packet,2> blockr, blocki;
PacketBlock<PacketC,4> cblock;
dhs_ccopy(blockAt, lhs2, i, rir, rii, depth, vectorSize);

if(StorageOrder == ColMajor)
{
cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i + 0); //[a1 a1i]
cblock.packet[1] = lhs2.template loadPacket<PacketC>(0, i + 1); //[b1 b1i]

cblock.packet[2] = lhs2.template loadPacket<PacketC>(1, i + 0); //[a2 a2i]
cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1); //[b2 b2i]

blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[2].v); //[a1 a2]
blockr.packet[1] = vec_mergeh(cblock.packet[1].v, cblock.packet[3].v); //[b1 b2]

blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[2].v);
blocki.packet[1] = vec_mergel(cblock.packet[1].v, cblock.packet[3].v);
} else {
cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i); //[a1 a1i]
cblock.packet[1] = lhs2.template loadPacket<PacketC>(1, i); //[a2 a2i]

cblock.packet[2] = lhs2.template loadPacket<PacketC>(0, i + 1); //[b1 b1i]
cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1); //[b2 b2i

blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v); //[a1 a2]
blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v); //[b1 b2]

blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v);
}

if(Conjugate)
{
blocki.packet[0] = -blocki.packet[0];
blocki.packet[1] = -blocki.packet[1];
}

storeBlock<double, Packet, 2>(blockAt + rir, blockr);
storeBlock<double, Packet, 2>(blockAt + rii, blocki);

rir += 2*vectorSize;
rii += 2*vectorSize;
}
for(; i < depth; i++)
{
PacketBlock<Packet,1> blockr, blocki;
@@ -1410,6 +1511,35 @@ struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, P
template<typename DataMapper, typename Packet, typename PacketC, int StorageOrder, bool Conjugate, bool PanelMode>
struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, PanelMode, false>
{
EIGEN_ALWAYS_INLINE void dhs_ccopy(double* blockBt, const DataMapper& rhs2, Index& i, Index& rir, Index& rii, Index depth, const Index vectorSize)
{
for(; i < depth; i++)
{
PacketBlock<PacketC,4> cblock;
PacketBlock<Packet,2> blockr, blocki;

bload<DataMapper, PacketC, 2, ColMajor, false, 4>(cblock, rhs2, i, 0);

blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v);
blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v);

blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v);

if(Conjugate)
{
blocki.packet[0] = -blocki.packet[0];
blocki.packet[1] = -blocki.packet[1];
}

storeBlock<double, Packet, 2>(blockBt + rir, blockr);
storeBlock<double, Packet, 2>(blockBt + rii, blocki);

rir += 2*vectorSize;
rii += 2*vectorSize;
}
}

EIGEN_STRONG_INLINE void operator()(std::complex<double>* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride, Index offset)
{
const Index vectorSize = quad_traits<double>::vectorsize;
@@ -1425,31 +1555,7 @@ struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, P

rii = rir + vectorDelta;

for(; i < depth; i++)
{
PacketBlock<PacketC,4> cblock;
PacketBlock<Packet,2> blockr, blocki;

bload<DataMapper, PacketC, 2, ColMajor, false, 4>(cblock, rhs2, i, 0);

blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v);
blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v);

blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v);

if(Conjugate)
{
blocki.packet[0] = -blocki.packet[0];
blocki.packet[1] = -blocki.packet[1];
}

storeBlock<double, Packet, 2>(blockBt + rir, blockr);
storeBlock<double, Packet, 2>(blockBt + rii, blocki);

rir += 2*vectorSize;
rii += 2*vectorSize;
}
dhs_ccopy(blockBt, rhs2, i, rir, rii, depth, vectorSize);

rir += ((PanelMode) ? (2*vectorSize*(2*stride - depth)) : vectorDelta);
}
@@ -42,19 +42,13 @@ EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc)
__builtin_mma_xxsetaccz(acc);
}

#ifdef USE_PARTIAL_PACKETS
template<typename DataMapper, typename Packet, bool full>
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, const Index elements, __vector_quad* acc)
#else
template<typename DataMapper, typename Packet, const Index accCols, const Index accCols2>
EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, const Packet& pMask, __vector_quad* acc)
#endif
{
PacketBlock<Packet, 4> result;
__builtin_mma_disassemble_acc(&result.packet, acc);

PacketBlock<Packet, 4> tRes;
#ifdef USE_PARTIAL_PACKETS
if (full) {
EIGEN_UNUSED_VARIABLE(elements);
bload<DataMapper, Packet, 0, ColMajor, false, 4>(tRes, data, i, 0);
@@ -65,11 +59,6 @@ EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const
bscale<Packet, 4>(tRes, result, alpha);
bstore_partial<DataMapper, Packet, 4>(tRes, data, i, elements);
}
#else
bload<DataMapper, Packet, 0, ColMajor, false, 4>(tRes, data, i, 0);
bscale<Packet, 4, (accCols != accCols2)>(tRes, result, alpha, pMask);
bstore<DataMapper, Packet, 4>(tRes, data, i);
#endif
}

template<typename DataMapper, typename Packet, typename Packetc, const Index accCols, const Index accCols2>
@@ -166,78 +155,118 @@ EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV)
ploadRhsMMA(lhs, lhsV);
}

#if (EIGEN_COMP_LLVM || (__GNUC__ >= 11))
#define GEMM_MULTIPLE_COLS

// Disable in GCC until unnecessary register moves are fixed
//#if (EIGEN_COMP_LLVM || (__GNUC__ >= 11))
#if EIGEN_COMP_LLVM
#define VECTOR_PAIR_LOADS_LHS
#endif

// PEEL_MMA loop factor.
#ifdef GEMM_MULTIPLE_COLS
#define PEEL_MMA 8
#else
// Register spillage with GCC12+
#if EIGEN_COMP_LLVM || (__GNUC__ < 12) || defined(VECTOR_PAIR_LOADS_LHS)
#define PEEL_MMA 7
#else
#define PEEL_MMA 6
#endif
#endif

#define MICRO_MMA_UNROLL(func) \
func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)

#define MICRO_MMA_WORK(func, type, peel) \
func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
func(4,type,peel) func(5,type,peel) func(6,type,peel) func(7,type,peel)
if (accItr == 1) { \
func(0,type,peel,0,0) func(1,type,peel,1,0) func(2,type,peel,2,0) func(3,type,peel,3,0) \
func(4,type,peel,4,0) func(5,type,peel,5,0) func(6,type,peel,6,0) func(7,type,peel,7,0) \
} else if (accItr == 2) { \
func(0,type,peel,0,0) func(1,type,peel,0,1) func(2,type,peel,1,0) func(3,type,peel,1,1) \
func(4,type,peel,2,0) func(5,type,peel,2,1) func(6,type,peel,3,0) func(7,type,peel,3,1) \
} else { \
func(0,type,peel,0,0) func(1,type,peel,0,1) func(2,type,peel,0,2) func(3,type,peel,0,3) \
func(4,type,peel,1,0) func(5,type,peel,1,1) func(6,type,peel,1,2) func(7,type,peel,1,3) \
}

#define MICRO_MMA_WORK_ONE(iter, type, peel) \
if (unroll_factor > iter) { \
pgerMMA<Packet, type, false>(&accZero##iter, rhsV[peel], lhsV##iter); \
#define MICRO_MMA_WORK_ONE(iter, type, peel, left, right) \
if (unroll_factor > left) { \
pgerMMA<Packet, type, false>(&accZero##iter, rhsV##right[peel], lhsV##left); \
}

#ifdef VECTOR_PAIR_LOADS_LHS
#define MICRO_MMA_WORK_TWO(iter, type, peel) \
if (unroll_factor > iter) { \
pgerMMA<Packet, type, false>(&accZero##iter, rhsV[peel], lhsV2##iter.packet[peel & 1]); \
#define MICRO_MMA_WORK_TWO(iter, type, peel, left, right) \
if (unroll_factor > left) { \
pgerMMA<Packet, type, false>(&accZero##iter, rhsV##right[peel], lhsV2##left.packet[peel & 1]); \
}

#define MICRO_MMA_LOAD1_TWO(lhs_ptr, iter) \
if (unroll_factor > iter) { \
if (MICRO_NORMAL(iter)) { \
ploadLhsMMA(reinterpret_cast<const double*>(lhs_ptr##iter), plhsV##iter); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&lhsV2##iter.packet), &plhsV##iter); \
lhs_ptr##iter += accCols*2; \
#define MICRO_MMA_LOAD1_TWO(lhs_ptr, left) \
if (unroll_factor > left) { \
if (MICRO_NORMAL(left)) { \
ploadLhsMMA(reinterpret_cast<const double*>(lhs_ptr##left), plhsV##left); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&lhsV2##left.packet), &plhsV##left); \
lhs_ptr##left += accCols*2; \
} else { \
lhsV2##iter.packet[0] = ploadLhs<Packet>(lhs_ptr##iter); \
lhsV2##iter.packet[1] = ploadLhs<Packet>(lhs_ptr##iter + accCols2); \
lhs_ptr##iter += accCols2*2; \
EIGEN_UNUSED_VARIABLE(plhsV##iter) \
lhsV2##left.packet[0] = ploadLhs<Packet>(lhs_ptr##left); \
lhsV2##left.packet[1] = ploadLhs<Packet>(lhs_ptr##left + accCols2); \
lhs_ptr##left += accCols2*2; \
EIGEN_UNUSED_VARIABLE(plhsV##left); \
} \
} else { \
EIGEN_UNUSED_VARIABLE(lhsV2##iter); \
EIGEN_UNUSED_VARIABLE(plhsV##iter) \
EIGEN_UNUSED_VARIABLE(lhsV2##left); \
EIGEN_UNUSED_VARIABLE(plhsV##left); \
}

#define MICRO_MMA_LOAD_TWO(iter) MICRO_MMA_LOAD1_TWO(lhs_ptr, iter)
#define MICRO_MMA_LOAD_TWO(left) MICRO_MMA_LOAD1_TWO(lhs_ptr, left)
#endif

#define MICRO_MMA_UNROLL_ITER(func, val) \
func(val,0) \
if (accItr > 1) { \
func(val,1) \
if (accItr > 2) { \
func(val,2) \
func(val,3) \
} \
}

#define MICRO_MMA_LOAD_ONE_RHS1(peel, right) \
ploadRhsMMA(rhs_ptr##right + (accRows * peel), rhsV##right[peel]);

#define MICRO_MMA_LOAD_ONE_RHS(peel) \
MICRO_MMA_UNROLL_ITER(MICRO_MMA_LOAD_ONE_RHS1, peel)

#define MICRO_MMA_TYPE_PEEL(funcw, funcl, type, peel) \
if (PEEL_MMA > peel) { \
Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
ploadRhsMMA(rhs_ptr + (accRows * peel), rhsV[peel]); \
MICRO_MMA_LOAD_ONE_RHS(peel) \
MICRO_MMA_UNROLL(funcl) \
MICRO_MMA_WORK(funcw, type, peel) \
}

#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_MMA_UNROLL_TYPE_PEEL(funcw, funcl, type) \
type rhsV[8]; \
type rhsV0[8], rhsV1[(accItr > 1) ? 8 : 1], rhsV2[(accItr > 2) ? 8 : 1], rhsV3[(accItr > 2) ? 8 : 1]; \
MICRO_MMA_TYPE_PEEL(funcw,funcl,type,0) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,1) \
MICRO_MMA_TYPE_PEEL(funcw,funcl,type,2) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,3) \
MICRO_MMA_TYPE_PEEL(funcw,funcl,type,4) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,5) \
MICRO_MMA_TYPE_PEEL(funcw,funcl,type,6) MICRO_MMA_TYPE_PEEL(funcw,funcl,type,7)
#else
#define MICRO_MMA_LOAD_TWO_RHS(peel1, right) \
ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr##right + (accRows * peel1)), prhsV##peel1); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsV##right[peel1]), &prhsV##peel1);

#define MICRO_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, peel1, peel2) \
if (PEEL_MMA > peel2) { \
PacketBlock<Packet,2> lhsV20, lhsV21, lhsV22, lhsV23, lhsV24, lhsV25, lhsV26, lhsV27; \
__vector_pair plhsV0, plhsV1, plhsV2, plhsV3, plhsV4, plhsV5, plhsV6, plhsV7; \
if (sizeof(type) == 16) { \
ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr + (accRows * peel1)), prhsV##peel1); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsV[peel1]), &prhsV##peel1); \
MICRO_MMA_UNROLL_ITER(MICRO_MMA_LOAD_TWO_RHS, peel1) \
} else { \
EIGEN_UNUSED_VARIABLE(prhsV##peel1); \
ploadRhsMMA(rhs_ptr + (accRows * peel1), rhsV[peel1]); \
ploadRhsMMA(rhs_ptr + (accRows * peel2), rhsV[peel2]); \
MICRO_MMA_LOAD_ONE_RHS(peel1) \
MICRO_MMA_LOAD_ONE_RHS(peel2) \
} \
MICRO_MMA_UNROLL(funcl2) \
MICRO_MMA_WORK(funcw2, type, peel1) \
@@ -248,7 +277,7 @@ EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV)
}

#define MICRO_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type) \
type rhsV[8]; \
type rhsV0[8], rhsV1[(accItr > 1) ? 8 : 1], rhsV2[(accItr > 2) ? 8 : 1], rhsV3[(accItr > 2) ? 8 : 1]; \
__vector_pair prhsV0, prhsV2, prhsV4, prhsV6; \
MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,0,1) \
MICRO_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,2,3) \
@@ -257,19 +286,25 @@ EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV)
#endif

#define MICRO_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \
type rhsV[1]; \
type rhsV0[1], rhsV1[1], rhsV2[1], rhsV3[1]; \
MICRO_MMA_TYPE_PEEL(funcw,funcl,type,0)

#define MICRO_MMA_UPDATE_RHS1(size, right) \
rhs_ptr##right += (accRows * size);

#define MICRO_MMA_UPDATE_RHS(size) \
MICRO_MMA_UNROLL_ITER(MICRO_MMA_UPDATE_RHS1, size)

#define MICRO_MMA_UNROLL_TYPE(MICRO_MMA_TYPE, size) \
MICRO_MMA_TYPE(MICRO_MMA_WORK_ONE, MICRO_LOAD_ONE, RhsPacket) \
rhs_ptr += (accRows * size);
MICRO_MMA_UPDATE_RHS(size)

#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_MMA_ONE_PEEL MICRO_MMA_UNROLL_TYPE(MICRO_MMA_UNROLL_TYPE_PEEL, PEEL_MMA)
#else
#define MICRO_MMA_UNROLL_TYPE2(MICRO_MMA_TYPE, size) \
MICRO_MMA_TYPE(MICRO_MMA_WORK_ONE, MICRO_LOAD_ONE, MICRO_MMA_WORK_TWO, MICRO_MMA_LOAD_TWO, RhsPacket) \
rhs_ptr += (accRows * size);
MICRO_MMA_UPDATE_RHS(size)

#define MICRO_MMA_ONE_PEEL MICRO_MMA_UNROLL_TYPE2(MICRO_MMA_UNROLL_TYPE_PEEL2, PEEL_MMA)
#endif
@@ -277,7 +312,7 @@ EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV)
#define MICRO_MMA_ONE MICRO_MMA_UNROLL_TYPE(MICRO_MMA_UNROLL_TYPE_ONE, 1)

#define MICRO_MMA_DST_PTR_ONE(iter) \
if (unroll_factor > iter) { \
if (unroll_factor * accItr > iter) { \
bsetzeroMMA(&accZero##iter); \
} else { \
EIGEN_UNUSED_VARIABLE(accZero##iter); \
@@ -289,45 +324,69 @@ EIGEN_ALWAYS_INLINE void ploadLhsMMA(const double* lhs, __vector_pair& lhsV)

#define MICRO_MMA_PREFETCH MICRO_MMA_UNROLL(MICRO_PREFETCH_ONE)

#ifdef USE_PARTIAL_PACKETS
#define MICRO_MMA_STORE_ONE(iter) \
if (unroll_factor > iter) { \
storeAccumulator<DataMapper, Packet, MICRO_NORMAL_PARTIAL(iter)>(row + iter*accCols, res, pAlpha, accCols2, &accZero##iter); \
#define MICRO_MMA_STORE_ONE(iter, left, right) \
if (unroll_factor > left) { \
storeAccumulator<DataMapper, Packet, MICRO_NORMAL_PARTIAL(left)>(row + left*accCols, res##right, pAlpha, accCols2, &accZero##iter); \
}
#else
#define MICRO_MMA_STORE_ONE(iter) \
if (unroll_factor > iter) { \
storeAccumulator<DataMapper, Packet, accCols, (unroll_factor != (iter + 1)) ? accCols : accCols2>(row + iter*accCols, res, pAlpha, pMask, &accZero##iter); \

#define MICRO_MMA_ITER_UNROLL(func) \
if (accItr == 1) { \
func(0,0,0) func(1,1,0) func(2,2,0) func(3,3,0) \
func(4,4,0) func(5,5,0) func(6,6,0) func(7,7,0) \
} else if (accItr == 2) { \
func(0,0,0) func(1,0,1) func(2,1,0) func(3,1,1) \
func(4,2,0) func(5,2,1) func(6,3,0) func(7,3,1) \
} else { \
func(0,0,0) func(1,0,1) func(2,0,2) func(3,0,3) \
func(4,1,0) func(5,1,1) func(6,1,2) func(7,1,3) \
}
#endif

#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE)
#define MICRO_MMA_STORE MICRO_MMA_ITER_UNROLL(MICRO_MMA_STORE_ONE)

#ifdef USE_PARTIAL_PACKETS
template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool full>
#else
template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, const Index accCols2>
#endif
#define MICRO_MMA_EXTRA_ROWS(right) \
gemm_extra_row<Scalar, Packet, DataMapper, accRows, accCols>(res3##right, blockA, rhs_base + right*accRows*strideB, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlpha, pMask);

#define MICRO_MMA_EXTRA_ROWS1(val, right) \
MICRO_MMA_EXTRA_ROWS(right);

template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool full, const Index accItr>
EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration(
const DataMapper& res,
const DataMapper& res0,
const DataMapper& res1,
const DataMapper& res2,
const DataMapper& res3,
const Scalar* lhs_base,
const Scalar* rhs_base,
Index depth,
Index strideA,
Index strideB,
Index offsetA,
Index& row,
const Packet& pAlpha,
#ifdef USE_PARTIAL_PACKETS
Index accCols2
#else
const Packet& pMask
#endif
)
{
const Scalar* rhs_ptr = rhs_base;
const Scalar* rhs_ptr0 = rhs_base, * rhs_ptr1 = NULL, * rhs_ptr2 = NULL, * rhs_ptr3 = NULL;
const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
__vector_quad accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;

if (accItr > 1) {
rhs_ptr1 = rhs_base + (accRows * strideB);
} else {
EIGEN_UNUSED_VARIABLE(strideB);
EIGEN_UNUSED_VARIABLE(rhs_ptr1);
EIGEN_UNUSED_VARIABLE(res1);
}
if (accItr > 2) {
rhs_ptr2 = rhs_base + (2 * accRows * strideB);
rhs_ptr3 = rhs_base + (3 * accRows * strideB);
} else {
EIGEN_UNUSED_VARIABLE(rhs_ptr2);
EIGEN_UNUSED_VARIABLE(rhs_ptr3);
EIGEN_UNUSED_VARIABLE(res2);
EIGEN_UNUSED_VARIABLE(res3);
}

MICRO_MMA_SRC_PTR
MICRO_MMA_DST_PTR

@@ -347,17 +406,16 @@ EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration(
MICRO_UPDATE
}

#ifdef USE_PARTIAL_PACKETS
#define MICRO_MMA_UNROLL_ITER2(N, M) \
gemm_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, RhsPacket, DataMapper, accRows, accCols, !M>(res3, lhs_base, rhs_base, depth, strideA, offsetA, row, pAlpha, M ? remaining_rows : accCols); \
gemm_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, RhsPacket, DataMapper, accRows, accCols, !M, accItr>(res30, res31, res32, res33, lhs_base, rhs_base, depth, strideA, strideB, offsetA, row, pAlpha, M ? remaining_rows : accCols); \
if (M) return;
#else
#define MICRO_MMA_UNROLL_ITER2(N, M) \
gemm_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, RhsPacket, DataMapper, accRows, accCols, M ? M : accCols>(res3, lhs_base, rhs_base, depth, strideA, offsetA, row, pAlpha, pMask); \
if (M) return;
#endif

template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
#define MICRO_MMA_ROWS(n) \
while(row + n*accCols <= rows) { \
MICRO_MMA_UNROLL_ITER2(n, 0); \
}

template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, const Index accItr>
EIGEN_ALWAYS_INLINE void gemmMMA_cols(
const DataMapper& res,
const Scalar* blockA,
@@ -373,45 +431,71 @@ EIGEN_ALWAYS_INLINE void gemmMMA_cols(
const Packet& pAlpha,
const Packet& pMask)
{
const DataMapper res3 = res.getSubMapper(0, col);
const DataMapper res30 = res.getSubMapper(0, col);
const DataMapper res31 = (accItr > 1) ? res30.getSubMapper(0, accRows*1) : res30;
const DataMapper res32 = (accItr > 2) ? res30.getSubMapper(0, accRows*2) : res30;
const DataMapper res33 = (accItr > 2) ? res30.getSubMapper(0, accRows*3) : res30;

const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
const Scalar* lhs_base = blockA + accCols*offsetA;
Index row = 0;

#define MAX_MMA_UNROLL 7
while(row + MAX_MMA_UNROLL*accCols <= rows) {
MICRO_MMA_UNROLL_ITER2(MAX_MMA_UNROLL, 0);

#if MAX_MMA_UNROLL < 2
if (1) {
#elif MAX_MMA_UNROLL < 4
if (accItr <= 2) {
#else
if (accItr == 1) {
#endif
MICRO_MMA_ROWS(MAX_MMA_UNROLL);
} else if (accItr == 2) {
MICRO_MMA_ROWS(4);
} else {
MICRO_MMA_ROWS(2);
}
switch( (rows-row)/accCols ) {
#if MAX_MMA_UNROLL > 7
case 7:
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 7)
if (accItr == 1) {
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 7)
}
break;
#endif
#if MAX_MMA_UNROLL > 6
case 6:
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 6)
if (accItr == 1) {
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 6)
}
break;
#endif
#if MAX_MMA_UNROLL > 5
case 5:
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 5)
if (accItr == 1) {
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 5)
}
break;
#endif
#if MAX_MMA_UNROLL > 4
case 4:
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 4)
if (accItr == 1) {
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 4)
}
break;
#endif
#if MAX_MMA_UNROLL > 3
case 3:
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 3)
if (accItr <= 2) {
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 3)
}
break;
#endif
#if MAX_MMA_UNROLL > 2
case 2:
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 2)
if (accItr <= 2) {
MICRO_UNROLL_ITER(MICRO_MMA_UNROLL_ITER2, 2)
}
break;
#endif
#if MAX_MMA_UNROLL > 1
@@ -426,10 +510,16 @@ EIGEN_ALWAYS_INLINE void gemmMMA_cols(

if(remaining_rows > 0)
{
gemm_extra_row<Scalar, Packet, DataMapper, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlpha, pMask);
MICRO_MMA_UNROLL_ITER(MICRO_MMA_EXTRA_ROWS1, 0)
}
}

#define MICRO_MMA_COLS(n) \
for(; col + n*accRows <= cols; col += n*accRows) \
{ \
gemmMMA_cols<Scalar, Packet, RhsPacket2, DataMapper, accRows, accCols, n>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask); \
}

template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
@@ -444,10 +534,11 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
typedef typename std::conditional_t<(sizeof(Scalar) == sizeof(float)), RhsPacket, __vector_pair> RhsPacket2;

Index col = 0;
for(; col + accRows <= cols; col += accRows)
{
gemmMMA_cols<Scalar, Packet, RhsPacket2, DataMapper, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlpha, pMask);
}
#ifdef GEMM_MULTIPLE_COLS
MICRO_MMA_COLS(4);
MICRO_MMA_COLS(2);
#endif
MICRO_MMA_COLS(1);

if (col != cols)
{
@@ -459,62 +550,88 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
#define advanceCols ((RhsIsReal) ? 1 : 2)

// PEEL_COMPLEX_MMA loop factor.
#ifdef GEMM_MULTIPLE_COLS
#define PEEL_COMPLEX_MMA 4
#else
#define PEEL_COMPLEX_MMA 3
#endif

#define MICRO_COMPLEX_MMA_UNROLL(func) \
func(0) func(1) func(2) func(3)

#define MICRO_COMPLEX_MMA_WORK(func, type, peel) \
func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel)
if (accItr == 1) { \
func(0,type,peel,0,0) func(1,type,peel,1,0) func(2,type,peel,2,0) func(3,type,peel,3,0) \
} else if (accItr == 2) { \
func(0,type,peel,0,0) func(1,type,peel,0,1) func(2,type,peel,1,0) func(3,type,peel,1,1) \
} else { \
func(0,type,peel,0,0) func(1,type,peel,0,1) func(2,type,peel,0,2) func(3,type,peel,0,3) \
}

#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel) \
if (unroll_factor > iter) { \
pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV[peel], rhsVi[peel]); \
#define MICRO_COMPLEX_MMA_WORK_ONE(iter, type, peel, left, right) \
if (unroll_factor > left) { \
pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##left, lhsVi##left, rhsV##right[peel], rhsVi##right[peel]); \
}

#ifdef VECTOR_PAIR_LOADS_LHS
#define MICRO_COMPLEX_MMA_WORK_TWO(iter, type, peel) \
if (unroll_factor > iter) { \
pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV2##iter.packet[peel & 1], lhsVi2##iter.packet[peel & 1], rhsV[peel], rhsVi[peel]); \
#define MICRO_COMPLEX_MMA_WORK_TWO(iter, type, peel, left, right) \
if (unroll_factor > left) { \
pgercMMA<Packet, type, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV2##left.packet[peel & 1], lhsVi2##left.packet[peel & 1], rhsV##right[peel], rhsVi##right[peel]); \
}

#define MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, iter) \
if (!LhsIsReal && (unroll_factor > iter)) { \
if (MICRO_NORMAL(iter)) { \
ploadLhsMMA(reinterpret_cast<const double*>(lhs_ptr_real##iter + imag_delta), plhsVi##iter); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&lhsVi2##iter.packet), &plhsVi##iter); \
#define MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, left) \
if (!LhsIsReal && (unroll_factor > left)) { \
if (MICRO_NORMAL(left)) { \
ploadLhsMMA(reinterpret_cast<const double*>(lhs_ptr_real##left + imag_delta), plhsVi##left); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&lhsVi2##left.packet), &plhsVi##left); \
} else { \
lhsVi2##iter.packet[0] = ploadLhs<Packet>(lhs_ptr_real##iter + imag_delta2); \
lhsVi2##iter.packet[1] = ploadLhs<Packet>(lhs_ptr_real##iter + imag_delta2 + accCols2); \
EIGEN_UNUSED_VARIABLE(plhsVi##iter) \
lhsVi2##left.packet[0] = ploadLhs<Packet>(lhs_ptr_real##left + imag_delta2); \
lhsVi2##left.packet[1] = ploadLhs<Packet>(lhs_ptr_real##left + imag_delta2 + accCols2); \
EIGEN_UNUSED_VARIABLE(plhsVi##left); \
} \
} else { \
EIGEN_UNUSED_VARIABLE(lhsVi2##iter); \
EIGEN_UNUSED_VARIABLE(plhsVi##iter) \
EIGEN_UNUSED_VARIABLE(lhsVi2##left); \
EIGEN_UNUSED_VARIABLE(plhsVi##left); \
} \
MICRO_MMA_LOAD1_TWO(lhs_ptr_real, iter)
MICRO_MMA_LOAD1_TWO(lhs_ptr_real, left)

#define MICRO_COMPLEX_MMA_LOAD_TWO(iter) MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, iter)
#define MICRO_COMPLEX_MMA_LOAD_TWO(left) MICRO_COMPLEX_MMA_LOAD1_TWO(lhs_ptr, left)
#endif

#define MICRO_COMPLEX_MMA_LOAD_RHS1(peel, right) \
ploadRhsMMA(rhs_ptr_real##right + (accRows * peel), rhsV##right[peel]); \
if (!RhsIsReal) { \
ploadRhsMMA(rhs_ptr_imag##right + (accRows * peel), rhsVi##right[peel]); \
}

#define MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel) \
MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_LOAD_RHS1, peel)

#define MICRO_COMPLEX_MMA_TYPE_PEEL(funcw, funcl, type, peel) \
if (PEEL_COMPLEX_MMA > peel) { \
Packet lhsV0, lhsV1, lhsV2, lhsV3; \
Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
ploadRhsMMA(rhs_ptr_real + (accRows * peel), rhsV[peel]); \
if(!RhsIsReal) { \
ploadRhsMMA(rhs_ptr_imag + (accRows * peel), rhsVi[peel]); \
} \
MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel) \
MICRO_COMPLEX_MMA_UNROLL(funcl) \
MICRO_COMPLEX_MMA_WORK(funcw, type, peel) \
}

#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(funcw, funcl, type) \
type rhsV[4], rhsVi[4]; \
type rhsV0[4], rhsVi0[4], rhsV1[(accItr > 1) ? 4 : 1], rhsVi1[(accItr > 1) ? 4 : 1], rhsV2[(accItr > 2) ? 4 : 1], rhsVi2[(accItr > 2) ? 4 : 1], rhsV3[(accItr > 2) ? 4 : 1], rhsVi3[(accItr > 2) ? 4 : 1]; \
MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,0) MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,1) \
MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,2) MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,3)
#else
#define MICRO_COMPLEX_MMA_LOAD_TWO_RHS(peel1, right) \
ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr_real##right + (accRows * peel1)), prhsV##peel1); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsV##right[peel1]), &prhsV##peel1); \
if(!RhsIsReal) { \
ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr_imag##right + (accRows * peel1)), prhsVi##peel1); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsVi##right[peel1]), &prhsVi##peel1); \
} else { \
EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \
}

#define MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type, peel1, peel2) \
if (PEEL_COMPLEX_MMA > peel2) { \
PacketBlock<Packet,2> lhsV20, lhsV21, lhsV22, lhsV23; \
@@ -522,23 +639,12 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
__vector_pair plhsV0, plhsV1, plhsV2, plhsV3; \
__vector_pair plhsVi0, plhsVi1, plhsVi2, plhsVi3; \
if (sizeof(type) == 16) { \
ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr_real + (accRows * peel1)), prhsV##peel1); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsV[peel1]), &prhsV##peel1); \
if(!RhsIsReal) { \
ploadRhsMMA(reinterpret_cast<const double*>(rhs_ptr_imag + (accRows * peel1)), prhsVi##peel1); \
__builtin_vsx_disassemble_pair(reinterpret_cast<void*>(&rhsVi[peel1]), &prhsVi##peel1); \
} else { \
EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \
} \
MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_LOAD_TWO_RHS, peel1) \
} else { \
EIGEN_UNUSED_VARIABLE(prhsV##peel1); \
EIGEN_UNUSED_VARIABLE(prhsVi##peel1); \
ploadRhsMMA(rhs_ptr_real + (accRows * peel1), rhsV[peel1]); \
ploadRhsMMA(rhs_ptr_real + (accRows * peel2), rhsV[peel2]); \
if(!RhsIsReal) { \
ploadRhsMMA(rhs_ptr_imag + (accRows * peel1), rhsVi[peel1]); \
ploadRhsMMA(rhs_ptr_imag + (accRows * peel2), rhsVi[peel2]); \
} \
MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel1); \
MICRO_COMPLEX_MMA_LOAD_ONE_RHS(peel2); \
} \
MICRO_COMPLEX_MMA_UNROLL(funcl2) \
MICRO_COMPLEX_MMA_WORK(funcw2, type, peel1) \
@@ -550,7 +656,7 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
}

#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2(funcw1, funcl1, funcw2, funcl2, type) \
type rhsV[4], rhsVi[4]; \
type rhsV0[4], rhsVi0[4], rhsV1[(accItr > 1) ? 4 : 1], rhsVi1[(accItr > 1) ? 4 : 1], rhsV2[(accItr > 2) ? 4 : 1], rhsVi2[(accItr > 2) ? 4 : 1], rhsV3[(accItr > 2) ? 4 : 1], rhsVi3[(accItr > 2) ? 4 : 1]; \
__vector_pair prhsV0, prhsV2; \
__vector_pair prhsVi0, prhsVi2; \
MICRO_COMPLEX_MMA_TYPE_PEEL2(funcw1,funcl1,funcw2,funcl2,type,0,1) \
@@ -558,21 +664,26 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
#endif

#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(funcw, funcl, type) \
type rhsV[1], rhsVi[1]; \
type rhsV0[1], rhsVi0[1], rhsV1[1], rhsVi1[1], rhsV2[1], rhsVi2[1], rhsV3[1], rhsVi3[1]; \
MICRO_COMPLEX_MMA_TYPE_PEEL(funcw,funcl,type,0)

#define MICRO_COMPLEX_MMA_UPDATE_RHS1(size, right) \
rhs_ptr_real##right += (accRows * size); \
if(!RhsIsReal) rhs_ptr_imag##right += (accRows * size);

#define MICRO_COMPLEX_MMA_UPDATE_RHS(size) \
MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_UPDATE_RHS1, size)

#define MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_TYPE, size) \
MICRO_COMPLEX_MMA_TYPE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_LOAD_ONE, RhsPacket) \
rhs_ptr_real += (accRows * size); \
if(!RhsIsReal) rhs_ptr_imag += (accRows * size);
MICRO_COMPLEX_MMA_UPDATE_RHS(size);

#ifndef VECTOR_PAIR_LOADS_LHS
#define MICRO_COMPLEX_MMA_ONE_PEEL MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL, PEEL_COMPLEX_MMA)
#else
#define MICRO_COMPLEX_MMA_UNROLL_TYPE2(MICRO_COMPLEX_MMA_TYPE, size) \
MICRO_COMPLEX_MMA_TYPE(MICRO_COMPLEX_MMA_WORK_ONE, MICRO_COMPLEX_LOAD_ONE, MICRO_COMPLEX_MMA_WORK_TWO, MICRO_COMPLEX_MMA_LOAD_TWO, RhsPacket) \
rhs_ptr_real += (accRows * size); \
if(!RhsIsReal) rhs_ptr_imag += (accRows * size);
MICRO_COMPLEX_MMA_UPDATE_RHS(size);

#define MICRO_COMPLEX_MMA_ONE_PEEL MICRO_COMPLEX_MMA_UNROLL_TYPE2(MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL2, PEEL_COMPLEX_MMA)
#endif
@@ -580,7 +691,7 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
#define MICRO_COMPLEX_MMA_ONE MICRO_COMPLEX_MMA_UNROLL_TYPE(MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE, 1)

#define MICRO_COMPLEX_MMA_DST_PTR_ONE(iter) \
if (unroll_factor > iter) { \
if (unroll_factor * accItr > iter) { \
bsetzeroMMA(&accReal##iter); \
bsetzeroMMA(&accImag##iter); \
} else { \
@@ -594,16 +705,34 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,

#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)

#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \
if (unroll_factor > iter) { \
storeComplexAccumulator<DataMapper, Packet, Packetc, accCols, (unroll_factor != (iter + 1)) ? accCols : accCols2>(row + iter*accCols, res, pAlphaReal, pAlphaImag, pMask, &accReal##iter, &accImag##iter); \
#define MICRO_COMPLEX_MMA_STORE_ONE(iter, left, right) \
if (unroll_factor > left) { \
storeComplexAccumulator<DataMapper, Packet, Packetc, accCols, (unroll_factor != (left + 1)) ? accCols : accCols2>(row + left*accCols, res##right, pAlphaReal, pAlphaImag, pMask, &accReal##iter, &accImag##iter); \
}

#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)
#define MICRO_COMPLEX_MMA_ITER_UNROLL(func) \
if (accItr == 1) { \
func(0,0,0) func(1,1,0) func(2,2,0) func(3,3,0) \
} else if (accItr == 2) { \
func(0,0,0) func(1,0,1) func(2,1,0) func(3,1,1) \
} else { \
func(0,0,0) func(1,0,1) func(2,0,2) func(3,0,3) \
}

template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_ITER_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)

#define MICRO_COMPLEX_MMA_EXTRA_ROWS(right) \
gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3##right, blockA, rhs_base + right*accRows*(RhsIsReal ? 1 : 2)*strideB, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);

#define MICRO_COMPLEX_MMA_EXTRA_ROWS1(val, right) \
MICRO_COMPLEX_MMA_EXTRA_ROWS(right);

template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, const Index accCols2, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index accItr>
EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(
const DataMapper& res,
const DataMapper& res0,
const DataMapper& res1,
const DataMapper& res2,
const DataMapper& res3,
const Scalar* lhs_base,
const Scalar* rhs_base,
Index depth,
@ -615,14 +744,48 @@ EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(
|
||||
const Packet& pAlphaImag,
|
||||
const Packet& pMask)
|
||||
{
|
||||
const Scalar* rhs_ptr_real = rhs_base;
|
||||
const Scalar* rhs_ptr_imag = NULL;
|
||||
const Scalar* rhs_ptr_real0 = rhs_base, * rhs_ptr_real1 = NULL, * rhs_ptr_real2 = NULL, * rhs_ptr_real3 = NULL;
|
||||
const Scalar* rhs_ptr_imag0 = NULL, * rhs_ptr_imag1 = NULL, * rhs_ptr_imag2 = NULL, * rhs_ptr_imag3 = NULL;
|
||||
const Index imag_delta = accCols*strideA;
|
||||
const Index imag_delta2 = accCols2*strideA;
|
||||
|
||||
if(!RhsIsReal) {
|
||||
rhs_ptr_imag = rhs_base + accRows*strideB;
|
||||
rhs_ptr_imag0 = rhs_base + accRows*strideB;
|
||||
} else {
|
||||
EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
|
||||
EIGEN_UNUSED_VARIABLE(rhs_ptr_imag0);
|
||||
}
|
||||
if (accItr > 1) {
|
||||
if(!RhsIsReal) {
|
||||
rhs_ptr_real1 = rhs_base + (2*accRows*strideB);
|
||||
rhs_ptr_imag1 = rhs_base + (3*accRows*strideB);
|
||||
} else {
|
||||
rhs_ptr_real1 = rhs_base + accRows*strideB;
|
||||
EIGEN_UNUSED_VARIABLE(rhs_ptr_imag1);
|
||||
}
|
||||
} else {
|
||||
EIGEN_UNUSED_VARIABLE(rhs_ptr_real1);
|
||||
EIGEN_UNUSED_VARIABLE(rhs_ptr_imag1);
|
||||
EIGEN_UNUSED_VARIABLE(res1);
|
||||
}
|
||||
if (accItr > 2) {
|
||||
if(!RhsIsReal) {
|
||||
rhs_ptr_real2 = rhs_base + (4*accRows*strideB);
|
||||
rhs_ptr_imag2 = rhs_base + (5*accRows*strideB);
|
||||
rhs_ptr_real3 = rhs_base + (6*accRows*strideB);
|
||||
rhs_ptr_imag3 = rhs_base + (7*accRows*strideB);
|
||||
} else {
|
||||
rhs_ptr_real2 = rhs_base + (2*accRows*strideB);
|
||||
rhs_ptr_real3 = rhs_base + (3*accRows*strideB);
|
||||
EIGEN_UNUSED_VARIABLE(rhs_ptr_imag2);
|
||||
EIGEN_UNUSED_VARIABLE(rhs_ptr_imag3);
|
||||
}
|
||||
} else {
|
||||
EIGEN_UNUSED_VARIABLE(rhs_ptr_real2);
|
||||
EIGEN_UNUSED_VARIABLE(rhs_ptr_real3);
|
||||
EIGEN_UNUSED_VARIABLE(rhs_ptr_imag2);
|
||||
EIGEN_UNUSED_VARIABLE(rhs_ptr_imag3);
|
||||
EIGEN_UNUSED_VARIABLE(res2);
|
||||
EIGEN_UNUSED_VARIABLE(res3);
|
||||
}
|
||||
const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_real1 = NULL;
|
||||
const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_real3 = NULL;
|
||||
@ -651,10 +814,15 @@ EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(
|
||||
}
|
||||
|
||||
#define MICRO_COMPLEX_MMA_UNROLL_ITER2(N, M) \
|
||||
gemm_complex_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask); \
|
||||
gemm_complex_unrolled_MMA_iteration<N + (M ? 1 : 0), Scalar, Packet, Packetc, RhsPacket, DataMapper, accRows, accCols, M ? M : accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, accItr>(res30, res31, res32, res33, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, pAlphaReal, pAlphaImag, pMask); \
|
||||
if (M) return;
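The pointer setup in the iteration function above lays the packed RHS panels out back to back in blockB: when the RHS is complex, panel r keeps its real data at offset 2*r*accRows*strideB and its imaginary data one accRows*strideB further, while with RhsIsReal only the real blocks are packed at r*accRows*strideB. A small sketch of that addressing follows; the helper name rhs_panel_offset is an assumption for illustration, not an Eigen symbol.

// Illustrative only -- mirrors the rhs_ptr_real*/rhs_ptr_imag* setup above.
#include <cstddef>

static std::ptrdiff_t rhs_panel_offset(int r, bool imag, bool rhsIsReal,
                                       std::ptrdiff_t accRows, std::ptrdiff_t strideB) {
  if (rhsIsReal)
    return r * accRows * strideB;                        // only real data is packed
  return (2 * r + (imag ? 1 : 0)) * accRows * strideB;   // real/imag blocks alternate per panel
}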

template<typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
#define MICRO_COMPLEX_MMA_ROWS(n) \
  while(row + n*accCols <= rows) { \
    MICRO_COMPLEX_MMA_UNROLL_ITER2(n, 0); \
  }

template<typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index accItr>
EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(
  const DataMapper& res,
  const Scalar* blockA,
@ -671,35 +839,50 @@ EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(
  const Packet& pAlphaImag,
  const Packet& pMask)
{
  const DataMapper res3 = res.getSubMapper(0, col);
  const DataMapper res30 = res.getSubMapper(0, col);
  const DataMapper res31 = (accItr > 1) ? res30.getSubMapper(0, accRows*1) : res30;
  const DataMapper res32 = (accItr > 2) ? res30.getSubMapper(0, accRows*2) : res30;
  const DataMapper res33 = (accItr > 2) ? res30.getSubMapper(0, accRows*3) : res30;

  const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
  const Scalar* lhs_base = blockA + accCols*offsetA;
  Index row = 0;

#define MAX_COMPLEX_MMA_UNROLL 4
  while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) {
    MICRO_COMPLEX_MMA_UNROLL_ITER2(MAX_COMPLEX_MMA_UNROLL, 0);

#if MAX_COMPLEX_MMA_UNROLL < 2
  if (1) {
#elif MAX_COMPLEX_MMA_UNROLL < 4
  if (accItr <= 2) {
#else
  if (accItr == 1) {
#endif
    MICRO_COMPLEX_MMA_ROWS(MAX_COMPLEX_MMA_UNROLL);
  } else if (accItr == 2) {
    MICRO_COMPLEX_MMA_ROWS(2);
  } else {
    MICRO_COMPLEX_MMA_ROWS(1);
  }
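With MAX_COMPLEX_MMA_UNROLL fixed at 4, the branch above scales the row unrolling down as more column panels are processed per iteration, so the number of live accumulator pairs stays the same: 4 row blocks for one panel, 2 for two panels, 1 for four. A one-line sketch of that trade-off, with a hypothetical helper name:

// Illustrative only: rows unrolled per iteration when accItr panels are active.
static int row_unroll(int accItr) {
  return 4 / accItr;  // accItr = 1, 2, 4  ->  4, 2, 1 row blocks (4 accumulators total)
}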
  switch( (rows-row)/accCols ) {
#if MAX_COMPLEX_MMA_UNROLL > 4
    case 4:
      MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 4)
      break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 3
    case 3:
      MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 3)
      if (accItr == 1) {
        MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 3)
      }
      break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 2
    case 2:
      MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 2)
      if (accItr == 1) {
        MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 2)
      }
      break;
#endif
#if MAX_COMPLEX_MMA_UNROLL > 1
    case 1:
      MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 1)
      if (accItr <= 2) {
        MICRO_COMPLEX_UNROLL_ITER(MICRO_COMPLEX_MMA_UNROLL_ITER2, 1)
      }
      break;
#endif
    default:
@ -709,10 +892,16 @@ EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols(

  if(remaining_rows > 0)
  {
    gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
    MICRO_MMA_UNROLL_ITER(MICRO_COMPLEX_MMA_EXTRA_ROWS1, 0)
  }
}

#define MICRO_COMPLEX_MMA_COLS(n) \
  for(; col + n*accRows <= cols; col += n*accRows) \
  { \
    gemmMMA_complex_cols<Scalar, Packet, Packetc, RhsPacket2, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, n>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask); \
  }

template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
@ -731,10 +920,11 @@ void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsS
  typedef typename std::conditional_t<(sizeof(Scalar) == sizeof(float)), RhsPacket, __vector_pair> RhsPacket2;

  Index col = 0;
  for(; col + accRows <= cols; col += accRows)
  {
    gemmMMA_complex_cols<Scalar, Packet, Packetc, RhsPacket2, DataMapper, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, remaining_rows, pAlphaReal, pAlphaImag, pMask);
  }
#ifdef GEMM_MULTIPLE_COLS
  MICRO_COMPLEX_MMA_COLS(4);
  MICRO_COMPLEX_MMA_COLS(2);
#endif
  MICRO_COMPLEX_MMA_COLS(1);
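When GEMM_MULTIPLE_COLS is enabled, the three MICRO_COMPLEX_MMA_COLS invocations above walk the columns in 4-panel, then 2-panel, then single-panel steps; any tail narrower than accRows falls through to the extra-columns handling below. A standalone sketch of that traversal (not the Eigen code, helper name assumed):

// Illustrative only: column traversal produced by MICRO_COMPLEX_MMA_COLS(4/2/1).
static long traverse_cols(long cols, long accRows) {
  long col = 0;
  for (; col + 4 * accRows <= cols; col += 4 * accRows) { /* accItr == 4 */ }
  for (; col + 2 * accRows <= cols; col += 2 * accRows) { /* accItr == 2 */ }
  for (; col + 1 * accRows <= cols; col += 1 * accRows) { /* accItr == 1 */ }
  return col;  // col != cols means a remainder narrower than accRows is left over
}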

  if (col != cols)
  {