Fix evalShardedByInnerDim for AVX512 builds

evalShardedByInnerDim ensures that the values it passes for start_k and
end_k to evalGemmPartialWithoutOutputKernel are multiples of 8, because the
kernel does not work correctly when the values of k are not multiples of the
packet_size. While this precaution works for AVX builds, it is insufficient
for AVX512 builds, where the maximum packet size is 16. The result is slightly
incorrect float32 contractions on AVX512 builds.

This commit fixes the problem by ensuring that k is always a multiple of
the packet_size when the packet_size is > 8, and a multiple of 8 otherwise.
This commit is contained in:
Mark D Ryan 2018-12-05 12:29:03 +01:00
parent 6f5b126e6d
commit 36f8f6d0be

View File

@ -788,9 +788,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
const Index m = this->m_i_size; const Index m = this->m_i_size;
const Index n = this->m_j_size; const Index n = this->m_j_size;
const Index k = this->m_k_size; const Index k = this->m_k_size;
// The underlying GEMM kernel assumes that k is a multiple of 8 and const Index packet_size = internal::packet_traits<RhsScalar>::size;
// subtle breakage occurs if this is violated. const Index kmultiple = packet_size <= 8 ? 8 : packet_size;
Index block_size = 8 * divup<Index>(k, 8 * num_threads); // The underlying GEMM kernel assumes that k is a multiple of
// the packet size and subtle breakage occurs if this is violated.
Index block_size = kmultiple * divup<Index>(k, kmultiple * num_threads);
Index num_blocks = divup<Index>(k, block_size); Index num_blocks = divup<Index>(k, block_size);
// we use 'result' for the first block's partial result. // we use 'result' for the first block's partial result.
MaxSizeVector<Scalar*> block_buffers(num_blocks - 1); MaxSizeVector<Scalar*> block_buffers(num_blocks - 1);
@ -805,9 +807,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
Index start = 0; Index start = 0;
for (Index blocks_left = num_blocks; blocks_left > 0; --blocks_left) { for (Index blocks_left = num_blocks; blocks_left > 0; --blocks_left) {
// The underlying GEMM kernel assumes that k is a multiple of packet size // The underlying GEMM kernel assumes that k is a multiple of packet size
// (currently largest packet size is 8) and subtle breakage occurs if // (currently largest packet size is 16) and subtle breakage occurs if
// this is violated. // this is violated.
block_size = 8 * divup<Index>(k - start, 8 * blocks_left); block_size = kmultiple * divup<Index>(k - start, kmultiple * blocks_left);
Scalar* buf; Scalar* buf;
if (start == 0) { if (start == 0) {
buf = result; buf = result;