Simplified the contraction code

Benoit Steiner 2016-12-21 16:42:56 -08:00
parent 519d63d350
commit 4236aebe10


@@ -720,24 +720,20 @@ protected:
const LhsScalar* leftData = m_leftImpl.data();
const RhsScalar* rightData = m_rightImpl.data();
-libxsmm_blasint stride_A = static_cast<libxsmm_blasint>(transposeA ? k : m);
-libxsmm_blasint stride_B = static_cast<libxsmm_blasint>(transposeB ? n : k);
-libxsmm_blasint stride_C = static_cast<libxsmm_blasint>(m);
+const libxsmm_blasint stride_A = static_cast<libxsmm_blasint>(transposeA ? k : m);
+const libxsmm_blasint stride_B = static_cast<libxsmm_blasint>(transposeB ? n : k);
+const libxsmm_blasint stride_C = static_cast<libxsmm_blasint>(m);
-libxsmm_blasint stride_blockA = static_cast<libxsmm_blasint>(mc);
+const libxsmm_blasint stride_blockA = static_cast<libxsmm_blasint>(mc);
// Use bigger stride to avoid hitting same cache line too often.
// This consistently gives +~0.5 Gflops.
-libxsmm_blasint stride_panelB = static_cast<libxsmm_blasint>(
+const libxsmm_blasint stride_panelB = static_cast<libxsmm_blasint>(
kc % 32 == 0 ? kc + 16 : kc
);
// Kernel for the general case (not edges)
internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar> kernel;
-const LhsScalar *ap;
-const RhsScalar *bp;
-const Scalar *cp;
LhsScalar* blockA = NULL;
RhsScalar* panelB = NULL;
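The comment above ("Use bigger stride to avoid hitting same cache line too often") refers to a standard padding trick: when kc is a multiple of 32 floats, successive rows of the packed B panel sit a power-of-two number of bytes apart and tend to collide in the same cache sets, so the stride is padded by 16 elements. A standalone sketch of the effect under assumed cache parameters (64-byte lines, 64 sets; the sizes are illustrative, not from the commit):

    #include <cstdio>

    int main() {
      // set(addr) = (addr / 64) % 64 for an assumed 64-byte-line, 64-set cache.
      const int kc = 1024;                             // a multiple of 32
      const int padded = kc % 32 == 0 ? kc + 16 : kc;  // the rule from the patch
      for (int row = 0; row < 4; ++row) {
        std::printf("row %d: set %2d unpadded, set %2d padded\n", row,
                    (row * kc * 4 / 64) % 64, (row * padded * 4 / 64) % 64);
      }
      return 0;
    }

With the unpadded stride every row start maps to set 0; with the padded stride the starts spread across sets 0, 1, 2, 3.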
@@ -748,8 +744,8 @@ protected:
panelB = static_cast<RhsScalar*>(this->m_device.allocate(nc_outer * stride_panelB * sizeof(RhsScalar)));
}
-Index kernel_stride_A = copyA ? stride_blockA : stride_A;
-Index kernel_stride_B = copyB ? stride_panelB : stride_B;
+const Index kernel_stride_A = copyA ? stride_blockA : stride_A;
+const Index kernel_stride_B = copyB ? stride_panelB : stride_B;
kernel = internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar>(0, mc, nc, kc, kernel_stride_A, kernel_stride_B, stride_C, 1, 1, blocking.prefetch());
// Outer blocking
@@ -763,6 +759,7 @@ protected:
// Inner blocking
for (Index ki = ki_outer; ki < mini(ki_outer+kc_outer, k); ki += kc) {
const Index actual_kc = mini(ki_outer+kc_outer, mini(ki+kc, k)) - ki;
+const float beta = ki == 0 ? 0 : 1;
if (copyB) {
if (transposeB) {
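The hoisted beta follows the usual GEMM convention C = alpha*A*B + beta*C: the first depth slice (ki == 0) overwrites the output and every later slice accumulates into it. The value depends only on ki, not on ni, which is why it can move out of the inner loop (its old location is removed in the next hunk). A naive column-major reference for this k-blocked accumulation (illustrative names, not Eigen's or libxsmm's API):

    #include <algorithm>

    // C (m x n) = alpha * A (m x k) * B (k x n) + beta * C, all column-major.
    static void gemm_ref(int m, int n, int k, float alpha,
                         const float* A, int lda, const float* B, int ldb,
                         float beta, float* C, int ldc) {
      for (int j = 0; j < n; ++j)
        for (int i = 0; i < m; ++i) {
          float acc = 0.0f;
          for (int p = 0; p < k; ++p) acc += A[i + p*lda] * B[p + j*ldb];
          C[i + j*ldc] = alpha * acc + beta * C[i + j*ldc];
        }
    }

    // Accumulate the full-depth product from kc-deep slices.
    static void gemm_by_k_blocks(int m, int n, int k, int kc,
                                 const float* A, const float* B, float* C) {
      for (int ki = 0; ki < k; ki += kc) {
        const int actual_kc = std::min(kc, k - ki);
        const float beta = ki == 0 ? 0.0f : 1.0f;  // overwrite, then accumulate
        gemm_ref(m, n, actual_kc, 1.0f, A + ki*m, m, B + ki, k, beta, C, m);
      }
    }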
@@ -785,30 +782,24 @@ protected:
internal::pack_simple<LhsScalar, Index>(blockA, a, actual_kc, actual_mc, stride_blockA, stride_A);
}
}
+const LhsScalar* actual_a = copyA ? blockA : a;
for (Index ni = ni_outer; ni < mini(ni_outer+nc_outer, n); ni += nc) {
const Index actual_nc = mini(ni_outer+nc_outer, mini(ni+nc, n)) - ni;
const RhsScalar* b = rightData + ni*stride_B + ki;
Scalar* c = buffer + ni*stride_C + mi;
-cp = c + nc*stride_C;
-const LhsScalar * actual_a = copyA ? blockA : a;
-const Index actual_lda = copyA ? stride_blockA : stride_A;
-ap = copyA ? blockA : a;
+const Scalar* cp = c + nc*stride_C;
const RhsScalar* actual_b = copyB ? panelB + (ni-ni_outer)*stride_panelB : b;
-const Index actual_ldb = copyB ? stride_panelB : stride_B;
-bp = copyB ? panelB + nc*stride_panelB : b + nc*stride_B;
-float beta = ki == 0 ? 0 : 1;
+const RhsScalar* bp = copyB ? panelB + nc*stride_panelB : b + nc*stride_B;
if (actual_mc == mc && actual_kc == kc && actual_nc == nc && beta == 1) {
// Most used, cached kernel.
-kernel(actual_a, actual_b, c, ap, bp, cp);
+kernel(actual_a, actual_b, c, actual_a, bp, cp);
} else {
// Edges - use libxsmm kernel cache.
-internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar>(0, actual_mc, actual_nc, actual_kc, actual_lda, actual_ldb, stride_C, 1, beta, blocking.prefetch())(actual_a, actual_b, c, ap, bp, cp);
+internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar>(0, actual_mc, actual_nc, actual_kc, kernel_stride_A, kernel_stride_B, stride_C, 1, beta, blocking.prefetch())(actual_a, actual_b, c, actual_a, bp, cp);
}
}
}
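Net effect of the commit: the ap/bp/cp temporaries and the per-iteration actual_lda/actual_ldb recomputation are gone; the extra pointer arguments (prefetch hints) are derived in place, and both call sites reuse the kernel_stride_A/kernel_stride_B computed once before the loops. What remains is a common blocked-GEMM dispatch shape: one kernel object, built once for the full (mc x nc x kc) tile, serves the hot path, while boundary tiles (and, since the cached kernel was built with beta fixed to 1, every first k-slice) go through the generic libxsmm wrapper. A schematic of that dispatch pattern (hypothetical helper, loosely mirroring the loop nest above, not the Eigen code):

    #include <functional>

    // (mi, ni, ki) -> compute one output tile.
    using TileFn = std::function<void(int, int, int)>;

    // Blocked loops with one cached kernel for full tiles and a generic
    // fallback for edge tiles.
    static void blocked_dispatch(int m, int n, int k, int mc, int nc, int kc,
                                 const TileFn& hot, const TileFn& edge) {
      for (int mi = 0; mi < m; mi += mc)
        for (int ki = 0; ki < k; ki += kc)
          for (int ni = 0; ni < n; ni += nc) {
            const bool full_tile = mi + mc <= m && ni + nc <= n && ki + kc <= k;
            (full_tile ? hot : edge)(mi, ni, ki);  // cached kernel vs. wrapper
          }
    }

Keeping the full-tile shape on a single pre-built kernel is what lets libxsmm's code cache pay off; the edge path still benefits from libxsmm's internal kernel cache, as the "Edges - use libxsmm kernel cache." comment notes.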