fix issue #100 (fix syrk)

This commit is contained in:
Gael Guennebaud 2010-03-06 21:16:43 +01:00
parent 271fc84e47
commit 6f0b96dcf4

View File

@ -74,47 +74,51 @@ struct ei_selfadjoint_product<Scalar,MatStorageOrder, ColMajor, AAT, UpLo>
int mc = std::min<int>(Blocking::Max_mc,size); // cache block size along the M direction
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
Scalar* blockB = ei_aligned_stack_new(Scalar, kc*size*Blocking::PacketSize);
std::size_t sizeB = kc*Blocking::PacketSize*Blocking::nr + kc*size;
Scalar* allocatedBlockB = ei_aligned_stack_new(Scalar, sizeB);
Scalar* blockB = allocatedBlockB + kc*Blocking::PacketSize*Blocking::nr;
// note that the actual rhs is the transpose/adjoint of mat
typedef ei_conj_helper<NumTraits<Scalar>::IsComplex && !AAT, NumTraits<Scalar>::IsComplex && AAT> Conj;
ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, Conj> gebp_kernel;
ei_gemm_pack_rhs<Scalar,Blocking::nr,MatStorageOrder==RowMajor ? ColMajor : RowMajor> pack_rhs;
ei_gemm_pack_lhs<Scalar,Blocking::mr,MatStorageOrder, false> pack_lhs;
ei_sybb_kernel<Scalar, Blocking::mr, Blocking::nr, Conj, UpLo> sybb;
for(int k2=0; k2<depth; k2+=kc)
{
const int actual_kc = std::min(k2+kc,depth)-k2;
// note that the actual rhs is the transpose/adjoint of mat
ei_gemm_pack_rhs<Scalar,Blocking::nr,MatStorageOrder==RowMajor ? ColMajor : RowMajor>()
(blockB, &mat(0,k2), matStride, alpha, actual_kc, size);
pack_rhs(blockB, &mat(0,k2), matStride, alpha, actual_kc, size);
for(int i2=0; i2<size; i2+=mc)
{
const int actual_mc = std::min(i2+mc,size)-i2;
ei_gemm_pack_lhs<Scalar,Blocking::mr,MatStorageOrder, false>()
(blockA, &mat(i2, k2), matStride, actual_kc, actual_mc);
pack_lhs(blockA, &mat(i2, k2), matStride, actual_kc, actual_mc);
// the selected actual_mc * size panel of res is split into three different part:
// 1 - before the diagonal => processed with gebp or skipped
// 2 - the actual_mc x actual_mc symmetric block => processed with a special kernel
// 3 - after the diagonal => processed with gebp or skipped
if (UpLo==Lower)
gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, std::min(size,i2));
gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, std::min(size,i2),
-1, -1, 0, 0, allocatedBlockB);
ei_sybb_kernel<Scalar, Blocking::mr, Blocking::nr, Conj, UpLo>()
(res+resStride*i2 + i2, resStride, blockA, blockB + actual_kc*Blocking::PacketSize*i2, actual_mc, actual_kc);
sybb(res+resStride*i2 + i2, resStride, blockA, blockB + actual_kc*i2, actual_mc, actual_kc, allocatedBlockB);
if (UpLo==Upper)
{
int j2 = i2+actual_mc;
gebp_kernel(res+resStride*j2+i2, resStride, blockA, blockB+actual_kc*Blocking::PacketSize*j2, actual_mc, actual_kc, std::max(0,size-j2));
gebp_kernel(res+resStride*j2+i2, resStride, blockA, blockB+actual_kc*j2, actual_mc, actual_kc, std::max(0,size-j2),
-1, -1, 0, 0, allocatedBlockB);
}
}
}
ei_aligned_stack_delete(Scalar, blockA, kc*mc);
ei_aligned_stack_delete(Scalar, blockB, kc*size*Blocking::PacketSize);
ei_aligned_stack_delete(Scalar, allocatedBlockB, sizeB);
}
};
@ -161,7 +165,7 @@ struct ei_sybb_kernel
PacketSize = ei_packet_traits<Scalar>::size,
BlockSize = EIGEN_ENUM_MAX(mr,nr)
};
void operator()(Scalar* res, int resStride, const Scalar* blockA, const Scalar* blockB, int size, int depth)
void operator()(Scalar* res, int resStride, const Scalar* blockA, const Scalar* blockB, int size, int depth, Scalar* workspace)
{
ei_gebp_kernel<Scalar, mr, nr, Conj> gebp_kernel;
Matrix<Scalar,BlockSize,BlockSize,ColMajor> buffer;
@ -171,7 +175,7 @@ struct ei_sybb_kernel
for (int j=0; j<size; j+=BlockSize)
{
int actualBlockSize = std::min<int>(BlockSize,size - j);
const Scalar* actual_b = blockB+j*depth*PacketSize;
const Scalar* actual_b = blockB+j*depth;
if(UpLo==Upper)
gebp_kernel(res+j*resStride, resStride, blockA, actual_b, j, depth, actualBlockSize);
@ -181,7 +185,8 @@ struct ei_sybb_kernel
int i = j;
buffer.setZero();
// 1 - apply the kernel on the temporary buffer
gebp_kernel(buffer.data(), BlockSize, blockA+depth*i, actual_b, actualBlockSize, depth, actualBlockSize);
gebp_kernel(buffer.data(), BlockSize, blockA+depth*i, actual_b, actualBlockSize, depth, actualBlockSize,
-1, -1, 0, 0, workspace);
// 2 - triangular accumulation
for(int j1=0; j1<actualBlockSize; ++j1)
{
@ -195,7 +200,8 @@ struct ei_sybb_kernel
if(UpLo==Lower)
{
int i = j+actualBlockSize;
gebp_kernel(res+j*resStride+i, resStride, blockA+depth*i, actual_b, size-i, depth, actualBlockSize);
gebp_kernel(res+j*resStride+i, resStride, blockA+depth*i, actual_b, size-i, depth, actualBlockSize,
-1, -1, 0, 0, workspace);
}
}
}