Fix for mixed products

This commit is contained in:
Gael Guennebaud 2014-04-25 13:22:34 +02:00
parent 2dbfd83424
commit c20e3641de

View File

@@ -480,8 +480,14 @@ public:
loadRhs(b,dest); loadRhs(b,dest);
} }
// linking error if instantiated without being optimized out: EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3); {
// FIXME not sure that's the best way to implement it!
loadRhs(b+0, b0);
loadRhs(b+1, b1);
loadRhs(b+2, b2);
loadRhs(b+3, b3);
}
// Vectorized path // Vectorized path
EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, DoublePacketType& b0, DoublePacketType& b1) EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, DoublePacketType& b0, DoublePacketType& b1)
@@ -602,9 +608,11 @@ public:
dest = pset1<RhsPacket>(*b); dest = pset1<RhsPacket>(*b);
} }
// linking error if instantiated without being optimized out: void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3)
// void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3); {
// pbroadcast4(b, b0, b1, b2, b3);
}
// EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1) // EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1)
// { // {
// // FIXME not sure that's the best way to implement it! // // FIXME not sure that's the best way to implement it!
@@ -1137,19 +1145,16 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
for(Index k=0; k<peeled_kc; k+=pk) for(Index k=0; k<peeled_kc; k+=pk)
{ {
EIGEN_ASM_COMMENT("begin gegp micro kernel 1pX4"); EIGEN_ASM_COMMENT("begin gegp micro kernel 1pX4");
RhsPacket B_0, B1; RhsPacket B_0, B1, B2, B3;
#define EIGEN_GEBGP_ONESTEP(K) \ #define EIGEN_GEBGP_ONESTEP(K) \
traits.loadLhs(&blA[(0+1*K)*LhsProgress], A0); \ traits.loadLhs(&blA[(0+1*K)*LhsProgress], A0); \
traits.loadRhs(&blB[(0+4*K)*RhsProgress], B_0); \ traits.broadcastRhs(&blB[(0+4*K)*RhsProgress], B_0, B1, B2, B3); \
traits.madd(A0, B_0, C0, B1); \ traits.madd(A0, B_0, C0, B_0); \
traits.loadRhs(&blB[1+4*K*RhsProgress], B_0); \ traits.madd(A0, B1, C1, B1); \
traits.madd(A0, B_0, C1, B1); \ traits.madd(A0, B2, C2, B2); \
traits.loadRhs(&blB[2+4*K*RhsProgress], B_0); \ traits.madd(A0, B3, C3, B3);
traits.madd(A0, B_0, C2, B1); \
traits.loadRhs(&blB[3+4*K*RhsProgress], B_0); \
traits.madd(A0, B_0, C3 , B1); \
internal::prefetch(blB+(48+0)); internal::prefetch(blB+(48+0));
EIGEN_GEBGP_ONESTEP(0); EIGEN_GEBGP_ONESTEP(0);
EIGEN_GEBGP_ONESTEP(1); EIGEN_GEBGP_ONESTEP(1);
@@ -1167,7 +1172,7 @@ void gebp_kernel<LhsScalar,RhsScalar,Index,mr,nr,ConjugateLhs,ConjugateRhs>
// process remaining peeled loop // process remaining peeled loop
for(Index k=peeled_kc; k<depth; k++) for(Index k=peeled_kc; k<depth; k++)
{ {
RhsPacket B_0, B1; RhsPacket B_0, B1, B2, B3;
EIGEN_GEBGP_ONESTEP(0); EIGEN_GEBGP_ONESTEP(0);
blB += 4*RhsProgress; blB += 4*RhsProgress;
blA += 1*LhsProgress; blA += 1*LhsProgress;
@@ -1448,10 +1453,12 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, ColMajor, Conj
const Index peeled_mc0 = Pack2>=1*PacketSize ? peeled_mc1 const Index peeled_mc0 = Pack2>=1*PacketSize ? peeled_mc1
: Pack2>1 ? (rows/Pack2)*Pack2 : 0; : Pack2>1 ? (rows/Pack2)*Pack2 : 0;
Index i=0;
// Pack 3 packets // Pack 3 packets
if(Pack1>=3*PacketSize) if(Pack1>=3*PacketSize)
{ {
for(Index i=0; i<peeled_mc3; i+=3*PacketSize) for(; i<peeled_mc3; i+=3*PacketSize)
{ {
if(PanelMode) count += (3*PacketSize) * offset; if(PanelMode) count += (3*PacketSize) * offset;
@@ -1471,7 +1478,7 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, ColMajor, Conj
// Pack 2 packets // Pack 2 packets
if(Pack1>=2*PacketSize) if(Pack1>=2*PacketSize)
{ {
for(Index i=peeled_mc3; i<peeled_mc2; i+=2*PacketSize) for(; i<peeled_mc2; i+=2*PacketSize)
{ {
if(PanelMode) count += (2*PacketSize) * offset; if(PanelMode) count += (2*PacketSize) * offset;
@@ -1489,7 +1496,7 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, ColMajor, Conj
// Pack 1 packets // Pack 1 packets
if(Pack1>=1*PacketSize) if(Pack1>=1*PacketSize)
{ {
for(Index i=peeled_mc2; i<peeled_mc1; i+=1*PacketSize) for(; i<peeled_mc1; i+=1*PacketSize)
{ {
if(PanelMode) count += (1*PacketSize) * offset; if(PanelMode) count += (1*PacketSize) * offset;
@@ -1506,7 +1513,7 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, ColMajor, Conj
// Pack scalars // Pack scalars
if(Pack2<PacketSize && Pack2>1) if(Pack2<PacketSize && Pack2>1)
{ {
for(Index i=peeled_mc1; i<peeled_mc0; i+=Pack2) for(; i<peeled_mc0; i+=Pack2)
{ {
if(PanelMode) count += Pack2 * offset; if(PanelMode) count += Pack2 * offset;
@@ -1517,7 +1524,7 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, ColMajor, Conj
if(PanelMode) count += Pack2 * (stride-offset-depth); if(PanelMode) count += Pack2 * (stride-offset-depth);
} }
} }
for(Index i=peeled_mc0; i<rows; i++) for(; i<rows; i++)
{ {
if(PanelMode) count += offset; if(PanelMode) count += offset;
for(Index k=0; k<depth; k++) for(Index k=0; k<depth; k++)