Detect "effectively inner/outer" chipping in TensorChipping

This commit is contained in:
Eugene Zhulenev 2024-08-29 17:49:59 +00:00 committed by Rasmus Munk Larsen
parent 648bce6cae
commit c59332d74a

View File

@ -173,6 +173,26 @@ struct TensorEvaluator<const TensorChippingOp<DimId, ArgType>, Device> {
}
m_inputStride *= input_dims[m_dim.actualDim()];
m_inputOffset = m_stride * op.offset();
// Check if chipping is effectively inner or outer: products of dimensions
// before or after the chipped dimension is `1`.
Index after_chipped_dim_product = 1;
for (int i = m_dim.actualDim() + 1; i < NumInputDims; ++i) {
after_chipped_dim_product *= input_dims[i];
}
Index before_chipped_dim_product = 1;
for (int i = 0; i < m_dim.actualDim(); ++i) {
before_chipped_dim_product *= input_dims[i];
}
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
m_isEffectivelyInnerChipping = before_chipped_dim_product == 1;
m_isEffectivelyOuterChipping = after_chipped_dim_product == 1;
} else {
m_isEffectivelyInnerChipping = after_chipped_dim_product == 1;
m_isEffectivelyOuterChipping = before_chipped_dim_product == 1;
}
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
@ -336,13 +356,11 @@ struct TensorEvaluator<const TensorChippingOp<DimId, ArgType>, Device> {
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isInnerChipping() const {
return IsInnerChipping || (static_cast<int>(Layout) == ColMajor && m_dim.actualDim() == 0) ||
(static_cast<int>(Layout) == RowMajor && m_dim.actualDim() == NumInputDims - 1);
return IsInnerChipping || m_isEffectivelyInnerChipping;
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isOuterChipping() const {
return IsOuterChipping || (static_cast<int>(Layout) == ColMajor && m_dim.actualDim() == NumInputDims - 1) ||
(static_cast<int>(Layout) == RowMajor && m_dim.actualDim() == 0);
return IsOuterChipping || m_isEffectivelyOuterChipping;
}
Dimensions m_dimensions;
@ -352,6 +370,11 @@ struct TensorEvaluator<const TensorChippingOp<DimId, ArgType>, Device> {
TensorEvaluator<ArgType, Device> m_impl;
const internal::DimensionId<DimId> m_dim;
const Device EIGEN_DEVICE_REF m_device;
// If product of all dimensions after or before the chipped dimension is `1`,
// it is effectively the same as chipping innermost or outermost dimension.
bool m_isEffectivelyInnerChipping;
bool m_isEffectivelyOuterChipping;
};
// Eval as lvalue