diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h b/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h index 000b1fb58..f5172cd8d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h @@ -173,6 +173,26 @@ struct TensorEvaluator, Device> { } m_inputStride *= input_dims[m_dim.actualDim()]; m_inputOffset = m_stride * op.offset(); + + // Check if chipping is effectively inner or outer: products of dimensions + // before or after the chipped dimension is `1`. + Index after_chipped_dim_product = 1; + for (int i = m_dim.actualDim() + 1; i < NumInputDims; ++i) { + after_chipped_dim_product *= input_dims[i]; + } + + Index before_chipped_dim_product = 1; + for (int i = 0; i < m_dim.actualDim(); ++i) { + before_chipped_dim_product *= input_dims[i]; + } + + if (static_cast(Layout) == static_cast(ColMajor)) { + m_isEffectivelyInnerChipping = before_chipped_dim_product == 1; + m_isEffectivelyOuterChipping = after_chipped_dim_product == 1; + } else { + m_isEffectivelyInnerChipping = after_chipped_dim_product == 1; + m_isEffectivelyOuterChipping = before_chipped_dim_product == 1; + } } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } @@ -336,13 +356,11 @@ struct TensorEvaluator, Device> { } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isInnerChipping() const { - return IsInnerChipping || (static_cast(Layout) == ColMajor && m_dim.actualDim() == 0) || - (static_cast(Layout) == RowMajor && m_dim.actualDim() == NumInputDims - 1); + return IsInnerChipping || m_isEffectivelyInnerChipping; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isOuterChipping() const { - return IsOuterChipping || (static_cast(Layout) == ColMajor && m_dim.actualDim() == NumInputDims - 1) || - (static_cast(Layout) == RowMajor && m_dim.actualDim() == 0); + return IsOuterChipping || m_isEffectivelyOuterChipping; } Dimensions m_dimensions; @@ -352,6 +370,11 @@ struct TensorEvaluator, Device> { TensorEvaluator m_impl; const internal::DimensionId m_dim; const Device EIGEN_DEVICE_REF m_device; + + // If product of all dimensions after or before the chipped dimension is `1`, + // it is effectively the same as chipping innermost or outermost dimension. + bool m_isEffectivelyInnerChipping; + bool m_isEffectivelyOuterChipping; }; // Eval as lvalue