mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 03:39:01 +08:00
Added support for boolean reductions (ie 'and' & 'or' reductions)
This commit is contained in:
parent
f5c1587e4e
commit
eaf4b98180
@ -1149,6 +1149,19 @@ are the smallest of the reduced values.
|
||||
Reduce a tensor using the prod() operator. The resulting values
|
||||
are the product of the reduced values.
|
||||
|
||||
### <Operation> all(const Dimensions& new_dims)
|
||||
### <Operation> all()
|
||||
Reduce a tensor using the all() operator. Casts tensor to bool and then checks
|
||||
whether all elements are true. Runs through all elements rather than
|
||||
short-circuiting, so may be significantly inefficient.
|
||||
|
||||
### <Operation> any(const Dimensions& new_dims)
|
||||
### <Operation> any()
|
||||
Reduce a tensor using the any() operator. Casts tensor to bool and then checks
|
||||
whether any element is true. Runs through all elements rather than
|
||||
short-circuiting, so may be significantly inefficient.
|
||||
|
||||
|
||||
### <Operation> reduce(const Dimensions& new_dims, const Reducer& reducer)
|
||||
|
||||
Reduce a tensor using a user-defined reduction operator. See ```SumReducer```
|
||||
|
@ -363,6 +363,32 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
||||
return TensorReductionOp<internal::MinReducer<CoeffReturnType>, const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims, internal::MinReducer<CoeffReturnType>());
|
||||
}
|
||||
|
||||
template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
const TensorReductionOp<internal::AndReducer, const Dims, const TensorConversionOp<bool, const Derived> >
|
||||
all(const Dims& dims) const {
|
||||
return cast<bool>().reduce(dims, internal::AndReducer());
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
const TensorReductionOp<internal::AndReducer, const DimensionList<Index, NumDimensions>, const TensorConversionOp<bool, const Derived> >
|
||||
all() const {
|
||||
DimensionList<Index, NumDimensions> in_dims;
|
||||
return cast<bool>().reduce(in_dims, internal::AndReducer());
|
||||
}
|
||||
|
||||
template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
const TensorReductionOp<internal::OrReducer, const Dims, const TensorConversionOp<bool, const Derived> >
|
||||
any(const Dims& dims) const {
|
||||
return cast<bool>().reduce(dims, internal::OrReducer());
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
const TensorReductionOp<internal::OrReducer, const DimensionList<Index, NumDimensions>, const TensorConversionOp<bool, const Derived> >
|
||||
any() const {
|
||||
DimensionList<Index, NumDimensions> in_dims;
|
||||
return cast<bool>().reduce(in_dims, internal::OrReducer());
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
const TensorTupleReducerOp<
|
||||
internal::ArgMaxTupleReducer<Tuple<Index, CoeffReturnType> >,
|
||||
|
@ -219,6 +219,33 @@ template <typename T> struct ProdReducer
|
||||
};
|
||||
|
||||
|
||||
struct AndReducer
|
||||
{
|
||||
static const bool PacketAccess = false;
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(bool t, bool* accum) const {
|
||||
*accum = *accum && t;
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool initialize() const {
|
||||
return true;
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool finalize(bool accum) const {
|
||||
return accum;
|
||||
}
|
||||
};
|
||||
|
||||
struct OrReducer {
|
||||
static const bool PacketAccess = false;
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(bool t, bool* accum) const {
|
||||
*accum = *accum || t;
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool initialize() const {
|
||||
return false;
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool finalize(bool accum) const {
|
||||
return accum;
|
||||
}
|
||||
};
|
||||
|
||||
// Argmin/Argmax reducers
|
||||
template <typename T> struct ArgMaxTupleReducer
|
||||
{
|
||||
|
@ -180,6 +180,23 @@ static void test_simple_reductions() {
|
||||
|
||||
VERIFY_IS_APPROX(mean1(0), mean2(0));
|
||||
}
|
||||
|
||||
{
|
||||
Tensor<int, 1> ints(10);
|
||||
std::iota(ints.data(), ints.data() + ints.dimension(0), 0);
|
||||
|
||||
TensorFixedSize<bool, Sizes<1> > all;
|
||||
all = ints.all();
|
||||
VERIFY(!all(0));
|
||||
all = (ints >= ints.constant(0)).all();
|
||||
VERIFY(all(0));
|
||||
|
||||
TensorFixedSize<bool, Sizes<1> > any;
|
||||
any = (ints > ints.constant(10)).any();
|
||||
VERIFY(!any(0));
|
||||
any = (ints < ints.constant(1)).any();
|
||||
VERIFY(any(0));
|
||||
}
|
||||
}
|
||||
|
||||
template <int DataLayout>
|
||||
|
Loading…
x
Reference in New Issue
Block a user