From 3aeeca32af00b1921b4424d7be2e03bbaeaa05b4 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 22 Jan 2016 16:36:30 -0800 Subject: [PATCH] Leverage the new blocking code in the tensor contraction code. --- unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h | 4 +--- .../Eigen/CXX11/src/Tensor/TensorContractionMapper.h | 5 +++-- .../Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h | 8 ++++---- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index 624e814e2..e6a008ba7 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -582,10 +582,8 @@ struct TensorEvaluator BlockingType; - // Sizes of the blocks to load in cache. See the Goto paper for details. - BlockingType blocking(m, n, k, 1, true); + internal::TensorContractionBlocking blocking(k, m, n, 1); const Index kc = blocking.kc(); const Index mc = numext::mini(m, blocking.mc()); const Index nc = numext::mini(n, blocking.nc()); diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h index 9b6d18090..63c8ae126 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h @@ -426,15 +426,16 @@ class TensorContractionSubMapper { }; -template class TensorContractionInputMapper - : public BaseTensorContractionMapper { + : public BaseTensorContractionMapper { public: + typedef Scalar_ Scalar; typedef BaseTensorContractionMapper Base; typedef TensorContractionSubMapper SubMapper; typedef SubMapper VectorMapper; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index 576bea295..51a3b9490 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -176,10 +176,10 @@ struct TensorEvaluatorm_device.numThreads(); - Index mc = m; - Index nc = n; - Index kc = k; - internal::computeProductBlockingSizes(kc, mc, nc, num_threads); + internal::TensorContractionBlocking blocking(k, m, n, num_threads); + Index mc = blocking.mc(); + Index nc = blocking.nc(); + Index kc = blocking.kc(); eigen_assert(mc <= m); eigen_assert(nc <= n); eigen_assert(kc <= k);