Fix SparseLU special gemm kernel on 32 bits system w/o SSE

This commit is contained in:
Gael Guennebaud 2013-01-23 19:34:01 +01:00
parent ee36eaefc6
commit 73026eab4d

View File

@ -29,10 +29,11 @@ void sparselu_gemm(int m, int n, int d, const Scalar* A, int lda, const Scalar*
typedef typename packet_traits<Scalar>::type Packet;
enum {
NumberOfRegisters = EIGEN_ARCH_DEFAULT_NUMBER_OF_REGISTERS,
PacketSize = packet_traits<Scalar>::size,
PM = 8, // peeling in M
RN = 2, // register blocking
RK = 4, // register blocking
RK = NumberOfRegisters>=16 ? 4 : 2, // register blocking
BM = 4096/sizeof(Scalar), // number of rows of A-C per chunk
SM = PM*PacketSize // step along M
};
@ -73,12 +74,12 @@ void sparselu_gemm(int m, int n, int d, const Scalar* A, int lda, const Scalar*
Packet b00, b10, b20, b30, b01, b11, b21, b31;
b00 = pset1<Packet>(Bc0[0]);
b10 = pset1<Packet>(Bc0[1]);
b20 = pset1<Packet>(Bc0[2]);
b30 = pset1<Packet>(Bc0[3]);
if(RK==4) b20 = pset1<Packet>(Bc0[2]);
if(RK==4) b30 = pset1<Packet>(Bc0[3]);
b01 = pset1<Packet>(Bc1[0]);
b11 = pset1<Packet>(Bc1[1]);
b21 = pset1<Packet>(Bc1[2]);
b31 = pset1<Packet>(Bc1[3]);
if(RK==4) b21 = pset1<Packet>(Bc1[2]);
if(RK==4) b31 = pset1<Packet>(Bc1[3]);
Packet a0, a1, a2, a3, c0, c1, t0, t1;
@ -92,8 +93,8 @@ void sparselu_gemm(int m, int n, int d, const Scalar* A, int lda, const Scalar*
a0 = pload<Packet>(A0);
a1 = pload<Packet>(A1);
a2 = pload<Packet>(A2);
a3 = pload<Packet>(A3);
if(RK==4) a2 = pload<Packet>(A2);
if(RK==4) a3 = pload<Packet>(A3);
#define KMADD(c, a, b, tmp) tmp = b; tmp = pmul(a,tmp); c = padd(c,tmp);
#define WORK(I) \
@ -105,12 +106,12 @@ void sparselu_gemm(int m, int n, int d, const Scalar* A, int lda, const Scalar*
KMADD(c0, a1, b10, t0); \
KMADD(c1, a1, b11, t1); \
a1 = pload<Packet>(A1+i+(I+1)*PacketSize); \
KMADD(c0, a2, b20, t0); \
KMADD(c1, a2, b21, t1); \
a2 = pload<Packet>(A2+i+(I+1)*PacketSize); \
KMADD(c0, a3, b30, t0); \
KMADD(c1, a3, b31, t1); \
a3 = pload<Packet>(A3+i+(I+1)*PacketSize); \
if(RK==4) KMADD(c0, a2, b20, t0); \
if(RK==4) KMADD(c1, a2, b21, t1); \
if(RK==4) a2 = pload<Packet>(A2+i+(I+1)*PacketSize); \
if(RK==4) KMADD(c0, a3, b30, t0); \
if(RK==4) KMADD(c1, a3, b31, t1); \
if(RK==4) a3 = pload<Packet>(A3+i+(I+1)*PacketSize); \
pstore(C0+i+(I)*PacketSize, c0); \
pstore(C1+i+(I)*PacketSize, c1)
@ -118,10 +119,10 @@ void sparselu_gemm(int m, int n, int d, const Scalar* A, int lda, const Scalar*
for(int i=0; i<actual_b_end1; i+=PacketSize*8)
{
EIGEN_ASM_COMMENT("SPARSELU_GEMML_KERNEL1");
_mm_prefetch((const char*)(A0+i+(5)*PacketSize), _MM_HINT_T0);
_mm_prefetch((const char*)(A1+i+(5)*PacketSize), _MM_HINT_T0);
_mm_prefetch((const char*)(A2+i+(5)*PacketSize), _MM_HINT_T0);
_mm_prefetch((const char*)(A3+i+(5)*PacketSize), _MM_HINT_T0);
prefetch((A0+i+(5)*PacketSize));
prefetch((A1+i+(5)*PacketSize));
if(RK==4) prefetch((A2+i+(5)*PacketSize));
if(RK==4) prefetch((A3+i+(5)*PacketSize));
WORK(0);
WORK(1);
WORK(2);
@ -138,10 +139,18 @@ void sparselu_gemm(int m, int n, int d, const Scalar* A, int lda, const Scalar*
}
// process the remaining rows without vectorization
for(int i=actual_b_end2; i<actual_b; ++i)
{
if(RK==4)
{
C0[i] += A0[i]*Bc0[0]+A1[i]*Bc0[1]+A2[i]*Bc0[2]+A3[i]*Bc0[3];
C1[i] += A0[i]*Bc1[0]+A1[i]*Bc1[1]+A2[i]*Bc1[2]+A3[i]*Bc1[3];
}
else
{
C0[i] += A0[i]*Bc0[0]+A1[i]*Bc0[1];
C1[i] += A0[i]*Bc1[0]+A1[i]*Bc1[1];
}
}
Bc0 += RK;
Bc1 += RK;
@ -156,12 +165,12 @@ void sparselu_gemm(int m, int n, int d, const Scalar* A, int lda, const Scalar*
for(int k=0; k<d_end; k+=RK)
{
// load and expand a RN x RK block of B
// load and expand a 1 x RK block of B
Packet b00, b10, b20, b30;
b00 = pset1<Packet>(Bc0[0]);
b10 = pset1<Packet>(Bc0[1]);
b20 = pset1<Packet>(Bc0[2]);
b30 = pset1<Packet>(Bc0[3]);
if(RK==4) b20 = pset1<Packet>(Bc0[2]);
if(RK==4) b30 = pset1<Packet>(Bc0[3]);
Packet a0, a1, a2, a3, c0, t0/*, t1*/;
@ -174,8 +183,8 @@ void sparselu_gemm(int m, int n, int d, const Scalar* A, int lda, const Scalar*
a0 = pload<Packet>(A0);
a1 = pload<Packet>(A1);
a2 = pload<Packet>(A2);
a3 = pload<Packet>(A3);
if(RK==4) a2 = pload<Packet>(A2);
if(RK==4) a3 = pload<Packet>(A3);
#define WORK(I) \
c0 = pload<Packet>(C0+i+(I)*PacketSize); \
@ -183,10 +192,10 @@ void sparselu_gemm(int m, int n, int d, const Scalar* A, int lda, const Scalar*
a0 = pload<Packet>(A0+i+(I+1)*PacketSize); \
KMADD(c0, a1, b10, t0); \
a1 = pload<Packet>(A1+i+(I+1)*PacketSize); \
KMADD(c0, a2, b20, t0); \
a2 = pload<Packet>(A2+i+(I+1)*PacketSize); \
KMADD(c0, a3, b30, t0); \
a3 = pload<Packet>(A3+i+(I+1)*PacketSize); \
if(RK==4) KMADD(c0, a2, b20, t0); \
if(RK==4) a2 = pload<Packet>(A2+i+(I+1)*PacketSize); \
if(RK==4) KMADD(c0, a3, b30, t0); \
if(RK==4) a3 = pload<Packet>(A3+i+(I+1)*PacketSize); \
pstore(C0+i+(I)*PacketSize, c0);
// agressive vectorization and peeling
@ -210,7 +219,10 @@ void sparselu_gemm(int m, int n, int d, const Scalar* A, int lda, const Scalar*
// remaining scalars
for(int i=actual_b_end2; i<actual_b; ++i)
{
if(RK==4)
C0[i] += A0[i]*Bc0[0]+A1[i]*Bc0[1]+A2[i]*Bc0[2]+A3[i]*Bc0[3];
else
C0[i] += A0[i]*Bc0[0]+A1[i]*Bc0[1];
}
Bc0 += RK;
@ -224,8 +236,11 @@ void sparselu_gemm(int m, int n, int d, const Scalar* A, int lda, const Scalar*
{
for(int j=0; j<n; ++j)
{
typedef Map<Matrix<Scalar,Dynamic,1>, Aligned > MapVector;
typedef Map<const Matrix<Scalar,Dynamic,1>, Aligned > ConstMapVector;
enum {
Alignment = PacketSize>1 ? Aligned : 0
};
typedef Map<Matrix<Scalar,Dynamic,1>, Alignment > MapVector;
typedef Map<const Matrix<Scalar,Dynamic,1>, Alignment > ConstMapVector;
if(rd==1) MapVector(C+j*ldc+ib,actual_b) += B[0+d_end+j*ldb] * ConstMapVector(A+(d_end+0)*lda+ib, actual_b);
else if(rd==2) MapVector(C+j*ldc+ib,actual_b) += B[0+d_end+j*ldb] * ConstMapVector(A+(d_end+0)*lda+ib, actual_b)