diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index 6f113b903..9d0d432ee 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -37,7 +37,7 @@ struct traits > typedef typename remove_reference::type _RhsNested; // From NumDims below. - static const int NumDimensions = max_n_1::NumDimensions + traits::NumDimensions - 2 * array_size::value>::size; + static const int NumDimensions = traits::NumDimensions + traits::NumDimensions - 2 * array_size::value; static const int Layout = traits::Layout; enum { @@ -65,7 +65,7 @@ struct traits::NumDimensions + traits::NumDimensions - 2 * array_size::value>::size; + static const int NumDimensions = traits::NumDimensions + traits::NumDimensions - 2 * array_size::value; }; } // end namespace internal @@ -140,7 +140,7 @@ struct TensorContractionEvaluatorBase static const int RDims = internal::array_size::Dimensions>::value; static const int ContractDims = internal::array_size::value; - static const int NumDims = max_n_1::size; + static const int NumDims = LDims + RDims - 2 * ContractDims; typedef array contract_t; typedef array::size> left_nocontract_t; @@ -218,11 +218,9 @@ struct TensorContractionEvaluatorBase rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i]; } - m_i_strides[0] = 1; - m_j_strides[0] = 1; - if(ContractDims) { - m_k_strides[0] = 1; - } + if (m_i_strides.size() > 0) m_i_strides[0] = 1; + if (m_j_strides.size() > 0) m_j_strides[0] = 1; + if (m_k_strides.size() > 0) m_k_strides[0] = 1; m_i_size = 1; m_j_size = 1; @@ -318,11 +316,6 @@ struct TensorContractionEvaluatorBase } } - // Scalar case. We represent the result as a 1d tensor of size 1. - if (LDims + RDims == 2 * ContractDims) { - m_dimensions[0] = 1; - } - // If the layout is RowMajor, we need to reverse the m_dimensions if (static_cast(Layout) == static_cast(RowMajor)) { for (int i = 0, j = NumDims - 1; i < j; i++, j--) { @@ -607,15 +600,14 @@ struct TensorEvaluator::value; typedef array contract_t; - typedef array::size> left_nocontract_t; - typedef array::size> right_nocontract_t; + typedef array left_nocontract_t; + typedef array right_nocontract_t; - static const int NumDims = max_n_1::size; + static const int NumDims = LDims + RDims - 2 * ContractDims; // Could we use NumDimensions here? typedef DSizes Dimensions; - EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) : Base(op, device) { } diff --git a/unsupported/test/cxx11_tensor_contraction.cpp b/unsupported/test/cxx11_tensor_contraction.cpp index 0e16308a2..73623b2ed 100644 --- a/unsupported/test/cxx11_tensor_contraction.cpp +++ b/unsupported/test/cxx11_tensor_contraction.cpp @@ -87,19 +87,14 @@ static void test_scalar() vec1.setRandom(); vec2.setRandom(); - Tensor scalar(1); - scalar.setZero(); Eigen::array dims = {{DimPair(0, 0)}}; - typedef TensorEvaluator Evaluator; - Evaluator eval(vec1.contract(vec2, dims), DefaultDevice()); - eval.evalTo(scalar.data()); - EIGEN_STATIC_ASSERT(Evaluator::NumDims==1ul, YOU_MADE_A_PROGRAMMING_MISTAKE); + Tensor scalar = vec1.contract(vec2, dims); float expected = 0.0f; for (int i = 0; i < 6; ++i) { expected += vec1(i) * vec2(i); } - VERIFY_IS_APPROX(scalar(0), expected); + VERIFY_IS_APPROX(scalar(), expected); } template