diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index c5ec42cf4..a02a273e7 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -320,6 +320,8 @@ class TensorContractionInputMapper }; + + template struct max_n_1 { + static const size_t size = n; +}; +template <> struct max_n_1<0> { + static const size_t size = 1; +}; + + template struct traits > { @@ -378,6 +388,10 @@ struct traits > typedef typename remove_reference::type _LhsNested; typedef typename remove_reference::type _RhsNested; + // From NumDims below. + static const int NumDimensions = max_n_1::NumDimensions + traits::NumDimensions - 2 * array_size::value>::size; + static const int Layout = traits::Layout; + enum { Flags = 0, }; @@ -401,19 +415,19 @@ struct traits::NumDimensions + traits::NumDimensions - 2 * array_size::value>::size; }; } // end namespace internal - - template class TensorContractionOp : public TensorBase > { public: typedef typename Eigen::internal::traits::Scalar Scalar; typedef typename Eigen::internal::traits::Packet Packet; - typedef typename Eigen::NumTraits::Real RealScalar; typedef typename internal::promote_storage_type::ret CoeffReturnType; typedef typename internal::promote_storage_type::StorageKind StorageKind; typedef typename Eigen::internal::traits::Index Index; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp( + const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims) : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {} - EIGEN_DEVICE_FUNC - const Indices& indices() const { return m_indices; } + EIGEN_DEVICE_FUNC + const Indices& indices() const { return m_indices; } - /** \returns the nested expressions */ - EIGEN_DEVICE_FUNC - const typename internal::remove_all::type& - lhsExpression() const { return m_lhs_xpr; } + /** \returns the nested expressions */ + EIGEN_DEVICE_FUNC + const typename internal::remove_all::type& + lhsExpression() const { return m_lhs_xpr; } - EIGEN_DEVICE_FUNC - const typename internal::remove_all::type& - rhsExpression() const { return m_rhs_xpr; } + EIGEN_DEVICE_FUNC + const typename internal::remove_all::type& + rhsExpression() const { return m_rhs_xpr; } protected: typename LhsXprType::Nested m_lhs_xpr; @@ -444,12 +459,17 @@ class TensorContractionOp : public TensorBase struct max_n_1 { - static const size_t size = n; -}; -template <> struct max_n_1<0> { - static const size_t size = 1; -}; +template struct Cond {}; + +template EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE +const T1& choose(Cond, const T1& first, const T2&) { + return first; +} + +template EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE +const T2& choose(Cond, const T1&, const T2& second) { + return second; +} template @@ -467,37 +487,94 @@ struct TensorContractionEvaluatorBase typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; - typedef array::Dimensions::count> left_dim_mapper_t; - typedef array::Dimensions::count> right_dim_mapper_t; - - typedef array::value> contract_t; - typedef array::Dimensions::count - internal::array_size::value>::size> left_nocontract_t; - typedef array::Dimensions::count - internal::array_size::value>::size> right_nocontract_t; - - static const int NumDims = max_n_1::Dimensions::count + TensorEvaluator::Dimensions::count - 2 * internal::array_size::value>::size; - - typedef DSizes Dimensions; - enum { IsAligned = true, PacketAccess = (internal::packet_traits::size > 1), + Layout = TensorEvaluator::Layout, + CoordAccess = false, // to be implemented }; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionEvaluatorBase(const XprType& op, const Device& device) - : m_leftImpl(op.lhsExpression(), device), m_rightImpl(op.rhsExpression(), device), m_device(device), m_result(NULL) - { + // Most of the code is assuming that both input tensors are ColMajor. If the + // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS: + // If we want to compute A * B = C, where A is LHS and B is RHS, the code + // will pretend B is LHS and A is RHS. + typedef typename internal::conditional< + Layout == ColMajor, LeftArgType, RightArgType>::type EvalLeftArgType; + typedef typename internal::conditional< + Layout == ColMajor, RightArgType, LeftArgType>::type EvalRightArgType; + + static const int LDims = + internal::array_size::Dimensions>::value; + static const int RDims = + internal::array_size::Dimensions>::value; + static const int ContractDims = internal::array_size::value; + static const int NumDims = internal::max_n_1::size; + + typedef array left_dim_mapper_t; + typedef array right_dim_mapper_t; + typedef array contract_t; + typedef array::size> left_nocontract_t; + typedef array::size> right_nocontract_t; + + typedef DSizes Dimensions; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + TensorContractionEvaluatorBase(const XprType& op, const Device& device) + : m_leftImpl(choose(Cond(), + op.lhsExpression(), op.rhsExpression()), device), + m_rightImpl(choose(Cond(), + op.rhsExpression(), op.lhsExpression()), device), + m_device(device), + m_result(NULL) { + EIGEN_STATIC_ASSERT((TensorEvaluator::Layout == + TensorEvaluator::Layout), + YOU_MADE_A_PROGRAMMING_MISTAKE); + eigen_assert((internal::array_size::value > 0) && "Must contract on some indices"); - array::Dimensions::count> lhs_strides; - lhs_strides[0] = 1; - for (int i = 0; i < TensorEvaluator::Dimensions::count-1; ++i) { - lhs_strides[i+1] = lhs_strides[i] * m_leftImpl.dimensions()[i]; + + DSizes eval_left_dims; + DSizes eval_right_dims; + array, ContractDims> eval_op_indices; + if (Layout == ColMajor) { + // For ColMajor, we keep using the existing dimensions + for (int i = 0; i < LDims; i++) { + eval_left_dims[i] = m_leftImpl.dimensions()[i]; + } + for (int i = 0; i < RDims; i++) { + eval_right_dims[i] = m_rightImpl.dimensions()[i]; + } + // We keep the pairs of contracting indices. + for (int i = 0; i < ContractDims; i++) { + eval_op_indices[i].first = op.indices()[i].first; + eval_op_indices[i].second = op.indices()[i].second; + } + } else { + // For RowMajor, we need to reverse the existing dimensions + for (int i = 0; i < LDims; i++) { + eval_left_dims[i] = m_leftImpl.dimensions()[LDims - i - 1]; + } + for (int i = 0; i < RDims; i++) { + eval_right_dims[i] = m_rightImpl.dimensions()[RDims - i - 1]; + } + // We need to flip all the pairs of contracting indices as well as + // reversing the dimensions. + for (int i = 0; i < ContractDims; i++) { + eval_op_indices[i].first = LDims - 1 - op.indices()[i].second; + eval_op_indices[i].second = RDims - 1 - op.indices()[i].first; + } } - array::Dimensions::count> rhs_strides; + array lhs_strides; + lhs_strides[0] = 1; + for (int i = 0; i < LDims-1; ++i) { + lhs_strides[i+1] = lhs_strides[i] * eval_left_dims[i]; + } + + array rhs_strides; rhs_strides[0] = 1; - for (int i = 0; i < TensorEvaluator::Dimensions::count-1; ++i) { - rhs_strides[i+1] = rhs_strides[i] * m_rightImpl.dimensions()[i]; + for (int i = 0; i < RDims-1; ++i) { + rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i]; } m_i_strides[0] = 1; @@ -515,27 +592,28 @@ struct TensorContractionEvaluatorBase m_lhs_inner_dim_contiguous = true; int dim_idx = 0; int nocontract_idx = 0; - const typename TensorEvaluator::Dimensions& left_dims = m_leftImpl.dimensions(); - for (int i = 0; i < TensorEvaluator::Dimensions::count; i++) { + + for (int i = 0; i < LDims; i++) { // find if we are contracting on index i of left tensor bool contracting = false; - for (int j = 0; j < internal::array_size::value; j++) { - if (op.indices()[j].first == i) { + for (int j = 0; j < ContractDims; j++) { + if (eval_op_indices[j].first == i) { contracting = true; break; } } if (!contracting) { // add dimension size to output dimensions - m_dimensions[dim_idx] = left_dims[i]; + m_dimensions[dim_idx] = eval_left_dims[i]; m_left_nocontract_strides[nocontract_idx] = lhs_strides[i]; if (dim_idx != i) { m_lhs_inner_dim_contiguous = false; } if (nocontract_idx+1 < internal::array_size::value) { - m_i_strides[nocontract_idx+1] = m_i_strides[nocontract_idx] * left_dims[i]; + m_i_strides[nocontract_idx+1] = + m_i_strides[nocontract_idx] * eval_left_dims[i]; } else { - m_i_size = m_i_strides[nocontract_idx] * left_dims[i]; + m_i_size = m_i_strides[nocontract_idx] * eval_left_dims[i]; } dim_idx++; nocontract_idx++; @@ -543,22 +621,22 @@ struct TensorContractionEvaluatorBase } nocontract_idx = 0; - const typename TensorEvaluator::Dimensions& right_dims = m_rightImpl.dimensions(); - for (int i = 0; i < TensorEvaluator::Dimensions::count; i++) { + for (int i = 0; i < RDims; i++) { bool contracting = false; // find if we are contracting on index i of right tensor - for (int j = 0; j < internal::array_size::value; j++) { - if (op.indices()[j].second == i) { + for (int j = 0; j < ContractDims; j++) { + if (eval_op_indices[j].second == i) { contracting = true; break; } } if (!contracting) { - m_dimensions[dim_idx] = right_dims[i]; + m_dimensions[dim_idx] = eval_right_dims[i]; if (nocontract_idx+1 < internal::array_size::value) { - m_j_strides[nocontract_idx+1] = m_j_strides[nocontract_idx] * right_dims[i]; + m_j_strides[nocontract_idx+1] = + m_j_strides[nocontract_idx] * eval_right_dims[i]; } else { - m_j_size = m_j_strides[nocontract_idx] * right_dims[i]; + m_j_size = m_j_strides[nocontract_idx] * eval_right_dims[i]; } m_right_nocontract_strides[nocontract_idx] = rhs_strides[i]; dim_idx++; @@ -573,12 +651,13 @@ struct TensorContractionEvaluatorBase // each tensor, we'll only look at the first tensor here. m_rhs_inner_dim_contiguous = true; m_rhs_inner_dim_reordered = false; - for (int i = 0; i < internal::array_size::value; i++) { - Index left = op.indices()[i].first; - Index right = op.indices()[i].second; + for (int i = 0; i < ContractDims; i++) { + Index left = eval_op_indices[i].first; + Index right = eval_op_indices[i].second; - Index size = left_dims[left]; - eigen_assert(size == right_dims[right] && "Contraction axes must be same size"); + Index size = eval_left_dims[left]; + eigen_assert(size == eval_right_dims[right] && + "Contraction axes must be same size"); if (i+1 < internal::array_size::value) { m_k_strides[i+1] = m_k_strides[i] * size; @@ -588,7 +667,7 @@ struct TensorContractionEvaluatorBase m_left_contracting_strides[i] = lhs_strides[left]; m_right_contracting_strides[i] = rhs_strides[right]; - if (i > 0 && right < op.indices()[i-1].second) { + if (i > 0 && right < eval_op_indices[i-1].second) { m_rhs_inner_dim_reordered = true; } if (right != i) { @@ -597,9 +676,16 @@ struct TensorContractionEvaluatorBase } // Scalar case. We represent the result as a 1d tensor of size 1. - if (TensorEvaluator::Dimensions::count + TensorEvaluator::Dimensions::count == 2 * internal::array_size::value) { + if (LDims + RDims == 2 * ContractDims) { m_dimensions[0] = 1; } + + // If the layout is RowMajor, we need to reverse the m_dimensions + if (Layout == RowMajor) { + for (int i = 0, j = NumDims - 1; i < j; i++, j--) { + std::swap(m_dimensions[i], m_dimensions[j]); + } + } } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } @@ -661,10 +747,10 @@ struct TensorContractionEvaluatorBase const Index rows = m_i_size; const Index cols = m_k_size; - typedef typename internal::remove_const::type LhsScalar; - typedef typename internal::remove_const::type RhsScalar; - typedef TensorEvaluator LeftEvaluator; - typedef TensorEvaluator RightEvaluator; + typedef typename internal::remove_const::type LhsScalar; + typedef typename internal::remove_const::type RhsScalar; + typedef TensorEvaluator LeftEvaluator; + typedef TensorEvaluator RightEvaluator; const int lhs_packet_size = internal::packet_traits::size; const int rhs_packet_size = internal::packet_traits::size; typedef internal::TensorContractionInputMapper m_leftImpl; - TensorEvaluator m_rightImpl; + TensorEvaluator m_leftImpl; + TensorEvaluator m_rightImpl; const Device& m_device; Scalar* m_result; }; +// evaluator for default device template struct TensorEvaluator, Device> : - public TensorContractionEvaluatorBase, Device> > { + public TensorContractionEvaluatorBase< + TensorEvaluator, Device> > { typedef TensorEvaluator, Device> Self; typedef TensorContractionEvaluatorBase Base; @@ -759,15 +846,35 @@ struct TensorEvaluator::Dimensions::count> left_dim_mapper_t; - typedef array::Dimensions::count> right_dim_mapper_t; + enum { + Layout = TensorEvaluator::Layout, + }; - typedef array::value> contract_t; - typedef array::Dimensions::count - internal::array_size::value>::size> left_nocontract_t; - typedef array::Dimensions::count - internal::array_size::value>::size> right_nocontract_t; + // Most of the code is assuming that both input tensors are ColMajor. If the + // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS: + // If we want to compute A * B = C, where A is LHS and B is RHS, the code + // will pretend B is LHS and A is RHS. + typedef typename internal::conditional< + Layout == ColMajor, LeftArgType, RightArgType>::type EvalLeftArgType; + typedef typename internal::conditional< + Layout == ColMajor, RightArgType, LeftArgType>::type EvalRightArgType; - static const int NumDims = max_n_1::Dimensions::count + TensorEvaluator::Dimensions::count - 2 * internal::array_size::value>::size; + static const int LDims = + internal::array_size::Dimensions>::value; + static const int RDims = + internal::array_size::Dimensions>::value; + static const int ContractDims = internal::array_size::value; + typedef array left_dim_mapper_t; + typedef array right_dim_mapper_t; + + typedef array contract_t; + typedef array::size> left_nocontract_t; + typedef array::size> right_nocontract_t; + + static const int NumDims = internal::max_n_1::size; + + // Could we use NumDimensions here? typedef DSizes Dimensions; @@ -799,15 +906,15 @@ struct TensorEvaluatorm_device.memset(buffer, 0, m * n * sizeof(Scalar)); // define mr, nr, and all of my data mapper types - typedef typename internal::remove_const::type LhsScalar; - typedef typename internal::remove_const::type RhsScalar; + typedef typename internal::remove_const::type LhsScalar; + typedef typename internal::remove_const::type RhsScalar; typedef typename internal::gebp_traits Traits; const Index nr = Traits::nr; const Index mr = Traits::mr; - typedef TensorEvaluator LeftEvaluator; - typedef TensorEvaluator RightEvaluator; + typedef TensorEvaluator LeftEvaluator; + typedef TensorEvaluator RightEvaluator; const int lhs_packet_size = internal::packet_traits::size; const int rhs_packet_size = internal::packet_traits::size; @@ -826,10 +933,10 @@ struct TensorEvaluator OutputMapper; - // Declare GEBP packing and kernel structs internal::gemm_pack_lhs pack_lhs; internal::gemm_pack_rhs pack_rhs; + internal::gebp_kernel gebp; // initialize data mappers diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h index f6bd949bd..588770bb4 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionCuda.h @@ -1241,10 +1241,10 @@ struct TensorEvaluator right_dim_mapper_t; typedef array contract_t; - typedef array::size> left_nocontract_t; - typedef array::size> right_nocontract_t; + typedef array::size> left_nocontract_t; + typedef array::size> right_nocontract_t; - static const int NumDims = max_n_1::size; + static const int NumDims = internal::max_n_1::size; typedef DSizes Dimensions; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h index f0e9bb616..5851e5adc 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h @@ -70,24 +70,43 @@ struct TensorEvaluator::Dimensions::count> left_dim_mapper_t; - typedef array::Dimensions::count> right_dim_mapper_t; + enum { + Layout = TensorEvaluator::Layout, + }; - typedef array::value> contract_t; - typedef array::Dimensions::count - internal::array_size::value>::size> left_nocontract_t; - typedef array::Dimensions::count - internal::array_size::value>::size> right_nocontract_t; + // Most of the code is assuming that both input tensors are ColMajor. If the + // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS: + // If we want to compute A * B = C, where A is LHS and B is RHS, the code + // will pretend B is LHS and A is RHS. + typedef typename internal::conditional< + Layout == ColMajor, LeftArgType, RightArgType>::type EvalLeftArgType; + typedef typename internal::conditional< + Layout == ColMajor, RightArgType, LeftArgType>::type EvalRightArgType; - static const int NumDims = max_n_1::Dimensions::count + TensorEvaluator::Dimensions::count - 2 * internal::array_size::value>::size; + static const int LDims = + internal::array_size::Dimensions>::value; + static const int RDims = + internal::array_size::Dimensions>::value; + static const int ContractDims = internal::array_size::value; + + typedef array left_dim_mapper_t; + typedef array right_dim_mapper_t; + + typedef array contract_t; + typedef array::size> left_nocontract_t; + typedef array::size> right_nocontract_t; + + static const int NumDims = max_n_1::size; typedef DSizes Dimensions; // typedefs needed in evalTo - typedef typename internal::remove_const::type LhsScalar; - typedef typename internal::remove_const::type RhsScalar; + typedef typename internal::remove_const::type LhsScalar; + typedef typename internal::remove_const::type RhsScalar; typedef typename internal::gebp_traits Traits; - typedef TensorEvaluator LeftEvaluator; - typedef TensorEvaluator RightEvaluator; + typedef TensorEvaluator LeftEvaluator; + typedef TensorEvaluator RightEvaluator; TensorEvaluator(const XprType& op, const Device& device) : Base(op, device) {}