diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 8f3580ba7..87fa672f4 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -486,22 +486,22 @@ class TensorBase typedef TensorScanOp, const Derived> TensorScanSumOp; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorScanSumOp - cumsum(const Index& axis) const { - return TensorScanSumOp(derived(), axis); + cumsum(const Index& axis, bool exclusive = false) const { + return TensorScanSumOp(derived(), axis, exclusive); } typedef TensorScanOp, const Derived> TensorScanProdOp; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorScanProdOp - cumprod(const Index& axis) const { - return TensorScanProdOp(derived(), axis); + cumprod(const Index& axis, bool exclusive = false) const { + return TensorScanProdOp(derived(), axis, exclusive); } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorScanOp - scan(const Index& axis, const Reducer& reducer) const { - return TensorScanOp(derived(), axis, reducer); + scan(const Index& axis, const Reducer& reducer, bool exclusive = false) const { + return TensorScanOp(derived(), axis, exclusive, reducer); } // Reductions. diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h b/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h index 5207f6a8d..1aa196b84 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorScan.h @@ -57,8 +57,8 @@ public: typedef typename Eigen::internal::traits::Index Index; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorScanOp( - const XprType& expr, const Index& axis, const Op& op = Op()) - : m_expr(expr), m_axis(axis), m_accumulator(op) {} + const XprType& expr, const Index& axis, bool exclusive = false, const Op& op = Op()) + : m_expr(expr), m_axis(axis), m_accumulator(op), m_exclusive(exclusive) {} EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Index axis() const { return m_axis; } @@ -66,11 +66,14 @@ public: const XprType& expression() const { return m_expr; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Op accumulator() const { return m_accumulator; } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + bool exclusive() const { return m_exclusive; } protected: typename XprType::Nested m_expr; const Index m_axis; const Op m_accumulator; + const bool m_exclusive; }; // Eval as rvalue @@ -99,6 +102,7 @@ struct TensorEvaluator, Device> { : m_impl(op.expression(), device), m_device(device), m_axis(op.axis()), + m_exclusive(op.exclusive()), m_accumulator(op.accumulator()), m_dimensions(m_impl.dimensions()), m_size(m_dimensions[m_axis]), @@ -168,6 +172,7 @@ protected: TensorEvaluator m_impl; const Device& m_device; const Index m_axis; + const bool m_exclusive; Op m_accumulator; const Dimensions& m_dimensions; const Index& m_size; @@ -176,7 +181,7 @@ protected: // TODO(ibab) Parallelize this single-threaded implementation if desired EIGEN_DEVICE_FUNC void accumulateTo(Scalar* data) { - // We fix the index along the scan axis to 0 and perform an + // We fix the index along the scan axis to 0 and perform a // scan per remaining entry. The iteration is split into two nested // loops to avoid an integer division by keeping track of each idx1 and idx2. for (Index idx1 = 0; idx1 < dimensions().TotalSize() / m_size; idx1 += m_stride) { @@ -184,12 +189,17 @@ protected: // Calculate the starting offset for the scan Index offset = idx1 * m_size + idx2; - // Compute the prefix sum along the axis, starting at the calculated offset + // Compute the scan along the axis, starting at the calculated offset CoeffReturnType accum = m_accumulator.initialize(); for (Index idx3 = 0; idx3 < m_size; idx3++) { Index curr = offset + idx3 * m_stride; - m_accumulator.reduce(m_impl.coeff(curr), &accum); - data[curr] = m_accumulator.finalize(accum); + if (m_exclusive) { + data[curr] = m_accumulator.finalize(accum); + m_accumulator.reduce(m_impl.coeff(curr), &accum); + } else { + m_accumulator.reduce(m_impl.coeff(curr), &accum); + data[curr] = m_accumulator.finalize(accum); + } } } } diff --git a/unsupported/test/cxx11_tensor_scan.cpp b/unsupported/test/cxx11_tensor_scan.cpp index dbd3023d7..bafa6c96e 100644 --- a/unsupported/test/cxx11_tensor_scan.cpp +++ b/unsupported/test/cxx11_tensor_scan.cpp @@ -38,6 +38,30 @@ static void test_1d_scan() } } +template +static void test_1d_inclusive_scan() +{ + int size = 50; + Tensor tensor(size); + tensor.setRandom(); + Tensor result = tensor.cumsum(0, true); + + VERIFY_IS_EQUAL(tensor.dimension(0), result.dimension(0)); + + float accum = 0; + for (int i = 0; i < size; i++) { + VERIFY_IS_EQUAL(result(i), accum); + accum += tensor(i); + } + + accum = 1; + result = tensor.cumprod(0, true); + for (int i = 0; i < size; i++) { + VERIFY_IS_EQUAL(result(i), accum); + accum *= tensor(i); + } +} + template static void test_4d_scan() {