From b1789c112b5cf8d478a03786c6c1243320aefd47 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Mon, 3 Nov 2014 08:51:33 -0800 Subject: [PATCH] Improved handling of 1d tensors --- .../CXX11/src/Tensor/TensorContraction.h | 98 +++++++++++++++++-- .../src/Tensor/TensorContractionThreadPool.h | 12 ++- 2 files changed, 99 insertions(+), 11 deletions(-) diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index c530b27a7..8e898619d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -48,7 +48,7 @@ class BaseTensorContractionMapper { m_k_strides(k_strides) { } EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE void prefetch(int /*i*/) { } + EIGEN_STRONG_INLINE void prefetch(Index /*i*/) { } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(Index row) const { @@ -142,6 +142,13 @@ class BaseTensorContractionMapper { return IndexPair(linidx[0], linidx[1]); } + Index firstAligned(Index size) const { + return size; + } + Index stride() const { + return 1; + } + protected: const Tensor m_tensor; const nocontract_t m_nocontract_strides; @@ -202,6 +209,18 @@ class TensorContractionSubMapper { return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset); } + template + EIGEN_ALWAYS_INLINE PacketT load(Index i) const { + EIGEN_STATIC_ASSERT((internal::is_same::value), YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT((AlignmentType == Aligned || Alignment == Unaligned), YOU_MADE_A_PROGRAMMING_MISTAKE); + return loadPacket(i); + } + + template + bool aligned(Index /*i*/) const { + return false; + } + private: const ParentMapper& m_base_mapper; const Index m_vert_offset; @@ -220,6 +239,7 @@ class TensorContractionInputMapper public: typedef BaseTensorContractionMapper Base; typedef TensorContractionSubMapper SubMapper; + typedef SubMapper VectorMapper; TensorContractionInputMapper(const Tensor& tensor, const nocontract_t& nocontract_strides, @@ -233,6 +253,10 @@ class TensorContractionInputMapper return SubMapper(*this, i, j); } + EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const { + return VectorMapper(*this, i, j); + } + typedef typename packet_traits::type Packet; typedef typename packet_traits::half HalfPacket; @@ -306,6 +330,7 @@ class TensorContractionInputMapper Base; typedef TensorContractionSubMapper SubMapper; + typedef SubMapper VectorMapper; TensorContractionInputMapper(const Tensor& tensor, const nocontract_t& nocontract_strides, @@ -319,6 +344,10 @@ class TensorContractionInputMapper::type Packet; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j) const { @@ -592,41 +621,80 @@ struct TensorContractionEvaluatorBase if (this->m_lhs_inner_dim_contiguous) { if (this->m_rhs_inner_dim_contiguous) { if (this->m_rhs_inner_dim_reordered) { - static_cast(this)->template evalTyped(buffer); + static_cast(this)->template evalProduct(buffer); } else { - static_cast(this)->template evalTyped(buffer); + static_cast(this)->template evalProduct(buffer); } } else { if (this->m_rhs_inner_dim_reordered) { - static_cast(this)->template evalTyped(buffer); + static_cast(this)->template evalProduct(buffer); } else { - static_cast(this)->template evalTyped(buffer); + static_cast(this)->template evalProduct(buffer); } } } else { if (this->m_rhs_inner_dim_contiguous) { if (this->m_rhs_inner_dim_reordered) { - static_cast(this)->template evalTyped(buffer); + static_cast(this)->template evalProduct(buffer); } else { - static_cast(this)->template evalTyped(buffer); + static_cast(this)->template evalProduct(buffer); } } else { if (this->m_rhs_inner_dim_reordered) { - static_cast(this)->template evalTyped(buffer); + static_cast(this)->template evalProduct(buffer); } else { - static_cast(this)->template evalTyped(buffer); + static_cast(this)->template evalProduct(buffer); } } } } + template + void evalGemv(Scalar* buffer) const { + const Index rows = m_i_size; + const Index cols = m_k_size; + + typedef typename internal::remove_const::type LhsScalar; + typedef typename internal::remove_const::type RhsScalar; + typedef TensorEvaluator LeftEvaluator; + typedef TensorEvaluator RightEvaluator; + const int lhs_packet_size = internal::packet_traits::size; + const int rhs_packet_size = internal::packet_traits::size; + typedef internal::TensorContractionInputMapper LhsMapper; + + typedef internal::TensorContractionInputMapper RhsMapper; + + LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides, + m_left_contracting_strides, m_k_strides); + RhsMapper rhs(m_rightImpl, m_right_nocontract_strides, m_j_strides, + m_right_contracting_strides, m_k_strides); + + const Scalar alpha(1); + const Index resIncr(1); + + // zero out the result buffer (which must be of size at least rows * sizeof(Scalar) + m_device.memset(buffer, 0, rows * sizeof(Scalar)); + + internal::general_matrix_vector_product::run( + rows, cols, lhs, rhs, + buffer, resIncr, alpha); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_leftImpl.cleanup(); m_rightImpl.cleanup(); @@ -707,7 +775,17 @@ struct TensorEvaluator - EIGEN_DEVICE_FUNC void evalTyped(Scalar* buffer) const { + void evalProduct(Scalar* buffer) const { + if (this->m_j_size == 1) { + this->template evalGemv(buffer); + return; + } + + evalGemm(buffer); + } + + template + EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const { // columns in left side, rows in right side const Index k = this->m_k_size; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index cf1352a31..f0e9bb616 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -93,7 +93,17 @@ struct TensorEvaluator - void evalTyped(Scalar* buffer) const { + void evalProduct(Scalar* buffer) const { + if (this->m_j_size == 1) { + this->template evalGemv(buffer); + return; + } + + evalGemm(buffer); + } + + template + void evalGemm(Scalar* buffer) const { // columns in left side, rows in right side const Index k = this->m_k_size;