mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-05-17 10:16:43 +08:00
Updated the matrix multiplication code to make it compile with AVX512 enabled.
This commit is contained in:
parent
85b6d82b49
commit
ef66f2887b
@ -595,7 +595,7 @@ DoublePacket<Packet> padd(const DoublePacket<Packet> &a, const DoublePacket<Pack
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<typename Packet>
|
template<typename Packet>
|
||||||
const DoublePacket<Packet>& predux4(const DoublePacket<Packet> &a)
|
const DoublePacket<Packet>& predux_half(const DoublePacket<Packet> &a)
|
||||||
{
|
{
|
||||||
return a;
|
return a;
|
||||||
}
|
}
|
||||||
@ -1682,10 +1682,10 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
|
|||||||
if(SwappedTraits::LhsProgress==8)
|
if(SwappedTraits::LhsProgress==8)
|
||||||
{
|
{
|
||||||
// Special case where we have to first reduce the accumulation register C0
|
// Special case where we have to first reduce the accumulation register C0
|
||||||
typedef typename conditional<SwappedTraits::LhsProgress==8,typename unpacket_traits<SResPacket>::half,SResPacket>::type SResPacketHalf;
|
typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SResPacket>::half,SResPacket>::type SResPacketHalf;
|
||||||
typedef typename conditional<SwappedTraits::LhsProgress==8,typename unpacket_traits<SLhsPacket>::half,SLhsPacket>::type SLhsPacketHalf;
|
typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SLhsPacket>::half,SLhsPacket>::type SLhsPacketHalf;
|
||||||
typedef typename conditional<SwappedTraits::LhsProgress==8,typename unpacket_traits<SLhsPacket>::half,SRhsPacket>::type SRhsPacketHalf;
|
typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SLhsPacket>::half,SRhsPacket>::type SRhsPacketHalf;
|
||||||
typedef typename conditional<SwappedTraits::LhsProgress==8,typename unpacket_traits<SAccPacket>::half,SAccPacket>::type SAccPacketHalf;
|
typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SAccPacket>::half,SAccPacket>::type SAccPacketHalf;
|
||||||
|
|
||||||
SResPacketHalf R = res.template gatherPacket<SResPacketHalf>(i, j2);
|
SResPacketHalf R = res.template gatherPacket<SResPacketHalf>(i, j2);
|
||||||
SResPacketHalf alphav = pset1<SResPacketHalf>(alpha);
|
SResPacketHalf alphav = pset1<SResPacketHalf>(alpha);
|
||||||
@ -1697,13 +1697,13 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
|
|||||||
SRhsPacketHalf b0;
|
SRhsPacketHalf b0;
|
||||||
straits.loadLhsUnaligned(blB, a0);
|
straits.loadLhsUnaligned(blB, a0);
|
||||||
straits.loadRhs(blA, b0);
|
straits.loadRhs(blA, b0);
|
||||||
SAccPacketHalf c0 = predux4(C0);
|
SAccPacketHalf c0 = predux_half(C0);
|
||||||
straits.madd(a0,b0,c0,b0);
|
straits.madd(a0,b0,c0,b0);
|
||||||
straits.acc(c0, alphav, R);
|
straits.acc(c0, alphav, R);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
straits.acc(predux4(C0), alphav, R);
|
straits.acc(predux_half(C0), alphav, R);
|
||||||
}
|
}
|
||||||
res.scatterPacket(i, j2, R);
|
res.scatterPacket(i, j2, R);
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user