Mirror of https://gitlab.com/libeigen/eigen.git, synced 2025-09-13 18:03:13 +08:00
Reduce block evaluation overhead for small tensor expressions
parent 7252163335
commit 788bef6ab5
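In short: TensorBlockMapper now computes the total block count and the block/tensor strides inside InitializeBlockDimensions(), reporting zero blocks for an empty tensor and exactly one block (with trivial strides) when the whole tensor fits into the target block size; the thread-pool TensorExecutor and TensorAsyncExecutor then check blockCount() and evaluate single-block expressions inline instead of dispatching them through parallelFor / parallelForAsync.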
@@ -282,19 +282,8 @@ class TensorBlockMapper {
   TensorBlockMapper(const DSizes<IndexType, NumDims>& dimensions,
                     const TensorBlockResourceRequirements& requirements)
       : m_tensor_dimensions(dimensions), m_requirements(requirements) {
-    // Initialize `m_block_dimensions`.
+    // Compute block dimensions and the total number of blocks.
     InitializeBlockDimensions();
-
-    // Calculate block counts by dimension and total block count.
-    DSizes<IndexType, NumDims> block_count;
-    for (int i = 0; i < NumDims; ++i) {
-      block_count[i] = divup(m_tensor_dimensions[i], m_block_dimensions[i]);
-    }
-    m_total_block_count = array_prod(block_count);
-
-    // Calculate block strides (used for enumerating blocks).
-    m_tensor_strides = strides<Layout>(m_tensor_dimensions);
-    m_block_strides = strides<Layout>(block_count);
   }
 
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE IndexType blockCount() const {
@@ -339,23 +328,33 @@ class TensorBlockMapper {
   void InitializeBlockDimensions() {
     // Requested block shape and size.
     const TensorBlockShapeType shape_type = m_requirements.shape_type;
-    const IndexType target_block_size =
+    IndexType target_block_size =
         numext::maxi<IndexType>(1, static_cast<IndexType>(m_requirements.size));
 
+    IndexType tensor_size = m_tensor_dimensions.TotalSize();
+
     // Corner case: one of the dimensions is zero. Logic below is too complex
     // to handle this case on a general basis, just use unit block size.
     // Note: we must not yield blocks with zero dimensions (recipe for
     // overflows/underflows, divisions by zero and NaNs later).
-    if (m_tensor_dimensions.TotalSize() == 0) {
+    if (tensor_size == 0) {
       for (int i = 0; i < NumDims; ++i) {
         m_block_dimensions[i] = 1;
       }
+      m_total_block_count = 0;
       return;
     }
 
     // If tensor fits into a target block size, evaluate it as a single block.
-    if (m_tensor_dimensions.TotalSize() <= target_block_size) {
+    if (tensor_size <= target_block_size) {
       m_block_dimensions = m_tensor_dimensions;
+      m_total_block_count = 1;
+      // The only valid block index is `0`, and in this case we do not need
+      // to compute real strides for tensor or blocks (see blockDescriptor).
+      for (int i = 0; i < NumDims; ++i) {
+        m_tensor_strides[i] = 0;
+        m_block_strides[i] = 1;
+      }
       return;
     }
 
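Worth noting: in the single-block case the mapper leaves m_tensor_strides at zero and m_block_strides at one, since block index 0 is the only valid index and blockDescriptor() never needs real strides for it; the empty-tensor case now reports a block count of zero, so the executors have nothing to schedule.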
@@ -418,6 +417,17 @@ class TensorBlockMapper {
     eigen_assert(m_block_dimensions.TotalSize() >=
                  numext::mini<IndexType>(target_block_size,
                                          m_tensor_dimensions.TotalSize()));
+
+    // Calculate block counts by dimension and total block count.
+    DSizes<IndexType, NumDims> block_count;
+    for (int i = 0; i < NumDims; ++i) {
+      block_count[i] = divup(m_tensor_dimensions[i], m_block_dimensions[i]);
+    }
+    m_total_block_count = array_prod(block_count);
+
+    // Calculate block strides (used for enumerating blocks).
+    m_tensor_strides = strides<Layout>(m_tensor_dimensions);
+    m_block_strides = strides<Layout>(block_count);
   }
 
   DSizes<IndexType, NumDims> m_tensor_dimensions;
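For the record, the block-count arithmetic restored here is just a per-dimension ceiling division followed by a product. A minimal standalone sketch of that computation (plain C++ with a local divup helper and made-up sizes, not Eigen's internal types):

#include <cstdint>
#include <cstdio>

// Ceiling division, mirroring the behavior of Eigen's divup().
static std::int64_t divup(std::int64_t a, std::int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  // Hypothetical 2D tensor of shape {100, 30} split into {64, 16} blocks.
  const std::int64_t tensor_dims[2] = {100, 30};
  const std::int64_t block_dims[2] = {64, 16};

  std::int64_t block_count[2];
  std::int64_t total_block_count = 1;  // plays the role of array_prod(block_count)
  for (int i = 0; i < 2; ++i) {
    block_count[i] = divup(tensor_dims[i], block_dims[i]);  // 2 and 2
    total_block_count *= block_count[i];
  }
  std::printf("blocks per dim: %lld x %lld, total: %lld\n",
              static_cast<long long>(block_count[0]),
              static_cast<long long>(block_count[1]),
              static_cast<long long>(total_block_count));  // 2 x 2, total: 4
  return 0;
}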
@@ -374,15 +374,23 @@ class TensorExecutor<Expression, ThreadPoolDevice, Vectorizable,
                                           IndexType lastBlockIdx) {
       TensorBlockScratch scratch(device);
 
-      for (IndexType block_idx = firstBlockIdx; block_idx < lastBlockIdx; ++block_idx) {
+      for (IndexType block_idx = firstBlockIdx; block_idx < lastBlockIdx;
+           ++block_idx) {
         TensorBlockDesc desc = tiling.block_mapper.blockDescriptor(block_idx);
         evaluator.evalBlock(desc, scratch);
         scratch.reset();
       }
     };
 
-    device.parallelFor(tiling.block_mapper.blockCount(), tiling.cost,
-                       eval_block);
+    // Evaluate small expressions directly as a single block.
+    if (tiling.block_mapper.blockCount() == 1) {
+      TensorBlockScratch scratch(device);
+      TensorBlockDesc desc(0, tiling.block_mapper.blockDimensions());
+      evaluator.evalBlock(desc, scratch);
+    } else {
+      device.parallelFor(tiling.block_mapper.blockCount(), tiling.cost,
+                         eval_block);
+    }
   }
   evaluator.cleanup();
 }
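The executor-side change is a dispatch on the block count: a single-block expression is evaluated inline on the calling thread, and only multi-block expressions pay for the thread-pool round trip. A toy, self-contained illustration of that shape (FakeDevice and eval_block are stand-ins, not the real Eigen classes):

#include <cstdio>
#include <functional>

// Stand-in for a thread-pool device; here it simply runs the whole
// [0, n) range on the calling thread.
struct FakeDevice {
  void parallelFor(long n, const std::function<void(long, long)>& fn) const {
    fn(0, n);
  }
};

int main() {
  const long block_count = 1;  // try 1 vs. e.g. 16
  FakeDevice device;

  auto eval_block = [](long block_idx) {
    std::printf("evaluating block %ld\n", block_idx);
  };

  // Fast path in the spirit of this commit: one block means no dispatch.
  if (block_count == 1) {
    eval_block(0);
  } else {
    device.parallelFor(block_count, [&](long first, long last) {
      for (long i = first; i < last; ++i) eval_block(i);
    });
  }
  return 0;
}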
@@ -486,8 +494,18 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, DoneCallback,
           scratch.reset();
         }
       };
-      ctx->device.parallelForAsync(ctx->tiling.block_mapper.blockCount(),
-                                   ctx->tiling.cost, eval_block, [ctx]() { delete ctx; });
+
+      // Evaluate small expressions directly as a single block.
+      if (ctx->tiling.block_mapper.blockCount() == 1) {
+        TensorBlockScratch scratch(ctx->device);
+        TensorBlockDesc desc(0, ctx->tiling.block_mapper.blockDimensions());
+        ctx->evaluator.evalBlock(desc, scratch);
+        delete ctx;
+      } else {
+        ctx->device.parallelForAsync(ctx->tiling.block_mapper.blockCount(),
+                                     ctx->tiling.cost, eval_block,
+                                     [ctx]() { delete ctx; });
+      }
     };
 
     ctx->evaluator.evalSubExprsIfNeededAsync(nullptr, on_eval_subexprs);
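One subtlety in the async variant: the single-block branch deletes ctx itself, because the [ctx]() { delete ctx; } completion callback that normally frees the context is only invoked by parallelForAsync, which the fast path skips.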