matrix product: move the alpha factor to gebp instead of the packing,

clean some temporaries, etc.
This commit is contained in:
Gael Guennebaud 2010-07-12 16:31:46 +02:00
parent f8678272a4
commit b72b7ab76f
6 changed files with 177 additions and 225 deletions

View File

@ -136,7 +136,7 @@ inline void computeProductBlockingSizes(std::ptrdiff_t& k, std::ptrdiff_t& m, st
// FIXME // FIXME
#ifndef EIGEN_HAS_FUSE_CJMADD #ifndef EIGEN_HAS_FUSE_CJMADD
#define EIGEN_HAS_FUSE_CJMADD #define EIGEN_HAS_FUSE_CJMADD
#endif #endif
#ifdef EIGEN_HAS_FUSE_CJMADD #ifdef EIGEN_HAS_FUSE_CJMADD
#define MADD(CJ,A,B,C,T) C = CJ.pmadd(A,B,C); #define MADD(CJ,A,B,C,T) C = CJ.pmadd(A,B,C);
#else #else
@ -144,7 +144,7 @@ inline void computeProductBlockingSizes(std::ptrdiff_t& k, std::ptrdiff_t& m, st
#endif #endif
/* optimized GEneral packed Block * packed Panel product kernel /* optimized GEneral packed Block * packed Panel product kernel
* *
* Mixing type logic: C += A * B * Mixing type logic: C += A * B
* | A | B | comments * | A | B | comments
* |real |cplx | no vectorization yet, would require to pack A with duplication * |real |cplx | no vectorization yet, would require to pack A with duplication
@ -156,10 +156,7 @@ struct ei_gebp_kernel
typedef typename ei_scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar; typedef typename ei_scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
enum { enum {
Vectorizable = ei_product_blocking_traits<LhsScalar,RhsScalar>::Vectorizable /*ei_packet_traits<LhsScalar>::Vectorizable Vectorizable = ei_product_blocking_traits<LhsScalar,RhsScalar>::Vectorizable,
&& ei_packet_traits<RhsScalar>::Vectorizable
&& (ei_is_same_type<LhsScalar,RhsScalar>::ret
|| (NumTraits<LhsScalar>::IsComplex && !NumTraits<RhsScalar>::IsComplex))*/,
LhsPacketSize = Vectorizable ? ei_packet_traits<LhsScalar>::size : 1, LhsPacketSize = Vectorizable ? ei_packet_traits<LhsScalar>::size : 1,
RhsPacketSize = Vectorizable ? ei_packet_traits<RhsScalar>::size : 1, RhsPacketSize = Vectorizable ? ei_packet_traits<RhsScalar>::size : 1,
ResPacketSize = Vectorizable ? ei_packet_traits<ResScalar>::size : 1 ResPacketSize = Vectorizable ? ei_packet_traits<ResScalar>::size : 1
@ -173,7 +170,7 @@ struct ei_gebp_kernel
typedef typename ei_meta_if<Vectorizable,_RhsPacket,RhsScalar>::ret RhsPacket; typedef typename ei_meta_if<Vectorizable,_RhsPacket,RhsScalar>::ret RhsPacket;
typedef typename ei_meta_if<Vectorizable,_ResPacket,ResScalar>::ret ResPacket; typedef typename ei_meta_if<Vectorizable,_ResPacket,ResScalar>::ret ResPacket;
void operator()(ResScalar* res, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index rows, Index depth, Index cols, void operator()(ResScalar* res, Index resStride, const LhsScalar* blockA, const RhsScalar* blockB, Index rows, Index depth, Index cols, ResScalar alpha,
Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0, RhsScalar* unpackedB = 0) Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0, RhsScalar* unpackedB = 0)
{ {
if(strideA==-1) strideA = depth; if(strideA==-1) strideA = depth;
@ -447,6 +444,7 @@ EIGEN_ASM_COMMENT("myend");
} }
ResPacket R0, R1, R2, R3, R4, R5, R6, R7; ResPacket R0, R1, R2, R3, R4, R5, R6, R7;
ResPacket alphav = ei_pset1<ResPacket>(alpha);
R0 = ei_ploadu<ResPacket>(r0); R0 = ei_ploadu<ResPacket>(r0);
R1 = ei_ploadu<ResPacket>(r1); R1 = ei_ploadu<ResPacket>(r1);
@ -457,14 +455,14 @@ EIGEN_ASM_COMMENT("myend");
if(nr==4) R6 = ei_ploadu<ResPacket>(r2 + ResPacketSize); if(nr==4) R6 = ei_ploadu<ResPacket>(r2 + ResPacketSize);
if(nr==4) R7 = ei_ploadu<ResPacket>(r3 + ResPacketSize); if(nr==4) R7 = ei_ploadu<ResPacket>(r3 + ResPacketSize);
C0 = ei_padd(R0, C0); C0 = ei_pmadd(C0, alphav, R0);
C1 = ei_padd(R1, C1); C1 = ei_pmadd(C1, alphav, R1);
if(nr==4) C2 = ei_padd(R2, C2); if(nr==4) C2 = ei_pmadd(C2, alphav, R2);
if(nr==4) C3 = ei_padd(R3, C3); if(nr==4) C3 = ei_pmadd(C3, alphav, R3);
C4 = ei_padd(R4, C4); C4 = ei_pmadd(C4, alphav, R4);
C5 = ei_padd(R5, C5); C5 = ei_pmadd(C5, alphav, R5);
if(nr==4) C6 = ei_padd(R6, C6); if(nr==4) C6 = ei_pmadd(C6, alphav, R6);
if(nr==4) C7 = ei_padd(R7, C7); if(nr==4) C7 = ei_pmadd(C7, alphav, R7);
ei_pstoreu(r0, C0); ei_pstoreu(r0, C0);
ei_pstoreu(r1, C1); ei_pstoreu(r1, C1);
@ -483,10 +481,10 @@ EIGEN_ASM_COMMENT("myend");
// gets res block as register // gets res block as register
ResPacket C0, C1, C2, C3; ResPacket C0, C1, C2, C3;
C0 = ei_ploadu<ResPacket>(&res[(j2+0)*resStride + i]); C0 = ei_pset1<ResPacket>(ResScalar(0));
C1 = ei_ploadu<ResPacket>(&res[(j2+1)*resStride + i]); C1 = ei_pset1<ResPacket>(ResScalar(0));
if(nr==4) C2 = ei_ploadu<ResPacket>(&res[(j2+2)*resStride + i]); if(nr==4) C2 = ei_pset1<ResPacket>(ResScalar(0));
if(nr==4) C3 = ei_ploadu<ResPacket>(&res[(j2+3)*resStride + i]); if(nr==4) C3 = ei_pset1<ResPacket>(ResScalar(0));
// performs "inner" product // performs "inner" product
const RhsScalar* blB = unpackedB; const RhsScalar* blB = unpackedB;
@ -573,24 +571,18 @@ EIGEN_ASM_COMMENT("myend");
if(nr==2) if(nr==2)
{ {
LhsPacket A0; LhsPacket A0;
RhsPacket B0; RhsPacket B0, B1;
#ifndef EIGEN_HAS_FUSE_CJMADD
RhsPacket T0;
#endif
A0 = ei_pload<LhsPacket>(&blA[0*LhsPacketSize]); A0 = ei_pload<LhsPacket>(&blA[0*LhsPacketSize]);
B0 = ei_pload<RhsPacket>(&blB[0*RhsPacketSize]); B0 = ei_pload<RhsPacket>(&blB[0*RhsPacketSize]);
MADD(pcj,A0,B0,C0,T0); B1 = ei_pload<RhsPacket>(&blB[1*RhsPacketSize]);
B0 = ei_pload<RhsPacket>(&blB[1*RhsPacketSize]); MADD(pcj,A0,B0,C0,B0);
MADD(pcj,A0,B0,C1,T0); MADD(pcj,A0,B1,C1,B1);
} }
else else
{ {
LhsPacket A0; LhsPacket A0;
RhsPacket B0, B1, B2, B3; RhsPacket B0, B1, B2, B3;
#ifndef EIGEN_HAS_FUSE_CJMADD
RhsPacket T0, T1;
#endif
A0 = ei_pload<LhsPacket>(&blA[0*LhsPacketSize]); A0 = ei_pload<LhsPacket>(&blA[0*LhsPacketSize]);
B0 = ei_pload<RhsPacket>(&blB[0*RhsPacketSize]); B0 = ei_pload<RhsPacket>(&blB[0*RhsPacketSize]);
@ -598,20 +590,38 @@ EIGEN_ASM_COMMENT("myend");
B2 = ei_pload<RhsPacket>(&blB[2*RhsPacketSize]); B2 = ei_pload<RhsPacket>(&blB[2*RhsPacketSize]);
B3 = ei_pload<RhsPacket>(&blB[3*RhsPacketSize]); B3 = ei_pload<RhsPacket>(&blB[3*RhsPacketSize]);
MADD(pcj,A0,B0,C0,T0); MADD(pcj,A0,B0,C0,B0);
MADD(pcj,A0,B1,C1,T1); MADD(pcj,A0,B1,C1,B1);
MADD(pcj,A0,B2,C2,T0); MADD(pcj,A0,B2,C2,B2);
MADD(pcj,A0,B3,C3,T1); MADD(pcj,A0,B3,C3,B3);
} }
blB += nr*RhsPacketSize; blB += nr*RhsPacketSize;
blA += LhsPacketSize; blA += LhsPacketSize;
} }
ei_pstoreu(&res[(j2+0)*resStride + i], C0); ResPacket R0, R1, R2, R3;
ei_pstoreu(&res[(j2+1)*resStride + i], C1); ResPacket alphav = ei_pset1<ResPacket>(alpha);
if(nr==4) ei_pstoreu(&res[(j2+2)*resStride + i], C2);
if(nr==4) ei_pstoreu(&res[(j2+3)*resStride + i], C3); ResScalar* r0 = &res[(j2+0)*resStride + i];
ResScalar* r1 = r0 + resStride;
ResScalar* r2 = r1 + resStride;
ResScalar* r3 = r2 + resStride;
R0 = ei_ploadu<ResPacket>(r0);
R1 = ei_ploadu<ResPacket>(r1);
if(nr==4) R2 = ei_ploadu<ResPacket>(r2);
if(nr==4) R3 = ei_ploadu<ResPacket>(r3);
C0 = ei_pmadd(C0, alphav, R0);
C1 = ei_pmadd(C1, alphav, R1);
if(nr==4) C2 = ei_pmadd(C2, alphav, R2);
if(nr==4) C3 = ei_pmadd(C3, alphav, R3);
ei_pstoreu(r0, C0);
ei_pstoreu(r1, C1);
if(nr==4) ei_pstoreu(r2, C2);
if(nr==4) ei_pstoreu(r3, C3);
} }
for(Index i=peeled_mc2; i<rows; i++) for(Index i=peeled_mc2; i<rows; i++)
{ {
@ -627,24 +637,18 @@ EIGEN_ASM_COMMENT("myend");
if(nr==2) if(nr==2)
{ {
LhsScalar A0; LhsScalar A0;
RhsScalar B0; RhsScalar B0, B1;
#ifndef EIGEN_HAS_FUSE_CJMADD
ResScalar T0;
#endif
A0 = blA[k]; A0 = blA[k];
B0 = blB[0*RhsPacketSize]; B0 = blB[0*RhsPacketSize];
MADD(cj,A0,B0,C0,T0); B1 = blB[1*RhsPacketSize];
B0 = blB[1*RhsPacketSize]; MADD(cj,A0,B0,C0,B0);
MADD(cj,A0,B0,C1,T0); MADD(cj,A0,B1,C1,B1);
} }
else else
{ {
LhsScalar A0; LhsScalar A0;
RhsScalar B0, B1, B2, B3; RhsScalar B0, B1, B2, B3;
#ifndef EIGEN_HAS_FUSE_CJMADD
ResScalar T0, T1;
#endif
A0 = blA[k]; A0 = blA[k];
B0 = blB[0*RhsPacketSize]; B0 = blB[0*RhsPacketSize];
@ -652,18 +656,18 @@ EIGEN_ASM_COMMENT("myend");
B2 = blB[2*RhsPacketSize]; B2 = blB[2*RhsPacketSize];
B3 = blB[3*RhsPacketSize]; B3 = blB[3*RhsPacketSize];
MADD(cj,A0,B0,C0,T0); MADD(cj,A0,B0,C0,B0);
MADD(cj,A0,B1,C1,T1); MADD(cj,A0,B1,C1,B1);
MADD(cj,A0,B2,C2,T0); MADD(cj,A0,B2,C2,B2);
MADD(cj,A0,B3,C3,T1); MADD(cj,A0,B3,C3,B3);
} }
blB += nr*RhsPacketSize; blB += nr*RhsPacketSize;
} }
res[(j2+0)*resStride + i] += C0; res[(j2+0)*resStride + i] += alpha*C0;
res[(j2+1)*resStride + i] += C1; res[(j2+1)*resStride + i] += alpha*C1;
if(nr==4) res[(j2+2)*resStride + i] += C2; if(nr==4) res[(j2+2)*resStride + i] += alpha*C2;
if(nr==4) res[(j2+3)*resStride + i] += C3; if(nr==4) res[(j2+3)*resStride + i] += alpha*C3;
} }
} }
@ -687,8 +691,9 @@ EIGEN_ASM_COMMENT("myend");
// get res block as registers // get res block as registers
ResPacket C0, C4; ResPacket C0, C4;
C0 = ei_ploadu<ResPacket>(&res[(j2+0)*resStride + i]); C0 = ei_pset1<ResPacket>(ResScalar(0));
C4 = ei_ploadu<ResPacket>(&res[(j2+0)*resStride + i + ResPacketSize]); C4 = ei_pset1<ResPacket>(ResScalar(0));
const RhsScalar* blB = unpackedB; const RhsScalar* blB = unpackedB;
for(Index k=0; k<depth; k++) for(Index k=0; k<depth; k++)
@ -696,21 +701,31 @@ EIGEN_ASM_COMMENT("myend");
LhsPacket A0, A1; LhsPacket A0, A1;
RhsPacket B0; RhsPacket B0;
#ifndef EIGEN_HAS_FUSE_CJMADD #ifndef EIGEN_HAS_FUSE_CJMADD
RhsPacket T0, T1; RhsPacket T0;
#endif #endif
A0 = ei_pload<LhsPacket>(&blA[0*LhsPacketSize]); A0 = ei_pload<LhsPacket>(&blA[0*LhsPacketSize]);
A1 = ei_pload<LhsPacket>(&blA[1*LhsPacketSize]); A1 = ei_pload<LhsPacket>(&blA[1*LhsPacketSize]);
B0 = ei_pload<RhsPacket>(&blB[0*RhsPacketSize]); B0 = ei_pload<RhsPacket>(&blB[0*RhsPacketSize]);
MADD(pcj,A0,B0,C0,T0); MADD(pcj,A0,B0,C0,T0);
MADD(pcj,A1,B0,C4,T1); MADD(pcj,A1,B0,C4,B0);
blB += RhsPacketSize; blB += RhsPacketSize;
blA += mr; blA += mr;
} }
ResPacket R0, R4;
ResPacket alphav = ei_pset1<ResPacket>(alpha);
ei_pstoreu(&res[(j2+0)*resStride + i], C0); ResScalar* r0 = &res[(j2+0)*resStride + i];
ei_pstoreu(&res[(j2+0)*resStride + i + ResPacketSize], C4);
R0 = ei_ploadu<ResPacket>(r0);
R4 = ei_ploadu<ResPacket>(r0+ResPacketSize);
C0 = ei_pmadd(C0, alphav, R0);
C4 = ei_pmadd(C4, alphav, R4);
ei_pstoreu(r0, C0);
ei_pstoreu(r0+ResPacketSize, C4);
} }
if(rows-peeled_mc>=LhsPacketSize) if(rows-peeled_mc>=LhsPacketSize)
{ {
@ -718,20 +733,21 @@ EIGEN_ASM_COMMENT("myend");
const LhsScalar* blA = &blockA[i*strideA+offsetA*LhsPacketSize]; const LhsScalar* blA = &blockA[i*strideA+offsetA*LhsPacketSize];
ei_prefetch(&blA[0]); ei_prefetch(&blA[0]);
ResPacket C0 = ei_ploadu<ResPacket>(&res[(j2+0)*resStride + i]); ResPacket C0 = ei_pset1<ResPacket>(ResScalar(0));
const RhsScalar* blB = unpackedB; const RhsScalar* blB = unpackedB;
for(Index k=0; k<depth; k++) for(Index k=0; k<depth; k++)
{ {
#ifndef EIGEN_HAS_FUSE_CJMADD LhsPacket A0 = ei_pload<LhsPacket>(blA);
RhsPacket T0; RhsPacket B0 = ei_pload<RhsPacket>(blB);
#endif MADD(pcj, A0, B0, C0, B0);
MADD(pcj,ei_pload<LhsPacket>(blA), ei_pload<RhsPacket>(blB), C0, T0);
blB += RhsPacketSize; blB += RhsPacketSize;
blA += LhsPacketSize; blA += LhsPacketSize;
} }
ei_pstoreu(&res[(j2+0)*resStride + i], C0); ResPacket alphav = ei_pset1<ResPacket>(alpha);
ResPacket R0 = ei_ploadu<ResPacket>(&res[(j2+0)*resStride + i]);
ei_pstoreu(&res[(j2+0)*resStride + i], ei_pmadd(C0, alphav, R0));
} }
for(Index i=peeled_mc2; i<rows; i++) for(Index i=peeled_mc2; i<rows; i++)
{ {
@ -744,12 +760,11 @@ EIGEN_ASM_COMMENT("myend");
const RhsScalar* blB = unpackedB; const RhsScalar* blB = unpackedB;
for(Index k=0; k<depth; k++) for(Index k=0; k<depth; k++)
{ {
#ifndef EIGEN_HAS_FUSE_CJMADD LhsScalar A0 = blA[k];
ResScalar T0; RhsScalar B0 = blB[k*RhsPacketSize];
#endif MADD(cj, A0, B0, C0, B0);
MADD(cj,blA[k], blB[k*RhsPacketSize], C0, T0);
} }
res[(j2+0)*resStride + i] += C0; res[(j2+0)*resStride + i] += alpha*C0;
} }
} }
} }
@ -775,51 +790,36 @@ template<typename Scalar, typename Index, int mr, int StorageOrder, bool Conjuga
struct ei_gemm_pack_lhs struct ei_gemm_pack_lhs
{ {
void operator()(Scalar* blockA, const Scalar* EIGEN_RESTRICT _lhs, Index lhsStride, Index depth, Index rows, void operator()(Scalar* blockA, const Scalar* EIGEN_RESTRICT _lhs, Index lhsStride, Index depth, Index rows,
Scalar alpha = Scalar(1), Index stride=0, Index offset=0) Index stride=0, Index offset=0)
{ {
enum { PacketSize = ei_packet_traits<Scalar>::size }; enum { PacketSize = ei_packet_traits<Scalar>::size };
ei_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride)); ei_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
ei_conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj; ei_conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
ei_const_blas_data_mapper<Scalar, Index, StorageOrder> lhs(_lhs,lhsStride); ei_const_blas_data_mapper<Scalar, Index, StorageOrder> lhs(_lhs,lhsStride);
bool hasAlpha = alpha != Scalar(1);
Index count = 0; Index count = 0;
Index peeled_mc = (rows/mr)*mr; Index peeled_mc = (rows/mr)*mr;
for(Index i=0; i<peeled_mc; i+=mr) for(Index i=0; i<peeled_mc; i+=mr)
{ {
if(PanelMode) count += mr * offset; if(PanelMode) count += mr * offset;
if(hasAlpha) for(Index k=0; k<depth; k++)
for(Index k=0; k<depth; k++) for(Index w=0; w<mr; w++)
for(Index w=0; w<mr; w++) blockA[count++] = cj(lhs(i+w, k));
blockA[count++] = alpha * cj(lhs(i+w, k));
else
for(Index k=0; k<depth; k++)
for(Index w=0; w<mr; w++)
blockA[count++] = cj(lhs(i+w, k));
if(PanelMode) count += mr * (stride-offset-depth); if(PanelMode) count += mr * (stride-offset-depth);
} }
if(rows-peeled_mc>=PacketSize) if(rows-peeled_mc>=PacketSize)
{ {
if(PanelMode) count += PacketSize*offset; if(PanelMode) count += PacketSize*offset;
if(hasAlpha) for(Index k=0; k<depth; k++)
for(Index k=0; k<depth; k++) for(Index w=0; w<PacketSize; w++)
for(Index w=0; w<PacketSize; w++) blockA[count++] = cj(lhs(peeled_mc+w, k));
blockA[count++] = alpha * cj(lhs(peeled_mc+w, k));
else
for(Index k=0; k<depth; k++)
for(Index w=0; w<PacketSize; w++)
blockA[count++] = cj(lhs(peeled_mc+w, k));
if(PanelMode) count += PacketSize * (stride-offset-depth); if(PanelMode) count += PacketSize * (stride-offset-depth);
peeled_mc += PacketSize; peeled_mc += PacketSize;
} }
for(Index i=peeled_mc; i<rows; i++) for(Index i=peeled_mc; i<rows; i++)
{ {
if(PanelMode) count += offset; if(PanelMode) count += offset;
if(hasAlpha) for(Index k=0; k<depth; k++)
for(Index k=0; k<depth; k++) blockA[count++] = cj(lhs(i, k));
blockA[count++] = alpha * cj(lhs(i, k));
else
for(Index k=0; k<depth; k++)
blockA[count++] = cj(lhs(i, k));
if(PanelMode) count += (stride-offset-depth); if(PanelMode) count += (stride-offset-depth);
} }
} }
@ -837,12 +837,11 @@ struct ei_gemm_pack_rhs<Scalar, Index, nr, ColMajor, Conjugate, PanelMode>
{ {
typedef typename ei_packet_traits<Scalar>::type Packet; typedef typename ei_packet_traits<Scalar>::type Packet;
enum { PacketSize = ei_packet_traits<Scalar>::size }; enum { PacketSize = ei_packet_traits<Scalar>::size };
void operator()(Scalar* blockB, const Scalar* rhs, Index rhsStride, Scalar alpha, Index depth, Index cols, void operator()(Scalar* blockB, const Scalar* rhs, Index rhsStride, Index depth, Index cols,
Index stride=0, Index offset=0) Index stride=0, Index offset=0)
{ {
ei_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride)); ei_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
ei_conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj; ei_conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
bool hasAlpha = alpha != Scalar(1);
Index packet_cols = (cols/nr) * nr; Index packet_cols = (cols/nr) * nr;
Index count = 0; Index count = 0;
for(Index j2=0; j2<packet_cols; j2+=nr) for(Index j2=0; j2<packet_cols; j2+=nr)
@ -853,24 +852,14 @@ struct ei_gemm_pack_rhs<Scalar, Index, nr, ColMajor, Conjugate, PanelMode>
const Scalar* b1 = &rhs[(j2+1)*rhsStride]; const Scalar* b1 = &rhs[(j2+1)*rhsStride];
const Scalar* b2 = &rhs[(j2+2)*rhsStride]; const Scalar* b2 = &rhs[(j2+2)*rhsStride];
const Scalar* b3 = &rhs[(j2+3)*rhsStride]; const Scalar* b3 = &rhs[(j2+3)*rhsStride];
if (hasAlpha) for(Index k=0; k<depth; k++)
for(Index k=0; k<depth; k++) {
{ blockB[count+0] = cj(b0[k]);
blockB[count+0] = alpha*cj(b0[k]); blockB[count+1] = cj(b1[k]);
blockB[count+1] = alpha*cj(b1[k]); if(nr==4) blockB[count+2] = cj(b2[k]);
if(nr==4) blockB[count+2] = alpha*cj(b2[k]); if(nr==4) blockB[count+3] = cj(b3[k]);
if(nr==4) blockB[count+3] = alpha*cj(b3[k]); count += nr;
count += nr; }
}
else
for(Index k=0; k<depth; k++)
{
blockB[count+0] = cj(b0[k]);
blockB[count+1] = cj(b1[k]);
if(nr==4) blockB[count+2] = cj(b2[k]);
if(nr==4) blockB[count+3] = cj(b3[k]);
count += nr;
}
// skip what we have after // skip what we have after
if(PanelMode) count += nr * (stride-offset-depth); if(PanelMode) count += nr * (stride-offset-depth);
} }
@ -880,18 +869,11 @@ struct ei_gemm_pack_rhs<Scalar, Index, nr, ColMajor, Conjugate, PanelMode>
{ {
if(PanelMode) count += offset; if(PanelMode) count += offset;
const Scalar* b0 = &rhs[(j2+0)*rhsStride]; const Scalar* b0 = &rhs[(j2+0)*rhsStride];
if (hasAlpha) for(Index k=0; k<depth; k++)
for(Index k=0; k<depth; k++) {
{ blockB[count] = cj(b0[k]);
blockB[count] = alpha*cj(b0[k]); count += 1;
count += 1; }
}
else
for(Index k=0; k<depth; k++)
{
blockB[count] = cj(b0[k]);
count += 1;
}
if(PanelMode) count += (stride-offset-depth); if(PanelMode) count += (stride-offset-depth);
} }
} }
@ -902,41 +884,25 @@ template<typename Scalar, typename Index, int nr, bool Conjugate, bool PanelMode
struct ei_gemm_pack_rhs<Scalar, Index, nr, RowMajor, Conjugate, PanelMode> struct ei_gemm_pack_rhs<Scalar, Index, nr, RowMajor, Conjugate, PanelMode>
{ {
enum { PacketSize = ei_packet_traits<Scalar>::size }; enum { PacketSize = ei_packet_traits<Scalar>::size };
void operator()(Scalar* blockB, const Scalar* rhs, Index rhsStride, Scalar alpha, Index depth, Index cols, void operator()(Scalar* blockB, const Scalar* rhs, Index rhsStride, Index depth, Index cols,
Index stride=0, Index offset=0) Index stride=0, Index offset=0)
{ {
ei_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride)); ei_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
ei_conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj; ei_conj_if<NumTraits<Scalar>::IsComplex && Conjugate> cj;
bool hasAlpha = alpha != Scalar(1);
Index packet_cols = (cols/nr) * nr; Index packet_cols = (cols/nr) * nr;
Index count = 0; Index count = 0;
for(Index j2=0; j2<packet_cols; j2+=nr) for(Index j2=0; j2<packet_cols; j2+=nr)
{ {
// skip what we have before // skip what we have before
if(PanelMode) count += nr * offset; if(PanelMode) count += nr * offset;
if (hasAlpha) for(Index k=0; k<depth; k++)
{ {
for(Index k=0; k<depth; k++) const Scalar* b0 = &rhs[k*rhsStride + j2];
{ blockB[count+0] = cj(b0[0]);
const Scalar* b0 = &rhs[k*rhsStride + j2]; blockB[count+1] = cj(b0[1]);
blockB[count+0] = alpha*cj(b0[0]); if(nr==4) blockB[count+2] = cj(b0[2]);
blockB[count+1] = alpha*cj(b0[1]); if(nr==4) blockB[count+3] = cj(b0[3]);
if(nr==4) blockB[count+2] = alpha*cj(b0[2]); count += nr;
if(nr==4) blockB[count+3] = alpha*cj(b0[3]);
count += nr;
}
}
else
{
for(Index k=0; k<depth; k++)
{
const Scalar* b0 = &rhs[k*rhsStride + j2];
blockB[count+0] = cj(b0[0]);
blockB[count+1] = cj(b0[1]);
if(nr==4) blockB[count+2] = cj(b0[2]);
if(nr==4) blockB[count+3] = cj(b0[3]);
count += nr;
}
} }
// skip what we have after // skip what we have after
if(PanelMode) count += nr * (stride-offset-depth); if(PanelMode) count += nr * (stride-offset-depth);
@ -948,7 +914,7 @@ struct ei_gemm_pack_rhs<Scalar, Index, nr, RowMajor, Conjugate, PanelMode>
const Scalar* b0 = &rhs[j2]; const Scalar* b0 = &rhs[j2];
for(Index k=0; k<depth; k++) for(Index k=0; k<depth; k++)
{ {
blockB[count] = alpha*cj(b0[k*rhsStride]); blockB[count] = cj(b0[k*rhsStride]);
count += 1; count += 1;
} }
if(PanelMode) count += stride-offset-depth; if(PanelMode) count += stride-offset-depth;

View File

@ -74,7 +74,6 @@ static void run(Index rows, Index cols, Index depth,
ei_const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride); ei_const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
typedef ei_product_blocking_traits<LhsScalar,RhsScalar> Blocking; typedef ei_product_blocking_traits<LhsScalar,RhsScalar> Blocking;
bool alphaOnLhs = NumTraits<LhsScalar>::IsComplex && !NumTraits<RhsScalar>::IsComplex;
Index kc = blocking.kc(); // cache block size along the K direction Index kc = blocking.kc(); // cache block size along the K direction
Index mc = std::min(rows,blocking.mc()); // cache block size along the M direction Index mc = std::min(rows,blocking.mc()); // cache block size along the M direction
@ -87,8 +86,6 @@ static void run(Index rows, Index cols, Index depth,
ei_gemm_pack_rhs<RhsScalar, Index, Blocking::nr, RhsStorageOrder, ConjugateRhs> pack_rhs; ei_gemm_pack_rhs<RhsScalar, Index, Blocking::nr, RhsStorageOrder, ConjugateRhs> pack_rhs;
ei_gebp_kernel<LhsScalar, RhsScalar, Index, Blocking::mr, Blocking::nr> gebp; ei_gebp_kernel<LhsScalar, RhsScalar, Index, Blocking::mr, Blocking::nr> gebp;
// if ((ConjugateRhs && !alphaOnLhs) || (ConjugateLhs && alphaOnLhs))
// alpha = ei_conj(alpha);
// ei_gemm_pack_lhs<LhsScalar, Index, Blocking::mr, LhsStorageOrder> pack_lhs; // ei_gemm_pack_lhs<LhsScalar, Index, Blocking::mr, LhsStorageOrder> pack_lhs;
// ei_gemm_pack_rhs<RhsScalar, Index, Blocking::nr, RhsStorageOrder> pack_rhs; // ei_gemm_pack_rhs<RhsScalar, Index, Blocking::nr, RhsStorageOrder> pack_rhs;
// ei_gebp_kernel<LhsScalar, RhsScalar, Index, Blocking::mr, Blocking::nr, ConjugateLhs, ConjugateRhs> gebp; // ei_gebp_kernel<LhsScalar, RhsScalar, Index, Blocking::mr, Blocking::nr, ConjugateLhs, ConjugateRhs> gebp;
@ -178,9 +175,6 @@ static void run(Index rows, Index cols, Index depth,
RhsScalar *blockB = blocking.blockB()==0 ? ei_aligned_stack_new(RhsScalar, sizeB) : blocking.blockB(); RhsScalar *blockB = blocking.blockB()==0 ? ei_aligned_stack_new(RhsScalar, sizeB) : blocking.blockB();
RhsScalar *blockW = blocking.blockW()==0 ? ei_aligned_stack_new(RhsScalar, sizeW) : blocking.blockW(); RhsScalar *blockW = blocking.blockW()==0 ? ei_aligned_stack_new(RhsScalar, sizeW) : blocking.blockW();
LhsScalar lhsAlpha = alphaOnLhs ? ei_get_factor<ResScalar,LhsScalar>::run(alpha) : LhsScalar(1);
RhsScalar rhsAlpha = alphaOnLhs ? RhsScalar(1) : ei_get_factor<ResScalar,RhsScalar>::run(alpha);
// For each horizontal panel of the rhs, and corresponding panel of the lhs... // For each horizontal panel of the rhs, and corresponding panel of the lhs...
// (==GEMM_VAR1) // (==GEMM_VAR1)
for(Index k2=0; k2<depth; k2+=kc) for(Index k2=0; k2<depth; k2+=kc)
@ -191,7 +185,7 @@ static void run(Index rows, Index cols, Index depth,
// => Pack rhs's panel into a sequential chunk of memory (L2 caching) // => Pack rhs's panel into a sequential chunk of memory (L2 caching)
// Note that this panel will be read as many times as the number of blocks in the lhs's // Note that this panel will be read as many times as the number of blocks in the lhs's
// vertical panel which is, in practice, a very low number. // vertical panel which is, in practice, a very low number.
pack_rhs(blockB, &rhs(k2,0), rhsStride, rhsAlpha, actual_kc, cols); pack_rhs(blockB, &rhs(k2,0), rhsStride, actual_kc, cols);
// For each mc x kc block of the lhs's vertical panel... // For each mc x kc block of the lhs's vertical panel...
@ -203,10 +197,10 @@ static void run(Index rows, Index cols, Index depth,
// We pack the lhs's block into a sequential chunk of memory (L1 caching) // We pack the lhs's block into a sequential chunk of memory (L1 caching)
// Note that this block will be read a very high number of times, which is equal to the number of // Note that this block will be read a very high number of times, which is equal to the number of
// micro vertical panel of the large rhs's panel (e.g., cols/4 times). // micro vertical panel of the large rhs's panel (e.g., cols/4 times).
pack_lhs(blockA, &lhs(i2,k2), lhsStride, actual_kc, actual_mc, lhsAlpha); pack_lhs(blockA, &lhs(i2,k2), lhsStride, actual_kc, actual_mc);
// Everything is packed, we can now call the block * panel kernel: // Everything is packed, we can now call the block * panel kernel:
gebp(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, -1, -1, 0, 0, blockW); gebp(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha, -1, -1, 0, 0, blockW);
} }
} }

View File

@ -89,7 +89,7 @@ template<typename Scalar, typename Index, int nr, int StorageOrder>
struct ei_symm_pack_rhs struct ei_symm_pack_rhs
{ {
enum { PacketSize = ei_packet_traits<Scalar>::size }; enum { PacketSize = ei_packet_traits<Scalar>::size };
void operator()(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Scalar alpha, Index rows, Index cols, Index k2) void operator()(Scalar* blockB, const Scalar* _rhs, Index rhsStride, Index rows, Index cols, Index k2)
{ {
Index end_k = k2 + rows; Index end_k = k2 + rows;
Index count = 0; Index count = 0;
@ -101,12 +101,12 @@ struct ei_symm_pack_rhs
{ {
for(Index k=k2; k<end_k; k++) for(Index k=k2; k<end_k; k++)
{ {
blockB[count+0] = alpha*rhs(k,j2+0); blockB[count+0] = rhs(k,j2+0);
blockB[count+1] = alpha*rhs(k,j2+1); blockB[count+1] = rhs(k,j2+1);
if (nr==4) if (nr==4)
{ {
blockB[count+2] = alpha*rhs(k,j2+2); blockB[count+2] = rhs(k,j2+2);
blockB[count+3] = alpha*rhs(k,j2+3); blockB[count+3] = rhs(k,j2+3);
} }
count += nr; count += nr;
} }
@ -119,12 +119,12 @@ struct ei_symm_pack_rhs
// transpose // transpose
for(Index k=k2; k<j2; k++) for(Index k=k2; k<j2; k++)
{ {
blockB[count+0] = alpha*ei_conj(rhs(j2+0,k)); blockB[count+0] = ei_conj(rhs(j2+0,k));
blockB[count+1] = alpha*ei_conj(rhs(j2+1,k)); blockB[count+1] = ei_conj(rhs(j2+1,k));
if (nr==4) if (nr==4)
{ {
blockB[count+2] = alpha*ei_conj(rhs(j2+2,k)); blockB[count+2] = ei_conj(rhs(j2+2,k));
blockB[count+3] = alpha*ei_conj(rhs(j2+3,k)); blockB[count+3] = ei_conj(rhs(j2+3,k));
} }
count += nr; count += nr;
} }
@ -134,25 +134,25 @@ struct ei_symm_pack_rhs
{ {
// normal // normal
for (Index w=0 ; w<h; ++w) for (Index w=0 ; w<h; ++w)
blockB[count+w] = alpha*rhs(k,j2+w); blockB[count+w] = rhs(k,j2+w);
blockB[count+h] = alpha*rhs(k,k); blockB[count+h] = rhs(k,k);
// transpose // transpose
for (Index w=h+1 ; w<nr; ++w) for (Index w=h+1 ; w<nr; ++w)
blockB[count+w] = alpha*ei_conj(rhs(j2+w,k)); blockB[count+w] = ei_conj(rhs(j2+w,k));
count += nr; count += nr;
++h; ++h;
} }
// normal // normal
for(Index k=j2+nr; k<end_k; k++) for(Index k=j2+nr; k<end_k; k++)
{ {
blockB[count+0] = alpha*rhs(k,j2+0); blockB[count+0] = rhs(k,j2+0);
blockB[count+1] = alpha*rhs(k,j2+1); blockB[count+1] = rhs(k,j2+1);
if (nr==4) if (nr==4)
{ {
blockB[count+2] = alpha*rhs(k,j2+2); blockB[count+2] = rhs(k,j2+2);
blockB[count+3] = alpha*rhs(k,j2+3); blockB[count+3] = rhs(k,j2+3);
} }
count += nr; count += nr;
} }
@ -163,12 +163,12 @@ struct ei_symm_pack_rhs
{ {
for(Index k=k2; k<end_k; k++) for(Index k=k2; k<end_k; k++)
{ {
blockB[count+0] = alpha*ei_conj(rhs(j2+0,k)); blockB[count+0] = ei_conj(rhs(j2+0,k));
blockB[count+1] = alpha*ei_conj(rhs(j2+1,k)); blockB[count+1] = ei_conj(rhs(j2+1,k));
if (nr==4) if (nr==4)
{ {
blockB[count+2] = alpha*ei_conj(rhs(j2+2,k)); blockB[count+2] = ei_conj(rhs(j2+2,k));
blockB[count+3] = alpha*ei_conj(rhs(j2+3,k)); blockB[count+3] = ei_conj(rhs(j2+3,k));
} }
count += nr; count += nr;
} }
@ -181,13 +181,13 @@ struct ei_symm_pack_rhs
Index half = std::min(end_k,j2); Index half = std::min(end_k,j2);
for(Index k=k2; k<half; k++) for(Index k=k2; k<half; k++)
{ {
blockB[count] = alpha*ei_conj(rhs(j2,k)); blockB[count] = ei_conj(rhs(j2,k));
count += 1; count += 1;
} }
if(half==j2 && half<k2+rows) if(half==j2 && half<k2+rows)
{ {
blockB[count] = alpha*ei_real(rhs(j2,j2)); blockB[count] = ei_real(rhs(j2,j2));
count += 1; count += 1;
} }
else else
@ -196,7 +196,7 @@ struct ei_symm_pack_rhs
// normal // normal
for(Index k=half+1; k<k2+rows; k++) for(Index k=half+1; k<k2+rows; k++)
{ {
blockB[count] = alpha*rhs(k,j2); blockB[count] = rhs(k,j2);
count += 1; count += 1;
} }
} }
@ -253,9 +253,6 @@ struct ei_product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,Conjugate
ei_const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride); ei_const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
ei_const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride); ei_const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
if (ConjugateRhs)
alpha = ei_conj(alpha);
typedef ei_product_blocking_traits<Scalar,Scalar> Blocking; typedef ei_product_blocking_traits<Scalar,Scalar> Blocking;
Index kc = size; // cache block size along the K direction Index kc = size; // cache block size along the K direction
@ -282,7 +279,7 @@ struct ei_product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,Conjugate
// we have selected one row panel of rhs and one column panel of lhs // we have selected one row panel of rhs and one column panel of lhs
// pack rhs's panel into a sequential chunk of memory // pack rhs's panel into a sequential chunk of memory
// and expand each coeff to a constant packet for further reuse // and expand each coeff to a constant packet for further reuse
pack_rhs(blockB, &rhs(k2,0), rhsStride, alpha, actual_kc, cols); pack_rhs(blockB, &rhs(k2,0), rhsStride, actual_kc, cols);
// the select lhs's panel has to be split in three different parts: // the select lhs's panel has to be split in three different parts:
// 1 - the transposed panel above the diagonal block => transposed packed copy // 1 - the transposed panel above the diagonal block => transposed packed copy
@ -294,7 +291,7 @@ struct ei_product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,Conjugate
// transposed packed copy // transposed packed copy
pack_lhs_transposed(blockA, &lhs(k2, i2), lhsStride, actual_kc, actual_mc); pack_lhs_transposed(blockA, &lhs(k2, i2), lhsStride, actual_kc, actual_mc);
gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols); gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha);
} }
// the block diagonal // the block diagonal
{ {
@ -302,7 +299,7 @@ struct ei_product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,Conjugate
// symmetric packed copy // symmetric packed copy
pack_lhs(blockA, &lhs(k2,k2), lhsStride, actual_kc, actual_mc); pack_lhs(blockA, &lhs(k2,k2), lhsStride, actual_kc, actual_mc);
gebp_kernel(res+k2, resStride, blockA, blockB, actual_mc, actual_kc, cols); gebp_kernel(res+k2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha);
} }
for(Index i2=k2+kc; i2<size; i2+=mc) for(Index i2=k2+kc; i2<size; i2+=mc)
@ -311,7 +308,7 @@ struct ei_product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,true,Conjugate
ei_gemm_pack_lhs<Scalar, Index, Blocking::mr,LhsStorageOrder,false>() ei_gemm_pack_lhs<Scalar, Index, Blocking::mr,LhsStorageOrder,false>()
(blockA, &lhs(i2, k2), lhsStride, actual_kc, actual_mc); (blockA, &lhs(i2, k2), lhsStride, actual_kc, actual_mc);
gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols); gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha);
} }
} }
@ -338,9 +335,6 @@ struct ei_product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,Conjugat
ei_const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride); ei_const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
if (ConjugateRhs)
alpha = ei_conj(alpha);
typedef ei_product_blocking_traits<Scalar,Scalar> Blocking; typedef ei_product_blocking_traits<Scalar,Scalar> Blocking;
Index kc = size; // cache block size along the K direction Index kc = size; // cache block size along the K direction
@ -361,7 +355,7 @@ struct ei_product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,Conjugat
{ {
const Index actual_kc = std::min(k2+kc,size)-k2; const Index actual_kc = std::min(k2+kc,size)-k2;
pack_rhs(blockB, _rhs, rhsStride, alpha, actual_kc, cols, k2); pack_rhs(blockB, _rhs, rhsStride, actual_kc, cols, k2);
// => GEPP // => GEPP
for(Index i2=0; i2<rows; i2+=mc) for(Index i2=0; i2<rows; i2+=mc)
@ -369,7 +363,7 @@ struct ei_product_selfadjoint_matrix<Scalar,Index,LhsStorageOrder,false,Conjugat
const Index actual_mc = std::min(i2+mc,rows)-i2; const Index actual_mc = std::min(i2+mc,rows)-i2;
pack_lhs(blockA, &lhs(i2, k2), lhsStride, actual_kc, actual_mc); pack_lhs(blockA, &lhs(i2, k2), lhsStride, actual_kc, actual_mc);
gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols); gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha);
} }
} }

