mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 11:19:02 +08:00
Added support for static list of indices
This commit is contained in:
parent
cb37f818ca
commit
c2d1074932
@ -43,6 +43,7 @@
|
|||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceType.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceType.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h"
|
||||||
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h"
|
||||||
|
264
unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h
Normal file
264
unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h
Normal file
@ -0,0 +1,264 @@
|
|||||||
|
// This file is part of Eigen, a lightweight C++ template library
|
||||||
|
// for linear algebra.
|
||||||
|
//
|
||||||
|
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla
|
||||||
|
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||||
|
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
#ifndef EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H
|
||||||
|
#define EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H
|
||||||
|
|
||||||
|
#if __cplusplus > 199711L
|
||||||
|
|
||||||
|
namespace Eigen {
|
||||||
|
|
||||||
|
/** \internal
|
||||||
|
*
|
||||||
|
* \class TensorIndexList
|
||||||
|
* \ingroup CXX11_Tensor_Module
|
||||||
|
*
|
||||||
|
* \brief Set of classes used to encode a set of Tensor dimensions/indices.
|
||||||
|
*
|
||||||
|
* The indices in the list can be known at compile time or at runtime. A mix
|
||||||
|
* of static and dynamic indices can also be provided if needed. The tensor
|
||||||
|
* code will attempt to take advantage of the indices that are known at
|
||||||
|
* compile time to optimize the code it generates.
|
||||||
|
*
|
||||||
|
* This functionality requires a c++11 compliant compiler. If your compiler
|
||||||
|
* is older you need to use arrays of indices instead.
|
||||||
|
*
|
||||||
|
* Several examples are provided in the cxx11_tensor_index_list.cpp file.
|
||||||
|
*
|
||||||
|
* \sa Tensor
|
||||||
|
*/
|
||||||
|
|
||||||
|
template <DenseIndex n>
|
||||||
|
struct type2index {
|
||||||
|
static const DenseIndex value = n;
|
||||||
|
constexpr operator DenseIndex() const { return n; }
|
||||||
|
void set(DenseIndex val) {
|
||||||
|
eigen_assert(val == n);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
template <typename T>
|
||||||
|
void update_value(T& val, DenseIndex new_val) {
|
||||||
|
val = new_val;
|
||||||
|
}
|
||||||
|
template <DenseIndex n>
|
||||||
|
void update_value(type2index<n>& val, DenseIndex new_val) {
|
||||||
|
val.set(new_val);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct is_compile_time_constant {
|
||||||
|
static constexpr bool value = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <DenseIndex idx>
|
||||||
|
struct is_compile_time_constant<type2index<idx> > {
|
||||||
|
static constexpr bool value = true;
|
||||||
|
};
|
||||||
|
template <DenseIndex idx>
|
||||||
|
struct is_compile_time_constant<const type2index<idx> > {
|
||||||
|
static constexpr bool value = true;
|
||||||
|
};
|
||||||
|
template <DenseIndex idx>
|
||||||
|
struct is_compile_time_constant<type2index<idx>& > {
|
||||||
|
static constexpr bool value = true;
|
||||||
|
};
|
||||||
|
template <DenseIndex idx>
|
||||||
|
struct is_compile_time_constant<const type2index<idx>& > {
|
||||||
|
static constexpr bool value = true;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <DenseIndex Idx>
|
||||||
|
struct tuple_coeff {
|
||||||
|
template <typename... T>
|
||||||
|
static constexpr DenseIndex get(const DenseIndex i, const std::tuple<T...>& t) {
|
||||||
|
return std::get<Idx>(t) * (i == Idx) + tuple_coeff<Idx-1>::get(i, t) * (i != Idx);
|
||||||
|
}
|
||||||
|
template <typename... T>
|
||||||
|
static void set(const DenseIndex i, std::tuple<T...>& t, const DenseIndex value) {
|
||||||
|
if (i == Idx) {
|
||||||
|
update_value(std::get<Idx>(t), value);
|
||||||
|
} else {
|
||||||
|
tuple_coeff<Idx-1>::set(i, t, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename... T>
|
||||||
|
static constexpr bool value_known_statically(const DenseIndex i, const std::tuple<T...>& t) {
|
||||||
|
return ((i == Idx) & is_compile_time_constant<typename std::tuple_element<Idx, std::tuple<T...> >::type>::value) ||
|
||||||
|
tuple_coeff<Idx-1>::value_known_statically(i, t);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <>
|
||||||
|
struct tuple_coeff<0> {
|
||||||
|
template <typename... T>
|
||||||
|
static constexpr DenseIndex get(const DenseIndex i, const std::tuple<T...>& t) {
|
||||||
|
// eigen_assert (i == 0); // gcc fails to compile assertions in constexpr
|
||||||
|
return std::get<0>(t) * (i == 0);
|
||||||
|
}
|
||||||
|
template <typename... T>
|
||||||
|
static void set(const DenseIndex i, std::tuple<T...>& t, const DenseIndex value) {
|
||||||
|
eigen_assert (i == 0);
|
||||||
|
update_value(std::get<0>(t), value);
|
||||||
|
}
|
||||||
|
template <typename... T>
|
||||||
|
static constexpr bool value_known_statically(const DenseIndex i, const std::tuple<T...>& t) {
|
||||||
|
// eigen_assert (i == 0); // gcc fails to compile assertions in constexpr
|
||||||
|
return is_compile_time_constant<typename std::tuple_element<0, std::tuple<T...> >::type>::value & (i == 0);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace internal
|
||||||
|
|
||||||
|
|
||||||
|
template<typename FirstType, typename... OtherTypes>
|
||||||
|
struct IndexList : std::tuple<FirstType, OtherTypes...> {
|
||||||
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr DenseIndex operator[] (const DenseIndex i) const {
|
||||||
|
return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1>::get(i, *this);
|
||||||
|
}
|
||||||
|
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void set(const DenseIndex i, const DenseIndex value) {
|
||||||
|
return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1>::set(i, *this, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr IndexList(const std::tuple<FirstType, OtherTypes...>& other) : std::tuple<FirstType, OtherTypes...>(other) { }
|
||||||
|
constexpr IndexList() : std::tuple<FirstType, OtherTypes...>() { }
|
||||||
|
|
||||||
|
constexpr bool value_known_statically(const DenseIndex i) const {
|
||||||
|
return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1>::value_known_statically(i, *this);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
template<typename FirstType, typename... OtherTypes>
|
||||||
|
constexpr IndexList<FirstType, OtherTypes...> make_index_list(FirstType val1, OtherTypes... other_vals) {
|
||||||
|
return std::make_tuple(val1, other_vals...);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
template<typename FirstType, typename... OtherTypes> struct array_size<IndexList<FirstType, OtherTypes...> > {
|
||||||
|
static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value;
|
||||||
|
};
|
||||||
|
template<typename FirstType, typename... OtherTypes> struct array_size<const IndexList<FirstType, OtherTypes...> > {
|
||||||
|
static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<DenseIndex n, typename FirstType, typename... OtherTypes> constexpr DenseIndex array_get(IndexList<FirstType, OtherTypes...>& a) {
|
||||||
|
return std::get<n>(a);
|
||||||
|
}
|
||||||
|
template<DenseIndex n, typename FirstType, typename... OtherTypes> constexpr DenseIndex array_get(const IndexList<FirstType, OtherTypes...>& a) {
|
||||||
|
return std::get<n>(a);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct index_known_statically {
|
||||||
|
constexpr bool operator() (DenseIndex) const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename FirstType, typename... OtherTypes>
|
||||||
|
struct index_known_statically<IndexList<FirstType, OtherTypes...> > {
|
||||||
|
constexpr bool operator() (const DenseIndex i) const {
|
||||||
|
return IndexList<FirstType, OtherTypes...>().value_known_statically(i);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename FirstType, typename... OtherTypes>
|
||||||
|
struct index_known_statically<const IndexList<FirstType, OtherTypes...> > {
|
||||||
|
constexpr bool operator() (const DenseIndex i) const {
|
||||||
|
return IndexList<FirstType, OtherTypes...>().value_known_statically(i);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Tx>
|
||||||
|
struct index_statically_eq {
|
||||||
|
constexpr bool operator() (DenseIndex, DenseIndex) const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename FirstType, typename... OtherTypes>
|
||||||
|
struct index_statically_eq<IndexList<FirstType, OtherTypes...> > {
|
||||||
|
constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
|
||||||
|
return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
|
||||||
|
IndexList<FirstType, OtherTypes...>()[i] == value;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename FirstType, typename... OtherTypes>
|
||||||
|
struct index_statically_eq<const IndexList<FirstType, OtherTypes...> > {
|
||||||
|
constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
|
||||||
|
return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
|
||||||
|
IndexList<FirstType, OtherTypes...>()[i] == value;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct index_statically_ne {
|
||||||
|
constexpr bool operator() (DenseIndex, DenseIndex) const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename FirstType, typename... OtherTypes>
|
||||||
|
struct index_statically_ne<IndexList<FirstType, OtherTypes...> > {
|
||||||
|
constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
|
||||||
|
return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
|
||||||
|
IndexList<FirstType, OtherTypes...>()[i] != value;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename FirstType, typename... OtherTypes>
|
||||||
|
struct index_statically_ne<const IndexList<FirstType, OtherTypes...> > {
|
||||||
|
constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
|
||||||
|
return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
|
||||||
|
IndexList<FirstType, OtherTypes...>()[i] != value;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
} // end namespace internal
|
||||||
|
} // end namespace Eigen
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
namespace Eigen {
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
// No C++11 support
|
||||||
|
template <typename T>
|
||||||
|
struct index_known_statically {
|
||||||
|
EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex) const{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct index_statically_eq {
|
||||||
|
EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct index_statically_ne {
|
||||||
|
EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace internal
|
||||||
|
} // end namespace Eigen
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif // EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H
|
@ -102,6 +102,7 @@ if(EIGEN_TEST_CXX11)
|
|||||||
ei_add_test(cxx11_tensor_symmetry "-std=c++0x")
|
ei_add_test(cxx11_tensor_symmetry "-std=c++0x")
|
||||||
ei_add_test(cxx11_tensor_assign "-std=c++0x")
|
ei_add_test(cxx11_tensor_assign "-std=c++0x")
|
||||||
ei_add_test(cxx11_tensor_dimension "-std=c++0x")
|
ei_add_test(cxx11_tensor_dimension "-std=c++0x")
|
||||||
|
ei_add_test(cxx11_tensor_index_list "-std=c++0x")
|
||||||
ei_add_test(cxx11_tensor_comparisons "-std=c++0x")
|
ei_add_test(cxx11_tensor_comparisons "-std=c++0x")
|
||||||
ei_add_test(cxx11_tensor_contraction "-std=c++0x")
|
ei_add_test(cxx11_tensor_contraction "-std=c++0x")
|
||||||
ei_add_test(cxx11_tensor_convolution "-std=c++0x")
|
ei_add_test(cxx11_tensor_convolution "-std=c++0x")
|
||||||
|
133
unsupported/test/cxx11_tensor_index_list.cpp
Normal file
133
unsupported/test/cxx11_tensor_index_list.cpp
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
// This file is part of Eigen, a lightweight C++ template library
|
||||||
|
// for linear algebra.
|
||||||
|
//
|
||||||
|
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla
|
||||||
|
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||||
|
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
#include "main.h"
|
||||||
|
|
||||||
|
#include <Eigen/CXX11/Tensor>
|
||||||
|
|
||||||
|
|
||||||
|
static void test_static_index_list()
|
||||||
|
{
|
||||||
|
Tensor<float, 4> tensor(2,3,5,7);
|
||||||
|
tensor.setRandom();
|
||||||
|
|
||||||
|
constexpr auto reduction_axis = make_index_list(0, 1, 2);
|
||||||
|
VERIFY_IS_EQUAL(internal::array_get<0>(reduction_axis), 0);
|
||||||
|
VERIFY_IS_EQUAL(internal::array_get<1>(reduction_axis), 1);
|
||||||
|
VERIFY_IS_EQUAL(internal::array_get<2>(reduction_axis), 2);
|
||||||
|
VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[0]), 0);
|
||||||
|
VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[1]), 1);
|
||||||
|
VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[2]), 2);
|
||||||
|
|
||||||
|
EIGEN_STATIC_ASSERT((internal::array_get<0>(reduction_axis) == 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::array_get<1>(reduction_axis) == 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::array_get<2>(reduction_axis) == 2), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
|
||||||
|
Tensor<float, 1> result = tensor.sum(reduction_axis);
|
||||||
|
for (int i = 0; i < result.size(); ++i) {
|
||||||
|
float expected = 0.0f;
|
||||||
|
for (int j = 0; j < 2; ++j) {
|
||||||
|
for (int k = 0; k < 3; ++k) {
|
||||||
|
for (int l = 0; l < 5; ++l) {
|
||||||
|
expected += tensor(j,k,l,i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
VERIFY_IS_APPROX(result(i), expected);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static void test_dynamic_index_list()
|
||||||
|
{
|
||||||
|
Tensor<float, 4> tensor(2,3,5,7);
|
||||||
|
tensor.setRandom();
|
||||||
|
|
||||||
|
int dim1 = 2;
|
||||||
|
int dim2 = 1;
|
||||||
|
int dim3 = 0;
|
||||||
|
|
||||||
|
auto reduction_axis = make_index_list(dim1, dim2, dim3);
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(internal::array_get<0>(reduction_axis), 2);
|
||||||
|
VERIFY_IS_EQUAL(internal::array_get<1>(reduction_axis), 1);
|
||||||
|
VERIFY_IS_EQUAL(internal::array_get<2>(reduction_axis), 0);
|
||||||
|
VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[0]), 2);
|
||||||
|
VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[1]), 1);
|
||||||
|
VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[2]), 0);
|
||||||
|
|
||||||
|
Tensor<float, 1> result = tensor.sum(reduction_axis);
|
||||||
|
for (int i = 0; i < result.size(); ++i) {
|
||||||
|
float expected = 0.0f;
|
||||||
|
for (int j = 0; j < 2; ++j) {
|
||||||
|
for (int k = 0; k < 3; ++k) {
|
||||||
|
for (int l = 0; l < 5; ++l) {
|
||||||
|
expected += tensor(j,k,l,i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
VERIFY_IS_APPROX(result(i), expected);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void test_mixed_index_list()
|
||||||
|
{
|
||||||
|
Tensor<float, 4> tensor(2,3,5,7);
|
||||||
|
tensor.setRandom();
|
||||||
|
|
||||||
|
int dim2 = 1;
|
||||||
|
int dim4 = 3;
|
||||||
|
|
||||||
|
auto reduction_axis = make_index_list(0, dim2, 2, dim4);
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(internal::array_get<0>(reduction_axis), 0);
|
||||||
|
VERIFY_IS_EQUAL(internal::array_get<1>(reduction_axis), 1);
|
||||||
|
VERIFY_IS_EQUAL(internal::array_get<2>(reduction_axis), 2);
|
||||||
|
VERIFY_IS_EQUAL(internal::array_get<3>(reduction_axis), 3);
|
||||||
|
VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[0]), 0);
|
||||||
|
VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[1]), 1);
|
||||||
|
VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[2]), 2);
|
||||||
|
VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[3]), 3);
|
||||||
|
|
||||||
|
typedef IndexList<type2index<0>, int, type2index<2>, int> ReductionIndices;
|
||||||
|
ReductionIndices reduction_indices;
|
||||||
|
reduction_indices.set(1, 1);
|
||||||
|
reduction_indices.set(3, 3);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::array_get<0>(reduction_indices) == 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::array_get<2>(reduction_indices) == 2), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::index_known_statically<ReductionIndices>()(0) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::index_known_statically<ReductionIndices>()(2) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::index_statically_eq<ReductionIndices>()(0, 0) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((internal::index_statically_eq<ReductionIndices>()(2, 2) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
|
||||||
|
|
||||||
|
Tensor<float, 1> result1 = tensor.sum(reduction_axis);
|
||||||
|
Tensor<float, 1> result2 = tensor.sum(reduction_indices);
|
||||||
|
|
||||||
|
float expected = 0.0f;
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
for (int j = 0; j < 3; ++j) {
|
||||||
|
for (int k = 0; k < 5; ++k) {
|
||||||
|
for (int l = 0; l < 7; ++l) {
|
||||||
|
expected += tensor(i,j,k,l);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
VERIFY_IS_APPROX(result1(0), expected);
|
||||||
|
VERIFY_IS_APPROX(result2(0), expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void test_cxx11_tensor_index_list()
|
||||||
|
{
|
||||||
|
CALL_SUBTEST(test_static_index_list());
|
||||||
|
CALL_SUBTEST(test_dynamic_index_list());
|
||||||
|
CALL_SUBTEST(test_mixed_index_list());
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user