mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 19:29:02 +08:00
Added support for statically known lists of pairs of indices
This commit is contained in:
parent
ed783872ab
commit
58026905ae
@ -29,14 +29,6 @@ namespace Eigen {
|
|||||||
* \sa Tensor
|
* \sa Tensor
|
||||||
*/
|
*/
|
||||||
|
|
||||||
// Can't use std::pair on cuda devices
|
|
||||||
template <typename Index> struct IndexPair {
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE IndexPair() : first(0), second(0) { }
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE IndexPair(Index f, Index s) : first(f), second(s) { }
|
|
||||||
Index first;
|
|
||||||
Index second;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Boilerplate code
|
// Boilerplate code
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
|
@ -10,6 +10,22 @@
|
|||||||
#ifndef EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H
|
#ifndef EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H
|
||||||
#define EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H
|
#define EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H
|
||||||
|
|
||||||
|
/*namespace Eigen {
|
||||||
|
|
||||||
|
template <typename Index> struct IndexPair {
|
||||||
|
constexpr EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE IndexPair() : first(0), second(0) {}
|
||||||
|
constexpr EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE IndexPair(Index f, Index s) : first(f), second(s) {}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC void set(IndexPair<Index> val) {
|
||||||
|
first = val.first;
|
||||||
|
second = val.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
Index first;
|
||||||
|
Index second;
|
||||||
|
};
|
||||||
|
}*/
|
||||||
|
|
||||||
#if EIGEN_HAS_CONSTEXPR && EIGEN_HAS_VARIADIC_TEMPLATES
|
#if EIGEN_HAS_CONSTEXPR && EIGEN_HAS_VARIADIC_TEMPLATES
|
||||||
|
|
||||||
#define EIGEN_HAS_INDEX_LIST
|
#define EIGEN_HAS_INDEX_LIST
|
||||||
@ -45,6 +61,24 @@ struct type2index {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// This can be used with IndexPairList to get compile-time constant pairs,
|
||||||
|
// such as IndexPairList<type2indexpair<1,2>, type2indexpair<3,4>>().
|
||||||
|
template <DenseIndex f, DenseIndex s>
|
||||||
|
struct type2indexpair {
|
||||||
|
static const DenseIndex first = f;
|
||||||
|
static const DenseIndex second = s;
|
||||||
|
|
||||||
|
constexpr EIGEN_DEVICE_FUNC operator IndexPair<DenseIndex>() const {
|
||||||
|
return IndexPair<DenseIndex>(f, s);
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC void set(const IndexPair<DenseIndex>& val) {
|
||||||
|
eigen_assert(val.first == f);
|
||||||
|
eigen_assert(val.second == s);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
template<DenseIndex n> struct NumTraits<type2index<n> >
|
template<DenseIndex n> struct NumTraits<type2index<n> >
|
||||||
{
|
{
|
||||||
typedef DenseIndex Real;
|
typedef DenseIndex Real;
|
||||||
@ -72,6 +106,16 @@ EIGEN_DEVICE_FUNC void update_value(type2index<n>& val, DenseIndex new_val) {
|
|||||||
val.set(new_val);
|
val.set(new_val);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
EIGEN_DEVICE_FUNC void update_value(T& val, IndexPair<DenseIndex> new_val) {
|
||||||
|
val = new_val;
|
||||||
|
}
|
||||||
|
template <DenseIndex f, DenseIndex s>
|
||||||
|
EIGEN_DEVICE_FUNC void update_value(type2indexpair<f, s>& val, IndexPair<DenseIndex> new_val) {
|
||||||
|
val.set(new_val);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct is_compile_time_constant {
|
struct is_compile_time_constant {
|
||||||
static constexpr bool value = false;
|
static constexpr bool value = false;
|
||||||
@ -94,7 +138,22 @@ struct is_compile_time_constant<const type2index<idx>& > {
|
|||||||
static constexpr bool value = true;
|
static constexpr bool value = true;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <DenseIndex f, DenseIndex s>
|
||||||
|
struct is_compile_time_constant<type2indexpair<f, s> > {
|
||||||
|
static constexpr bool value = true;
|
||||||
|
};
|
||||||
|
template <DenseIndex f, DenseIndex s>
|
||||||
|
struct is_compile_time_constant<const type2indexpair<f, s> > {
|
||||||
|
static constexpr bool value = true;
|
||||||
|
};
|
||||||
|
template <DenseIndex f, DenseIndex s>
|
||||||
|
struct is_compile_time_constant<type2indexpair<f, s>& > {
|
||||||
|
static constexpr bool value = true;
|
||||||
|
};
|
||||||
|
template <DenseIndex f, DenseIndex s>
|
||||||
|
struct is_compile_time_constant<const type2indexpair<f, s>& > {
|
||||||
|
static constexpr bool value = true;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
template<typename... T>
|
template<typename... T>
|
||||||
@ -184,31 +243,32 @@ template <typename T, typename... O>
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
template <DenseIndex Idx>
|
template <DenseIndex Idx, typename ValueT>
|
||||||
struct tuple_coeff {
|
struct tuple_coeff {
|
||||||
template <typename... T>
|
template <typename... T>
|
||||||
EIGEN_DEVICE_FUNC static constexpr DenseIndex get(const DenseIndex i, const IndexTuple<T...>& t) {
|
EIGEN_DEVICE_FUNC static constexpr ValueT get(const DenseIndex i, const IndexTuple<T...>& t) {
|
||||||
return array_get<Idx>(t) * (i == Idx) + tuple_coeff<Idx-1>::get(i, t) * (i != Idx);
|
// return array_get<Idx>(t) * (i == Idx) + tuple_coeff<Idx-1>::get(i, t) * (i != Idx);
|
||||||
|
return (i == Idx ? array_get<Idx>(t) : tuple_coeff<Idx-1, ValueT>::get(i, t));
|
||||||
}
|
}
|
||||||
template <typename... T>
|
template <typename... T>
|
||||||
EIGEN_DEVICE_FUNC static void set(const DenseIndex i, IndexTuple<T...>& t, const DenseIndex value) {
|
EIGEN_DEVICE_FUNC static void set(const DenseIndex i, IndexTuple<T...>& t, const ValueT& value) {
|
||||||
if (i == Idx) {
|
if (i == Idx) {
|
||||||
update_value(array_get<Idx>(t), value);
|
update_value(array_get<Idx>(t), value);
|
||||||
} else {
|
} else {
|
||||||
tuple_coeff<Idx-1>::set(i, t, value);
|
tuple_coeff<Idx-1, ValueT>::set(i, t, value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename... T>
|
template <typename... T>
|
||||||
EIGEN_DEVICE_FUNC static constexpr bool value_known_statically(const DenseIndex i, const IndexTuple<T...>& t) {
|
EIGEN_DEVICE_FUNC static constexpr bool value_known_statically(const DenseIndex i, const IndexTuple<T...>& t) {
|
||||||
return ((i == Idx) & is_compile_time_constant<typename IndexTupleExtractor<Idx, T...>::ValType>::value) ||
|
return ((i == Idx) & is_compile_time_constant<typename IndexTupleExtractor<Idx, T...>::ValType>::value) ||
|
||||||
tuple_coeff<Idx-1>::value_known_statically(i, t);
|
tuple_coeff<Idx-1, ValueT>::value_known_statically(i, t);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename... T>
|
template <typename... T>
|
||||||
EIGEN_DEVICE_FUNC static constexpr bool values_up_to_known_statically(const IndexTuple<T...>& t) {
|
EIGEN_DEVICE_FUNC static constexpr bool values_up_to_known_statically(const IndexTuple<T...>& t) {
|
||||||
return is_compile_time_constant<typename IndexTupleExtractor<Idx, T...>::ValType>::value &&
|
return is_compile_time_constant<typename IndexTupleExtractor<Idx, T...>::ValType>::value &&
|
||||||
tuple_coeff<Idx-1>::values_up_to_known_statically(t);
|
tuple_coeff<Idx-1, ValueT>::values_up_to_known_statically(t);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename... T>
|
template <typename... T>
|
||||||
@ -216,19 +276,19 @@ struct tuple_coeff {
|
|||||||
return is_compile_time_constant<typename IndexTupleExtractor<Idx, T...>::ValType>::value &&
|
return is_compile_time_constant<typename IndexTupleExtractor<Idx, T...>::ValType>::value &&
|
||||||
is_compile_time_constant<typename IndexTupleExtractor<Idx, T...>::ValType>::value &&
|
is_compile_time_constant<typename IndexTupleExtractor<Idx, T...>::ValType>::value &&
|
||||||
array_get<Idx>(t) > array_get<Idx-1>(t) &&
|
array_get<Idx>(t) > array_get<Idx-1>(t) &&
|
||||||
tuple_coeff<Idx-1>::values_up_to_statically_known_to_increase(t);
|
tuple_coeff<Idx-1, ValueT>::values_up_to_statically_known_to_increase(t);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <typename ValueT>
|
||||||
struct tuple_coeff<0> {
|
struct tuple_coeff<0, ValueT> {
|
||||||
template <typename... T>
|
template <typename... T>
|
||||||
EIGEN_DEVICE_FUNC static constexpr DenseIndex get(const DenseIndex i, const IndexTuple<T...>& t) {
|
EIGEN_DEVICE_FUNC static constexpr ValueT get(const DenseIndex /*i*/, const IndexTuple<T...>& t) {
|
||||||
// eigen_assert (i == 0); // gcc fails to compile assertions in constexpr
|
// eigen_assert (i == 0); // gcc fails to compile assertions in constexpr
|
||||||
return array_get<0>(t) * (i == 0);
|
return array_get<0>(t)/* * (i == 0)*/;
|
||||||
}
|
}
|
||||||
template <typename... T>
|
template <typename... T>
|
||||||
EIGEN_DEVICE_FUNC static void set(const DenseIndex i, IndexTuple<T...>& t, const DenseIndex value) {
|
EIGEN_DEVICE_FUNC static void set(const DenseIndex i, IndexTuple<T...>& t, const ValueT value) {
|
||||||
eigen_assert (i == 0);
|
eigen_assert (i == 0);
|
||||||
update_value(array_get<0>(t), value);
|
update_value(array_get<0>(t), value);
|
||||||
}
|
}
|
||||||
@ -254,13 +314,13 @@ struct tuple_coeff<0> {
|
|||||||
template<typename FirstType, typename... OtherTypes>
|
template<typename FirstType, typename... OtherTypes>
|
||||||
struct IndexList : internal::IndexTuple<FirstType, OtherTypes...> {
|
struct IndexList : internal::IndexTuple<FirstType, OtherTypes...> {
|
||||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr DenseIndex operator[] (const DenseIndex i) const {
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr DenseIndex operator[] (const DenseIndex i) const {
|
||||||
return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1>::get(i, *this);
|
return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1, DenseIndex>::get(i, *this);
|
||||||
}
|
}
|
||||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr DenseIndex get(const DenseIndex i) const {
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr DenseIndex get(const DenseIndex i) const {
|
||||||
return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1>::get(i, *this);
|
return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1, DenseIndex>::get(i, *this);
|
||||||
}
|
}
|
||||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void set(const DenseIndex i, const DenseIndex value) {
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void set(const DenseIndex i, const DenseIndex value) {
|
||||||
return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1>::set(i, *this, value);
|
return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1, DenseIndex>::set(i, *this, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC constexpr IndexList(const internal::IndexTuple<FirstType, OtherTypes...>& other) : internal::IndexTuple<FirstType, OtherTypes...>(other) { }
|
EIGEN_DEVICE_FUNC constexpr IndexList(const internal::IndexTuple<FirstType, OtherTypes...>& other) : internal::IndexTuple<FirstType, OtherTypes...>(other) { }
|
||||||
@ -268,14 +328,14 @@ struct IndexList : internal::IndexTuple<FirstType, OtherTypes...> {
|
|||||||
EIGEN_DEVICE_FUNC constexpr IndexList() : internal::IndexTuple<FirstType, OtherTypes...>() { }
|
EIGEN_DEVICE_FUNC constexpr IndexList() : internal::IndexTuple<FirstType, OtherTypes...>() { }
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC constexpr bool value_known_statically(const DenseIndex i) const {
|
EIGEN_DEVICE_FUNC constexpr bool value_known_statically(const DenseIndex i) const {
|
||||||
return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1>::value_known_statically(i, *this);
|
return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1, DenseIndex>::value_known_statically(i, *this);
|
||||||
}
|
}
|
||||||
EIGEN_DEVICE_FUNC constexpr bool all_values_known_statically() const {
|
EIGEN_DEVICE_FUNC constexpr bool all_values_known_statically() const {
|
||||||
return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1>::values_up_to_known_statically(*this);
|
return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1, DenseIndex>::values_up_to_known_statically(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC constexpr bool values_statically_known_to_increase() const {
|
EIGEN_DEVICE_FUNC constexpr bool values_statically_known_to_increase() const {
|
||||||
return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1>::values_up_to_statically_known_to_increase(*this);
|
return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1, DenseIndex>::values_up_to_statically_known_to_increase(*this);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -286,6 +346,23 @@ constexpr IndexList<FirstType, OtherTypes...> make_index_list(FirstType val1, Ot
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template<typename FirstType, typename... OtherTypes>
|
||||||
|
struct IndexPairList : internal::IndexTuple<FirstType, OtherTypes...> {
|
||||||
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr IndexPair<DenseIndex> operator[] (const DenseIndex i) const {
|
||||||
|
return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1, IndexPair<DenseIndex>>::get(i, *this);
|
||||||
|
}
|
||||||
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void set(const DenseIndex i, const IndexPair<DenseIndex> value) {
|
||||||
|
return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...>>::value-1, IndexPair<DenseIndex> >::set(i, *this, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC constexpr IndexPairList(const internal::IndexTuple<FirstType, OtherTypes...>& other) : internal::IndexTuple<FirstType, OtherTypes...>(other) { }
|
||||||
|
EIGEN_DEVICE_FUNC constexpr IndexPairList() : internal::IndexTuple<FirstType, OtherTypes...>() { }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC constexpr bool value_known_statically(const DenseIndex i) const {
|
||||||
|
return internal::tuple_coeff<internal::array_size<internal::IndexTuple<FirstType, OtherTypes...> >::value-1, DenseIndex>::value_known_statically(i, *this);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
template<typename FirstType, typename... OtherTypes> size_t array_prod(const IndexList<FirstType, OtherTypes...>& sizes) {
|
template<typename FirstType, typename... OtherTypes> size_t array_prod(const IndexList<FirstType, OtherTypes...>& sizes) {
|
||||||
@ -303,6 +380,13 @@ template<typename FirstType, typename... OtherTypes> struct array_size<const Ind
|
|||||||
static const size_t value = array_size<IndexTuple<FirstType, OtherTypes...> >::value;
|
static const size_t value = array_size<IndexTuple<FirstType, OtherTypes...> >::value;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<typename FirstType, typename... OtherTypes> struct array_size<IndexPairList<FirstType, OtherTypes...> > {
|
||||||
|
static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value;
|
||||||
|
};
|
||||||
|
template<typename FirstType, typename... OtherTypes> struct array_size<const IndexPairList<FirstType, OtherTypes...> > {
|
||||||
|
static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value;
|
||||||
|
};
|
||||||
|
|
||||||
template<DenseIndex N, typename FirstType, typename... OtherTypes> EIGEN_DEVICE_FUNC constexpr DenseIndex array_get(IndexList<FirstType, OtherTypes...>& a) {
|
template<DenseIndex N, typename FirstType, typename... OtherTypes> EIGEN_DEVICE_FUNC constexpr DenseIndex array_get(IndexList<FirstType, OtherTypes...>& a) {
|
||||||
return IndexTupleExtractor<N, FirstType, OtherTypes...>::get_val(a);
|
return IndexTupleExtractor<N, FirstType, OtherTypes...>::get_val(a);
|
||||||
}
|
}
|
||||||
@ -472,6 +556,57 @@ struct index_statically_lt_impl<const IndexList<FirstType, OtherTypes...> > {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename Tx>
|
||||||
|
struct index_pair_first_statically_eq_impl {
|
||||||
|
EIGEN_DEVICE_FUNC static constexpr bool run(DenseIndex, DenseIndex) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename FirstType, typename... OtherTypes>
|
||||||
|
struct index_pair_first_statically_eq_impl<IndexPairList<FirstType, OtherTypes...> > {
|
||||||
|
EIGEN_DEVICE_FUNC static constexpr bool run(const DenseIndex i, const DenseIndex value) {
|
||||||
|
return IndexPairList<FirstType, OtherTypes...>().value_known_statically(i) &
|
||||||
|
(IndexPairList<FirstType, OtherTypes...>()[i].first == value);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename FirstType, typename... OtherTypes>
|
||||||
|
struct index_pair_first_statically_eq_impl<const IndexPairList<FirstType, OtherTypes...> > {
|
||||||
|
EIGEN_DEVICE_FUNC static constexpr bool run(const DenseIndex i, const DenseIndex value) {
|
||||||
|
return IndexPairList<FirstType, OtherTypes...>().value_known_statically(i) &
|
||||||
|
(IndexPairList<FirstType, OtherTypes...>()[i].first == value);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
template <typename Tx>
|
||||||
|
struct index_pair_second_statically_eq_impl {
|
||||||
|
EIGEN_DEVICE_FUNC static constexpr bool run(DenseIndex, DenseIndex) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename FirstType, typename... OtherTypes>
|
||||||
|
struct index_pair_second_statically_eq_impl<IndexPairList<FirstType, OtherTypes...> > {
|
||||||
|
EIGEN_DEVICE_FUNC static constexpr bool run(const DenseIndex i, const DenseIndex value) {
|
||||||
|
return IndexPairList<FirstType, OtherTypes...>().value_known_statically(i) &
|
||||||
|
(IndexPairList<FirstType, OtherTypes...>()[i].second == value);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename FirstType, typename... OtherTypes>
|
||||||
|
struct index_pair_second_statically_eq_impl<const IndexPairList<FirstType, OtherTypes...> > {
|
||||||
|
EIGEN_DEVICE_FUNC static constexpr bool run(const DenseIndex i, const DenseIndex value) {
|
||||||
|
return IndexPairList<FirstType, OtherTypes...>().value_known_statically(i) &
|
||||||
|
(IndexPairList<FirstType, OtherTypes...>()[i].second == value);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|
||||||
@ -482,53 +617,69 @@ namespace internal {
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct index_known_statically_impl {
|
struct index_known_statically_impl {
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(const DenseIndex) {
|
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(const DenseIndex) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct all_indices_known_statically_impl {
|
struct all_indices_known_statically_impl {
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run() {
|
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run() {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct indices_statically_known_to_increase_impl {
|
struct indices_statically_known_to_increase_impl {
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run() {
|
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run() {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct index_statically_eq_impl {
|
struct index_statically_eq_impl {
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(DenseIndex, DenseIndex) {
|
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(DenseIndex, DenseIndex) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct index_statically_ne_impl {
|
struct index_statically_ne_impl {
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(DenseIndex, DenseIndex) {
|
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(DenseIndex, DenseIndex) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct index_statically_gt_impl {
|
struct index_statically_gt_impl {
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(DenseIndex, DenseIndex) {
|
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(DenseIndex, DenseIndex) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct index_statically_lt_impl {
|
struct index_statically_lt_impl {
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE static bool run(DenseIndex, DenseIndex) {
|
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(DenseIndex, DenseIndex) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename Tx>
|
||||||
|
struct index_pair_first_statically_eq_impl {
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(DenseIndex, DenseIndex) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Tx>
|
||||||
|
struct index_pair_second_statically_eq_impl {
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool run(DenseIndex, DenseIndex) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|
||||||
@ -572,6 +723,16 @@ static EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bool index_statically_lt(DenseIndex i,
|
|||||||
return index_statically_lt_impl<T>::run(i, value);
|
return index_statically_lt_impl<T>::run(i, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bool index_pair_first_statically_eq(DenseIndex i, DenseIndex value) {
|
||||||
|
return index_pair_first_statically_eq_impl<T>::run(i, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR bool index_pair_second_statically_eq(DenseIndex i, DenseIndex value) {
|
||||||
|
return index_pair_second_statically_eq_impl<T>::run(i, value);
|
||||||
|
}
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|
||||||
|
@ -112,6 +112,20 @@ bool operator!=(const Tuple<U, V>& x, const Tuple<U, V>& y) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Can't use std::pairs on cuda devices
|
||||||
|
template <typename Idx> struct IndexPair {
|
||||||
|
constexpr EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE IndexPair() : first(0), second(0) {}
|
||||||
|
constexpr EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE IndexPair(Idx f, Idx s) : first(f), second(s) {}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC void set(IndexPair<Idx> val) {
|
||||||
|
first = val.first;
|
||||||
|
second = val.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
Idx first;
|
||||||
|
Idx second;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
#ifdef EIGEN_HAS_SFINAE
|
#ifdef EIGEN_HAS_SFINAE
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
@ -159,6 +159,111 @@ static void test_type2index_list()
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static void test_type2indexpair_list()
|
||||||
|
{
|
||||||
|
Tensor<float, 5> tensor(2,3,5,7,11);
|
||||||
|
tensor.setRandom();
|
||||||
|
tensor += tensor.constant(10.0f);
|
||||||
|
|
||||||
|
typedef Eigen::IndexPairList<Eigen::type2indexpair<0,10>> Dims0;
|
||||||
|
typedef Eigen::IndexPairList<Eigen::type2indexpair<0,10>, Eigen::type2indexpair<1,11>, Eigen::type2indexpair<2,12>> Dims2_a;
|
||||||
|
typedef Eigen::IndexPairList<Eigen::type2indexpair<0,10>, Eigen::IndexPair<DenseIndex>, Eigen::type2indexpair<2,12>> Dims2_b;
|
||||||
|
typedef Eigen::IndexPairList<Eigen::IndexPair<DenseIndex>, Eigen::type2indexpair<1,11>, Eigen::IndexPair<DenseIndex>> Dims2_c;
|
||||||
|
|
||||||
|
Dims0 d0;
|
||||||
|
Dims2_a d2_a;
|
||||||
|
|
||||||
|
Dims2_b d2_b;
|
||||||
|
d2_b.set(1, Eigen::IndexPair<DenseIndex>(1,11));
|
||||||
|
|
||||||
|
Dims2_c d2_c;
|
||||||
|
d2_c.set(0, Eigen::IndexPair<DenseIndex>(Eigen::IndexPair<DenseIndex>(0,10)));
|
||||||
|
d2_c.set(1, Eigen::IndexPair<DenseIndex>(1,11)); // setting type2indexpair to correct value.
|
||||||
|
d2_c.set(2, Eigen::IndexPair<DenseIndex>(2,12));
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(d2_a[0].first, 0);
|
||||||
|
VERIFY_IS_EQUAL(d2_a[0].second, 10);
|
||||||
|
VERIFY_IS_EQUAL(d2_a[1].first, 1);
|
||||||
|
VERIFY_IS_EQUAL(d2_a[1].second, 11);
|
||||||
|
VERIFY_IS_EQUAL(d2_a[2].first, 2);
|
||||||
|
VERIFY_IS_EQUAL(d2_a[2].second, 12);
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(d2_b[0].first, 0);
|
||||||
|
VERIFY_IS_EQUAL(d2_b[0].second, 10);
|
||||||
|
VERIFY_IS_EQUAL(d2_b[1].first, 1);
|
||||||
|
VERIFY_IS_EQUAL(d2_b[1].second, 11);
|
||||||
|
VERIFY_IS_EQUAL(d2_b[2].first, 2);
|
||||||
|
VERIFY_IS_EQUAL(d2_b[2].second, 12);
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(d2_c[0].first, 0);
|
||||||
|
VERIFY_IS_EQUAL(d2_c[0].second, 10);
|
||||||
|
VERIFY_IS_EQUAL(d2_c[1].first, 1);
|
||||||
|
VERIFY_IS_EQUAL(d2_c[1].second, 11);
|
||||||
|
VERIFY_IS_EQUAL(d2_c[2].first, 2);
|
||||||
|
VERIFY_IS_EQUAL(d2_c[2].second, 12);
|
||||||
|
|
||||||
|
EIGEN_STATIC_ASSERT((d2_a.value_known_statically(0) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((d2_a.value_known_statically(1) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((d2_a.value_known_statically(2) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
|
||||||
|
EIGEN_STATIC_ASSERT((d2_b.value_known_statically(0) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((d2_b.value_known_statically(1) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((d2_b.value_known_statically(2) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
|
||||||
|
EIGEN_STATIC_ASSERT((d2_c.value_known_statically(0) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((d2_c.value_known_statically(1) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((d2_c.value_known_statically(2) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_first_statically_eq<Dims0>(0, 0) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_first_statically_eq<Dims0>(0, 1) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_first_statically_eq<Dims2_a>(0, 0) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_first_statically_eq<Dims2_a>(0, 1) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_first_statically_eq<Dims2_a>(1, 1) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_first_statically_eq<Dims2_a>(1, 2) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_first_statically_eq<Dims2_a>(2, 2) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_first_statically_eq<Dims2_a>(2, 3) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_first_statically_eq<Dims2_b>(0, 0) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_first_statically_eq<Dims2_b>(0, 1) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_first_statically_eq<Dims2_b>(1, 1) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_first_statically_eq<Dims2_b>(1, 2) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_first_statically_eq<Dims2_b>(2, 2) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_first_statically_eq<Dims2_b>(2, 3) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_first_statically_eq<Dims2_c>(0, 0) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_first_statically_eq<Dims2_c>(0, 1) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_first_statically_eq<Dims2_c>(1, 1) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_first_statically_eq<Dims2_c>(1, 2) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_first_statically_eq<Dims2_c>(2, 2) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_first_statically_eq<Dims2_c>(2, 3) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_second_statically_eq<Dims0>(0, 10) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_second_statically_eq<Dims0>(0, 11) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_second_statically_eq<Dims2_a>(0, 10) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_second_statically_eq<Dims2_a>(0, 11) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_second_statically_eq<Dims2_a>(1, 11) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_second_statically_eq<Dims2_a>(1, 12) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_second_statically_eq<Dims2_a>(2, 12) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_second_statically_eq<Dims2_a>(2, 13) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_second_statically_eq<Dims2_b>(0, 10) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_second_statically_eq<Dims2_b>(0, 11) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_second_statically_eq<Dims2_b>(1, 11) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_second_statically_eq<Dims2_b>(1, 12) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_second_statically_eq<Dims2_b>(2, 12) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_second_statically_eq<Dims2_b>(2, 13) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_second_statically_eq<Dims2_c>(0, 10) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_second_statically_eq<Dims2_c>(0, 11) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_second_statically_eq<Dims2_c>(1, 11) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_second_statically_eq<Dims2_c>(1, 12) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_second_statically_eq<Dims2_c>(2, 12) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((Eigen::internal::index_pair_second_statically_eq<Dims2_c>(2, 13) == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
static void test_dynamic_index_list()
|
static void test_dynamic_index_list()
|
||||||
{
|
{
|
||||||
Tensor<float, 4> tensor(2,3,5,7);
|
Tensor<float, 4> tensor(2,3,5,7);
|
||||||
@ -273,6 +378,7 @@ void test_cxx11_tensor_index_list()
|
|||||||
#ifdef EIGEN_HAS_INDEX_LIST
|
#ifdef EIGEN_HAS_INDEX_LIST
|
||||||
CALL_SUBTEST(test_static_index_list());
|
CALL_SUBTEST(test_static_index_list());
|
||||||
CALL_SUBTEST(test_type2index_list());
|
CALL_SUBTEST(test_type2index_list());
|
||||||
|
CALL_SUBTEST(test_type2indexpair_list());
|
||||||
CALL_SUBTEST(test_dynamic_index_list());
|
CALL_SUBTEST(test_dynamic_index_list());
|
||||||
CALL_SUBTEST(test_mixed_index_list());
|
CALL_SUBTEST(test_mixed_index_list());
|
||||||
CALL_SUBTEST(test_dim_check());
|
CALL_SUBTEST(test_dim_check());
|
||||||
|
Loading…
x
Reference in New Issue
Block a user