Rename Tuple -> Pair.

This is to make way for a new `Tuple` class that mimics `std::tuple`,
but can be reliably used on device and with aligned Eigen types.

The existing Tuple has very few references, and is actually an
analogue of `std::pair`.
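
A minimal, illustrative sketch (not part of this commit) of how the renamed user-facing pieces fit together; the rank-2 tensor, variable names, and the <unsupported/Eigen/CXX11/Tensor> include path are assumptions made for the example, while index_pairs(), Pair, and argmax()/argmin() come from the changes below.

    #include <unsupported/Eigen/CXX11/Tensor>
    #include <iostream>

    int main() {
      // Illustrative sketch only; a small rank-2 tensor stands in for real data.
      Eigen::Tensor<float, 2> t(3, 4);
      t.setRandom();

      // index_tuples() -> index_pairs(): each coefficient becomes an
      // Eigen::Pair<DenseIndex, float> holding (flat index, value).
      Eigen::Tensor<Eigen::Pair<Eigen::DenseIndex, float>, 2> pairs = t.index_pairs();
      std::cout << "pair at (0,0): index " << pairs(0, 0).first
                << ", value " << pairs(0, 0).second << "\n";

      // argmax()/argmin() keep their signatures; they now use the renamed
      // ArgMax/ArgMinPairReducer internally and still evaluate to index tensors.
      Eigen::Tensor<Eigen::DenseIndex, 0> flat_argmax = t.argmax();
      Eigen::Tensor<Eigen::DenseIndex, 1> argmax_along_dim0 = t.argmax(0);
      std::cout << "flat argmax: " << flat_argmax() << "\n";
      return 0;
    }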
Antonio Sanchez 2021-08-26 12:25:31 -07:00 committed by Rasmus Munk Larsen
parent 3d4ba855e0
commit 74da2e6821
6 changed files with 117 additions and 115 deletions

Changed file 1 of 6:

@@ -14,20 +14,20 @@
namespace Eigen {
namespace internal {
/** \class TensorIndexTuple
/** \class TensorIndexPair
* \ingroup CXX11_Tensor_Module
*
* \brief Tensor + Index Tuple class.
* \brief Tensor + Index Pair class.
*
*
*/
template<typename XprType>
struct traits<TensorIndexTupleOp<XprType> > : public traits<XprType>
struct traits<TensorIndexPairOp<XprType> > : public traits<XprType>
{
typedef traits<XprType> XprTraits;
typedef typename XprTraits::StorageKind StorageKind;
typedef typename XprTraits::Index Index;
typedef Tuple<Index, typename XprTraits::Scalar> Scalar;
typedef Pair<Index, typename XprTraits::Scalar> Scalar;
typedef typename XprType::Nested Nested;
typedef typename remove_reference<Nested>::type _Nested;
static const int NumDimensions = XprTraits::NumDimensions;
@@ -35,32 +35,32 @@ struct traits<TensorIndexTupleOp<XprType> > : public traits<XprType>
};
template<typename XprType>
struct eval<TensorIndexTupleOp<XprType>, Eigen::Dense>
struct eval<TensorIndexPairOp<XprType>, Eigen::Dense>
{
typedef const TensorIndexTupleOp<XprType>EIGEN_DEVICE_REF type;
typedef const TensorIndexPairOp<XprType>EIGEN_DEVICE_REF type;
};
template<typename XprType>
struct nested<TensorIndexTupleOp<XprType>, 1,
typename eval<TensorIndexTupleOp<XprType> >::type>
struct nested<TensorIndexPairOp<XprType>, 1,
typename eval<TensorIndexPairOp<XprType> >::type>
{
typedef TensorIndexTupleOp<XprType> type;
typedef TensorIndexPairOp<XprType> type;
};
} // end namespace internal
template<typename XprType>
class TensorIndexTupleOp : public TensorBase<TensorIndexTupleOp<XprType>, ReadOnlyAccessors>
class TensorIndexPairOp : public TensorBase<TensorIndexPairOp<XprType>, ReadOnlyAccessors>
{
public:
typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Scalar Scalar;
typedef typename Eigen::internal::traits<TensorIndexPairOp>::Scalar Scalar;
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
typedef typename Eigen::internal::nested<TensorIndexTupleOp>::type Nested;
typedef typename Eigen::internal::traits<TensorIndexTupleOp>::StorageKind StorageKind;
typedef typename Eigen::internal::traits<TensorIndexTupleOp>::Index Index;
typedef Tuple<Index, typename XprType::CoeffReturnType> CoeffReturnType;
typedef typename Eigen::internal::nested<TensorIndexPairOp>::type Nested;
typedef typename Eigen::internal::traits<TensorIndexPairOp>::StorageKind StorageKind;
typedef typename Eigen::internal::traits<TensorIndexPairOp>::Index Index;
typedef Pair<Index, typename XprType::CoeffReturnType> CoeffReturnType;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorIndexTupleOp(const XprType& expr)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorIndexPairOp(const XprType& expr)
: m_xpr(expr) {}
EIGEN_DEVICE_FUNC
@@ -73,9 +73,9 @@ class TensorIndexTupleOp : public TensorBase<TensorIndexTupleOp<XprType>, ReadOn
// Eval as rvalue
template<typename ArgType, typename Device>
struct TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device>
struct TensorEvaluator<const TensorIndexPairOp<ArgType>, Device>
{
typedef TensorIndexTupleOp<ArgType> XprType;
typedef TensorIndexPairOp<ArgType> XprType;
typedef typename XprType::Index Index;
typedef typename XprType::Scalar Scalar;
typedef typename XprType::CoeffReturnType CoeffReturnType;
@@ -138,14 +138,14 @@ struct TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device>
namespace internal {
/** \class TensorTupleIndex
/** \class TensorPairIndex
* \ingroup CXX11_Tensor_Module
*
* \brief Converts to Tensor<Tuple<Index, Scalar> > and reduces to Tensor<Index>.
* \brief Converts to Tensor<Pair<Index, Scalar> > and reduces to Tensor<Index>.
*
*/
template<typename ReduceOp, typename Dims, typename XprType>
struct traits<TensorTupleReducerOp<ReduceOp, Dims, XprType> > : public traits<XprType>
struct traits<TensorPairReducerOp<ReduceOp, Dims, XprType> > : public traits<XprType>
{
typedef traits<XprType> XprTraits;
typedef typename XprTraits::StorageKind StorageKind;
@@ -158,32 +158,32 @@ struct traits<TensorTupleReducerOp<ReduceOp, Dims, XprType> > : public traits<Xp
};
template<typename ReduceOp, typename Dims, typename XprType>
struct eval<TensorTupleReducerOp<ReduceOp, Dims, XprType>, Eigen::Dense>
struct eval<TensorPairReducerOp<ReduceOp, Dims, XprType>, Eigen::Dense>
{
typedef const TensorTupleReducerOp<ReduceOp, Dims, XprType>EIGEN_DEVICE_REF type;
typedef const TensorPairReducerOp<ReduceOp, Dims, XprType>EIGEN_DEVICE_REF type;
};
template<typename ReduceOp, typename Dims, typename XprType>
struct nested<TensorTupleReducerOp<ReduceOp, Dims, XprType>, 1,
typename eval<TensorTupleReducerOp<ReduceOp, Dims, XprType> >::type>
struct nested<TensorPairReducerOp<ReduceOp, Dims, XprType>, 1,
typename eval<TensorPairReducerOp<ReduceOp, Dims, XprType> >::type>
{
typedef TensorTupleReducerOp<ReduceOp, Dims, XprType> type;
typedef TensorPairReducerOp<ReduceOp, Dims, XprType> type;
};
} // end namespace internal
template<typename ReduceOp, typename Dims, typename XprType>
class TensorTupleReducerOp : public TensorBase<TensorTupleReducerOp<ReduceOp, Dims, XprType>, ReadOnlyAccessors>
class TensorPairReducerOp : public TensorBase<TensorPairReducerOp<ReduceOp, Dims, XprType>, ReadOnlyAccessors>
{
public:
typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Scalar Scalar;
typedef typename Eigen::internal::traits<TensorPairReducerOp>::Scalar Scalar;
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
typedef typename Eigen::internal::nested<TensorTupleReducerOp>::type Nested;
typedef typename Eigen::internal::traits<TensorTupleReducerOp>::StorageKind StorageKind;
typedef typename Eigen::internal::traits<TensorTupleReducerOp>::Index Index;
typedef typename Eigen::internal::nested<TensorPairReducerOp>::type Nested;
typedef typename Eigen::internal::traits<TensorPairReducerOp>::StorageKind StorageKind;
typedef typename Eigen::internal::traits<TensorPairReducerOp>::Index Index;
typedef Index CoeffReturnType;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTupleReducerOp(const XprType& expr,
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorPairReducerOp(const XprType& expr,
const ReduceOp& reduce_op,
const Index return_dim,
const Dims& reduce_dims)
@@ -211,27 +211,27 @@ class TensorTupleReducerOp : public TensorBase<TensorTupleReducerOp<ReduceOp, Di
// Eval as rvalue
template<typename ReduceOp, typename Dims, typename ArgType, typename Device>
struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Device>
struct TensorEvaluator<const TensorPairReducerOp<ReduceOp, Dims, ArgType>, Device>
{
typedef TensorTupleReducerOp<ReduceOp, Dims, ArgType> XprType;
typedef TensorPairReducerOp<ReduceOp, Dims, ArgType> XprType;
typedef typename XprType::Index Index;
typedef typename XprType::Scalar Scalar;
typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename TensorIndexTupleOp<ArgType>::CoeffReturnType TupleType;
typedef typename TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Dimensions Dimensions;
typedef typename TensorEvaluator<const TensorIndexTupleOp<ArgType> , Device>::Dimensions InputDimensions;
typedef typename TensorIndexPairOp<ArgType>::CoeffReturnType PairType;
typedef typename TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexPairOp<ArgType> >, Device>::Dimensions Dimensions;
typedef typename TensorEvaluator<const TensorIndexPairOp<ArgType> , Device>::Dimensions InputDimensions;
static const int NumDims = internal::array_size<InputDimensions>::value;
typedef array<Index, NumDims> StrideDims;
typedef StorageMemory<CoeffReturnType, Device> Storage;
typedef typename Storage::Type EvaluatorPointerType;
typedef StorageMemory<TupleType, Device> TupleStorageMem;
typedef StorageMemory<PairType, Device> PairStorageMem;
enum {
IsAligned = /*TensorEvaluator<ArgType, Device>::IsAligned*/ false,
PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/ false,
BlockAccess = false,
PreferBlockAccess = TensorEvaluator<ArgType, Device>::PreferBlockAccess,
Layout = TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Layout,
Layout = TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexPairOp<ArgType> >, Device>::Layout,
CoordAccess = false, // to be implemented
RawAccess = false
};
@@ -242,7 +242,7 @@ struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Devi
EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
: m_orig_impl(op.expression(), device),
m_impl(op.expression().index_tuples().reduce(op.reduce_dims(), op.reduce_op()), device),
m_impl(op.expression().index_pairs().reduce(op.reduce_dims(), op.reduce_op()), device),
m_return_dim(op.return_dim())
{
gen_strides(m_orig_impl.dimensions(), m_strides);
@@ -272,7 +272,7 @@ struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Devi
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
const TupleType v = m_impl.coeff(index);
const PairType v = m_impl.coeff(index);
return (m_return_dim < 0) ? v.first : (v.first % m_stride_mod) / m_stride_div;
}
@@ -316,8 +316,8 @@ struct TensorEvaluator<const TensorTupleReducerOp<ReduceOp, Dims, ArgType>, Devi
}
protected:
TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device> m_orig_impl;
TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device> m_impl;
TensorEvaluator<const TensorIndexPairOp<ArgType>, Device> m_orig_impl;
TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexPairOp<ArgType> >, Device> m_impl;
const Index m_return_dim;
StrideDims m_strides;
Index m_stride_mod;

Changed file 2 of 6:

@@ -741,55 +741,55 @@ class TensorBase<Derived, ReadOnlyAccessors>
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorTupleReducerOp<
internal::ArgMaxTupleReducer<Tuple<Index, CoeffReturnType> >,
const TensorPairReducerOp<
internal::ArgMaxPairReducer<Pair<Index, CoeffReturnType> >,
const array<Index, NumDimensions>, const Derived>
argmax() const {
array<Index, NumDimensions> in_dims;
for (Index d = 0; d < NumDimensions; ++d) in_dims[d] = d;
return TensorTupleReducerOp<
internal::ArgMaxTupleReducer<Tuple<Index, CoeffReturnType> >,
return TensorPairReducerOp<
internal::ArgMaxPairReducer<Pair<Index, CoeffReturnType> >,
const array<Index, NumDimensions>,
const Derived>(derived(), internal::ArgMaxTupleReducer<Tuple<Index, CoeffReturnType> >(), -1, in_dims);
const Derived>(derived(), internal::ArgMaxPairReducer<Pair<Index, CoeffReturnType> >(), -1, in_dims);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorTupleReducerOp<
internal::ArgMinTupleReducer<Tuple<Index, CoeffReturnType> >,
const TensorPairReducerOp<
internal::ArgMinPairReducer<Pair<Index, CoeffReturnType> >,
const array<Index, NumDimensions>, const Derived>
argmin() const {
array<Index, NumDimensions> in_dims;
for (Index d = 0; d < NumDimensions; ++d) in_dims[d] = d;
return TensorTupleReducerOp<
internal::ArgMinTupleReducer<Tuple<Index, CoeffReturnType> >,
return TensorPairReducerOp<
internal::ArgMinPairReducer<Pair<Index, CoeffReturnType> >,
const array<Index, NumDimensions>,
const Derived>(derived(), internal::ArgMinTupleReducer<Tuple<Index, CoeffReturnType> >(), -1, in_dims);
const Derived>(derived(), internal::ArgMinPairReducer<Pair<Index, CoeffReturnType> >(), -1, in_dims);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorTupleReducerOp<
internal::ArgMaxTupleReducer<Tuple<Index, CoeffReturnType> >,
const TensorPairReducerOp<
internal::ArgMaxPairReducer<Pair<Index, CoeffReturnType> >,
const array<Index, 1>, const Derived>
argmax(const Index return_dim) const {
array<Index, 1> in_dims;
in_dims[0] = return_dim;
return TensorTupleReducerOp<
internal::ArgMaxTupleReducer<Tuple<Index, CoeffReturnType> >,
return TensorPairReducerOp<
internal::ArgMaxPairReducer<Pair<Index, CoeffReturnType> >,
const array<Index, 1>,
const Derived>(derived(), internal::ArgMaxTupleReducer<Tuple<Index, CoeffReturnType> >(), return_dim, in_dims);
const Derived>(derived(), internal::ArgMaxPairReducer<Pair<Index, CoeffReturnType> >(), return_dim, in_dims);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorTupleReducerOp<
internal::ArgMinTupleReducer<Tuple<Index, CoeffReturnType> >,
const TensorPairReducerOp<
internal::ArgMinPairReducer<Pair<Index, CoeffReturnType> >,
const array<Index, 1>, const Derived>
argmin(const Index return_dim) const {
array<Index, 1> in_dims;
in_dims[0] = return_dim;
return TensorTupleReducerOp<
internal::ArgMinTupleReducer<Tuple<Index, CoeffReturnType> >,
return TensorPairReducerOp<
internal::ArgMinPairReducer<Pair<Index, CoeffReturnType> >,
const array<Index, 1>,
const Derived>(derived(), internal::ArgMinTupleReducer<Tuple<Index, CoeffReturnType> >(), return_dim, in_dims);
const Derived>(derived(), internal::ArgMinPairReducer<Pair<Index, CoeffReturnType> >(), return_dim, in_dims);
}
template <typename Reducer, typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
@@ -935,11 +935,11 @@ class TensorBase<Derived, ReadOnlyAccessors>
return TensorInflationOp<const Strides, const Derived>(derived(), strides);
}
// Returns a tensor containing index/value tuples
// Returns a tensor containing index/value pairs
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorIndexTupleOp<const Derived>
index_tuples() const {
return TensorIndexTupleOp<const Derived>(derived());
const TensorIndexPairOp<const Derived>
index_pairs() const {
return TensorIndexPairOp<const Derived>(derived());
}
// Support for custom unary and binary operations

Changed file 3 of 6:

@@ -61,8 +61,8 @@ template<typename BinaryOp, typename LeftXprType, typename RightXprType> class T
template<typename TernaryOp, typename Arg1XprType, typename Arg2XprType, typename Arg3XprType> class TensorCwiseTernaryOp;
template<typename IfXprType, typename ThenXprType, typename ElseXprType> class TensorSelectOp;
template<typename Op, typename Dims, typename XprType, template <class> class MakePointer_ = MakePointer > class TensorReductionOp;
template<typename XprType> class TensorIndexTupleOp;
template<typename ReduceOp, typename Dims, typename XprType> class TensorTupleReducerOp;
template<typename XprType> class TensorIndexPairOp;
template<typename ReduceOp, typename Dims, typename XprType> class TensorPairReducerOp;
template<typename Axis, typename LeftXprType, typename RightXprType> class TensorConcatenationOp;
template<typename Dimensions, typename LeftXprType, typename RightXprType, typename OutputKernelType> class TensorContractionOp;
template<typename TargetType, typename XprType> class TensorConversionOp;

Changed file 4 of 6:

@@ -367,7 +367,7 @@ struct reducer_traits<OrReducer, Device> {
// Argmin/Argmax reducers. Returns the first occurrence if multiple locations
// contain the same min/max value.
template <typename T> struct ArgMaxTupleReducer
template <typename T> struct ArgMaxPairReducer
{
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
if (t.second < accum->second) {
@@ -385,7 +385,7 @@ template <typename T> struct ArgMaxTupleReducer
};
template <typename T, typename Device>
struct reducer_traits<ArgMaxTupleReducer<T>, Device> {
struct reducer_traits<ArgMaxPairReducer<T>, Device> {
enum {
Cost = NumTraits<T>::AddCost,
PacketAccess = false,
@@ -395,7 +395,7 @@ struct reducer_traits<ArgMaxTupleReducer<T>, Device> {
};
template <typename T> struct ArgMinTupleReducer
template <typename T> struct ArgMinPairReducer
{
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T& t, T* accum) const {
if (t.second > accum->second) {
@@ -413,7 +413,7 @@ template <typename T> struct ArgMinTupleReducer
};
template <typename T, typename Device>
struct reducer_traits<ArgMinTupleReducer<T>, Device> {
struct reducer_traits<ArgMinPairReducer<T>, Device> {
enum {
Cost = NumTraits<T>::AddCost,
PacketAccess = false,

Changed file 5 of 6:

@@ -207,9 +207,11 @@ template<> struct PacketType<const half, const SyclDevice>: PacketType<half, Syc
#endif
#endif
// Tuple mimics std::pair but works on e.g. nvcc.
template <typename U, typename V> struct Tuple {
// Pair mimics std::pair but works on e.g. nvcc.
template <typename U, typename V> struct Pair {
public:
EIGEN_MAKE_ALIGNED_OPERATOR_NEW
U first;
V second;
@@ -217,13 +219,13 @@ template <typename U, typename V> struct Tuple {
typedef V second_type;
EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Tuple() : first(), second() {}
Pair() : first(), second() {}
EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Tuple(const U& f, const V& s) : first(f), second(s) {}
Pair(const U& f, const V& s) : first(f), second(s) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void swap(Tuple& rhs) {
void swap(Pair& rhs) {
using numext::swap;
swap(first, rhs.first);
swap(second, rhs.second);
@@ -232,13 +234,13 @@ template <typename U, typename V> struct Tuple {
template <typename U, typename V>
EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
bool operator==(const Tuple<U, V>& x, const Tuple<U, V>& y) {
bool operator==(const Pair<U, V>& x, const Pair<U, V>& y) {
return (x.first == y.first && x.second == y.second);
}
template <typename U, typename V>
EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
bool operator!=(const Tuple<U, V>& x, const Tuple<U, V>& y) {
bool operator!=(const Pair<U, V>& x, const Pair<U, V>& y) {
return !(x == y);
}

Changed file 6 of 6:

@@ -14,57 +14,57 @@
using Eigen::Tensor;
using Eigen::array;
using Eigen::Tuple;
using Eigen::Pair;
template <int DataLayout>
static void test_simple_index_tuples()
static void test_simple_index_pairs()
{
Tensor<float, 4, DataLayout> tensor(2,3,5,7);
tensor.setRandom();
tensor = (tensor + tensor.constant(0.5)).log();
Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7);
index_tuples = tensor.index_tuples();
Tensor<Pair<DenseIndex, float>, 4, DataLayout> index_pairs(2,3,5,7);
index_pairs = tensor.index_pairs();
for (DenseIndex n = 0; n < 2*3*5*7; ++n) {
const Tuple<DenseIndex, float>& v = index_tuples.coeff(n);
const Pair<DenseIndex, float>& v = index_pairs.coeff(n);
VERIFY_IS_EQUAL(v.first, n);
VERIFY_IS_EQUAL(v.second, tensor.coeff(n));
}
}
template <int DataLayout>
static void test_index_tuples_dim()
static void test_index_pairs_dim()
{
Tensor<float, 4, DataLayout> tensor(2,3,5,7);
tensor.setRandom();
tensor = (tensor + tensor.constant(0.5)).log();
Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7);
Tensor<Pair<DenseIndex, float>, 4, DataLayout> index_pairs(2,3,5,7);
index_tuples = tensor.index_tuples();
index_pairs = tensor.index_pairs();
for (Eigen::DenseIndex n = 0; n < tensor.size(); ++n) {
const Tuple<DenseIndex, float>& v = index_tuples(n); //(i, j, k, l);
const Pair<DenseIndex, float>& v = index_pairs(n); //(i, j, k, l);
VERIFY_IS_EQUAL(v.first, n);
VERIFY_IS_EQUAL(v.second, tensor(n));
}
}
template <int DataLayout>
static void test_argmax_tuple_reducer()
static void test_argmax_pair_reducer()
{
Tensor<float, 4, DataLayout> tensor(2,3,5,7);
tensor.setRandom();
tensor = (tensor + tensor.constant(0.5)).log();
Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7);
index_tuples = tensor.index_tuples();
Tensor<Pair<DenseIndex, float>, 4, DataLayout> index_pairs(2,3,5,7);
index_pairs = tensor.index_pairs();
Tensor<Tuple<DenseIndex, float>, 0, DataLayout> reduced;
Tensor<Pair<DenseIndex, float>, 0, DataLayout> reduced;
DimensionList<DenseIndex, 4> dims;
reduced = index_tuples.reduce(
dims, internal::ArgMaxTupleReducer<Tuple<DenseIndex, float> >());
reduced = index_pairs.reduce(
dims, internal::ArgMaxPairReducer<Pair<DenseIndex, float> >());
Tensor<float, 0, DataLayout> maxi = tensor.maximum();
@@ -72,9 +72,9 @@ static void test_argmax_tuple_reducer()
array<DenseIndex, 3> reduce_dims;
for (int d = 0; d < 3; ++d) reduce_dims[d] = d;
Tensor<Tuple<DenseIndex, float>, 1, DataLayout> reduced_by_dims(7);
reduced_by_dims = index_tuples.reduce(
reduce_dims, internal::ArgMaxTupleReducer<Tuple<DenseIndex, float> >());
Tensor<Pair<DenseIndex, float>, 1, DataLayout> reduced_by_dims(7);
reduced_by_dims = index_pairs.reduce(
reduce_dims, internal::ArgMaxPairReducer<Pair<DenseIndex, float> >());
Tensor<float, 1, DataLayout> max_by_dims = tensor.maximum(reduce_dims);
@@ -84,19 +84,19 @@ static void test_argmax_tuple_reducer()
}
template <int DataLayout>
static void test_argmin_tuple_reducer()
static void test_argmin_pair_reducer()
{
Tensor<float, 4, DataLayout> tensor(2,3,5,7);
tensor.setRandom();
tensor = (tensor + tensor.constant(0.5)).log();
Tensor<Tuple<DenseIndex, float>, 4, DataLayout> index_tuples(2,3,5,7);
index_tuples = tensor.index_tuples();
Tensor<Pair<DenseIndex, float>, 4, DataLayout> index_pairs(2,3,5,7);
index_pairs = tensor.index_pairs();
Tensor<Tuple<DenseIndex, float>, 0, DataLayout> reduced;
Tensor<Pair<DenseIndex, float>, 0, DataLayout> reduced;
DimensionList<DenseIndex, 4> dims;
reduced = index_tuples.reduce(
dims, internal::ArgMinTupleReducer<Tuple<DenseIndex, float> >());
reduced = index_pairs.reduce(
dims, internal::ArgMinPairReducer<Pair<DenseIndex, float> >());
Tensor<float, 0, DataLayout> mini = tensor.minimum();
@@ -104,9 +104,9 @@ static void test_argmin_tuple_reducer()
array<DenseIndex, 3> reduce_dims;
for (int d = 0; d < 3; ++d) reduce_dims[d] = d;
Tensor<Tuple<DenseIndex, float>, 1, DataLayout> reduced_by_dims(7);
reduced_by_dims = index_tuples.reduce(
reduce_dims, internal::ArgMinTupleReducer<Tuple<DenseIndex, float> >());
Tensor<Pair<DenseIndex, float>, 1, DataLayout> reduced_by_dims(7);
reduced_by_dims = index_pairs.reduce(
reduce_dims, internal::ArgMinPairReducer<Pair<DenseIndex, float> >());
Tensor<float, 1, DataLayout> min_by_dims = tensor.minimum(reduce_dims);
@@ -275,14 +275,14 @@ static void test_argmin_dim()
EIGEN_DECLARE_TEST(cxx11_tensor_argmax)
{
CALL_SUBTEST(test_simple_index_tuples<RowMajor>());
CALL_SUBTEST(test_simple_index_tuples<ColMajor>());
CALL_SUBTEST(test_index_tuples_dim<RowMajor>());
CALL_SUBTEST(test_index_tuples_dim<ColMajor>());
CALL_SUBTEST(test_argmax_tuple_reducer<RowMajor>());
CALL_SUBTEST(test_argmax_tuple_reducer<ColMajor>());
CALL_SUBTEST(test_argmin_tuple_reducer<RowMajor>());
CALL_SUBTEST(test_argmin_tuple_reducer<ColMajor>());
CALL_SUBTEST(test_simple_index_pairs<RowMajor>());
CALL_SUBTEST(test_simple_index_pairs<ColMajor>());
CALL_SUBTEST(test_index_pairs_dim<RowMajor>());
CALL_SUBTEST(test_index_pairs_dim<ColMajor>());
CALL_SUBTEST(test_argmax_pair_reducer<RowMajor>());
CALL_SUBTEST(test_argmax_pair_reducer<ColMajor>());
CALL_SUBTEST(test_argmin_pair_reducer<RowMajor>());
CALL_SUBTEST(test_argmin_pair_reducer<ColMajor>());
CALL_SUBTEST(test_simple_argmax<RowMajor>());
CALL_SUBTEST(test_simple_argmax<ColMajor>());
CALL_SUBTEST(test_simple_argmin<RowMajor>());