Mirror of https://gitlab.com/libeigen/eigen.git (synced 2025-04-29 07:14:12 +08:00)
Added support for libxsmm kernel in multithreaded contractions
Commit: 519d63d350
Parent: 0657228569
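For context, a minimal sketch (not part of this commit) of the kind of expression that reaches the evalProduct method patched below: a tensor contraction evaluated on a ThreadPoolDevice. With EIGEN_VECTORIZE_AVX and EIGEN_USE_LIBXSMM defined at build time (and, presumably, libxsmm linked in), such a contraction would dispatch to the new libxsmm branch added by this commit; the tensor sizes here are purely illustrative.

// Minimal usage sketch, assuming a build with AVX and EIGEN_USE_LIBXSMM.
#define EIGEN_USE_THREADS
#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::ThreadPool pool(4);                 // worker threads
  Eigen::ThreadPoolDevice device(&pool, 4);  // device used by the evaluator

  Eigen::Tensor<float, 2> a(256, 512), b(512, 128), c(256, 128);
  a.setRandom();
  b.setRandom();

  // Contract the second dimension of a with the first dimension of b
  // (an ordinary matrix product expressed as a tensor contraction).
  Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};
  c.device(device) = a.contract(b, dims);
  return 0;
}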
@@ -116,6 +116,28 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  void evalProduct(Scalar* buffer) const {
    const Index m = this->m_i_size;
    const Index n = this->m_j_size;
    const Index k = this->m_k_size;
    if (m == 0 || n == 0 || k == 0) return;

#if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM)
    if (this->m_can_use_xsmm) {
      bool transposeA = !this->m_lhs_inner_dim_contiguous;
      bool transposeB = !this->m_rhs_inner_dim_contiguous;
      internal::TensorXsmmContractionBlocking<LhsScalar, RhsScalar, Index>
          blocking(k, m, n, this->m_device.numThreads(), transposeA,
                   transposeB);

      if (blocking.num_threads() == 1) {
        this->evalGemmXSMM(buffer);
      } else {
        ContextXsmm<Alignment>(this, buffer, m, n, k, blocking).run();
      }
      return;
    }
#endif

    typedef
        typename internal::remove_const<typename EvalLeftArgType::Scalar>::type
            LhsScalar;
@@ -147,10 +169,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
                        Traits::mr, Traits::nr, false, false>
        GebpKernel;

    const Index m = this->m_i_size;
    const Index n = this->m_j_size;
    const Index k = this->m_k_size;
    if (m == 0 || n == 0 || k == 0) return;


    // Compute a set of algorithm parameters:
    // - kernel block sizes (bm, bn, bk)
@@ -1044,6 +1063,187 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
      rhsCost.dropMemoryCost();
    return cost + lhsCost + rhsCost;
  }

#if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM)
  template<int Alignment>
  class ContextXsmm {
   public:
    ContextXsmm(const Self* self, Scalar* buffer, Index m, Index n, Index k,
                const internal::TensorXsmmContractionBlocking<LhsScalar,
                    RhsScalar, Index>& blocking):
        device(self->m_device),
        m(m), k(k), n(n),
        stride_a(blocking.transposeA() ? k : m),
        stride_b(blocking.transposeB() ? n : k),
        stride_c(m),
        bm(blocking.mc()), bk(blocking.kc()), bn(blocking.nc()),
        blocks_m(blocking.blocks_m()), blocks_k(blocking.blocks_k()),
        blocks_n(blocking.blocks_n()),
        copyA(blocking.copyA()), copyB(blocking.copyB()),
        transposeA(blocking.transposeA()), transposeB(blocking.transposeB()),
        num_threads(blocking.num_threads()),
        buffer(buffer),
        leftData(self->m_leftImpl.data()), rightData(self->m_rightImpl.data()),
        workers_done(blocking.num_threads()),

        packingA_jobs(0), packingB_jobs(0), compute_jobs(0),
        packingA_done(blocking.blocks_m()), packingB_done(blocking.blocks_n()) {}

    void worker() {
      // Pack

      if (copyA) {
        while (true) {
          uint32_t mk = packingA_jobs++;
          Index mi = mk / blocks_k;
          Index ki = mk % blocks_k;
          if (mi >= blocks_m) break;

          LhsScalar * blockA = blocksA + (bk*bm) * (mi*blocks_k+ki);
          if (transposeA) {
            const LhsScalar * current_a = leftData + (bm*mi)*stride_a + (bk*ki);
            libxsmm_otrans(blockA, current_a, sizeof(LhsScalar), actual_bk(ki),
                           actual_bm(mi), stride_a, bm);
          } else {
            const LhsScalar * current_a = leftData + (bk*ki)*stride_a + (bm*mi);
            internal::pack_simple<LhsScalar, Index>(blockA, current_a,
                actual_bk(ki), actual_bm(mi), bm, stride_a);
          }
          packingA_done.at(mi)++;
        }
      }

      if (copyB) {
        while (true) {
          uint32_t nk = packingB_jobs++;
          Index ni = nk / blocks_k;
          Index ki = nk % blocks_k;
          if (ni >= blocks_n) break;

          RhsScalar * blockB = blocksB + (bk*bn) * (ni*blocks_k+ki);
          if (transposeB) {
            const RhsScalar * current_b = rightData + (ki*bk)*stride_b +
                (ni*bn);
            libxsmm_otrans(blockB, current_b, sizeof(RhsScalar), actual_bn(ni),
                           actual_bk(ki), stride_b, bk);
          } else {
            const RhsScalar * current_b = rightData + (ni*bn)*stride_b +
                (ki*bk);
            internal::pack_simple<RhsScalar, Index>(blockB, current_b,
                actual_bn(ni), actual_bk(ki), bk, stride_b);
          }
          packingB_done.at(ni)++;
        }
      }

      // Compute

      while (true) {
        uint32_t mn = compute_jobs++;
        Index mi = mn / blocks_n;
        Index ni = mn % blocks_n;
        if (mi >= blocks_m) break;

        // Wait for mi, ni packings to be done. This is more fine-grained than
        // waiting for all workers to finish packing.
        while ((copyA && (packingA_done.at(mi) < blocks_k)) ||
               (copyB && (packingB_done.at(ni) < blocks_k)))
        {}

        for (Index ki=0; ki < blocks_k; ++ki) {
          const LhsScalar * current_a = copyA ?
              blocksA + (bk*bm) * (mi*blocks_k+ki) :
              leftData + (bk*ki)*stride_a + (bm*mi);
          const RhsScalar * current_b = copyB ?
              blocksB + (bk*bn) * (ni*blocks_k+ki) :
              rightData + (ni*bn)*stride_b + (bk*ki);

          Index current_stride_a = copyA ? bm : stride_a;
          Index current_stride_b = copyB ? bk : stride_b;

          // Memory may not be zeroed, overwrite instead of adding in first
          // iteration.
          float beta = ki == 0 ? 0 : 1;

          Scalar * current_c = buffer + (mi*bm) + (ni*bn)*stride_c;
          internal::libxsmm_wrapper<LhsScalar, RhsScalar, Scalar>(
              0, actual_bm(mi), actual_bn(ni), actual_bk(ki),
              current_stride_a, current_stride_b, stride_c, 1, beta, 0)
              (current_a, current_b, current_c);
        }
      }

      workers_done.Notify();
    }

    void run() {
      // Parallelization strategy.
      //
      // First pack A into blocks (sharding by m, k) and B (sharding by n,k),
      // then shard by m, n.
      //
      // Do not use advanced ThreadPool queuing, just run a single long-standing
      // function in each thread.
      if (copyA) {
        blocksA = static_cast<LhsScalar*>(device.allocate(
            (blocks_m*bm)*(blocks_k*bk)*sizeof(LhsScalar)));
      }
      if (copyB) {
        blocksB = static_cast<RhsScalar*>(device.allocate(
            (blocks_n*bn)*(blocks_k*bk)*sizeof(RhsScalar)));
      }

      for (Index i = 0; i < num_threads; ++i) {
        device.enqueueNoNotification([=]() { worker(); });
      }

      workers_done.Wait();

      if (copyA) {
        device.deallocate(blocksA);
      }
      if (copyB) {
        device.deallocate(blocksB);
      }
    }

   private:
    // real block size for block index in [0, ..., blocks - 1].
    Index actual_bm(Index mi) const {
      return mi != blocks_m - 1 ? bm : m + bm - bm * blocks_m;
    }
    Index actual_bk(Index ki) const {
      return ki != blocks_k - 1 ? bk : k + bk - bk * blocks_k;
    }
    Index actual_bn(Index ni) const {
      return ni != blocks_n - 1 ? bn : n + bn - bn * blocks_n;
    }

    const Device& device;
    Index m, k, n;
    Index stride_a, stride_b, stride_c;
    Index bm, bk, bn;                    // Block sizes.
    Index blocks_m, blocks_k, blocks_n;  // Number of blocks in each dimension.
    bool copyA, copyB, transposeA, transposeB;
    Index num_threads;
    Scalar *buffer;
    const LhsScalar *leftData;
    const RhsScalar *rightData;

    LhsScalar *blocksA;
    RhsScalar *blocksB;
    // barrier for joining all threads after all done.
    Barrier workers_done;
    // "queues" of (mi,ki), (ki,ni), (mi,ni) jobs packed [0,p)x[0,q) -> [0, p*q)
    std::atomic<uint32_t> packingA_jobs;
    std::atomic<uint32_t> packingB_jobs;
    std::atomic<uint32_t> compute_jobs;
    // already packed blocks for each mi-panel in A and ni-panel in B.
    std::vector<std::atomic<uint8_t>> packingA_done;
    std::vector<std::atomic<uint8_t>> packingB_done;
  };
#endif

};

} // end namespace Eigen
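The worker loop above hands out flattened job indices from atomic counters and recovers 2-D block coordinates with division and modulo, while the actual_b* helpers shrink only the last block in each dimension. A small standalone sketch of that arithmetic follows; the sizes are illustrative and the ceil-division for the block count is an assumption about what the blocking class computes, not taken from the commit.

// Standalone sketch (not part of the commit) of ContextXsmm's index
// arithmetic: tail block sizes and flattened (mi, ki) packing-job ids.
#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative sizes: a dimension of length m split into blocks of bm.
  const int64_t m = 1000, bm = 384;
  const int64_t blocks_m = (m + bm - 1) / bm;  // assumed ceil division -> 3
  const int64_t blocks_k = 4;                  // arbitrary for the demo

  // actual_bm(mi): full bm for every block except the last, which gets the
  // remainder, matching "mi != blocks_m - 1 ? bm : m + bm - bm * blocks_m".
  for (int64_t mi = 0; mi < blocks_m; ++mi) {
    int64_t actual_bm = (mi != blocks_m - 1) ? bm : m + bm - bm * blocks_m;
    std::printf("mi=%lld actual_bm=%lld\n", (long long)mi, (long long)actual_bm);
  }
  // Prints 384, 384, 232, and 384*2 + 232 == 1000.

  // Packing jobs are numbered mk in [0, blocks_m*blocks_k); each worker pulls
  // the next mk from an atomic counter and decodes it as in worker():
  for (uint32_t mk = 0; mk < blocks_m * blocks_k; ++mk) {
    int64_t mi = mk / blocks_k;
    int64_t ki = mk % blocks_k;
    std::printf("job %u -> (mi=%lld, ki=%lld)\n", mk, (long long)mi, (long long)ki);
  }
  return 0;
}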