mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Detect "effectively inner/outer" chipping in TensorChipping
This commit is contained in:
parent
648bce6cae
commit
c59332d74a
@ -173,6 +173,26 @@ struct TensorEvaluator<const TensorChippingOp<DimId, ArgType>, Device> {
|
||||
}
|
||||
m_inputStride *= input_dims[m_dim.actualDim()];
|
||||
m_inputOffset = m_stride * op.offset();
|
||||
|
||||
// Check if chipping is effectively inner or outer: products of dimensions
|
||||
// before or after the chipped dimension is `1`.
|
||||
Index after_chipped_dim_product = 1;
|
||||
for (int i = m_dim.actualDim() + 1; i < NumInputDims; ++i) {
|
||||
after_chipped_dim_product *= input_dims[i];
|
||||
}
|
||||
|
||||
Index before_chipped_dim_product = 1;
|
||||
for (int i = 0; i < m_dim.actualDim(); ++i) {
|
||||
before_chipped_dim_product *= input_dims[i];
|
||||
}
|
||||
|
||||
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
|
||||
m_isEffectivelyInnerChipping = before_chipped_dim_product == 1;
|
||||
m_isEffectivelyOuterChipping = after_chipped_dim_product == 1;
|
||||
} else {
|
||||
m_isEffectivelyInnerChipping = after_chipped_dim_product == 1;
|
||||
m_isEffectivelyOuterChipping = before_chipped_dim_product == 1;
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
|
||||
@ -336,13 +356,11 @@ struct TensorEvaluator<const TensorChippingOp<DimId, ArgType>, Device> {
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isInnerChipping() const {
|
||||
return IsInnerChipping || (static_cast<int>(Layout) == ColMajor && m_dim.actualDim() == 0) ||
|
||||
(static_cast<int>(Layout) == RowMajor && m_dim.actualDim() == NumInputDims - 1);
|
||||
return IsInnerChipping || m_isEffectivelyInnerChipping;
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isOuterChipping() const {
|
||||
return IsOuterChipping || (static_cast<int>(Layout) == ColMajor && m_dim.actualDim() == NumInputDims - 1) ||
|
||||
(static_cast<int>(Layout) == RowMajor && m_dim.actualDim() == 0);
|
||||
return IsOuterChipping || m_isEffectivelyOuterChipping;
|
||||
}
|
||||
|
||||
Dimensions m_dimensions;
|
||||
@ -352,6 +370,11 @@ struct TensorEvaluator<const TensorChippingOp<DimId, ArgType>, Device> {
|
||||
TensorEvaluator<ArgType, Device> m_impl;
|
||||
const internal::DimensionId<DimId> m_dim;
|
||||
const Device EIGEN_DEVICE_REF m_device;
|
||||
|
||||
// If product of all dimensions after or before the chipped dimension is `1`,
|
||||
// it is effectively the same as chipping innermost or outermost dimension.
|
||||
bool m_isEffectivelyInnerChipping;
|
||||
bool m_isEffectivelyOuterChipping;
|
||||
};
|
||||
|
||||
// Eval as lvalue
|
||||
|
Loading…
x
Reference in New Issue
Block a user