Don't store the scan axis in the evaluator of the tensor scan operation since it's only used in the constructor.

Also avoid taking references to values that may become stale after copy construction.
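For context, a minimal self-contained sketch (not the Eigen code; BadEvaluator, GoodEvaluator and their members are made-up names) of why a reference member goes stale: the compiler-generated copy constructor makes the copy's reference refer to the same storage the original's reference pointed at, so once the original object is gone the copy is reading dead memory. Storing the derived value by value, as this commit does with m_size, keeps each copy self-contained.

#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative only: an evaluator-like class that keeps a reference to one of
// its own derived values. The implicit copy constructor copies the reference
// as-is, so the copy still refers into the *original* object's storage.
struct BadEvaluator {
  explicit BadEvaluator(const std::vector<std::ptrdiff_t>& dims, int axis)
      : m_dims(dims), m_size(m_dims[axis]) {}

  std::vector<std::ptrdiff_t> m_dims;  // owned copy of the dimensions
  const std::ptrdiff_t& m_size;        // reference into m_dims: stale after a copy
};

// Variant in the spirit of the patch: compute the value once in the
// constructor and store it by value, so copies are independent.
struct GoodEvaluator {
  explicit GoodEvaluator(const std::vector<std::ptrdiff_t>& dims, int axis)
      : m_dims(dims), m_size(m_dims[axis]) {}

  std::vector<std::ptrdiff_t> m_dims;
  std::ptrdiff_t m_size;  // plain value, valid for this object's whole lifetime
};

int main() {
  std::vector<std::ptrdiff_t> dims = {4, 5, 6};

  BadEvaluator* original = new BadEvaluator(dims, 1);
  BadEvaluator copy = *original;  // copy.m_size still refers to original->m_dims[1]
  delete original;                // ...which no longer exists
  // Reading copy.m_size here would be undefined behavior.

  GoodEvaluator good_original(dims, 1);
  GoodEvaluator good_copy = good_original;
  std::cout << good_copy.m_size << "\n";  // prints 5, independent of good_original
  return 0;
}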
Benoit Steiner 2016-06-27 10:32:38 -07:00
parent d476cadbb8
commit 75c333f94c


@@ -101,32 +101,31 @@ struct TensorEvaluator<const TensorScanOp<Op, ArgType>, Device> {
                                                         const Device& device)
       : m_impl(op.expression(), device),
         m_device(device),
-        m_axis(op.axis()),
         m_exclusive(op.exclusive()),
         m_accumulator(op.accumulator()),
-        m_dimensions(m_impl.dimensions()),
-        m_size(m_dimensions[m_axis]),
+        m_size(m_impl.dimensions()[op.axis()]),
         m_stride(1),
         m_output(NULL) {
     // Accumulating a scalar isn't supported.
     EIGEN_STATIC_ASSERT((NumDims > 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
-    eigen_assert(m_axis >= 0 && m_axis < NumDims);
+    eigen_assert(op.axis() >= 0 && op.axis() < NumDims);
     // Compute stride of scan axis
+    const Dimensions& dims = m_impl.dimensions();
     if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
-      for (int i = 0; i < m_axis; ++i) {
-        m_stride = m_stride * m_dimensions[i];
+      for (int i = 0; i < op.axis(); ++i) {
+        m_stride = m_stride * dims[i];
       }
     } else {
-      for (int i = NumDims - 1; i > m_axis; --i) {
-        m_stride = m_stride * m_dimensions[i];
+      for (int i = NumDims - 1; i > op.axis(); --i) {
+        m_stride = m_stride * dims[i];
       }
     }
   }

   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
-    return m_dimensions;
+    return m_impl.dimensions();
   }

   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
@@ -135,7 +134,8 @@ struct TensorEvaluator<const TensorScanOp<Op, ArgType>, Device> {
       accumulateTo(data);
       return false;
     } else {
-      m_output = static_cast<CoeffReturnType*>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
+      const Index total_size = internal::array_prod(dimensions());
+      m_output = static_cast<CoeffReturnType*>(m_device.allocate(total_size * sizeof(Scalar)));
       accumulateTo(m_output);
       return true;
     }
@ -171,11 +171,9 @@ struct TensorEvaluator<const TensorScanOp<Op, ArgType>, Device> {
protected: protected:
TensorEvaluator<ArgType, Device> m_impl; TensorEvaluator<ArgType, Device> m_impl;
const Device& m_device; const Device& m_device;
const Index m_axis;
const bool m_exclusive; const bool m_exclusive;
Op m_accumulator; Op m_accumulator;
const Dimensions& m_dimensions; const Index m_size;
const Index& m_size;
Index m_stride; Index m_stride;
CoeffReturnType* m_output; CoeffReturnType* m_output;