Add outer/inner chipping optimization for chipping dimension specified at runtime

This commit is contained in:
Eugene Zhulenev 2019-07-03 11:35:25 -07:00
parent 7eb2e0a95b
commit 6083014594

View File

@ -239,7 +239,7 @@ struct TensorEvaluator<const TensorChippingOp<DimId, ArgType>, Device>
EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
eigen_assert(index+PacketSize-1 < dimensions().TotalSize());
if (IsInnerChipping) {
if (isInnerChipping()) {
// m_stride is equal to 1, so let's avoid the integer division.
eigen_assert(m_stride == 1);
Index inputIndex = index * m_inputStride + m_inputOffset;
@ -251,7 +251,7 @@ struct TensorEvaluator<const TensorChippingOp<DimId, ArgType>, Device>
}
PacketReturnType rslt = internal::pload<PacketReturnType>(values);
return rslt;
} else if (IsOuterChipping) {
} else if (isOuterChipping()) {
// m_stride is always greater than index, so let's avoid the integer division.
eigen_assert(m_stride > index);
return m_impl.template packet<LoadMode>(index + m_inputOffset);
@ -354,7 +354,7 @@ struct TensorEvaluator<const TensorChippingOp<DimId, ArgType>, Device>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename Storage::Type data() const {
typename Storage::Type result = constCast(m_impl.data());
if (IsOuterChipping && result) {
if (isOuterChipping() && result) {
return result + m_inputOffset;
} else {
return NULL;
@ -371,11 +371,11 @@ struct TensorEvaluator<const TensorChippingOp<DimId, ArgType>, Device>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index srcCoeff(Index index) const
{
Index inputIndex;
if (IsInnerChipping) {
if (isInnerChipping()) {
// m_stride is equal to 1, so let's avoid the integer division.
eigen_assert(m_stride == 1);
inputIndex = index * m_inputStride + m_inputOffset;
} else if (IsOuterChipping) {
} else if (isOuterChipping()) {
// m_stride is always greater than index, so let's avoid the integer
// division.
eigen_assert(m_stride > index);
@ -389,6 +389,18 @@ struct TensorEvaluator<const TensorChippingOp<DimId, ArgType>, Device>
return inputIndex;
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isInnerChipping() const {
return IsInnerChipping ||
(static_cast<int>(Layout) == ColMajor && m_dim.actualDim() == 0) ||
(static_cast<int>(Layout) == RowMajor && m_dim.actualDim() == NumInputDims - 1);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool isOuterChipping() const {
return IsOuterChipping ||
(static_cast<int>(Layout) == ColMajor && m_dim.actualDim() == NumInputDims-1) ||
(static_cast<int>(Layout) == RowMajor && m_dim.actualDim() == 0);
}
Dimensions m_dimensions;
Index m_stride;
Index m_inputOffset;
@ -421,16 +433,7 @@ struct TensorEvaluator<TensorChippingOp<DimId, ArgType>, Device>
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
BlockAccess = TensorEvaluator<ArgType, Device>::BlockAccess,
Layout = TensorEvaluator<ArgType, Device>::Layout,
RawAccess = false,
// Chipping of outer-most dimension is a trivial operation, because we can
// read and write directly from the underlying tensor using single offset.
IsOuterChipping =
(static_cast<int>(Layout) == ColMajor && DimId == NumInputDims - 1) ||
(static_cast<int>(Layout) == RowMajor && DimId == 0),
// Chipping inner-most dimension.
IsInnerChipping =
(static_cast<int>(Layout) == ColMajor && DimId == 0) ||
(static_cast<int>(Layout) == RowMajor && DimId == NumInputDims - 1)
RawAccess = false
};
typedef typename internal::remove_const<Scalar>::type ScalarNoConst;
@ -454,7 +457,7 @@ struct TensorEvaluator<TensorChippingOp<DimId, ArgType>, Device>
{
EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE)
if (IsInnerChipping) {
if (this->isInnerChipping()) {
// m_stride is equal to 1, so let's avoid the integer division.
eigen_assert(this->m_stride == 1);
EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
@ -465,7 +468,7 @@ struct TensorEvaluator<TensorChippingOp<DimId, ArgType>, Device>
this->m_impl.coeffRef(inputIndex) = values[i];
inputIndex += this->m_inputStride;
}
} else if (IsOuterChipping) {
} else if (this->isOuterChipping()) {
// m_stride is always greater than index, so let's avoid the integer division.
eigen_assert(this->m_stride > index);
this->m_impl.template writePacket<StoreMode>(index + this->m_inputOffset, x);