mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-05-02 00:34:14 +08:00
Expanded the functionality of index lists
This commit is contained in:
parent
1ac8600126
commit
0feff6e987
@ -95,6 +95,20 @@ struct tuple_coeff {
|
|||||||
return ((i == Idx) & is_compile_time_constant<typename std::tuple_element<Idx, std::tuple<T...> >::type>::value) ||
|
return ((i == Idx) & is_compile_time_constant<typename std::tuple_element<Idx, std::tuple<T...> >::type>::value) ||
|
||||||
tuple_coeff<Idx-1>::value_known_statically(i, t);
|
tuple_coeff<Idx-1>::value_known_statically(i, t);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename... T>
|
||||||
|
static constexpr bool values_up_to_known_statically(const std::tuple<T...>& t) {
|
||||||
|
return is_compile_time_constant<typename std::tuple_element<Idx, std::tuple<T...> >::type>::value &&
|
||||||
|
tuple_coeff<Idx-1>::values_up_to_known_statically(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... T>
|
||||||
|
static constexpr bool values_up_to_statically_known_to_increase(const std::tuple<T...>& t) {
|
||||||
|
return is_compile_time_constant<typename std::tuple_element<Idx, std::tuple<T...> >::type>::value &&
|
||||||
|
is_compile_time_constant<typename std::tuple_element<Idx-1, std::tuple<T...> >::type>::value &&
|
||||||
|
std::get<Idx>(t) > std::get<Idx-1>(t) &&
|
||||||
|
tuple_coeff<Idx-1>::values_up_to_statically_known_to_increase(t);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
@ -110,10 +124,20 @@ struct tuple_coeff<0> {
|
|||||||
update_value(std::get<0>(t), value);
|
update_value(std::get<0>(t), value);
|
||||||
}
|
}
|
||||||
template <typename... T>
|
template <typename... T>
|
||||||
static constexpr bool value_known_statically(const DenseIndex i, const std::tuple<T...>&) {
|
static constexpr bool value_known_statically(const DenseIndex i, const std::tuple<T...>& t) {
|
||||||
// eigen_assert (i == 0); // gcc fails to compile assertions in constexpr
|
// eigen_assert (i == 0); // gcc fails to compile assertions in constexpr
|
||||||
return is_compile_time_constant<typename std::tuple_element<0, std::tuple<T...> >::type>::value & (i == 0);
|
return is_compile_time_constant<typename std::tuple_element<0, std::tuple<T...> >::type>::value & (i == 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename... T>
|
||||||
|
static constexpr bool values_up_to_known_statically(const std::tuple<T...>& t) {
|
||||||
|
return is_compile_time_constant<typename std::tuple_element<0, std::tuple<T...> >::type>::value;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... T>
|
||||||
|
static constexpr bool values_up_to_statically_known_to_increase(const std::tuple<T...>& t) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
|
|
||||||
@ -133,6 +157,13 @@ struct IndexList : std::tuple<FirstType, OtherTypes...> {
|
|||||||
constexpr bool value_known_statically(const DenseIndex i) const {
|
constexpr bool value_known_statically(const DenseIndex i) const {
|
||||||
return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1>::value_known_statically(i, *this);
|
return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1>::value_known_statically(i, *this);
|
||||||
}
|
}
|
||||||
|
constexpr bool all_values_known_statically() const {
|
||||||
|
return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1>::values_up_to_known_statically(*this);
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr bool values_statically_known_to_increase() const {
|
||||||
|
return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1>::values_up_to_statically_known_to_increase(*this);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@ -144,6 +175,14 @@ constexpr IndexList<FirstType, OtherTypes...> make_index_list(FirstType val1, Ot
|
|||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
|
template<typename FirstType, typename... OtherTypes> size_t array_prod(const IndexList<FirstType, OtherTypes...>& sizes) {
|
||||||
|
size_t result = 1;
|
||||||
|
for (int i = 0; i < array_size<IndexList<FirstType, OtherTypes...> >::value; ++i) {
|
||||||
|
result *= sizes[i];
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
template<typename FirstType, typename... OtherTypes> struct array_size<IndexList<FirstType, OtherTypes...> > {
|
template<typename FirstType, typename... OtherTypes> struct array_size<IndexList<FirstType, OtherTypes...> > {
|
||||||
static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value;
|
static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value;
|
||||||
};
|
};
|
||||||
@ -179,6 +218,48 @@ struct index_known_statically<const IndexList<FirstType, OtherTypes...> > {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct all_indices_known_statically {
|
||||||
|
constexpr bool operator() () const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename FirstType, typename... OtherTypes>
|
||||||
|
struct all_indices_known_statically<IndexList<FirstType, OtherTypes...> > {
|
||||||
|
constexpr bool operator() () const {
|
||||||
|
return IndexList<FirstType, OtherTypes...>().all_values_known_statically();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename FirstType, typename... OtherTypes>
|
||||||
|
struct all_indices_known_statically<const IndexList<FirstType, OtherTypes...> > {
|
||||||
|
constexpr bool operator() () const {
|
||||||
|
return IndexList<FirstType, OtherTypes...>().all_values_known_statically();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct indices_statically_known_to_increase {
|
||||||
|
constexpr bool operator() () const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename FirstType, typename... OtherTypes>
|
||||||
|
struct indices_statically_known_to_increase<IndexList<FirstType, OtherTypes...> > {
|
||||||
|
constexpr bool operator() () const {
|
||||||
|
return IndexList<FirstType, OtherTypes...>().values_statically_known_to_increase();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename FirstType, typename... OtherTypes>
|
||||||
|
struct indices_statically_known_to_increase<const IndexList<FirstType, OtherTypes...> > {
|
||||||
|
constexpr bool operator() () const {
|
||||||
|
return IndexList<FirstType, OtherTypes...>().values_statically_known_to_increase();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename Tx>
|
template <typename Tx>
|
||||||
struct index_statically_eq {
|
struct index_statically_eq {
|
||||||
constexpr bool operator() (DenseIndex, DenseIndex) const {
|
constexpr bool operator() (DenseIndex, DenseIndex) const {
|
||||||
@ -190,7 +271,7 @@ template <typename FirstType, typename... OtherTypes>
|
|||||||
struct index_statically_eq<IndexList<FirstType, OtherTypes...> > {
|
struct index_statically_eq<IndexList<FirstType, OtherTypes...> > {
|
||||||
constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
|
constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
|
||||||
return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
|
return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
|
||||||
(IndexList<FirstType, OtherTypes...>()[i] == value);
|
IndexList<FirstType, OtherTypes...>()[i] == value;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -198,7 +279,7 @@ template <typename FirstType, typename... OtherTypes>
|
|||||||
struct index_statically_eq<const IndexList<FirstType, OtherTypes...> > {
|
struct index_statically_eq<const IndexList<FirstType, OtherTypes...> > {
|
||||||
constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
|
constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
|
||||||
return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
|
return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
|
||||||
(IndexList<FirstType, OtherTypes...>()[i] == value);
|
IndexList<FirstType, OtherTypes...>()[i] == value;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -213,7 +294,7 @@ template <typename FirstType, typename... OtherTypes>
|
|||||||
struct index_statically_ne<IndexList<FirstType, OtherTypes...> > {
|
struct index_statically_ne<IndexList<FirstType, OtherTypes...> > {
|
||||||
constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
|
constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
|
||||||
return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
|
return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
|
||||||
(IndexList<FirstType, OtherTypes...>()[i] != value);
|
IndexList<FirstType, OtherTypes...>()[i] != value;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -221,7 +302,7 @@ template <typename FirstType, typename... OtherTypes>
|
|||||||
struct index_statically_ne<const IndexList<FirstType, OtherTypes...> > {
|
struct index_statically_ne<const IndexList<FirstType, OtherTypes...> > {
|
||||||
constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
|
constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
|
||||||
return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
|
return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
|
||||||
(IndexList<FirstType, OtherTypes...>()[i] != value);
|
IndexList<FirstType, OtherTypes...>()[i] != value;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -242,6 +323,20 @@ struct index_known_statically {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct all_indices_known_statically {
|
||||||
|
EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() () const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct indices_statically_known_to_increase {
|
||||||
|
EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() () const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct index_statically_eq {
|
struct index_statically_eq {
|
||||||
EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{
|
EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{
|
||||||
|
@ -44,6 +44,120 @@ static void test_static_index_list()
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static void test_type2index_list()
|
||||||
|
{
|
||||||
|
Tensor<float, 5> tensor(2,3,5,7,11);
|
||||||
|
tensor.setRandom();
|
||||||
|
tensor += tensor.constant(10.0f);
|
||||||
|
|
||||||
|
typedef Eigen::IndexList<Eigen::type2index<0>> Dims0;
|
||||||
|
typedef Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<1>> Dims1;
|
||||||
|
typedef Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<1>, Eigen::type2index<2>> Dims2;
|
||||||
|
typedef Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<1>, Eigen::type2index<2>, Eigen::type2index<3>> Dims3;
|
||||||
|
typedef Eigen::IndexList<Eigen::type2index<0>, Eigen::type2index<1>, Eigen::type2index<2>, Eigen::type2index<3>, Eigen::type2index<4>> Dims4;
|
||||||
|
|
||||||
|
#if 0
|
||||||
|
EIGEN_STATIC_ASSERT((internal::indices_statically_known_to_increase<Dims0>()() == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::indices_statically_known_to_increase<Dims1>()() == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::indices_statically_known_to_increase<Dims2>()() == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::indices_statically_known_to_increase<Dims3>()() == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::indices_statically_known_to_increase<Dims4>()() == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
EIGEN_STATIC_ASSERT((internal::are_inner_most_dims<Dims0, 1, ColMajor>::value == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::are_inner_most_dims<Dims1, 2, ColMajor>::value == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::are_inner_most_dims<Dims2, 3, ColMajor>::value == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::are_inner_most_dims<Dims3, 4, ColMajor>::value == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::are_inner_most_dims<Dims4, 5, ColMajor>::value == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
|
||||||
|
EIGEN_STATIC_ASSERT((internal::are_inner_most_dims<Dims0, 1, RowMajor>::value == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::are_inner_most_dims<Dims1, 2, RowMajor>::value == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::are_inner_most_dims<Dims2, 3, RowMajor>::value == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::are_inner_most_dims<Dims3, 4, RowMajor>::value == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::are_inner_most_dims<Dims4, 5, RowMajor>::value == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
|
||||||
|
const Dims0 reduction_axis0;
|
||||||
|
Tensor<float, 4> result0 = tensor.sum(reduction_axis0);
|
||||||
|
for (int m = 0; m < 11; ++m) {
|
||||||
|
for (int l = 0; l < 7; ++l) {
|
||||||
|
for (int k = 0; k < 5; ++k) {
|
||||||
|
for (int j = 0; j < 3; ++j) {
|
||||||
|
float expected = 0.0f;
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
expected += tensor(i,j,k,l,m);
|
||||||
|
}
|
||||||
|
VERIFY_IS_APPROX(result0(j,k,l,m), expected);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const Dims1 reduction_axis1;
|
||||||
|
Tensor<float, 3> result1 = tensor.sum(reduction_axis1);
|
||||||
|
for (int m = 0; m < 11; ++m) {
|
||||||
|
for (int l = 0; l < 7; ++l) {
|
||||||
|
for (int k = 0; k < 5; ++k) {
|
||||||
|
float expected = 0.0f;
|
||||||
|
for (int j = 0; j < 3; ++j) {
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
expected += tensor(i,j,k,l,m);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
VERIFY_IS_APPROX(result1(k,l,m), expected);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const Dims2 reduction_axis2;
|
||||||
|
Tensor<float, 2> result2 = tensor.sum(reduction_axis2);
|
||||||
|
for (int m = 0; m < 11; ++m) {
|
||||||
|
for (int l = 0; l < 7; ++l) {
|
||||||
|
float expected = 0.0f;
|
||||||
|
for (int k = 0; k < 5; ++k) {
|
||||||
|
for (int j = 0; j < 3; ++j) {
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
expected += tensor(i,j,k,l,m);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
VERIFY_IS_APPROX(result2(l,m), expected);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const Dims3 reduction_axis3;
|
||||||
|
Tensor<float, 1> result3 = tensor.sum(reduction_axis3);
|
||||||
|
for (int m = 0; m < 11; ++m) {
|
||||||
|
float expected = 0.0f;
|
||||||
|
for (int l = 0; l < 7; ++l) {
|
||||||
|
for (int k = 0; k < 5; ++k) {
|
||||||
|
for (int j = 0; j < 3; ++j) {
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
expected += tensor(i,j,k,l,m);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
VERIFY_IS_APPROX(result3(m), expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
const Dims4 reduction_axis4;
|
||||||
|
Tensor<float, 1> result4 = tensor.sum(reduction_axis4);
|
||||||
|
float expected = 0.0f;
|
||||||
|
for (int m = 0; m < 11; ++m) {
|
||||||
|
for (int l = 0; l < 7; ++l) {
|
||||||
|
for (int k = 0; k < 5; ++k) {
|
||||||
|
for (int j = 0; j < 3; ++j) {
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
expected += tensor(i,j,k,l,m);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
VERIFY_IS_APPROX(result4(0), expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
static void test_dynamic_index_list()
|
static void test_dynamic_index_list()
|
||||||
{
|
{
|
||||||
Tensor<float, 4> tensor(2,3,5,7);
|
Tensor<float, 4> tensor(2,3,5,7);
|
||||||
@ -105,10 +219,25 @@ static void test_mixed_index_list()
|
|||||||
EIGEN_STATIC_ASSERT((internal::index_known_statically<ReductionIndices>()(2) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
EIGEN_STATIC_ASSERT((internal::index_known_statically<ReductionIndices>()(2) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
EIGEN_STATIC_ASSERT((internal::index_statically_eq<ReductionIndices>()(0, 0) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
EIGEN_STATIC_ASSERT((internal::index_statically_eq<ReductionIndices>()(0, 0) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
EIGEN_STATIC_ASSERT((internal::index_statically_eq<ReductionIndices>()(2, 2) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
EIGEN_STATIC_ASSERT((internal::index_statically_eq<ReductionIndices>()(2, 2) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
#if 0
|
||||||
|
EIGEN_STATIC_ASSERT((internal::all_indices_known_statically<ReductionIndices>()() == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::indices_statically_known_to_increase<ReductionIndices>()() == false), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
typedef IndexList<type2index<0>, type2index<1>, type2index<2>, type2index<3>> ReductionList;
|
||||||
|
ReductionList reduction_list;
|
||||||
|
EIGEN_STATIC_ASSERT((internal::index_statically_eq<ReductionList>()(0, 0) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::index_statically_eq<ReductionList>()(1, 1) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::index_statically_eq<ReductionList>()(2, 2) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::index_statically_eq<ReductionList>()(3, 3) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
#if 0
|
||||||
|
EIGEN_STATIC_ASSERT((internal::all_indices_known_statically<ReductionList>()() == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::indices_statically_known_to_increase<ReductionList>()() == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
#endif
|
||||||
|
|
||||||
Tensor<float, 1> result1 = tensor.sum(reduction_axis);
|
Tensor<float, 1> result1 = tensor.sum(reduction_axis);
|
||||||
Tensor<float, 1> result2 = tensor.sum(reduction_indices);
|
Tensor<float, 1> result2 = tensor.sum(reduction_indices);
|
||||||
|
Tensor<float, 1> result3 = tensor.sum(reduction_list);
|
||||||
|
|
||||||
float expected = 0.0f;
|
float expected = 0.0f;
|
||||||
for (int i = 0; i < 2; ++i) {
|
for (int i = 0; i < 2; ++i) {
|
||||||
@ -122,12 +251,14 @@ static void test_mixed_index_list()
|
|||||||
}
|
}
|
||||||
VERIFY_IS_APPROX(result1(0), expected);
|
VERIFY_IS_APPROX(result1(0), expected);
|
||||||
VERIFY_IS_APPROX(result2(0), expected);
|
VERIFY_IS_APPROX(result2(0), expected);
|
||||||
|
VERIFY_IS_APPROX(result3(0), expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void test_cxx11_tensor_index_list()
|
void test_cxx11_tensor_index_list()
|
||||||
{
|
{
|
||||||
CALL_SUBTEST(test_static_index_list());
|
CALL_SUBTEST(test_static_index_list());
|
||||||
|
CALL_SUBTEST(test_type2index_list());
|
||||||
CALL_SUBTEST(test_dynamic_index_list());
|
CALL_SUBTEST(test_dynamic_index_list());
|
||||||
CALL_SUBTEST(test_mixed_index_list());
|
CALL_SUBTEST(test_mixed_index_list());
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user