Fix TensorReverse on GPU with m_strides[i]==0

This commit is contained in:
Eugene Zhulenev 2019-06-28 15:50:39 -07:00
parent 8053eeb51e
commit 81a03bec75

View File

@ -122,6 +122,8 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
RawAccess = false
};
typedef internal::TensorIntDivisor<Index> IndexDivisor;
typedef typename internal::remove_const<Scalar>::type ScalarNoConst;
typedef internal::TensorBlock<ScalarNoConst, Index, NumDims, Layout>
OutputTensorBlock;
@ -141,17 +143,15 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
m_strides[0] = 1;
for (int i = 1; i < NumDims; ++i) {
m_strides[i] = m_strides[i-1] * m_dimensions[i-1];
if (m_strides[i] > 0) m_fastStrides[i] = IndexDivisor(m_strides[i]);
}
} else {
m_strides[NumDims-1] = 1;
for (int i = NumDims - 2; i >= 0; --i) {
m_strides[i] = m_strides[i+1] * m_dimensions[i+1];
if (m_strides[i] > 0) m_fastStrides[i] = IndexDivisor(m_strides[i]);
}
}
// Remember the strides for fast division.
for (int i = 0; i < NumDims; ++i) {
m_fastStrides[i] = internal::TensorIntDivisor<Index>(m_strides[i]);
}
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
@ -377,7 +377,7 @@ struct TensorEvaluator<const TensorReverseOp<ReverseDimensions, ArgType>, Device
protected:
Dimensions m_dimensions;
array<Index, NumDims> m_strides;
array<internal::TensorIntDivisor<Index>, NumDims> m_fastStrides;
array<IndexDivisor, NumDims> m_fastStrides;
TensorEvaluator<ArgType, Device> m_impl;
ReverseDimensions m_reverse;
const Device& m_device;