fix issue #100 (fix syrk)

This commit is contained in:
Gael Guennebaud 2010-03-06 21:16:43 +01:00
parent 271fc84e47
commit 6f0b96dcf4

View File

@ -74,47 +74,51 @@ struct ei_selfadjoint_product<Scalar,MatStorageOrder, ColMajor, AAT, UpLo>
int mc = std::min<int>(Blocking::Max_mc,size); // cache block size along the M direction int mc = std::min<int>(Blocking::Max_mc,size); // cache block size along the M direction
Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc); Scalar* blockA = ei_aligned_stack_new(Scalar, kc*mc);
Scalar* blockB = ei_aligned_stack_new(Scalar, kc*size*Blocking::PacketSize); std::size_t sizeB = kc*Blocking::PacketSize*Blocking::nr + kc*size;
Scalar* allocatedBlockB = ei_aligned_stack_new(Scalar, sizeB);
Scalar* blockB = allocatedBlockB + kc*Blocking::PacketSize*Blocking::nr;
// note that the actual rhs is the transpose/adjoint of mat // note that the actual rhs is the transpose/adjoint of mat
typedef ei_conj_helper<NumTraits<Scalar>::IsComplex && !AAT, NumTraits<Scalar>::IsComplex && AAT> Conj; typedef ei_conj_helper<NumTraits<Scalar>::IsComplex && !AAT, NumTraits<Scalar>::IsComplex && AAT> Conj;
ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, Conj> gebp_kernel; ei_gebp_kernel<Scalar, Blocking::mr, Blocking::nr, Conj> gebp_kernel;
ei_gemm_pack_rhs<Scalar,Blocking::nr,MatStorageOrder==RowMajor ? ColMajor : RowMajor> pack_rhs;
ei_gemm_pack_lhs<Scalar,Blocking::mr,MatStorageOrder, false> pack_lhs;
ei_sybb_kernel<Scalar, Blocking::mr, Blocking::nr, Conj, UpLo> sybb;
for(int k2=0; k2<depth; k2+=kc) for(int k2=0; k2<depth; k2+=kc)
{ {
const int actual_kc = std::min(k2+kc,depth)-k2; const int actual_kc = std::min(k2+kc,depth)-k2;
// note that the actual rhs is the transpose/adjoint of mat // note that the actual rhs is the transpose/adjoint of mat
ei_gemm_pack_rhs<Scalar,Blocking::nr,MatStorageOrder==RowMajor ? ColMajor : RowMajor>() pack_rhs(blockB, &mat(0,k2), matStride, alpha, actual_kc, size);
(blockB, &mat(0,k2), matStride, alpha, actual_kc, size);
for(int i2=0; i2<size; i2+=mc) for(int i2=0; i2<size; i2+=mc)
{ {
const int actual_mc = std::min(i2+mc,size)-i2; const int actual_mc = std::min(i2+mc,size)-i2;
ei_gemm_pack_lhs<Scalar,Blocking::mr,MatStorageOrder, false>() pack_lhs(blockA, &mat(i2, k2), matStride, actual_kc, actual_mc);
(blockA, &mat(i2, k2), matStride, actual_kc, actual_mc);
// the selected actual_mc * size panel of res is split into three different part: // the selected actual_mc * size panel of res is split into three different part:
// 1 - before the diagonal => processed with gebp or skipped // 1 - before the diagonal => processed with gebp or skipped
// 2 - the actual_mc x actual_mc symmetric block => processed with a special kernel // 2 - the actual_mc x actual_mc symmetric block => processed with a special kernel
// 3 - after the diagonal => processed with gebp or skipped // 3 - after the diagonal => processed with gebp or skipped
if (UpLo==Lower) if (UpLo==Lower)
gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, std::min(size,i2)); gebp_kernel(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, std::min(size,i2),
-1, -1, 0, 0, allocatedBlockB);
ei_sybb_kernel<Scalar, Blocking::mr, Blocking::nr, Conj, UpLo>() sybb(res+resStride*i2 + i2, resStride, blockA, blockB + actual_kc*i2, actual_mc, actual_kc, allocatedBlockB);
(res+resStride*i2 + i2, resStride, blockA, blockB + actual_kc*Blocking::PacketSize*i2, actual_mc, actual_kc);
if (UpLo==Upper) if (UpLo==Upper)
{ {
int j2 = i2+actual_mc; int j2 = i2+actual_mc;
gebp_kernel(res+resStride*j2+i2, resStride, blockA, blockB+actual_kc*Blocking::PacketSize*j2, actual_mc, actual_kc, std::max(0,size-j2)); gebp_kernel(res+resStride*j2+i2, resStride, blockA, blockB+actual_kc*j2, actual_mc, actual_kc, std::max(0,size-j2),
-1, -1, 0, 0, allocatedBlockB);
} }
} }
} }
ei_aligned_stack_delete(Scalar, blockA, kc*mc); ei_aligned_stack_delete(Scalar, blockA, kc*mc);
ei_aligned_stack_delete(Scalar, blockB, kc*size*Blocking::PacketSize); ei_aligned_stack_delete(Scalar, allocatedBlockB, sizeB);
} }
}; };
@ -161,7 +165,7 @@ struct ei_sybb_kernel
PacketSize = ei_packet_traits<Scalar>::size, PacketSize = ei_packet_traits<Scalar>::size,
BlockSize = EIGEN_ENUM_MAX(mr,nr) BlockSize = EIGEN_ENUM_MAX(mr,nr)
}; };
void operator()(Scalar* res, int resStride, const Scalar* blockA, const Scalar* blockB, int size, int depth) void operator()(Scalar* res, int resStride, const Scalar* blockA, const Scalar* blockB, int size, int depth, Scalar* workspace)
{ {
ei_gebp_kernel<Scalar, mr, nr, Conj> gebp_kernel; ei_gebp_kernel<Scalar, mr, nr, Conj> gebp_kernel;
Matrix<Scalar,BlockSize,BlockSize,ColMajor> buffer; Matrix<Scalar,BlockSize,BlockSize,ColMajor> buffer;
@ -171,7 +175,7 @@ struct ei_sybb_kernel
for (int j=0; j<size; j+=BlockSize) for (int j=0; j<size; j+=BlockSize)
{ {
int actualBlockSize = std::min<int>(BlockSize,size - j); int actualBlockSize = std::min<int>(BlockSize,size - j);
const Scalar* actual_b = blockB+j*depth*PacketSize; const Scalar* actual_b = blockB+j*depth;
if(UpLo==Upper) if(UpLo==Upper)
gebp_kernel(res+j*resStride, resStride, blockA, actual_b, j, depth, actualBlockSize); gebp_kernel(res+j*resStride, resStride, blockA, actual_b, j, depth, actualBlockSize);
@ -181,7 +185,8 @@ struct ei_sybb_kernel
int i = j; int i = j;
buffer.setZero(); buffer.setZero();
// 1 - apply the kernel on the temporary buffer // 1 - apply the kernel on the temporary buffer
gebp_kernel(buffer.data(), BlockSize, blockA+depth*i, actual_b, actualBlockSize, depth, actualBlockSize); gebp_kernel(buffer.data(), BlockSize, blockA+depth*i, actual_b, actualBlockSize, depth, actualBlockSize,
-1, -1, 0, 0, workspace);
// 2 - triangular accumulation // 2 - triangular accumulation
for(int j1=0; j1<actualBlockSize; ++j1) for(int j1=0; j1<actualBlockSize; ++j1)
{ {
@ -195,7 +200,8 @@ struct ei_sybb_kernel
if(UpLo==Lower) if(UpLo==Lower)
{ {
int i = j+actualBlockSize; int i = j+actualBlockSize;
gebp_kernel(res+j*resStride+i, resStride, blockA+depth*i, actual_b, size-i, depth, actualBlockSize); gebp_kernel(res+j*resStride+i, resStride, blockA+depth*i, actual_b, size-i, depth, actualBlockSize,
-1, -1, 0, 0, workspace);
} }
} }
} }