Added support for static list of indices

This commit is contained in:
Benoit Steiner 2014-11-12 22:25:38 -08:00
parent cb37f818ca
commit c2d1074932
4 changed files with 399 additions and 0 deletions

View File

@ -43,6 +43,7 @@
#include "unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h"
#include "unsupported/Eigen/CXX11/src/Tensor/TensorDeviceType.h"
#include "unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h"
#include "unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h"
#include "unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h"
#include "unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h"
#include "unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h"

View File

@ -0,0 +1,264 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H
#define EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H
#if __cplusplus > 199711L
namespace Eigen {
/** \internal
*
* \class TensorIndexList
* \ingroup CXX11_Tensor_Module
*
* \brief Set of classes used to encode a set of Tensor dimensions/indices.
*
* The indices in the list can be known at compile time or at runtime. A mix
* of static and dynamic indices can also be provided if needed. The tensor
* code will attempt to take advantage of the indices that are known at
* compile time to optimize the code it generates.
*
* This functionality requires a c++11 compliant compiler. If your compiler
* is older you need to use arrays of indices instead.
*
* Several examples are provided in the cxx11_tensor_index_list.cpp file.
*
* \sa Tensor
*/
template <DenseIndex n>
struct type2index {
static const DenseIndex value = n;
constexpr operator DenseIndex() const { return n; }
void set(DenseIndex val) {
eigen_assert(val == n);
}
};
namespace internal {
template <typename T>
void update_value(T& val, DenseIndex new_val) {
val = new_val;
}
template <DenseIndex n>
void update_value(type2index<n>& val, DenseIndex new_val) {
val.set(new_val);
}
template <typename T>
struct is_compile_time_constant {
static constexpr bool value = false;
};
template <DenseIndex idx>
struct is_compile_time_constant<type2index<idx> > {
static constexpr bool value = true;
};
template <DenseIndex idx>
struct is_compile_time_constant<const type2index<idx> > {
static constexpr bool value = true;
};
template <DenseIndex idx>
struct is_compile_time_constant<type2index<idx>& > {
static constexpr bool value = true;
};
template <DenseIndex idx>
struct is_compile_time_constant<const type2index<idx>& > {
static constexpr bool value = true;
};
template <DenseIndex Idx>
struct tuple_coeff {
template <typename... T>
static constexpr DenseIndex get(const DenseIndex i, const std::tuple<T...>& t) {
return std::get<Idx>(t) * (i == Idx) + tuple_coeff<Idx-1>::get(i, t) * (i != Idx);
}
template <typename... T>
static void set(const DenseIndex i, std::tuple<T...>& t, const DenseIndex value) {
if (i == Idx) {
update_value(std::get<Idx>(t), value);
} else {
tuple_coeff<Idx-1>::set(i, t, value);
}
}
template <typename... T>
static constexpr bool value_known_statically(const DenseIndex i, const std::tuple<T...>& t) {
return ((i == Idx) & is_compile_time_constant<typename std::tuple_element<Idx, std::tuple<T...> >::type>::value) ||
tuple_coeff<Idx-1>::value_known_statically(i, t);
}
};
template <>
struct tuple_coeff<0> {
template <typename... T>
static constexpr DenseIndex get(const DenseIndex i, const std::tuple<T...>& t) {
// eigen_assert (i == 0); // gcc fails to compile assertions in constexpr
return std::get<0>(t) * (i == 0);
}
template <typename... T>
static void set(const DenseIndex i, std::tuple<T...>& t, const DenseIndex value) {
eigen_assert (i == 0);
update_value(std::get<0>(t), value);
}
template <typename... T>
static constexpr bool value_known_statically(const DenseIndex i, const std::tuple<T...>& t) {
// eigen_assert (i == 0); // gcc fails to compile assertions in constexpr
return is_compile_time_constant<typename std::tuple_element<0, std::tuple<T...> >::type>::value & (i == 0);
}
};
} // namespace internal
template<typename FirstType, typename... OtherTypes>
struct IndexList : std::tuple<FirstType, OtherTypes...> {
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr DenseIndex operator[] (const DenseIndex i) const {
return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1>::get(i, *this);
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC void set(const DenseIndex i, const DenseIndex value) {
return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1>::set(i, *this, value);
}
constexpr IndexList(const std::tuple<FirstType, OtherTypes...>& other) : std::tuple<FirstType, OtherTypes...>(other) { }
constexpr IndexList() : std::tuple<FirstType, OtherTypes...>() { }
constexpr bool value_known_statically(const DenseIndex i) const {
return internal::tuple_coeff<std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value-1>::value_known_statically(i, *this);
}
};
template<typename FirstType, typename... OtherTypes>
constexpr IndexList<FirstType, OtherTypes...> make_index_list(FirstType val1, OtherTypes... other_vals) {
return std::make_tuple(val1, other_vals...);
}
namespace internal {
template<typename FirstType, typename... OtherTypes> struct array_size<IndexList<FirstType, OtherTypes...> > {
static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value;
};
template<typename FirstType, typename... OtherTypes> struct array_size<const IndexList<FirstType, OtherTypes...> > {
static const size_t value = std::tuple_size<std::tuple<FirstType, OtherTypes...> >::value;
};
template<DenseIndex n, typename FirstType, typename... OtherTypes> constexpr DenseIndex array_get(IndexList<FirstType, OtherTypes...>& a) {
return std::get<n>(a);
}
template<DenseIndex n, typename FirstType, typename... OtherTypes> constexpr DenseIndex array_get(const IndexList<FirstType, OtherTypes...>& a) {
return std::get<n>(a);
}
template <typename T>
struct index_known_statically {
constexpr bool operator() (DenseIndex) const {
return false;
}
};
template <typename FirstType, typename... OtherTypes>
struct index_known_statically<IndexList<FirstType, OtherTypes...> > {
constexpr bool operator() (const DenseIndex i) const {
return IndexList<FirstType, OtherTypes...>().value_known_statically(i);
}
};
template <typename FirstType, typename... OtherTypes>
struct index_known_statically<const IndexList<FirstType, OtherTypes...> > {
constexpr bool operator() (const DenseIndex i) const {
return IndexList<FirstType, OtherTypes...>().value_known_statically(i);
}
};
template <typename Tx>
struct index_statically_eq {
constexpr bool operator() (DenseIndex, DenseIndex) const {
return false;
}
};
template <typename FirstType, typename... OtherTypes>
struct index_statically_eq<IndexList<FirstType, OtherTypes...> > {
constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
IndexList<FirstType, OtherTypes...>()[i] == value;
}
};
template <typename FirstType, typename... OtherTypes>
struct index_statically_eq<const IndexList<FirstType, OtherTypes...> > {
constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
IndexList<FirstType, OtherTypes...>()[i] == value;
}
};
template <typename T>
struct index_statically_ne {
constexpr bool operator() (DenseIndex, DenseIndex) const {
return false;
}
};
template <typename FirstType, typename... OtherTypes>
struct index_statically_ne<IndexList<FirstType, OtherTypes...> > {
constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
IndexList<FirstType, OtherTypes...>()[i] != value;
}
};
template <typename FirstType, typename... OtherTypes>
struct index_statically_ne<const IndexList<FirstType, OtherTypes...> > {
constexpr bool operator() (const DenseIndex i, const DenseIndex value) const {
return IndexList<FirstType, OtherTypes...>().value_known_statically(i) &
IndexList<FirstType, OtherTypes...>()[i] != value;
}
};
} // end namespace internal
} // end namespace Eigen
#else
namespace Eigen {
namespace internal {
// No C++11 support
template <typename T>
struct index_known_statically {
EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex) const{
return false;
}
};
template <typename T>
struct index_statically_eq {
EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{
return false;
}
};
template <typename T>
struct index_statically_ne {
EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool operator() (DenseIndex, DenseIndex) const{
return false;
}
};
} // end namespace internal
} // end namespace Eigen
#endif
#endif // EIGEN_CXX11_TENSOR_TENSOR_INDEX_LIST_H

