From 6e40454a6e6cc57c07c7340148657c985ca6c928 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 2 Oct 2019 11:06:02 -0700 Subject: [PATCH] Add beta to TensorContractionKernel and make memset optional --- .../CXX11/src/Tensor/TensorContraction.h | 32 ++++++++++++------- .../src/Tensor/TensorContractionThreadPool.h | 29 ++++++++++------- 2 files changed, 39 insertions(+), 22 deletions(-) diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index d61209133..87e8db3fd 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -180,6 +180,10 @@ template struct TensorContractionKernel { + // True if `invoke()` supports `beta` in `C <- alpha * A * B + beta * C` + // (otherwise beta should be always equal to 1). + enum { HasBeta = false }; + EIGEN_DEVICE_FUNC TensorContractionKernel(StorageIndex m_, StorageIndex k_, StorageIndex n_, StorageIndex bm_, StorageIndex bk_, StorageIndex bn_) @@ -248,7 +252,9 @@ struct TensorContractionKernel { const OutputMapper& output_mapper, const LhsBlock& lhsBlock, const RhsBlock& rhsBlock, const StorageIndex rows, const StorageIndex depth, const StorageIndex cols, - const ResScalar alpha) { + const ResScalar alpha, const ResScalar beta) { + // Default GEBP kernel does not support beta. + eigen_assert(beta == ResScalar(1)); static const int kComputeStrideFromBlockDimensions = -1; GebpKernel()(output_mapper, lhsBlock, rhsBlock, rows, depth, cols, alpha, /*strideA*/ kComputeStrideFromBlockDimensions, @@ -772,15 +778,6 @@ struct TensorContractionEvaluatorBase void evalGemm(Scalar* buffer) const { // columns in left side, rows in right side const Index k = this->m_k_size; - - // rows in left side - const Index m = this->m_i_size; - - // columns in right side - const Index n = this->m_j_size; - - // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar) - this->m_device.memset(buffer, 0, m * n * sizeof(Scalar)); this->template evalGemmPartialm_device, &blockA, &blockB); + // If a contraction kernel does not support beta, explicitly initialize + // output buffer with zeroes. + if (!TensorContractionKernel::HasBeta) { + this->m_device.memset(buffer, 0, m * n * sizeof(Scalar)); + } + for(Index i2=0; i2= k_end) { diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index 873db5efd..26c9fac17 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -904,14 +904,16 @@ struct TensorEvaluator void processBlock(Index block_idx, Index begin, Index end) { Scalar* buf = block_buffers[block_idx]; - ::memset(buf, 0, buffer_size_bytes); TENSOR_CONTRACTION_DISPATCH( evaluator->template evalGemmPartialWithoutOutputKernel, Alignment,