Detect "effectively inner/outer" chipping in TensorChipping

This commit is contained in:
Eugene Zhulenev 2024-08-29 17:49:59 +00:00 committed by Rasmus Munk Larsen
parent 648bce6cae
commit c59332d74a

View File

@ -173,6 +173,26 @@ struct TensorEvaluator<const TensorChippingOp<DimId, ArgType>, Device> {
} }
m_inputStride *= input_dims[m_dim.actualDim()]; m_inputStride *= input_dims[m_dim.actualDim()];
m_inputOffset = m_stride * op.offset(); m_inputOffset = m_stride * op.offset();
// Check if chipping is effectively inner or outer, i.e. whether the product
// of the dimensions before or after the chipped dimension is `1`.
Index after_chipped_dim_product = 1;
for (int i = m_dim.actualDim() + 1; i < NumInputDims; ++i) {
after_chipped_dim_product *= input_dims[i];
}
Index before_chipped_dim_product = 1;
for (int i = 0; i < m_dim.actualDim(); ++i) {
before_chipped_dim_product *= input_dims[i];
}
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
m_isEffectivelyInnerChipping = before_chipped_dim_product == 1;
m_isEffectivelyOuterChipping = after_chipped_dim_product == 1;
} else {
m_isEffectivelyInnerChipping = after_chipped_dim_product == 1;
m_isEffectivelyOuterChipping = before_chipped_dim_product == 1;
}
} }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
@ -336,13 +356,11 @@ struct TensorEvaluator<const TensorChippingOp<DimId, ArgType>, Device> {
} }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isInnerChipping() const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isInnerChipping() const {
return IsInnerChipping || (static_cast<int>(Layout) == ColMajor && m_dim.actualDim() == 0) || return IsInnerChipping || m_isEffectivelyInnerChipping;
(static_cast<int>(Layout) == RowMajor && m_dim.actualDim() == NumInputDims - 1);
} }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isOuterChipping() const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isOuterChipping() const {
return IsOuterChipping || (static_cast<int>(Layout) == ColMajor && m_dim.actualDim() == NumInputDims - 1) || return IsOuterChipping || m_isEffectivelyOuterChipping;
(static_cast<int>(Layout) == RowMajor && m_dim.actualDim() == 0);
} }
Dimensions m_dimensions; Dimensions m_dimensions;
@ -352,6 +370,11 @@ struct TensorEvaluator<const TensorChippingOp<DimId, ArgType>, Device> {
TensorEvaluator<ArgType, Device> m_impl; TensorEvaluator<ArgType, Device> m_impl;
const internal::DimensionId<DimId> m_dim; const internal::DimensionId<DimId> m_dim;
const Device EIGEN_DEVICE_REF m_device; const Device EIGEN_DEVICE_REF m_device;
// If the product of all dimensions after or before the chipped dimension is
// `1`, chipping is effectively the same as chipping the innermost or
// outermost dimension.
bool m_isEffectivelyInnerChipping;
bool m_isEffectivelyOuterChipping;
}; };
// Eval as lvalue // Eval as lvalue