From 4beb447e27baaa19081e835bd6aba76e9b02cc67 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Fri, 22 Jan 2016 14:37:26 -0800 Subject: [PATCH] Created a mechanism to enable contraction mappers to determine the best blocking strategy. --- unsupported/Eigen/CXX11/Tensor | 1 + .../src/Tensor/TensorContractionBlocking.h | 58 +++++++++++++++++++ 2 files changed, 59 insertions(+) create mode 100644 unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h diff --git a/unsupported/Eigen/CXX11/Tensor b/unsupported/Eigen/CXX11/Tensor index 1c5734383..b4f860c41 100644 --- a/unsupported/Eigen/CXX11/Tensor +++ b/unsupported/Eigen/CXX11/Tensor @@ -89,6 +89,7 @@ typedef unsigned __int64 uint64_t; #include "src/Tensor/TensorArgMax.h" #include "src/Tensor/TensorConcatenation.h" #include "src/Tensor/TensorContractionMapper.h" +#include "src/Tensor/TensorContractionBlocking.h" #include "src/Tensor/TensorContraction.h" #include "src/Tensor/TensorContractionThreadPool.h" #include "src/Tensor/TensorContractionCuda.h" diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h new file mode 100644 index 000000000..78ed5038f --- /dev/null +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h @@ -0,0 +1,58 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H +#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H + + +namespace Eigen { +namespace internal { + +enum { + ShardByRow = 0, + ShardByCol = 1 +}; + + +// Default Blocking Strategy +template +class TensorContractionBlocking { + public: + + typedef typename LhsMapper::Scalar LhsScalar; + typedef typename RhsMapper::Scalar RhsScalar; + + TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1) : + kc_(k), mc_(m), nc_(n) + { + if (ShardingType == ShardByCol) { + computeProductBlockingSizes(kc_, mc_, nc_, num_threads); + } + else { + if (kc_ && mc_ && nc_) { + mc_ = (((m / num_threads) + 15) / 16) * 16; + } + } + } + + EIGEN_ALWAYS_INLINE Index kc() const { return kc_; } + EIGEN_ALWAYS_INLINE Index mc() const { return mc_; } + EIGEN_ALWAYS_INLINE Index nc() const { return nc_; } + + private: + Index kc_; + Index mc_; + Index nc_; +}; + + +} // end namespace internal +} // end namespace Eigen + +#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_BLOCKING_H