Merged in ibab/eigen (pull request PR-197)

Implement exclusive scan option for Tensor library
This commit is contained in:
Benoit Steiner 2016-06-14 17:54:59 -07:00
commit 7d495d890a
3 changed files with 46 additions and 12 deletions

View File

@ -486,22 +486,22 @@ class TensorBase<Derived, ReadOnlyAccessors>
typedef TensorScanOp<internal::SumReducer<CoeffReturnType>, const Derived> TensorScanSumOp;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorScanSumOp
cumsum(const Index& axis) const {
return TensorScanSumOp(derived(), axis);
cumsum(const Index& axis, bool exclusive = false) const {
return TensorScanSumOp(derived(), axis, exclusive);
}
typedef TensorScanOp<internal::ProdReducer<CoeffReturnType>, const Derived> TensorScanProdOp;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorScanProdOp
cumprod(const Index& axis) const {
return TensorScanProdOp(derived(), axis);
cumprod(const Index& axis, bool exclusive = false) const {
return TensorScanProdOp(derived(), axis, exclusive);
}
template <typename Reducer>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorScanOp<Reducer, const Derived>
scan(const Index& axis, const Reducer& reducer) const {
return TensorScanOp<Reducer, const Derived>(derived(), axis, reducer);
scan(const Index& axis, const Reducer& reducer, bool exclusive = false) const {
return TensorScanOp<Reducer, const Derived>(derived(), axis, exclusive, reducer);
}
// Reductions.

View File

@ -57,8 +57,8 @@ public:
typedef typename Eigen::internal::traits<TensorScanOp>::Index Index;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorScanOp(
const XprType& expr, const Index& axis, const Op& op = Op())
: m_expr(expr), m_axis(axis), m_accumulator(op) {}
const XprType& expr, const Index& axis, bool exclusive = false, const Op& op = Op())
: m_expr(expr), m_axis(axis), m_accumulator(op), m_exclusive(exclusive) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Index axis() const { return m_axis; }
@ -66,11 +66,14 @@ public:
const XprType& expression() const { return m_expr; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Op accumulator() const { return m_accumulator; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
bool exclusive() const { return m_exclusive; }
protected:
typename XprType::Nested m_expr;
const Index m_axis;
const Op m_accumulator;
const bool m_exclusive;
};
// Eval as rvalue
@ -99,6 +102,7 @@ struct TensorEvaluator<const TensorScanOp<Op, ArgType>, Device> {
: m_impl(op.expression(), device),
m_device(device),
m_axis(op.axis()),
m_exclusive(op.exclusive()),
m_accumulator(op.accumulator()),
m_dimensions(m_impl.dimensions()),
m_size(m_dimensions[m_axis]),
@ -168,6 +172,7 @@ protected:
TensorEvaluator<ArgType, Device> m_impl;
const Device& m_device;
const Index m_axis;
const bool m_exclusive;
Op m_accumulator;
const Dimensions& m_dimensions;
const Index& m_size;
@ -176,7 +181,7 @@ protected:
// TODO(ibab) Parallelize this single-threaded implementation if desired
EIGEN_DEVICE_FUNC void accumulateTo(Scalar* data) {
// We fix the index along the scan axis to 0 and perform an
// We fix the index along the scan axis to 0 and perform a
// scan per remaining entry. The iteration is split into two nested
// loops to avoid an integer division by keeping track of each idx1 and idx2.
for (Index idx1 = 0; idx1 < dimensions().TotalSize() / m_size; idx1 += m_stride) {
@ -184,12 +189,17 @@ protected:
// Calculate the starting offset for the scan
Index offset = idx1 * m_size + idx2;
// Compute the prefix sum along the axis, starting at the calculated offset
// Compute the scan along the axis, starting at the calculated offset
CoeffReturnType accum = m_accumulator.initialize();
for (Index idx3 = 0; idx3 < m_size; idx3++) {
Index curr = offset + idx3 * m_stride;
m_accumulator.reduce(m_impl.coeff(curr), &accum);
data[curr] = m_accumulator.finalize(accum);
if (m_exclusive) {
data[curr] = m_accumulator.finalize(accum);
m_accumulator.reduce(m_impl.coeff(curr), &accum);
} else {
m_accumulator.reduce(m_impl.coeff(curr), &accum);
data[curr] = m_accumulator.finalize(accum);
}
}
}
}

View File

@ -38,6 +38,30 @@ static void test_1d_scan()
}
}
template <int DataLayout, typename Type=float>
static void test_1d_inclusive_scan()
{
int size = 50;
Tensor<Type, 1, DataLayout> tensor(size);
tensor.setRandom();
Tensor<Type, 1, DataLayout> result = tensor.cumsum(0, true);
VERIFY_IS_EQUAL(tensor.dimension(0), result.dimension(0));
float accum = 0;
for (int i = 0; i < size; i++) {
VERIFY_IS_EQUAL(result(i), accum);
accum += tensor(i);
}
accum = 1;
result = tensor.cumprod(0, true);
for (int i = 0; i < size; i++) {
VERIFY_IS_EQUAL(result(i), accum);
accum *= tensor(i);
}
}
template <int DataLayout, typename Type=float>
static void test_4d_scan()
{