Fixed a bug in TensorArgMax.h

This commit is contained in:
Benoit Steiner 2015-11-23 15:58:47 -08:00
parent 547a8608e5
commit 44848ac39b

View File

@ -215,10 +215,17 @@ struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Devi
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
: m_orig_impl(op.expression(), device), : m_orig_impl(op.expression(), device),
m_impl(op.expression().index_tuples().reduce(op.reduce_dims(), op.reduce_op()), device), m_impl(op.expression().index_tuples().reduce(op.reduce_dims(), op.reduce_op()), device),
m_return_dim(op.return_dim()), m_return_dim(op.return_dim()) {
m_stride_mod(gen_stride_mod(m_orig_impl.dimensions())),
m_stride_div(gen_stride_div()) {
gen_strides(m_orig_impl.dimensions(), m_strides); gen_strides(m_orig_impl.dimensions(), m_strides);
if (Layout == static_cast<int>(ColMajor)) {
const Index total_size = internal::array_prod(m_orig_impl.dimensions());
m_stride_mod = (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : total_size;
} else {
const Index total_size = internal::array_prod(m_orig_impl.dimensions());
m_stride_mod = (m_return_dim > 0) ? m_strides[m_return_dim - 1] : total_size;
}
m_stride_div = m_strides[m_return_dim];
} }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
@ -263,25 +270,13 @@ struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Devi
} }
} }
EIGEN_DEVICE_FUNC Index gen_stride_mod(const InputDimensions& dims) {
if (Layout == static_cast<int>(ColMajor)) {
return (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : dims.TotalSize();
} else {
return (m_return_dim > 0) ? m_strides[m_return_dim - 1] : dims.TotalSize();
}
}
EIGEN_DEVICE_FUNC Index gen_stride_div() {
return m_strides[m_return_dim];
}
protected: protected:
TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device> m_orig_impl; TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device> m_orig_impl;
TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device> m_impl; TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device> m_impl;
const int m_return_dim; const int m_return_dim;
StrideDims m_strides; StrideDims m_strides;
const Index m_stride_mod; Index m_stride_mod;
const Index m_stride_div; Index m_stride_div;
}; };
} // end namespace Eigen } // end namespace Eigen