View File

@ -65,8 +65,8 @@ struct ei_selfadjoint_product<Scalar, Index, MatStorageOrder, ColMajor, AAT, UpL
{ {
ei_const_blas_data_mapper<Scalar, Index, MatStorageOrder> mat(_mat,matStride); ei_const_blas_data_mapper<Scalar, Index, MatStorageOrder> mat(_mat,matStride);
if(AAT) // if(AAT)
alpha = ei_conj(alpha); // alpha = ei_conj(alpha);
typedef ei_product_blocking_traits<Scalar,Scalar> Blocking; typedef ei_product_blocking_traits<Scalar,Scalar> Blocking;
@ -99,7 +99,7 @@ struct ei_selfadjoint_product<Scalar, Index, MatStorageOrder, ColMajor, AAT, UpL
const Index actual_kc = std::min(k2+kc,depth)-k2; const Index actual_kc = std::min(k2+kc,depth)-k2;
// note that the actual rhs is the transpose/adjoint of mat // note that the actual rhs is the transpose/adjoint of mat
pack_rhs(blockB, &mat(0,k2), matStride, alpha, actual_kc, size); pack_rhs(blockB, &mat(0,k2), matStride, actual_kc, size);
for(Index i2=0; i2<size; i2+=mc) for(Index i2=0; i2<size; i2+=mc)
{ {
@ -112,15 +112,15 @@ struct ei_selfadjoint_product<Scalar, Index, MatStorageOrder, ColMajor, AAT, UpL
// 2 - the actual_mc x actual_mc symmetric block => processed with a special kernel // 2 - the actual_mc x actual_mc symmetric block => processed with a special kernel
// 3 - after the diagonal => processed with gebp or skipped // 3 - after the diagonal => processed with gebp or skipped
if (UpLo==Lower) if (UpLo==Lower)
gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, std::min(size,i2), gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, std::min(size,i2), alpha,
-1, -1, 0, 0, allocatedBlockB); -1, -1, 0, 0, allocatedBlockB);
sybb(res+resStride*i2 + i2, resStride, blockA, blockB + actual_kc*i2, actual_mc, actual_kc, allocatedBlockB); sybb(res+resStride*i2 + i2, resStride, blockA, blockB + actual_kc*i2, actual_mc, actual_kc, alpha, allocatedBlockB);
if (UpLo==Upper) if (UpLo==Upper)
{ {
Index j2 = i2+actual_mc; Index j2 = i2+actual_mc;
gebp_kernel(res+resStride*j2+i2, resStride, blockA, blockB+actual_kc*j2, actual_mc, actual_kc, std::max(Index(0),size-j2), gebp_kernel(res+resStride*j2+i2, resStride, blockA, blockB+actual_kc*j2, actual_mc, actual_kc, std::max(Index(0), size-j2), alpha,
-1, -1, 0, 0, allocatedBlockB); -1, -1, 0, 0, allocatedBlockB);
} }
} }
@ -173,7 +173,7 @@ struct ei_sybb_kernel
PacketSize = ei_packet_traits<Scalar>::size, PacketSize = ei_packet_traits<Scalar>::size,
BlockSize = EIGEN_PLAIN_ENUM_MAX(mr,nr) BlockSize = EIGEN_PLAIN_ENUM_MAX(mr,nr)
}; };
void operator()(Scalar* res, Index resStride, const Scalar* blockA, const Scalar* blockB, Index size, Index depth, Scalar* workspace) void operator()(Scalar* res, Index resStride, const Scalar* blockA, const Scalar* blockB, Index size, Index depth, Scalar alpha, Scalar* workspace)
{ {
ei_gebp_kernel<Scalar, Scalar, Index, mr, nr, ConjLhs, ConjRhs> gebp_kernel; ei_gebp_kernel<Scalar, Scalar, Index, mr, nr, ConjLhs, ConjRhs> gebp_kernel;
Matrix<Scalar,BlockSize,BlockSize,ColMajor> buffer; Matrix<Scalar,BlockSize,BlockSize,ColMajor> buffer;
@ -186,14 +186,15 @@ struct ei_sybb_kernel
const Scalar* actual_b = blockB+j*depth; const Scalar* actual_b = blockB+j*depth;
if(UpLo==Upper) if(UpLo==Upper)
gebp_kernel(res+j*resStride, resStride, blockA, actual_b, j, depth, actualBlockSize, -1, -1, 0, 0, workspace); gebp_kernel(res+j*resStride, resStride, blockA, actual_b, j, depth, actualBlockSize, alpha,
-1, -1, 0, 0, workspace);
// selfadjoint micro block // selfadjoint micro block
{ {
Index i = j; Index i = j;
buffer.setZero(); buffer.setZero();
// 1 - apply the kernel on the temporary buffer // 1 - apply the kernel on the temporary buffer
gebp_kernel(buffer.data(), BlockSize, blockA+depth*i, actual_b, actualBlockSize, depth, actualBlockSize, gebp_kernel(buffer.data(), BlockSize, blockA+depth*i, actual_b, actualBlockSize, depth, actualBlockSize, alpha,
-1, -1, 0, 0, workspace); -1, -1, 0, 0, workspace);
// 2 - triangular accumulation // 2 - triangular accumulation
for(Index j1=0; j1<actualBlockSize; ++j1) for(Index j1=0; j1<actualBlockSize; ++j1)
@ -208,7 +209,7 @@ struct ei_sybb_kernel
if(UpLo==Lower) if(UpLo==Lower)
{ {
Index i = j+actualBlockSize; Index i = j+actualBlockSize;
gebp_kernel(res+j*resStride+i, resStride, blockA+depth*i, actual_b, size-i, depth, actualBlockSize, gebp_kernel(res+j*resStride+i, resStride, blockA+depth*i, actual_b, size-i, depth, actualBlockSize, alpha,
-1, -1, 0, 0, workspace); -1, -1, 0, 0, workspace);
} }
} }

View File

@ -105,9 +105,6 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,true,
ei_const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride); ei_const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
ei_const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride); ei_const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
if (ConjugateRhs)
alpha = ei_conj(alpha);
typedef ei_product_blocking_traits<Scalar,Scalar> Blocking; typedef ei_product_blocking_traits<Scalar,Scalar> Blocking;
enum { enum {
SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Blocking::mr,Blocking::nr), SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Blocking::mr,Blocking::nr),
@ -147,7 +144,7 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,true,
k2 = k2+actual_kc-kc; k2 = k2+actual_kc-kc;
} }
pack_rhs(blockB, &rhs(actual_k2,0), rhsStride, alpha, actual_kc, cols); pack_rhs(blockB, &rhs(actual_k2,0), rhsStride, actual_kc, cols);
// the selected lhs's panel has to be split in three different parts: // the selected lhs's panel has to be split in three different parts:
// 1 - the part which is above the diagonal block => skip it // 1 - the part which is above the diagonal block => skip it
@ -176,7 +173,7 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,true,
} }
pack_lhs(blockA, triangularBuffer.data(), triangularBuffer.outerStride(), actualPanelWidth, actualPanelWidth); pack_lhs(blockA, triangularBuffer.data(), triangularBuffer.outerStride(), actualPanelWidth, actualPanelWidth);
gebp_kernel(res+startBlock, resStride, blockA, blockB, actualPanelWidth, actualPanelWidth, cols, gebp_kernel(res+startBlock, resStride, blockA, blockB, actualPanelWidth, actualPanelWidth, cols, alpha,
actualPanelWidth, actual_kc, 0, blockBOffset); actualPanelWidth, actual_kc, 0, blockBOffset);
// GEBP with remaining micro panel // GEBP with remaining micro panel
@ -186,7 +183,7 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,true,
pack_lhs(blockA, &lhs(startTarget,startBlock), lhsStride, actualPanelWidth, lengthTarget); pack_lhs(blockA, &lhs(startTarget,startBlock), lhsStride, actualPanelWidth, lengthTarget);
gebp_kernel(res+startTarget, resStride, blockA, blockB, lengthTarget, actualPanelWidth, cols, gebp_kernel(res+startTarget, resStride, blockA, blockB, lengthTarget, actualPanelWidth, cols, alpha,
actualPanelWidth, actual_kc, 0, blockBOffset); actualPanelWidth, actual_kc, 0, blockBOffset);
} }
} }
@ -201,7 +198,7 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,true,
ei_gemm_pack_lhs<Scalar, Index, Blocking::mr,LhsStorageOrder,false>() ei_gemm_pack_lhs<Scalar, Index, Blocking::mr,LhsStorageOrder,false>()
(blockA, &lhs(i2, actual_k2), lhsStride, actual_kc, actual_mc); (blockA, &lhs(i2, actual_k2), lhsStride, actual_kc, actual_mc);
gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols); gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha);
} }
} }
} }
@ -231,9 +228,6 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,false,
ei_const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride); ei_const_blas_data_mapper<Scalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
ei_const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride); ei_const_blas_data_mapper<Scalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
if (ConjugateRhs)
alpha = ei_conj(alpha);
typedef ei_product_blocking_traits<Scalar,Scalar> Blocking; typedef ei_product_blocking_traits<Scalar,Scalar> Blocking;
enum { enum {
SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Blocking::mr,Blocking::nr), SmallPanelWidth = EIGEN_PLAIN_ENUM_MAX(Blocking::mr,Blocking::nr),
@ -280,7 +274,7 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,false,
Scalar* geb = blockB+ts*ts; Scalar* geb = blockB+ts*ts;
pack_rhs(geb, &rhs(actual_k2,IsLower ? 0 : k2), rhsStride, alpha, actual_kc, rs); pack_rhs(geb, &rhs(actual_k2,IsLower ? 0 : k2), rhsStride, actual_kc, rs);
// pack the triangular part of the rhs padding the unrolled blocks with zeros // pack the triangular part of the rhs padding the unrolled blocks with zeros
if(ts>0) if(ts>0)
@ -293,7 +287,7 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,false,
Index panelLength = IsLower ? actual_kc-j2-actualPanelWidth : j2; Index panelLength = IsLower ? actual_kc-j2-actualPanelWidth : j2;
// general part // general part
pack_rhs_panel(blockB+j2*actual_kc, pack_rhs_panel(blockB+j2*actual_kc,
&rhs(actual_k2+panelOffset, actual_j2), rhsStride, alpha, &rhs(actual_k2+panelOffset, actual_j2), rhsStride,
panelLength, actualPanelWidth, panelLength, actualPanelWidth,
actual_kc, panelOffset); actual_kc, panelOffset);
@ -307,7 +301,7 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,false,
} }
pack_rhs_panel(blockB+j2*actual_kc, pack_rhs_panel(blockB+j2*actual_kc,
triangularBuffer.data(), triangularBuffer.outerStride(), alpha, triangularBuffer.data(), triangularBuffer.outerStride(),
actualPanelWidth, actualPanelWidth, actualPanelWidth, actualPanelWidth,
actual_kc, j2); actual_kc, j2);
} }
@ -330,6 +324,7 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,false,
gebp_kernel(res+i2+(actual_k2+j2)*resStride, resStride, gebp_kernel(res+i2+(actual_k2+j2)*resStride, resStride,
blockA, blockB+j2*actual_kc, blockA, blockB+j2*actual_kc,
actual_mc, panelLength, actualPanelWidth, actual_mc, panelLength, actualPanelWidth,
alpha,
actual_kc, actual_kc, // strides actual_kc, actual_kc, // strides
blockOffset, blockOffset,// offsets blockOffset, blockOffset,// offsets
allocatedBlockB); // workspace allocatedBlockB); // workspace
@ -337,6 +332,7 @@ struct ei_product_triangular_matrix_matrix<Scalar,Index,Mode,false,
} }
gebp_kernel(res+i2+(IsLower ? 0 : k2)*resStride, resStride, gebp_kernel(res+i2+(IsLower ? 0 : k2)*resStride, resStride,
blockA, geb, actual_mc, actual_kc, rs, blockA, geb, actual_mc, actual_kc, rs,
alpha,
-1, -1, 0, 0, allocatedBlockB); -1, -1, 0, 0, allocatedBlockB);
} }
} }

