Add subMappers to Power GEMM packing - simplifies the address calculations (10% faster)

This commit is contained in:
Chip Kerchner 2022-05-23 15:18:29 +00:00 committed by Antonio Sánchez
parent 32348091ba
commit aa8b7e2c37
3 changed files with 109 additions and 92 deletions

View File

@ -104,12 +104,6 @@ const static Packet16uc p16uc_GETIMAG32 = { 4, 5, 6, 7,
12, 13, 14, 15,
20, 21, 22, 23,
28, 29, 30, 31};
// Byte selector for vec_perm: picks bytes 0-7 of the first operand, then bytes 0-7 of the second.
//[a,ai],[b,bi] = [a,b]
const static Packet16uc p16uc_GETREAL64 = { 0, 1, 2, 3, 4, 5, 6, 7,
16, 17, 18, 19, 20, 21, 22, 23};
// Byte selector for vec_perm: picks bytes 8-15 of the first operand, then bytes 8-15 of the second.
//[a,ai],[b,bi] = [ai,bi]
const static Packet16uc p16uc_GETIMAG64 = { 8, 9, 10, 11, 12, 13, 14, 15,
24, 25, 26, 27, 28, 29, 30, 31};
/*********************************************
* Single precision real and complex packing *
@ -441,6 +435,7 @@ struct dhs_cpack {
for(; j + vectorSize <= rows; j+=vectorSize)
{
const DataMapper lhs2 = UseLhs ? lhs.getSubMapper(j, 0) : lhs.getSubMapper(0, j);
Index i = 0;
rii = rir + vectorDelta;
@ -451,9 +446,9 @@ struct dhs_cpack {
PacketBlock<PacketC,8> cblock;
if (UseLhs) {
bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs, j, i);
bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, 0, i);
} else {
bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs, i, j);
bload<DataMapper, PacketC, 2, StorageOrder, true, 4>(cblock, lhs2, i, 0);
}
blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32);
@ -494,19 +489,19 @@ struct dhs_cpack {
if(((StorageOrder == ColMajor) && UseLhs) || (((StorageOrder == RowMajor) && !UseLhs)))
{
if (UseLhs) {
cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i);
cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 2, i);
cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i);
cblock.packet[1] = lhs2.template loadPacket<PacketC>(2, i);
} else {
cblock.packet[0] = lhs.template loadPacket<PacketC>(i, j + 0);
cblock.packet[1] = lhs.template loadPacket<PacketC>(i, j + 2);
cblock.packet[0] = lhs2.template loadPacket<PacketC>(i, 0);
cblock.packet[1] = lhs2.template loadPacket<PacketC>(i, 2);
}
} else {
if (UseLhs) {
cblock.packet[0] = pload2(lhs(j + 0, i), lhs(j + 1, i));
cblock.packet[1] = pload2(lhs(j + 2, i), lhs(j + 3, i));
cblock.packet[0] = pload2(lhs2(0, i), lhs2(1, i));
cblock.packet[1] = pload2(lhs2(2, i), lhs2(3, i));
} else {
cblock.packet[0] = pload2(lhs(i, j + 0), lhs(i, j + 1));
cblock.packet[1] = pload2(lhs(i, j + 2), lhs(i, j + 3));
cblock.packet[0] = pload2(lhs2(i, 0), lhs2(i, 1));
cblock.packet[1] = pload2(lhs2(i, 2), lhs2(i, 3));
}
}
@ -534,16 +529,17 @@ struct dhs_cpack {
for(; j < rows; j++)
{
const DataMapper lhs2 = lhs.getSubMapper(0, j);
rii = rir + ((PanelMode) ? stride : depth);
for(Index i = 0; i < depth; i++)
{
blockAt[rir] = lhs(i, j).real();
blockAt[rir] = lhs2(i, 0).real();
if(Conjugate)
blockAt[rii] = -lhs(i, j).imag();
blockAt[rii] = -lhs2(i, 0).imag();
else
blockAt[rii] = lhs(i, j).imag();
blockAt[rii] = lhs2(i, 0).imag();
rir += 1;
rii += 1;
@ -588,6 +584,7 @@ struct dhs_pack{
for(; j + vectorSize <= rows; j+=vectorSize)
{
const DataMapper lhs2 = UseLhs ? lhs.getSubMapper(j, 0) : lhs.getSubMapper(0, j);
Index i = 0;
if(PanelMode) ri += vectorSize*offset;
@ -597,9 +594,9 @@ struct dhs_pack{
PacketBlock<Packet,4> block;
if (UseLhs) {
bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block, lhs, j, i);
bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block, lhs2, 0, i);
} else {
bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block, lhs, i, j);
bload<DataMapper, Packet, 4, StorageOrder, false, 4>(block, lhs2, i, 0);
}
if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
{
@ -615,22 +612,22 @@ struct dhs_pack{
if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
{
if (UseLhs) {
blockA[ri+0] = lhs(j+0, i);
blockA[ri+1] = lhs(j+1, i);
blockA[ri+2] = lhs(j+2, i);
blockA[ri+3] = lhs(j+3, i);
blockA[ri+0] = lhs2(0, i);
blockA[ri+1] = lhs2(1, i);
blockA[ri+2] = lhs2(2, i);
blockA[ri+3] = lhs2(3, i);
} else {
blockA[ri+0] = lhs(i, j+0);
blockA[ri+1] = lhs(i, j+1);
blockA[ri+2] = lhs(i, j+2);
blockA[ri+3] = lhs(i, j+3);
blockA[ri+0] = lhs2(i, 0);
blockA[ri+1] = lhs2(i, 1);
blockA[ri+2] = lhs2(i, 2);
blockA[ri+3] = lhs2(i, 3);
}
} else {
Packet lhsV;
if (UseLhs) {
lhsV = lhs.template loadPacket<Packet>(j, i);
lhsV = lhs2.template loadPacket<Packet>(0, i);
} else {
lhsV = lhs.template loadPacket<Packet>(i, j);
lhsV = lhs2.template loadPacket<Packet>(i, 0);
}
pstore<Scalar>(blockA + ri, lhsV);
}
@ -647,9 +644,10 @@ struct dhs_pack{
for(; j < rows; j++)
{
const DataMapper lhs2 = lhs.getSubMapper(0, j);
for(Index i = 0; i < depth; i++)
{
blockA[ri] = lhs(i, j);
blockA[ri] = lhs2(i, 0);
ri += 1;
}
@ -685,6 +683,7 @@ struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, true>
for(; j + vectorSize <= rows; j+=vectorSize)
{
const DataMapper lhs2 = lhs.getSubMapper(j, 0);
Index i = 0;
if(PanelMode) ri += vectorSize*offset;
@ -694,13 +693,13 @@ struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, true>
PacketBlock<Packet2d,2> block;
if(StorageOrder == RowMajor)
{
block.packet[0] = lhs.template loadPacket<Packet2d>(j + 0, i);
block.packet[1] = lhs.template loadPacket<Packet2d>(j + 1, i);
block.packet[0] = lhs2.template loadPacket<Packet2d>(0, i);
block.packet[1] = lhs2.template loadPacket<Packet2d>(1, i);
ptranspose(block);
} else {
block.packet[0] = lhs.template loadPacket<Packet2d>(j, i + 0);
block.packet[1] = lhs.template loadPacket<Packet2d>(j, i + 1);
block.packet[0] = lhs2.template loadPacket<Packet2d>(0, i + 0);
block.packet[1] = lhs2.template loadPacket<Packet2d>(0, i + 1);
}
storeBlock<double, Packet2d, 2>(blockA + ri, block);
@ -711,10 +710,10 @@ struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, true>
{
if(StorageOrder == RowMajor)
{
blockA[ri+0] = lhs(j+0, i);
blockA[ri+1] = lhs(j+1, i);
blockA[ri+0] = lhs2(0, i);
blockA[ri+1] = lhs2(1, i);
} else {
Packet2d lhsV = lhs.template loadPacket<Packet2d>(j, i);
Packet2d lhsV = lhs2.template loadPacket<Packet2d>(0, i);
pstore<double>(blockA + ri, lhsV);
}
@ -752,6 +751,7 @@ struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, false>
for(; j + 2*vectorSize <= cols; j+=2*vectorSize)
{
const DataMapper rhs2 = rhs.getSubMapper(0, j);
Index i = 0;
if(PanelMode) ri += offset*(2*vectorSize);
@ -762,10 +762,10 @@ struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, false>
if(StorageOrder == ColMajor)
{
PacketBlock<Packet2d,2> block1, block2;
block1.packet[0] = rhs.template loadPacket<Packet2d>(i, j + 0);
block1.packet[1] = rhs.template loadPacket<Packet2d>(i, j + 1);
block2.packet[0] = rhs.template loadPacket<Packet2d>(i, j + 2);
block2.packet[1] = rhs.template loadPacket<Packet2d>(i, j + 3);
block1.packet[0] = rhs2.template loadPacket<Packet2d>(i, 0);
block1.packet[1] = rhs2.template loadPacket<Packet2d>(i, 1);
block2.packet[0] = rhs2.template loadPacket<Packet2d>(i, 2);
block2.packet[1] = rhs2.template loadPacket<Packet2d>(i, 3);
ptranspose(block1);
ptranspose(block2);
@ -775,10 +775,10 @@ struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, false>
pstore<double>(blockB + ri + 4, block1.packet[1]);
pstore<double>(blockB + ri + 6, block2.packet[1]);
} else {
block.packet[0] = rhs.template loadPacket<Packet2d>(i + 0, j + 0); //[a1 a2]
block.packet[1] = rhs.template loadPacket<Packet2d>(i + 0, j + 2); //[a3 a4]
block.packet[2] = rhs.template loadPacket<Packet2d>(i + 1, j + 0); //[b1 b2]
block.packet[3] = rhs.template loadPacket<Packet2d>(i + 1, j + 2); //[b3 b4]
block.packet[0] = rhs2.template loadPacket<Packet2d>(i + 0, 0); //[a1 a2]
block.packet[1] = rhs2.template loadPacket<Packet2d>(i + 0, 2); //[a3 a4]
block.packet[2] = rhs2.template loadPacket<Packet2d>(i + 1, 0); //[b1 b2]
block.packet[3] = rhs2.template loadPacket<Packet2d>(i + 1, 2); //[b3 b4]
storeBlock<double, Packet2d, 4>(blockB + ri, block);
}
@ -789,20 +789,20 @@ struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, false>
{
if(StorageOrder == ColMajor)
{
blockB[ri+0] = rhs(i, j+0);
blockB[ri+1] = rhs(i, j+1);
blockB[ri+0] = rhs2(i, 0);
blockB[ri+1] = rhs2(i, 1);
ri += vectorSize;
blockB[ri+0] = rhs(i, j+2);
blockB[ri+1] = rhs(i, j+3);
blockB[ri+0] = rhs2(i, 2);
blockB[ri+1] = rhs2(i, 3);
} else {
Packet2d rhsV = rhs.template loadPacket<Packet2d>(i, j);
Packet2d rhsV = rhs2.template loadPacket<Packet2d>(i, 0);
pstore<double>(blockB + ri, rhsV);
ri += vectorSize;
rhsV = rhs.template loadPacket<Packet2d>(i, j + 2);
rhsV = rhs2.template loadPacket<Packet2d>(i, 2);
pstore<double>(blockB + ri, rhsV);
}
ri += vectorSize;
@ -815,9 +815,10 @@ struct dhs_pack<double, DataMapper, Packet2d, StorageOrder, PanelMode, false>
for(; j < cols; j++)
{
const DataMapper rhs2 = rhs.getSubMapper(0, j);
for(Index i = 0; i < depth; i++)
{
blockB[ri] = rhs(i, j);
blockB[ri] = rhs2(i, 0);
ri += 1;
}
@ -840,6 +841,7 @@ struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, P
for(; j + vectorSize <= rows; j+=vectorSize)
{
const DataMapper lhs2 = lhs.getSubMapper(j, 0);
Index i = 0;
rii = rir + vectorDelta;
@ -851,29 +853,29 @@ struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, P
if(StorageOrder == ColMajor)
{
cblock.packet[0] = lhs.template loadPacket<PacketC>(j, i + 0); //[a1 a1i]
cblock.packet[1] = lhs.template loadPacket<PacketC>(j, i + 1); //[b1 b1i]
cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i + 0); //[a1 a1i]
cblock.packet[1] = lhs2.template loadPacket<PacketC>(0, i + 1); //[b1 b1i]
cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 1, i + 0); //[a2 a2i]
cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1); //[b2 b2i]
cblock.packet[2] = lhs2.template loadPacket<PacketC>(1, i + 0); //[a2 a2i]
cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1); //[b2 b2i]
blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETREAL64); //[a1 a2]
blockr.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2]
blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[2].v); //[a1 a2]
blockr.packet[1] = vec_mergeh(cblock.packet[1].v, cblock.packet[3].v); //[b1 b2]
blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[2].v, p16uc_GETIMAG64);
blocki.packet[1] = vec_perm(cblock.packet[1].v, cblock.packet[3].v, p16uc_GETIMAG64);
blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[2].v);
blocki.packet[1] = vec_mergel(cblock.packet[1].v, cblock.packet[3].v);
} else {
cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i); //[a1 a1i]
cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 1, i); //[a2 a2i]
cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i); //[a1 a1i]
cblock.packet[1] = lhs2.template loadPacket<PacketC>(1, i); //[a2 a2i]
cblock.packet[2] = lhs.template loadPacket<PacketC>(j + 0, i + 1); //[b1 b1i]
cblock.packet[3] = lhs.template loadPacket<PacketC>(j + 1, i + 1); //[b2 b2i
cblock.packet[2] = lhs2.template loadPacket<PacketC>(0, i + 1); //[b1 b1i]
cblock.packet[3] = lhs2.template loadPacket<PacketC>(1, i + 1); //[b2 b2i
blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); //[a1 a2]
blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); //[b1 b2]
blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v); //[a1 a2]
blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v); //[b1 b2]
blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);
blocki.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64);
blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v);
}
if(Conjugate)
@ -893,11 +895,11 @@ struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, P
PacketBlock<Packet,1> blockr, blocki;
PacketBlock<PacketC,2> cblock;
cblock.packet[0] = lhs.template loadPacket<PacketC>(j + 0, i);
cblock.packet[1] = lhs.template loadPacket<PacketC>(j + 1, i);
cblock.packet[0] = lhs2.template loadPacket<PacketC>(0, i);
cblock.packet[1] = lhs2.template loadPacket<PacketC>(1, i);
blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64);
blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);
blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v);
blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
if(Conjugate)
{
@ -953,6 +955,7 @@ struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, P
for(; j + 2*vectorSize <= cols; j+=2*vectorSize)
{
const DataMapper rhs2 = rhs.getSubMapper(0, j);
Index i = 0;
rii = rir + vectorDelta;
@ -962,13 +965,13 @@ struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, P
PacketBlock<PacketC,4> cblock;
PacketBlock<Packet,2> blockr, blocki;
bload<DataMapper, PacketC, 2, ColMajor, false, 4>(cblock, rhs, i, j);
bload<DataMapper, PacketC, 2, ColMajor, false, 4>(cblock, rhs2, i, 0);
blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64);
blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64);
blockr.packet[0] = vec_mergeh(cblock.packet[0].v, cblock.packet[1].v);
blockr.packet[1] = vec_mergeh(cblock.packet[2].v, cblock.packet[3].v);
blocki.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETIMAG64);
blocki.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETIMAG64);
blocki.packet[0] = vec_mergel(cblock.packet[0].v, cblock.packet[1].v);
blocki.packet[1] = vec_mergel(cblock.packet[2].v, cblock.packet[3].v);
if(Conjugate)
{
@ -990,16 +993,17 @@ struct dhs_cpack<double, DataMapper, Packet, PacketC, StorageOrder, Conjugate, P
for(; j < cols; j++)
{
const DataMapper rhs2 = rhs.getSubMapper(0, j);
rii = rir + ((PanelMode) ? stride : depth);
for(Index i = 0; i < depth; i++)
{
blockBt[rir] = rhs(i, j).real();
blockBt[rir] = rhs2(i, 0).real();
if(Conjugate)
blockBt[rii] = -rhs(i, j).imag();
blockBt[rii] = -rhs2(i, 0).imag();
else
blockBt[rii] = rhs(i, j).imag();
blockBt[rii] = rhs2(i, 0).imag();
rir += 1;
rii += 1;

View File

@ -1937,12 +1937,12 @@ EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(ResPacket& a, ResPa
GEMV_UNROLL_ROW(GEMV_INIT_ROW, N) \
Index j = 0; \
for (; j + LhsPacketSize <= cols; j += LhsPacketSize) { \
RhsPacket a0 = rhs2.template load<RhsPacket, Unaligned>(j, 0); \
RhsPacket a0 = rhs2.template load<RhsPacket, Unaligned>(j); \
GEMV_UNROLL_ROW(GEMV_WORK_ROW, N) \
} \
GEMV_UNROLL_ROW_HALF(GEMV_PREDUX2, (N >> 1)) \
for (; j < cols; ++j) { \
RhsScalar a0 = rhs2(j, 0); \
RhsScalar a0 = rhs2(j); \
GEMV_UNROLL_ROW_HALF(GEMV_MULT, (N >> 1)) \
} \
GEMV_UNROLL_ROW_HALF(GEMV_STORE_ROW, (N >> 1)) \
@ -1965,7 +1965,7 @@ EIGEN_STRONG_INLINE void gemv_row(
// The following copy tells the compiler that lhs's attributes are not modified outside this function
// This helps GCC to generate proper code.
LhsMapper lhs(alhs);
RhsMapper rhs2(rhs);
typename RhsMapper::LinearMapper rhs2 = rhs.getLinearMapper(0, 0);
eigen_internal_assert(rhs.stride() == 1);
conj_helper<LhsScalar, RhsScalar, false, false> cj;
@ -2006,14 +2006,14 @@ EIGEN_STRONG_INLINE void gemv_row(
Index j = 0;
for (; j + LhsPacketSize <= cols; j += LhsPacketSize)
{
RhsPacket b0 = rhs2.template load<RhsPacket, Unaligned>(j, 0);
RhsPacket b0 = rhs2.template load<RhsPacket, Unaligned>(j);
d0 = pcj.pmadd(lhs.template load<LhsPacket, LhsAlignment>(i + 0, j), b0, d0);
}
ResScalar dd0 = predux(d0);
for (; j < cols; ++j)
{
dd0 += cj.pmul(lhs(i, j), rhs2(j, 0));
dd0 += cj.pmul(lhs(i, j), rhs2(j));
}
res[i * resIncr] += alpha * dd0;
}
@ -2075,14 +2075,14 @@ EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(PResPacket& a0, PRe
#define GEMV_PROCESS_ROW_COMPLEX_SINGLE_WORK(which, N) \
j = 0; \
for (; j + LhsPacketSize <= cols; j += LhsPacketSize) { \
const RhsScalar& b1 = rhs2(j, 0); \
const RhsScalar& b1 = rhs2(j); \
RhsScalar* b = const_cast<RhsScalar *>(&b1); \
GEMV_UNROLL_ROW(which, N) \
}
#define GEMV_PROCESS_END_ROW_COMPLEX(N) \
for (; j < cols; ++j) { \
RhsScalar b0 = rhs2(j, 0); \
RhsScalar b0 = rhs2(j); \
GEMV_UNROLL_ROW_HALF(GEMV_MULT_COMPLEX, (N >> 1)) \
} \
GEMV_UNROLL_ROW_HALF(GEMV_STORE_ROW_COMPLEX, (N >> 1))
@ -2216,7 +2216,7 @@ EIGEN_ALWAYS_INLINE ScalarBlock<ResScalar, 2> predux_complex(PResPacket& a0, PRe
GEMV_UNROLL_ROW(GEMV_INIT_COMPLEX_OLD, N) \
j = 0; \
for (; j + LhsPacketSize <= cols; j += LhsPacketSize) { \
RhsPacket b0 = rhs2.template load<RhsPacket, Unaligned>(j, 0); \
RhsPacket b0 = rhs2.template load<RhsPacket, Unaligned>(j); \
GEMV_UNROLL_ROW(GEMV_WORK_ROW_COMPLEX_OLD, N) \
}
@ -2289,7 +2289,7 @@ EIGEN_STRONG_INLINE void gemv_complex_row(
// The following copy tells the compiler that lhs's attributes are not modified outside this function
// This helps GCC to generate proper code.
LhsMapper lhs(alhs);
RhsMapper rhs2(rhs);
typename RhsMapper::LinearMapper rhs2 = rhs.getLinearMapper(0, 0);
eigen_internal_assert(rhs.stride() == 1);
conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;
@ -2340,7 +2340,7 @@ EIGEN_STRONG_INLINE void gemv_complex_row(
GEMV_PROCESS_ROW_COMPLEX_PREDUX(0)
for (; j < cols; ++j)
{
dd0 += cj.pmul(lhs(i, j), rhs2(j, 0));
dd0 += cj.pmul(lhs(i, j), rhs2(j));
}
res[i * resIncr] += alpha * dd0;
}

View File

@ -416,6 +416,7 @@ class TensorContractionSubMapper {
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> ParentMapper;
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> Self;
typedef Self LinearMapper;
typedef Self SubMapper;
enum {
// We can use direct offsets iff the parent mapper supports them and we can compute the strides.
@ -485,6 +486,13 @@ class TensorContractionSubMapper {
return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
}
// Creates a SubMapper over m_base_mapper for the requested (i, j) position.
// When UseDirectOffsets is set, (i, j) are forwarded untouched; otherwise this
// mapper's own m_vert_offset / m_horiz_offset are added first.
// NOTE(review): presumably m_base_mapper already carries the offsets in the
// direct-offset case -- confirm against the constructor.
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubMapper getSubMapper(Index i, Index j) const {
  const Index vert = UseDirectOffsets ? i : i + m_vert_offset;
  const Index horiz = UseDirectOffsets ? j : j + m_horiz_offset;
  return SubMapper(m_base_mapper, vert, horiz);
}
template <typename PacketT, int AlignmentType>
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i) const {
EIGEN_STATIC_ASSERT((internal::is_same<PacketT, PacketT>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
@ -531,6 +539,7 @@ class TensorContractionInputMapper
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> Base;
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment, MakePointer_> SubMapper;
typedef SubMapper VectorMapper;
typedef SubMapper LinearMapper;
EIGEN_DEVICE_FUNC TensorContractionInputMapper(const Tensor& tensor,
const nocontract_t& nocontract_strides,
@ -544,6 +553,10 @@ class TensorContractionInputMapper
return SubMapper(*this, i, j);
}
// Returns a LinearMapper (a typedef of SubMapper in this class) constructed
// from this mapper and the given (i, j) arguments.
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
return LinearMapper(*this, i, j);
}
// Returns a VectorMapper (a typedef of SubMapper in this class) constructed
// from this mapper and the given (i, j) arguments.
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
return VectorMapper(*this, i, j);
}