mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-14 02:13:13 +08:00
Fix evalShardedByInnerDim for AVX512 builds
evalShardedByInnerDim ensures that the values it passes for start_k and end_k to evalGemmPartialWithoutOutputKernel are multiples of 8 as the kernel does not work correctly when the values of k are not multiples of the packet_size. While this precaution works for AVX builds, it is insufficient for AVX512 builds where the maximum packet size is 16. The result is slightly incorrect float32 contractions on AVX512 builds. This commit fixes the problem by ensuring that k is always a multiple of the packet_size if the packet_size is > 8.
This commit is contained in:
parent
6f5b126e6d
commit
36f8f6d0be
@ -788,9 +788,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
|||||||
const Index m = this->m_i_size;
|
const Index m = this->m_i_size;
|
||||||
const Index n = this->m_j_size;
|
const Index n = this->m_j_size;
|
||||||
const Index k = this->m_k_size;
|
const Index k = this->m_k_size;
|
||||||
// The underlying GEMM kernel assumes that k is a multiple of 8 and
|
const Index packet_size = internal::packet_traits<RhsScalar>::size;
|
||||||
// subtle breakage occurs if this is violated.
|
const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
|
||||||
Index block_size = 8 * divup<Index>(k, 8 * num_threads);
|
// The underlying GEMM kernel assumes that k is a multiple of
|
||||||
|
// the packet size and subtle breakage occurs if this is violated.
|
||||||
|
Index block_size = kmultiple * divup<Index>(k, kmultiple * num_threads);
|
||||||
Index num_blocks = divup<Index>(k, block_size);
|
Index num_blocks = divup<Index>(k, block_size);
|
||||||
// we use 'result' for the first block's partial result.
|
// we use 'result' for the first block's partial result.
|
||||||
MaxSizeVector<Scalar*> block_buffers(num_blocks - 1);
|
MaxSizeVector<Scalar*> block_buffers(num_blocks - 1);
|
||||||
@ -805,9 +807,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
|||||||
Index start = 0;
|
Index start = 0;
|
||||||
for (Index blocks_left = num_blocks; blocks_left > 0; --blocks_left) {
|
for (Index blocks_left = num_blocks; blocks_left > 0; --blocks_left) {
|
||||||
// The underlying GEMM kernel assumes that k is a multiple of packet size
|
// The underlying GEMM kernel assumes that k is a multiple of packet size
|
||||||
// (currently largest packet size is 8) and subtle breakage occurs if
|
// (currently largest packet size is 16) and subtle breakage occurs if
|
||||||
// this is violated.
|
// this is violated.
|
||||||
block_size = 8 * divup<Index>(k - start, 8 * blocks_left);
|
block_size = kmultiple * divup<Index>(k - start, kmultiple * blocks_left);
|
||||||
Scalar* buf;
|
Scalar* buf;
|
||||||
if (start == 0) {
|
if (start == 0) {
|
||||||
buf = result;
|
buf = result;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user