View File

@ -140,7 +140,7 @@ struct ei_triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStora
Index blockBOffset = IsLower ? k1 : lengthTarget; Index blockBOffset = IsLower ? k1 : lengthTarget;
// update the respective rows of B from other // update the respective rows of B from other
pack_rhs(blockB, _other+startBlock, otherStride, -1, actualPanelWidth, cols, actual_kc, blockBOffset); pack_rhs(blockB, _other+startBlock, otherStride, actualPanelWidth, cols, actual_kc, blockBOffset);
// GEBP // GEBP
if (lengthTarget>0) if (lengthTarget>0)
@ -149,7 +149,7 @@ struct ei_triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStora
pack_lhs(blockA, &tri(startTarget,startBlock), triStride, actualPanelWidth, lengthTarget); pack_lhs(blockA, &tri(startTarget,startBlock), triStride, actualPanelWidth, lengthTarget);
gebp_kernel(_other+startTarget, otherStride, blockA, blockB, lengthTarget, actualPanelWidth, cols, gebp_kernel(_other+startTarget, otherStride, blockA, blockB, lengthTarget, actualPanelWidth, cols, Scalar(-1),
actualPanelWidth, actual_kc, 0, blockBOffset); actualPanelWidth, actual_kc, 0, blockBOffset);
} }
} }
@ -166,7 +166,7 @@ struct ei_triangular_solve_matrix<Scalar,Index,OnTheLeft,Mode,Conjugate,TriStora
{ {
pack_lhs(blockA, &tri(i2, IsLower ? k2 : k2-kc), triStride, actual_kc, actual_mc); pack_lhs(blockA, &tri(i2, IsLower ? k2 : k2-kc), triStride, actual_kc, actual_mc);
gebp_kernel(_other+i2, otherStride, blockA, blockB, actual_mc, actual_kc, cols); gebp_kernel(_other+i2, otherStride, blockA, blockB, actual_mc, actual_kc, cols, Scalar(-1));
} }
} }
} }
@ -242,7 +242,7 @@ struct ei_triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStor
if (panelLength>0) if (panelLength>0)
pack_rhs_panel(blockB+j2*actual_kc, pack_rhs_panel(blockB+j2*actual_kc,
&rhs(actual_k2+panelOffset, actual_j2), triStride, -1, &rhs(actual_k2+panelOffset, actual_j2), triStride,
panelLength, actualPanelWidth, panelLength, actualPanelWidth,
actual_kc, panelOffset); actual_kc, panelOffset);
} }
@ -273,6 +273,7 @@ struct ei_triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStor
gebp_kernel(&lhs(i2,absolute_j2), otherStride, gebp_kernel(&lhs(i2,absolute_j2), otherStride,
blockA, blockB+j2*actual_kc, blockA, blockB+j2*actual_kc,
actual_mc, panelLength, actualPanelWidth, actual_mc, panelLength, actualPanelWidth,
Scalar(-1),
actual_kc, actual_kc, // strides actual_kc, actual_kc, // strides
panelOffset, panelOffset, // offsets panelOffset, panelOffset, // offsets
allocatedBlockB); // workspace allocatedBlockB); // workspace
@ -305,7 +306,7 @@ struct ei_triangular_solve_matrix<Scalar,Index,OnTheRight,Mode,Conjugate,TriStor
if (rs>0) if (rs>0)
gebp_kernel(_other+i2+startPanel*otherStride, otherStride, blockA, geb, gebp_kernel(_other+i2+startPanel*otherStride, otherStride, blockA, geb,
actual_mc, actual_kc, rs, actual_mc, actual_kc, rs, Scalar(-1),
-1, -1, 0, 0, allocatedBlockB); -1, -1, 0, 0, allocatedBlockB);
} }
} }