From 36f8f6d0be1543e12c87c6f33df46fe7bcecab87 Mon Sep 17 00:00:00 2001 From: Mark D Ryan Date: Wed, 5 Dec 2018 12:29:03 +0100 Subject: [PATCH] Fix evalShardedByInnerDim for AVX512 builds evalShardedByInnerDim ensures that the values it passes for start_k and end_k to evalGemmPartialWithoutOutputKernel are multiples of 8 as the kernel does not work correctly when the values of k are not multiples of the packet_size. While this precaution works for AVX builds, it is insufficient for AVX512 builds where the maximum packet size is 16. The result is slightly incorrect float32 contractions on AVX512 builds. This commit fixes the problem by ensuring that k is always a multiple of the packet_size if the packet_size is > 8. --- .../CXX11/src/Tensor/TensorContractionThreadPool.h | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index 24ba3e431..3946e2fc4 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -788,9 +788,11 @@ struct TensorEvaluatorm_i_size; const Index n = this->m_j_size; const Index k = this->m_k_size; - // The underlying GEMM kernel assumes that k is a multiple of 8 and - // subtle breakage occurs if this is violated. - Index block_size = 8 * divup(k, 8 * num_threads); + const Index packet_size = internal::packet_traits::size; + const Index kmultiple = packet_size <= 8 ? 8 : packet_size; + // The underlying GEMM kernel assumes that k is a multiple of + // the packet size and subtle breakage occurs if this is violated. + Index block_size = kmultiple * divup(k, kmultiple * num_threads); Index num_blocks = divup(k, block_size); // we use 'result' for the first block's partial result. MaxSizeVector block_buffers(num_blocks - 1); @@ -805,9 +807,9 @@ struct TensorEvaluator 0; --blocks_left) { // The underlying GEMM kernel assumes that k is a multiple of packet size - // (currently largest packet size is 8) and subtle breakage occurs if + // (currently largest packet size is 16) and subtle breakage occurs if // this is violated. - block_size = 8 * divup(k - start, 8 * blocks_left); + block_size = kmultiple * divup(k - start, kmultiple * blocks_left); Scalar* buf; if (start == 0) { buf = result;