Updated the matrix multiplication code to make it compile with AVX512 enabled.

This commit is contained in:
Benoit Steiner 2016-02-01 14:38:05 -08:00
parent 85b6d82b49
commit ef66f2887b

View File

@ -595,7 +595,7 @@ DoublePacket<Packet> padd(const DoublePacket<Packet> &a, const DoublePacket<Pack
} }
template<typename Packet> template<typename Packet>
const DoublePacket<Packet>& predux4(const DoublePacket<Packet> &a) const DoublePacket<Packet>& predux_half(const DoublePacket<Packet> &a)
{ {
return a; return a;
} }
@ -1682,10 +1682,10 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
if(SwappedTraits::LhsProgress==8) if(SwappedTraits::LhsProgress==8)
{ {
// Special case where we have to first reduce the accumulation register C0 // Special case where we have to first reduce the accumulation register C0
typedef typename conditional<SwappedTraits::LhsProgress==8,typename unpacket_traits<SResPacket>::half,SResPacket>::type SResPacketHalf; typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SResPacket>::half,SResPacket>::type SResPacketHalf;
typedef typename conditional<SwappedTraits::LhsProgress==8,typename unpacket_traits<SLhsPacket>::half,SLhsPacket>::type SLhsPacketHalf; typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SLhsPacket>::half,SLhsPacket>::type SLhsPacketHalf;
typedef typename conditional<SwappedTraits::LhsProgress==8,typename unpacket_traits<SLhsPacket>::half,SRhsPacket>::type SRhsPacketHalf; typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SLhsPacket>::half,SRhsPacket>::type SRhsPacketHalf;
typedef typename conditional<SwappedTraits::LhsProgress==8,typename unpacket_traits<SAccPacket>::half,SAccPacket>::type SAccPacketHalf; typedef typename conditional<SwappedTraits::LhsProgress>=8,typename unpacket_traits<SAccPacket>::half,SAccPacket>::type SAccPacketHalf;
SResPacketHalf R = res.template gatherPacket<SResPacketHalf>(i, j2); SResPacketHalf R = res.template gatherPacket<SResPacketHalf>(i, j2);
SResPacketHalf alphav = pset1<SResPacketHalf>(alpha); SResPacketHalf alphav = pset1<SResPacketHalf>(alpha);
@ -1697,13 +1697,13 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,DataMapper,mr,nr,ConjugateLhs,Conjuga
SRhsPacketHalf b0; SRhsPacketHalf b0;
straits.loadLhsUnaligned(blB, a0); straits.loadLhsUnaligned(blB, a0);
straits.loadRhs(blA, b0); straits.loadRhs(blA, b0);
SAccPacketHalf c0 = predux4(C0); SAccPacketHalf c0 = predux_half(C0);
straits.madd(a0,b0,c0,b0); straits.madd(a0,b0,c0,b0);
straits.acc(c0, alphav, R); straits.acc(c0, alphav, R);
} }
else else
{ {
straits.acc(predux4(C0), alphav, R); straits.acc(predux_half(C0), alphav, R);
} }
res.scatterPacket(i, j2, R); res.scatterPacket(i, j2, R);
} }