Fix tensor contraction for AVX512 machines

This patch modifies the TensorContraction class to ensure that the kc_ field is
always a multiple of the packet_size, if the packet_size is > 8.  Without this
change spatial convolutions in Tensorflow do not work properly as the code that
re-arranges the input matrices can assert if kc_ is not a multiple of the
packet_size.  This leads to a unit test failure,
//tensorflow/python/kernel_tests:conv_ops_test, on AVX512 builds of tensorflow.
This commit is contained in:
Mark D Ryan 2018-07-31 09:33:37 +01:00
parent 77b447c24e
commit 6f5b126e6d

View File

@ -51,6 +51,10 @@ class TensorContractionBlocking {
else { else {
computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, nc_, mc_, num_threads); computeProductBlockingSizes<LhsScalar, RhsScalar, 1>(kc_, nc_, mc_, num_threads);
} }
const int rhs_packet_size = internal::packet_traits<RhsScalar>::size;
kc_ = (rhs_packet_size <= 8 || kc_ <= rhs_packet_size) ?
kc_ : (kc_ / rhs_packet_size) * rhs_packet_size;
} }
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex kc() const { return kc_; } EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE StorageIndex kc() const { return kc_; }