Added support for libxsmm kernel in multithreaded contractions

Benoit Steiner 2016-12-21 15:06:06 -08:00
parent 0657228569
commit 519d63d350

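The new code path is compiled in only when both EIGEN_VECTORIZE_AVX and EIGEN_USE_LIBXSMM are defined. As a minimal, hypothetical usage sketch (not part of this commit), a multithreaded contraction that could dispatch to the libxsmm kernel looks like this:

    // Sketch: a float matrix product evaluated on a ThreadPoolDevice.
    // Whether it actually takes the xsmm path depends on the build flags
    // and on the internal blocking heuristics.
    #define EIGEN_USE_THREADS
    #include <unsupported/Eigen/CXX11/Tensor>

    int main() {
      Eigen::ThreadPool pool(4);                 // four worker threads
      Eigen::ThreadPoolDevice device(&pool, 4);
      Eigen::Tensor<float, 2> a(256, 512), b(512, 128), c(256, 128);
      a.setRandom();
      b.setRandom();
      // Contract the second dimension of a with the first dimension of b.
      Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};
      c.device(device) = a.contract(b, dims);
      return 0;
    }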

@@ -116,6 +116,28 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  void evalProduct(Scalar* buffer) const {
    const Index m = this->m_i_size;
    const Index n = this->m_j_size;
    const Index k = this->m_k_size;
    if (m == 0 || n == 0 || k == 0) return;
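    // New in this commit: when libxsmm is usable, a blocking heuristic picks a
    // thread count, and the contraction runs either on the single-threaded
    // XSMM GEMM or on the multithreaded ContextXsmm driver defined below.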
#if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM)
    if (this->m_can_use_xsmm) {
      bool transposeA = !this->m_lhs_inner_dim_contiguous;
      bool transposeB = !this->m_rhs_inner_dim_contiguous;
      internal::TensorXsmmContractionBlocking<LhsScalar, RhsScalar, Index>
          blocking(k, m, n, this->m_device.numThreads(), transposeA,
                   transposeB);

      if (blocking.num_threads() == 1) {
        this->evalGemmXSMM(buffer);
      } else {
        ContextXsmm<Alignment>(this, buffer, m, n, k, blocking).run();
      }
      return;
    }
#endif

    typedef
        typename internal::remove_const<typename EvalLeftArgType::Scalar>::type
            LhsScalar;
@@ -147,10 +169,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
                                  Traits::mr, Traits::nr, false, false>
        GebpKernel;
    const Index m = this->m_i_size;
    const Index n = this->m_j_size;
    const Index k = this->m_k_size;
    if (m == 0 || n == 0 || k == 0) return;

    // Compute a set of algorithm parameters:
    // - kernel block sizes (bm, bn, bk)
@@ -1044,6 +1063,187 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
    rhsCost.dropMemoryCost();
    return cost + lhsCost + rhsCost;
  }

#if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM)
  template<int Alignment>
  class ContextXsmm {
   public:
    ContextXsmm(const Self* self, Scalar* buffer, Index m, Index n, Index k,
                const internal::TensorXsmmContractionBlocking<LhsScalar,
                    RhsScalar, Index>& blocking):
        device(self->m_device),
        m(m), k(k), n(n),
        stride_a(blocking.transposeA() ? k : m),
        stride_b(blocking.transposeB() ? n : k),
        stride_c(m),
        bm(blocking.mc()), bk(blocking.kc()), bn(blocking.nc()),
        blocks_m(blocking.blocks_m()), blocks_k(blocking.blocks_k()),
        blocks_n(blocking.blocks_n()),
        copyA(blocking.copyA()), copyB(blocking.copyB()),
        transposeA(blocking.transposeA()), transposeB(blocking.transposeB()),
        num_threads(blocking.num_threads()),
        buffer(buffer),
        leftData(self->m_leftImpl.data()), rightData(self->m_rightImpl.data()),
        workers_done(blocking.num_threads()),
        packingA_jobs(0), packingB_jobs(0), compute_jobs(0),
        packingA_done(blocking.blocks_m()), packingB_done(blocking.blocks_n()) {}
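
    // Body executed by each of the num_threads workers enqueued from run():
    // drain the two packing queues first, then the compute queue.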
    void worker() {
      // Pack
      if (copyA) {
        while (true) {
          uint32_t mk = packingA_jobs++;
          Index mi = mk / blocks_k;
          Index ki = mk % blocks_k;
          if (mi >= blocks_m) break;

          LhsScalar* blockA = blocksA + (bk*bm) * (mi*blocks_k+ki);
          if (transposeA) {
            const LhsScalar* current_a = leftData + (bm*mi)*stride_a + (bk*ki);
            libxsmm_otrans(blockA, current_a, sizeof(LhsScalar), actual_bk(ki),
                           actual_bm(mi), stride_a, bm);
          } else {
            const LhsScalar* current_a = leftData + (bk*ki)*stride_a + (bm*mi);
            internal::pack_simple<LhsScalar, Index>(blockA, current_a,
                actual_bk(ki), actual_bm(mi), bm, stride_a);
          }
          packingA_done.at(mi)++;
        }
      }

      if (copyB) {
        while (true) {
          uint32_t nk = packingB_jobs++;
          Index ni = nk / blocks_k;
          Index ki = nk % blocks_k;
          if (ni >= blocks_n) break;

          RhsScalar* blockB = blocksB + (bk*bn) * (ni*blocks_k+ki);
          if (transposeB) {
            const RhsScalar* current_b = rightData + (ki*bk)*stride_b +
                                         (ni*bn);
            libxsmm_otrans(blockB, current_b, sizeof(RhsScalar), actual_bn(ni),
                           actual_bk(ki), stride_b, bk);
          } else {
            const RhsScalar* current_b = rightData + (ni*bn)*stride_b +
                                         (ki*bk);
            internal::pack_simple<RhsScalar, Index>(blockB, current_b,
                actual_bn(ni), actual_bk(ki), bk, stride_b);
          }
          packingB_done.at(ni)++;
        }
      }

      // Compute
      while (true) {
        uint32_t mn = compute_jobs++;
        Index mi = mn / blocks_n;
        Index ni = mn % blocks_n;
        if (mi >= blocks_m) break;

        // Wait for mi, ni packings to be done. This is more fine-grained than
        // waiting for all workers to finish packing.
        while ((copyA && (packingA_done.at(mi) < blocks_k)) ||
               (copyB && (packingB_done.at(ni) < blocks_k)))
        {}
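        // (packingA_done.at(mi) reaches blocks_k only once every k-block of
        // panel mi has been packed, so this spin-wait releases per panel.)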
        for (Index ki = 0; ki < blocks_k; ++ki) {
          const LhsScalar* current_a = copyA ?
              blocksA + (bk*bm) * (mi*blocks_k+ki) :
              leftData + (bk*ki)*stride_a + (bm*mi);
          const RhsScalar* current_b = copyB ?
              blocksB + (bk*bn) * (ni*blocks_k+ki) :
              rightData + (ni*bn)*stride_b + (bk*ki);

          Index current_stride_a = copyA ? bm : stride_a;
          Index current_stride_b = copyB ? bk : stride_b;

          // Memory may not be zeroed, overwrite instead of adding in first
          // iteration.
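          // (The GEMM computes C = alpha*A*B + beta*C, so beta == 0
          // overwrites the output block and beta == 1 accumulates into it.)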
          float beta = ki == 0 ? 0 : 1;

          Scalar* current_c = buffer + (mi*bm) + (ni*bn)*stride_c;
          internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar>(
              0, actual_bm(mi), actual_bn(ni), actual_bk(ki),
              current_stride_a, current_stride_b, stride_c, 1, beta, 0)
            (current_a, current_b, current_c);
        }
      }

      workers_done.Notify();
    }

    void run() {
      // Parallelization strategy.
      //
      // First pack A into blocks (sharding by m, k) and B (sharding by n, k),
      // then shard the compute by m, n.
      //
      // Do not use advanced ThreadPool queuing: just run a single
      // long-running function in each thread.
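      // All workers execute the same worker() body; jobs are claimed from the
      // shared atomic counters, so the work split across threads is dynamic.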
      if (copyA) {
        blocksA = static_cast<LhsScalar*>(device.allocate(
            (blocks_m*bm)*(blocks_k*bk)*sizeof(LhsScalar)));
      }
      if (copyB) {
        blocksB = static_cast<RhsScalar*>(device.allocate(
            (blocks_n*bn)*(blocks_k*bk)*sizeof(RhsScalar)));
      }

      for (Index i = 0; i < num_threads; ++i) {
        device.enqueueNoNotification([=]() { worker(); });
      }

      workers_done.Wait();

      if (copyA) {
        device.deallocate(blocksA);
      }
      if (copyB) {
        device.deallocate(blocksB);
      }
    }

   private:
    // Real block size for block index in [0, ..., blocks - 1].
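    // E.g. m = 10, bm = 4, blocks_m = 3: blocks 0 and 1 have size 4, and the
    // last block has size 10 + 4 - 4*3 = 2.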
    Index actual_bm(Index mi) const {
      return mi != blocks_m - 1 ? bm : m + bm - bm * blocks_m;
    }
    Index actual_bk(Index ki) const {
      return ki != blocks_k - 1 ? bk : k + bk - bk * blocks_k;
    }
    Index actual_bn(Index ni) const {
      return ni != blocks_n - 1 ? bn : n + bn - bn * blocks_n;
    }

    const Device& device;
    Index m, k, n;
    Index stride_a, stride_b, stride_c;
    Index bm, bk, bn;                    // Block sizes.
    Index blocks_m, blocks_k, blocks_n;  // Number of blocks in each dimension.
    bool copyA, copyB, transposeA, transposeB;
    Index num_threads;
    Scalar* buffer;
    const LhsScalar* leftData;
    const RhsScalar* rightData;
    LhsScalar* blocksA;
    RhsScalar* blocksB;

    // Barrier for joining all threads after all jobs are done.
    Barrier workers_done;
    // "Queues" of (mi,ki), (ki,ni), (mi,ni) jobs, packed [0,p)x[0,q) -> [0,p*q).
    std::atomic<uint32_t> packingA_jobs;
    std::atomic<uint32_t> packingB_jobs;
    std::atomic<uint32_t> compute_jobs;
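    // Example with blocks_k = 3: packing-A job ids 0,1,2 decode to
    // (mi,ki) = (0,0),(0,1),(0,2); id 3 decodes to (1,0); any id with
    // mi >= blocks_m terminates that worker's packing loop.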
    // Count of already-packed k-blocks for each mi-panel in A and ni-panel in B.
    std::vector<std::atomic<uint8_t>> packingA_done;
    std::vector<std::atomic<uint8_t>> packingB_done;
  };
#endif
};
} // end namespace Eigen