View File

@ -102,6 +102,7 @@ if(EIGEN_TEST_CXX11)
ei_add_test(cxx11_tensor_symmetry "-std=c++0x")
ei_add_test(cxx11_tensor_assign "-std=c++0x")
ei_add_test(cxx11_tensor_dimension "-std=c++0x")
ei_add_test(cxx11_tensor_index_list "-std=c++0x")
ei_add_test(cxx11_tensor_comparisons "-std=c++0x")
ei_add_test(cxx11_tensor_contraction "-std=c++0x")
ei_add_test(cxx11_tensor_convolution "-std=c++0x")

View File

@ -0,0 +1,133 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#include "main.h"
#include <Eigen/CXX11/Tensor>
static void test_static_index_list()
{
Tensor<float, 4> tensor(2,3,5,7);
tensor.setRandom();
constexpr auto reduction_axis = make_index_list(0, 1, 2);
VERIFY_IS_EQUAL(internal::array_get<0>(reduction_axis), 0);
VERIFY_IS_EQUAL(internal::array_get<1>(reduction_axis), 1);
VERIFY_IS_EQUAL(internal::array_get<2>(reduction_axis), 2);
VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[0]), 0);
VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[1]), 1);
VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[2]), 2);
EIGEN_STATIC_ASSERT((internal::array_get<0>(reduction_axis) == 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_STATIC_ASSERT((internal::array_get<1>(reduction_axis) == 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_STATIC_ASSERT((internal::array_get<2>(reduction_axis) == 2), YOU_MADE_A_PROGRAMMING_MISTAKE);
Tensor<float, 1> result = tensor.sum(reduction_axis);
for (int i = 0; i < result.size(); ++i) {
float expected = 0.0f;
for (int j = 0; j < 2; ++j) {
for (int k = 0; k < 3; ++k) {
for (int l = 0; l < 5; ++l) {
expected += tensor(j,k,l,i);
}
}
}
VERIFY_IS_APPROX(result(i), expected);
}
}
static void test_dynamic_index_list()
{
Tensor<float, 4> tensor(2,3,5,7);
tensor.setRandom();
int dim1 = 2;
int dim2 = 1;
int dim3 = 0;
auto reduction_axis = make_index_list(dim1, dim2, dim3);
VERIFY_IS_EQUAL(internal::array_get<0>(reduction_axis), 2);
VERIFY_IS_EQUAL(internal::array_get<1>(reduction_axis), 1);
VERIFY_IS_EQUAL(internal::array_get<2>(reduction_axis), 0);
VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[0]), 2);
VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[1]), 1);
VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[2]), 0);
Tensor<float, 1> result = tensor.sum(reduction_axis);
for (int i = 0; i < result.size(); ++i) {
float expected = 0.0f;
for (int j = 0; j < 2; ++j) {
for (int k = 0; k < 3; ++k) {
for (int l = 0; l < 5; ++l) {
expected += tensor(j,k,l,i);
}
}
}
VERIFY_IS_APPROX(result(i), expected);
}
}
static void test_mixed_index_list()
{
Tensor<float, 4> tensor(2,3,5,7);
tensor.setRandom();
int dim2 = 1;
int dim4 = 3;
auto reduction_axis = make_index_list(0, dim2, 2, dim4);
VERIFY_IS_EQUAL(internal::array_get<0>(reduction_axis), 0);
VERIFY_IS_EQUAL(internal::array_get<1>(reduction_axis), 1);
VERIFY_IS_EQUAL(internal::array_get<2>(reduction_axis), 2);
VERIFY_IS_EQUAL(internal::array_get<3>(reduction_axis), 3);
VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[0]), 0);
VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[1]), 1);
VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[2]), 2);
VERIFY_IS_EQUAL(static_cast<DenseIndex>(reduction_axis[3]), 3);
typedef IndexList<type2index<0>, int, type2index<2>, int> ReductionIndices;
ReductionIndices reduction_indices;
reduction_indices.set(1, 1);
reduction_indices.set(3, 3);
EIGEN_STATIC_ASSERT((internal::array_get<0>(reduction_indices) == 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_STATIC_ASSERT((internal::array_get<2>(reduction_indices) == 2), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_STATIC_ASSERT((internal::index_known_statically<ReductionIndices>()(0) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_STATIC_ASSERT((internal::index_known_statically<ReductionIndices>()(2) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_STATIC_ASSERT((internal::index_statically_eq<ReductionIndices>()(0, 0) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_STATIC_ASSERT((internal::index_statically_eq<ReductionIndices>()(2, 2) == true), YOU_MADE_A_PROGRAMMING_MISTAKE);
Tensor<float, 1> result1 = tensor.sum(reduction_axis);
Tensor<float, 1> result2 = tensor.sum(reduction_indices);
float expected = 0.0f;
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 5; ++k) {
for (int l = 0; l < 7; ++l) {
expected += tensor(i,j,k,l);
}
}
}
}
VERIFY_IS_APPROX(result1(0), expected);
VERIFY_IS_APPROX(result2(0), expected);
}
void test_cxx11_tensor_index_list()
{
CALL_SUBTEST(test_static_index_list());
CALL_SUBTEST(test_dynamic_index_list());
CALL_SUBTEST(test_mixed_index_list());
}