Don't store the scan axis in the evaluator of the tensor scan operation since it's only used in the constructor.

Also avoid taking references to values that may becomes stale after a copy construction.
This commit is contained in:
Benoit Steiner 2016-06-27 10:32:38 -07:00
parent d476cadbb8
commit 75c333f94c

View File

@ -101,32 +101,31 @@ struct TensorEvaluator<const TensorScanOp<Op, ArgType>, Device> {
const Device& device)
: m_impl(op.expression(), device),
m_device(device),
m_axis(op.axis()),
m_exclusive(op.exclusive()),
m_accumulator(op.accumulator()),
m_dimensions(m_impl.dimensions()),
m_size(m_dimensions[m_axis]),
m_size(m_impl.dimensions()[op.axis()]),
m_stride(1),
m_output(NULL) {
// Accumulating a scalar isn't supported.
EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
eigen_assert(m_axis >= 0 && m_axis < NumDims);
eigen_assert(op.axis() >= 0 && op.axis() < NumDims);
// Compute stride of scan axis
const Dimensions& dims = m_impl.dimensions();
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
for (int i = 0; i < m_axis; ++i) {
m_stride = m_stride * m_dimensions[i];
for (int i = 0; i < op.axis(); ++i) {
m_stride = m_stride * dims[i];
}
} else {
for (int i = NumDims - 1; i > m_axis; --i) {
m_stride = m_stride * m_dimensions[i];
for (int i = NumDims - 1; i > op.axis(); --i) {
m_stride = m_stride * dims[i];
}
}
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
return m_dimensions;
return m_impl.dimensions();
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
@ -135,7 +134,8 @@ struct TensorEvaluator<const TensorScanOp<Op, ArgType>, Device> {
accumulateTo(data);
return false;
} else {
m_output = static_cast<CoeffReturnType*>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
const Index total_size = internal::array_prod(dimensions());
m_output = static_cast<CoeffReturnType*>(m_device.allocate(total_size * sizeof(Scalar)));
accumulateTo(m_output);
return true;
}
@ -171,11 +171,9 @@ struct TensorEvaluator<const TensorScanOp<Op, ArgType>, Device> {
protected:
TensorEvaluator<ArgType, Device> m_impl;
const Device& m_device;
const Index m_axis;
const bool m_exclusive;
Op m_accumulator;
const Dimensions& m_dimensions;
const Index& m_size;
const Index m_size;
Index m_stride;
CoeffReturnType* m_output;