Mirror of https://gitlab.com/libeigen/eigen.git
Add recursive work splitting to EvalShardedByInnerDimContext
commit bb7ccac3af
parent 25230d1862
@@ -1159,16 +1159,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
   template <int Alignment>
   void run() {
     Barrier barrier(internal::convert_index<int>(num_blocks));
-    for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
-      evaluator->m_device.enqueueNoNotification(
-          [this, block_idx, &barrier]() {
-            Index block_start = block_idx * block_size;
-            Index block_end = block_start + actualBlockSize(block_idx);
-
-            processBlock<Alignment>(block_idx, block_start, block_end);
-            barrier.Notify();
-          });
-    }
+    eval<Alignment>(barrier, 0, num_blocks);
     barrier.Wait();
 
     // Aggregate partial sums from l0 ranges.
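For context on the synchronous path: run() no longer enqueues one closure per block from the calling thread; it hands the whole [0, num_blocks) range to the new eval<Alignment>() helper (added in the last hunk below), which repeatedly halves the range, enqueues the upper half, and keeps the lower half until a single block is left to process in place. Below is a minimal standalone sketch of the same split-then-wait pattern; Pool, Latch, and the printf stand-in for processBlock are hypothetical substitutes for Eigen's ThreadPoolDevice and Barrier, not the library's API.

#include <condition_variable>
#include <cstdio>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

// Minimal fixed-size thread pool standing in for Eigen's ThreadPoolDevice.
class Pool {
 public:
  explicit Pool(int num_threads) {
    for (int i = 0; i < num_threads; ++i) {
      workers_.emplace_back([this] {
        for (;;) {
          std::function<void()> task;
          {
            std::unique_lock<std::mutex> lock(mu_);
            cv_.wait(lock, [this] { return stopped_ || !tasks_.empty(); });
            if (stopped_ && tasks_.empty()) return;
            task = std::move(tasks_.front());
            tasks_.pop();
          }
          task();
        }
      });
    }
  }

  // Analogue of enqueueNoNotification(): schedule work, no completion signal.
  void enqueue(std::function<void()> task) {
    { std::lock_guard<std::mutex> lock(mu_); tasks_.push(std::move(task)); }
    cv_.notify_one();
  }

  ~Pool() {
    { std::lock_guard<std::mutex> lock(mu_); stopped_ = true; }
    cv_.notify_all();
    for (std::thread& w : workers_) w.join();
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<std::function<void()>> tasks_;
  bool stopped_ = false;
  std::vector<std::thread> workers_;
};

// Counting latch standing in for Eigen's Barrier.
class Latch {
 public:
  explicit Latch(int count) : count_(count) {}
  void Notify() {
    std::lock_guard<std::mutex> lock(mu_);
    if (--count_ == 0) cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return count_ == 0; });
  }

 private:
  int count_;
  std::mutex mu_;
  std::condition_variable cv_;
};

// Recursive work splitting, shaped like eval(): push the upper half of the
// range onto the pool, keep the lower half, process one block in place.
void eval(Pool& pool, Latch& latch, int start_block, int end_block) {
  while (end_block - start_block > 1) {
    int mid_block = (start_block + end_block) / 2;
    pool.enqueue([&pool, &latch, mid_block, end_block] {
      eval(pool, latch, mid_block, end_block);
    });
    end_block = mid_block;
  }
  std::printf("processed block %d\n", start_block);  // processBlock() stand-in
  latch.Notify();
}

int main() {
  const int num_blocks = 8;
  Latch latch(num_blocks);  // declared first, so it is destroyed after the pool
  Pool pool(4);
  eval(pool, latch, 0, num_blocks);  // mirrors run(): split, process, then wait
  latch.Wait();
  return 0;
}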
@@ -1180,38 +1171,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
 
   template <int Alignment>
   void runAsync() {
-    for (Index block_idx = 0; block_idx < num_blocks; ++block_idx) {
-      evaluator->m_device.enqueueNoNotification([this, block_idx]() {
-        Index block_start = block_idx * block_size;
-        Index block_end = block_start + actualBlockSize(block_idx);
-
-        processBlock<Alignment>(block_idx, block_start, block_end);
-
-        int v = num_pending_blocks.fetch_sub(1);
-        eigen_assert(v >= 1);
-
-        if (v == 1) {
-          // Aggregate partial sums from l0 ranges.
-          aggregateL0Blocks<Alignment>();
-
-          // Apply output kernel.
-          applyOutputKernel();
-
-          // NOTE: If we call `done` callback before deleting this (context),
-          // it might deallocate Self* pointer captured by context, and we'll
-          // fail in destructor trying to deallocate temporary buffers.
-
-          // Move done call back from context before it will be destructed.
-          DoneCallback done_copy = std::move(done);
-
-          // We are confident that we are the last one who touches context.
-          delete this;
-
-          // Now safely call the done callback.
-          done_copy();
-        }
-      });
-    }
+    evalAsync<Alignment>(0, num_blocks);
   }
 
  private:
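The async path cannot block in a barrier, so it ends with the last-one-out pattern visible in evalAsync(): every block decrements num_pending_blocks, and whichever task observes the count hit zero aggregates, applies the output kernel, moves the done callback out of the context, deletes the context, and only then invokes the callback. Below is a minimal sketch of just that completion protocol; Context, processBlock, and the printf bodies are hypothetical, but the std::move / delete this / call ordering mirrors the comment block in the diff.

#include <atomic>
#include <cassert>
#include <cstdio>
#include <functional>
#include <thread>
#include <utility>
#include <vector>

// Heap-allocated context shared by all block tasks; owns the done callback,
// much like EvalShardedByInnerDimContext owns `done` in the async path.
struct Context {
  Context(int num_blocks, std::function<void()> on_done)
      : pending(num_blocks), done(std::move(on_done)) {}

  std::atomic<int> pending;
  std::function<void()> done;

  void processBlock(int block_idx) {
    std::printf("block %d\n", block_idx);  // stand-in for the real block work

    int v = pending.fetch_sub(1);  // returns the value *before* decrementing
    assert(v >= 1);
    if (v == 1) {
      // Last block: move the callback out *before* deleting the context,
      // otherwise the destructor could free resources the callback (or its
      // captures) still needs -- the hazard the NOTE in the diff describes.
      std::function<void()> done_copy = std::move(done);
      delete this;   // no member access past this point
      done_copy();   // safe: owned locally now
    }
  }
};

int main() {
  const int num_blocks = 4;
  auto* ctx = new Context(num_blocks, [] { std::printf("all blocks done\n"); });

  std::vector<std::thread> threads;
  for (int i = 0; i < num_blocks; ++i) {
    threads.emplace_back([ctx, i] { ctx->processBlock(i); });
  }
  for (std::thread& t : threads) t.join();
  return 0;
}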
@@ -1405,6 +1365,68 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     }
   }
 
+  template <int Alignment>
+  void eval(Barrier& barrier, Index start_block_idx, Index end_block_idx) {
+    while (end_block_idx - start_block_idx > 1) {
+      Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
+      evaluator->m_device.enqueueNoNotification(
+          [this, &barrier, mid_block_idx, end_block_idx]() {
+            eval<Alignment>(barrier, mid_block_idx, end_block_idx);
+          });
+      end_block_idx = mid_block_idx;
+    }
+
+    Index block_idx = start_block_idx;
+    Index block_start = block_idx * block_size;
+    Index block_end = block_start + actualBlockSize(block_idx);
+
+    processBlock<Alignment>(block_idx, block_start, block_end);
+    barrier.Notify();
+  }
+
+  template <int Alignment>
+  void evalAsync(Index start_block_idx, Index end_block_idx) {
+    while (end_block_idx - start_block_idx > 1) {
+      Index mid_block_idx = (start_block_idx + end_block_idx) / 2;
+      evaluator->m_device.enqueueNoNotification(
+          [this, mid_block_idx, end_block_idx]() {
+            evalAsync<Alignment>(mid_block_idx, end_block_idx);
+          });
+      end_block_idx = mid_block_idx;
+    }
+
+    Index block_idx = start_block_idx;
+
+    Index block_start = block_idx * block_size;
+    Index block_end = block_start + actualBlockSize(block_idx);
+
+    processBlock<Alignment>(block_idx, block_start, block_end);
+
+    int v = num_pending_blocks.fetch_sub(1);
+    eigen_assert(v >= 1);
+
+    if (v == 1) {
+      // Aggregate partial sums from l0 ranges.
+      aggregateL0Blocks<Alignment>();
+
+      // Apply output kernel.
+      applyOutputKernel();
+
+      // NOTE: If we call `done` callback before deleting this (context),
+      // it might deallocate Self* pointer captured by context, and we'll
+      // fail in destructor trying to deallocate temporary buffers.
+
+      // Move done call back from context before it will be destructed.
+      DoneCallback done_copy = std::move(done);
+
+      // We are confident that we are the last one who touches context.
+      delete this;
+
+      // Now safely call the done callback.
+      done_copy();
+    }
+  }
+
   // Cost model doesn't capture well the cost associated with constructing
   // tensor contraction mappers and computing loop bounds in gemm_pack_lhs
   // and gemm_pack_rhs, so we specify minimum desired block size.
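A design note on why the recursion helps (not stated in the commit, but it follows from the code above): with the old flat loops the calling thread had to submit all num_blocks closures itself, so task submission was serialized on one thread even when the pool sat idle. With recursive splitting the caller submits only O(log num_blocks) tasks and each submitted task fans out further from inside the pool. For num_blocks = 8 the schedule unfolds like this:

  eval(0, 8)  caller  -> enqueue [4,8), then [2,4), then [1,2); process block 0
  eval(4, 8)  worker  -> enqueue [6,8), then [5,6);             process block 4
  eval(2, 4)  worker  -> enqueue [3,4);                         process block 2
  eval(1, 2)  worker  ->                                        process block 1
  (and likewise for [6,8), [5,6), [3,4))

evalAsync() unfolds identically; it differs only in how completion is signaled (the atomic countdown on num_pending_blocks instead of a Barrier).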