From c2d1074932ae92a001eadb27e9f85eaf2de187b9 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 12 Nov 2014 22:25:38 -0800 Subject: [PATCH] Added support for static list of indices --- unsupported/Eigen/CXX11/Tensor | 1 + .../Eigen/CXX11/src/Tensor/TensorIndexList.h | 264 ++++++++++++++++++ unsupported/test/CMakeLists.txt | 1 + unsupported/test/cxx11_tensor_index_list.cpp | 133 +++++++++ 4 files changed, 399 insertions(+) create mode 100644 unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h create mode 100644 unsupported/test/cxx11_tensor_index_list.cpp diff --git a/unsupported/Eigen/CXX11/Tensor b/unsupported/Eigen/CXX11/Tensor index c36db96ec..44d5a4d82 100644 --- a/unsupported/Eigen/CXX11/Tensor +++ b/unsupported/Eigen/CXX11/Tensor @@ -43,6 +43,7 @@ #include "unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceType.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h" diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h b/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h new file mode 100644 index 000000000..010221e74 --- /dev/null +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h @@ -0,0 +1,264 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H +#define EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H + +#if __cplusplus > 199711L + +namespace Eigen { + +/** \internal + * + * \class TensorIndexList + * \ingroup CXX11_Tensor_Module + * + * \brief Set of classes used to encode a set of Tensor dimensions/indices. + * + * The indices in the list can be known at compile time or at runtime. A mix + * of static and dynamic indices can also be provided if needed. The tensor + * code will attempt to take advantage of the indices that are known at + * compile time to optimize the code it generates. + * + * This functionality requires a c++11 compliant compiler. If your compiler + * is older you need to use arrays of indices instead. + * + * Several examples are provided in the cxx11_tensor_index_list.cpp file. + * + * \sa Tensor + */ + +template +struct type2index { + static const DenseIndex value = n; + constexpr operator DenseIndex() const { return n; } + void set(DenseIndex val) { + eigen_assert(val == n); + } +}; + +namespace internal { +template +void update_value(T& val, DenseIndex new_val) { + val = new_val; +} +template +void update_value(type2index& val, DenseIndex new_val) { + val.set(new_val); +} + +template +struct is_compile_time_constant { + static constexpr bool value = false; +}; + +template +struct is_compile_time_constant > { + static constexpr bool value = true; +}; +template +struct is_compile_time_constant > { + static constexpr bool value = true; +}; +template +struct is_compile_time_constant& > { + static constexpr bool value = true; +}; +template +struct is_compile_time_constant& > { + static constexpr bool value = true; +}; + +template +struct tuple_coeff { + template + static constexpr DenseIndex get(const DenseIndex i, const std::tuple& t) { + return std::get(t) * (i == Idx) + tuple_coeff::get(i, t) * (i != Idx); + } + template + static void set(const DenseIndex i, std::tuple& t, const DenseIndex value) { + if (i == Idx) { + update_value(std::get(t), value); + } else { + tuple_coeff::set(i, t, value); + } + } + + template + static constexpr bool value_known_statically(const DenseIndex i, const std::tuple& t) { + return ((i == Idx) & is_compile_time_constant >::type>::value) || + tuple_coeff::value_known_statically(i, t); + } +}; + +template <> +struct tuple_coeff<0> { + template + static constexpr DenseIndex get(const DenseIndex i, const std::tuple& t) { + // eigen_assert (i == 0); // gcc fails to compile assertions in constexpr + return std::get<0>(t) * (i == 0); + } + template + static void set(const DenseIndex i, std::tuple& t, const DenseIndex value) { + eigen_assert (i == 0); + update_value(std::get<0>(t), value); + } + template + static constexpr bool value_known_statically(const DenseIndex i, const std::tuple& t) { + // eigen_assert (i == 0); // gcc fails to compile assertions in constexpr + return is_compile_time_constant >::type>::value & (i == 0); + } +}; +} // namespace internal + + +template +struct IndexList : std::tuple { + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr DenseIndex operator[] (const DenseIndex i) const { + return internal::tuple_coeff >::value-1>::get(i, *this); + } + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void set(const DenseIndex i, const DenseIndex value) { + return internal::tuple_coeff >::value-1>::set(i, *this, value); + } + + constexpr IndexList(const std::tuple& other) : std::tuple(other) { } + constexpr IndexList() : std::tuple() { } + + constexpr bool value_known_statically(const DenseIndex i) const { + return internal::tuple_coeff >::value-1>::value_known_statically(i, *this); + } +}; + + +template +constexpr IndexList make_index_list(FirstType val1, OtherTypes... other_vals) { + return std::make_tuple(val1, other_vals...); +} + + +namespace internal { + +template struct array_size > { + static const size_t value = std::tuple_size >::value; +}; +template struct array_size > { + static const size_t value = std::tuple_size >::value; +}; + +template constexpr DenseIndex array_get(IndexList& a) { + return std::get(a); +} +template constexpr DenseIndex array_get(const IndexList& a) { + return std::get(a); +} + +template +struct index_known_statically { + constexpr bool operator() (DenseIndex) const { + return false; + } +}; + +template +struct index_known_statically > { + constexpr bool operator() (const DenseIndex i) const { + return IndexList().value_known_statically(i); + } +}; + +template +struct index_known_statically > { + constexpr bool operator() (const DenseIndex i) const { + return IndexList().value_known_statically(i); + } +}; + +template +struct index_statically_eq { + constexpr bool operator() (DenseIndex, DenseIndex) const { + return false; + } +}; + +template +struct index_statically_eq > { + constexpr bool operator() (const DenseIndex i, const DenseIndex value) const { + return IndexList().value_known_statically(i) & + IndexList()[i] == value; + } +}; + +template +struct index_statically_eq > { + constexpr bool operator() (const DenseIndex i, const DenseIndex value) const { + return IndexList().value_known_statically(i) & + IndexList()[i] == value; + } +}; + +template +struct index_statically_ne { + constexpr bool operator() (DenseIndex, DenseIndex) const { + return false; + } +}; + +template +struct index_statically_ne > { + constexpr bool operator() (const DenseIndex i, const DenseIndex value) const { + return IndexList().value_known_statically(i) & + IndexList()[i] != value; + } +}; + +template +struct index_statically_ne > { + constexpr bool operator() (const DenseIndex i, const DenseIndex value) const { + return IndexList().value_known_statically(i) & + IndexList()[i] != value; + } +}; + + +} // end namespace internal +} // end namespace Eigen + +#else + +namespace Eigen { +namespace internal { + +// No C++11 support +template +struct index_known_statically { + EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex) const{ + return false; + } +}; + +template +struct index_statically_eq { + EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{ + return false; + } +}; + +template +struct index_statically_ne { + EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{ + return false; + } +}; + +} // end namespace internal +} // end namespace Eigen + +#endif + +#endif // EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H diff --git a/unsupported/test/CMakeLists.txt b/unsupported/test/CMakeLists.txt index 6b8ed2826..181f06fc7 100644 --- a/unsupported/test/CMakeLists.txt +++ b/unsupported/test/CMakeLists.txt @@ -102,6 +102,7 @@ if(EIGEN_TEST_CXX11) ei_add_test(cxx11_tensor_symmetry "-std=c++0x") ei_add_test(cxx11_tensor_assign "-std=c++0x") ei_add_test(cxx11_tensor_dimension "-std=c++0x") + ei_add_test(cxx11_tensor_index_list "-std=c++0x") ei_add_test(cxx11_tensor_comparisons "-std=c++0x") ei_add_test(cxx11_tensor_contraction "-std=c++0x") ei_add_test(cxx11_tensor_convolution "-std=c++0x") diff --git a/unsupported/test/cxx11_tensor_index_list.cpp b/unsupported/test/cxx11_tensor_index_list.cpp new file mode 100644 index 000000000..6a103cab1 --- /dev/null +++ b/unsupported/test/cxx11_tensor_index_list.cpp @@ -0,0 +1,133 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2014 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "main.h" + +#include + + +static void test_static_index_list() +{ + Tensor tensor(2,3,5,7); + tensor.setRandom(); + + constexpr auto reduction_axis = make_index_list(0, 1, 2); + VERIFY_IS_EQUAL(internal::array_get<0>(reduction_axis), 0); + VERIFY_IS_EQUAL(internal::array_get<1>(reduction_axis), 1); + VERIFY_IS_EQUAL(internal::array_get<2>(reduction_axis), 2); + VERIFY_IS_EQUAL(static_cast(reduction_axis[0]), 0); + VERIFY_IS_EQUAL(static_cast(reduction_axis[1]), 1); + VERIFY_IS_EQUAL(static_cast(reduction_axis[2]), 2); + + EIGEN_STATIC_ASSERT((internal::array_get<0>(reduction_axis) == 0), YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT((internal::array_get<1>(reduction_axis) == 1), YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT((internal::array_get<2>(reduction_axis) == 2), YOU_MADE_A_PROGRAMMING_MISTAKE); + + Tensor result = tensor.sum(reduction_axis); + for (int i = 0; i < result.size(); ++i) { + float expected = 0.0f; + for (int j = 0; j < 2; ++j) { + for (int k = 0; k < 3; ++k) { + for (int l = 0; l < 5; ++l) { + expected += tensor(j,k,l,i); + } + } + } + VERIFY_IS_APPROX(result(i), expected); + } +} + + +static void test_dynamic_index_list() +{ + Tensor tensor(2,3,5,7); + tensor.setRandom(); + + int dim1 = 2; + int dim2 = 1; + int dim3 = 0; + + auto reduction_axis = make_index_list(dim1, dim2, dim3); + + VERIFY_IS_EQUAL(internal::array_get<0>(reduction_axis), 2); + VERIFY_IS_EQUAL(internal::array_get<1>(reduction_axis), 1); + VERIFY_IS_EQUAL(internal::array_get<2>(reduction_axis), 0); + VERIFY_IS_EQUAL(static_cast(reduction_axis[0]), 2); + VERIFY_IS_EQUAL(static_cast(reduction_axis[1]), 1); + VERIFY_IS_EQUAL(static_cast(reduction_axis[2]), 0); + + Tensor result = tensor.sum(reduction_axis); + for (int i = 0; i < result.size(); ++i) { + float expected = 0.0f; + for (int j = 0; j < 2; ++j) { + for (int k = 0; k < 3; ++k) { + for (int l = 0; l < 5; ++l) { + expected += tensor(j,k,l,i); + } + } + } + VERIFY_IS_APPROX(result(i), expected); + } +} + +static void test_mixed_index_list() +{ + Tensor tensor(2,3,5,7); + tensor.setRandom(); + + int dim2 = 1; + int dim4 = 3; + + auto reduction_axis = make_index_list(0, dim2, 2, dim4); + + VERIFY_IS_EQUAL(internal::array_get<0>(reduction_axis), 0); + VERIFY_IS_EQUAL(internal::array_get<1>(reduction_axis), 1); + VERIFY_IS_EQUAL(internal::array_get<2>(reduction_axis), 2); + VERIFY_IS_EQUAL(internal::array_get<3>(reduction_axis), 3); + VERIFY_IS_EQUAL(static_cast(reduction_axis[0]), 0); + VERIFY_IS_EQUAL(static_cast(reduction_axis[1]), 1); + VERIFY_IS_EQUAL(static_cast(reduction_axis[2]), 2); + VERIFY_IS_EQUAL(static_cast(reduction_axis[3]), 3); + + typedef IndexList, int, type2index<2>, int> ReductionIndices; + ReductionIndices reduction_indices; + reduction_indices.set(1, 1); + reduction_indices.set(3, 3); + EIGEN_STATIC_ASSERT((internal::array_get<0>(reduction_indices) == 0), YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT((internal::array_get<2>(reduction_indices) == 2), YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT((internal::index_known_statically()(0) == true), YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT((internal::index_known_statically()(2) == true), YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT((internal::index_statically_eq()(0, 0) == true), YOU_MADE_A_PROGRAMMING_MISTAKE); + EIGEN_STATIC_ASSERT((internal::index_statically_eq()(2, 2) == true), YOU_MADE_A_PROGRAMMING_MISTAKE); + + + Tensor result1 = tensor.sum(reduction_axis); + Tensor result2 = tensor.sum(reduction_indices); + + float expected = 0.0f; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 5; ++k) { + for (int l = 0; l < 7; ++l) { + expected += tensor(i,j,k,l); + } + } + } + } + VERIFY_IS_APPROX(result1(0), expected); + VERIFY_IS_APPROX(result2(0), expected); +} + + +void test_cxx11_tensor_index_list() +{ + CALL_SUBTEST(test_static_index_list()); + CALL_SUBTEST(test_dynamic_index_list()); + CALL_SUBTEST(test_mixed_index_list()); +}