mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 11:49:02 +08:00
Simplified the contraction code
This commit is contained in:
parent
519d63d350
commit
4236aebe10
@ -720,24 +720,20 @@ protected:
|
||||
const LhsScalar* leftData = m_leftImpl.data();
|
||||
const RhsScalar* rightData = m_rightImpl.data();
|
||||
|
||||
libxsmm_blasint stride_A = static_cast<libxsmm_blasint>(transposeA ? k : m);
|
||||
libxsmm_blasint stride_B = static_cast<libxsmm_blasint>(transposeB ? n : k);
|
||||
libxsmm_blasint stride_C = static_cast<libxsmm_blasint>(m);
|
||||
const libxsmm_blasint stride_A = static_cast<libxsmm_blasint>(transposeA ? k : m);
|
||||
const libxsmm_blasint stride_B = static_cast<libxsmm_blasint>(transposeB ? n : k);
|
||||
const libxsmm_blasint stride_C = static_cast<libxsmm_blasint>(m);
|
||||
|
||||
libxsmm_blasint stride_blockA = static_cast<libxsmm_blasint>(mc);
|
||||
const libxsmm_blasint stride_blockA = static_cast<libxsmm_blasint>(mc);
|
||||
// Use bigger stride to avoid hitting same cache line too often.
|
||||
// This consistently gives +~0.5 Gflops.
|
||||
libxsmm_blasint stride_panelB = static_cast<libxsmm_blasint>(
|
||||
const libxsmm_blasint stride_panelB = static_cast<libxsmm_blasint>(
|
||||
kc % 32 == 0 ? kc + 16 : kc
|
||||
);
|
||||
|
||||
// Kernel for the general case (not edges)
|
||||
internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar> kernel;
|
||||
|
||||
const LhsScalar *ap;
|
||||
const RhsScalar *bp;
|
||||
const Scalar *cp;
|
||||
|
||||
LhsScalar* blockA = NULL;
|
||||
RhsScalar* panelB = NULL;
|
||||
|
||||
@ -748,8 +744,8 @@ protected:
|
||||
panelB = static_cast<RhsScalar*>(this->m_device.allocate(nc_outer * stride_panelB * sizeof(RhsScalar)));
|
||||
}
|
||||
|
||||
Index kernel_stride_A = copyA ? stride_blockA : stride_A;
|
||||
Index kernel_stride_B = copyB ? stride_panelB : stride_B;
|
||||
const Index kernel_stride_A = copyA ? stride_blockA : stride_A;
|
||||
const Index kernel_stride_B = copyB ? stride_panelB : stride_B;
|
||||
kernel = internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar>(0, mc, nc, kc, kernel_stride_A, kernel_stride_B, stride_C, 1, 1, blocking.prefetch());
|
||||
|
||||
// Outer blocking
|
||||
@ -763,6 +759,7 @@ protected:
|
||||
// Inner blocking
|
||||
for (Index ki = ki_outer; ki < mini(ki_outer+kc_outer, k); ki += kc) {
|
||||
const Index actual_kc = mini(ki_outer+kc_outer, mini(ki+kc, k)) - ki;
|
||||
const float beta = ki == 0 ? 0 : 1;
|
||||
|
||||
if (copyB) {
|
||||
if (transposeB) {
|
||||
@ -775,7 +772,7 @@ protected:
|
||||
for (Index mi = mi_outer; mi < mini(mi_outer+mc_outer, m); mi += mc) {
|
||||
const Index actual_mc = mini(mi_outer+mc_outer, mini(mi+mc, m)) - mi;
|
||||
|
||||
const LhsScalar * a = transposeA ? leftData + mi*stride_A + ki :
|
||||
const LhsScalar* a = transposeA ? leftData + mi*stride_A + ki :
|
||||
leftData + ki*stride_A + mi;
|
||||
|
||||
if (copyA) {
|
||||
@ -785,30 +782,24 @@ protected:
|
||||
internal::pack_simple<LhsScalar, Index>(blockA, a, actual_kc, actual_mc, stride_blockA, stride_A);
|
||||
}
|
||||
}
|
||||
const LhsScalar* actual_a = copyA ? blockA : a;
|
||||
|
||||
for (Index ni = ni_outer; ni < mini(ni_outer+nc_outer, n); ni += nc) {
|
||||
const Index actual_nc = mini(ni_outer+nc_outer, mini(ni+nc, n)) - ni;
|
||||
|
||||
const RhsScalar * b = rightData + ni*stride_B + ki;
|
||||
Scalar * c = buffer + ni*stride_C + mi;
|
||||
cp = c + nc*stride_C;
|
||||
const RhsScalar* b = rightData + ni*stride_B + ki;
|
||||
Scalar* c = buffer + ni*stride_C + mi;
|
||||
const Scalar* cp = c + nc*stride_C;
|
||||
|
||||
const LhsScalar * actual_a = copyA ? blockA : a;
|
||||
const Index actual_lda = copyA ? stride_blockA : stride_A;
|
||||
ap = copyA ? blockA : a;
|
||||
|
||||
const RhsScalar * actual_b = copyB ? panelB + (ni-ni_outer)*stride_panelB : b;
|
||||
const Index actual_ldb = copyB ? stride_panelB : stride_B;
|
||||
bp = copyB ? panelB + nc*stride_panelB : b + nc*stride_B;
|
||||
|
||||
float beta = ki == 0 ? 0 : 1;
|
||||
const RhsScalar* actual_b = copyB ? panelB + (ni-ni_outer)*stride_panelB : b;
|
||||
const RhsScalar* bp = copyB ? panelB + nc*stride_panelB : b + nc*stride_B;
|
||||
|
||||
if (actual_mc == mc && actual_kc == kc && actual_nc == nc && beta == 1) {
|
||||
// Most used, cached kernel.
|
||||
kernel(actual_a, actual_b, c, ap, bp, cp);
|
||||
kernel(actual_a, actual_b, c, actual_a, bp, cp);
|
||||
} else {
|
||||
// Edges - use libxsmm kernel cache.
|
||||
internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar>(0, actual_mc, actual_nc, actual_kc, actual_lda, actual_ldb, stride_C, 1, beta, blocking.prefetch())(actual_a, actual_b, c, ap, bp, cp);
|
||||
internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar>(0, actual_mc, actual_nc, actual_kc, kernel_stride_A, kernel_stride_B, stride_C, 1, beta, blocking.prefetch())(actual_a, actual_b, c, actual_a, bp, cp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user