diff --git a/unsupported/Eigen/CXX11/Tensor b/unsupported/Eigen/CXX11/Tensor index cbe416602..3331ccb55 100644 --- a/unsupported/Eigen/CXX11/Tensor +++ b/unsupported/Eigen/CXX11/Tensor @@ -57,6 +57,7 @@ #endif +#include "src/Tensor/TensorMacros.h" #include "src/Tensor/TensorForwardDeclarations.h" #include "src/Tensor/TensorMeta.h" #include "src/Tensor/TensorDeviceType.h" diff --git a/unsupported/Eigen/CXX11/src/Tensor/Tensor.h b/unsupported/Eigen/CXX11/src/Tensor/Tensor.h index 6c16e0faa..3ac465d24 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/Tensor.h +++ b/unsupported/Eigen/CXX11/src/Tensor/Tensor.h @@ -88,6 +88,15 @@ class Tensor : public TensorBase m_storage; +#ifdef EIGEN_HAS_SFINAE + template + struct isOfNormalIndex{ + static const bool is_array = internal::is_base_of, CustomIndices>::value; + static const bool is_int = NumTraits::IsInteger; + static const bool value = is_array | is_int; + }; +#endif + public: // Metadata EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rank() const { return NumIndices; } @@ -113,12 +122,24 @@ class Tensor : public TensorBase& indices) const { eigen_internal_assert(checkIndexRange(indices)); return m_storage.data()[linearizedIndex(indices)]; } + // custom indices +#ifdef EIGEN_HAS_SFINAE + template::value) ) + > + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(CustomIndices& indices) const + { + return coeff(internal::customIndices2Array(indices)); + } +#endif + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(Index index) const { eigen_internal_assert(index >= 0 && index < size()); @@ -135,12 +156,24 @@ class Tensor : public TensorBase& indices) { eigen_internal_assert(checkIndexRange(indices)); return m_storage.data()[linearizedIndex(indices)]; } + // custom indices +#ifdef EIGEN_HAS_SFINAE + template::value) ) + > + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(CustomIndices& indices) + { + return coeffRef(internal::customIndices2Array(indices)); + } +#endif + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) { eigen_internal_assert(index >= 0 && index < size()); @@ -178,9 +211,20 @@ class Tensor : public TensorBase::value) ) + > + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(CustomIndices& indices) const + { + return coeff(internal::customIndices2Array(indices)); + } +#endif + + // normal indices EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(const array& indices) const { - eigen_assert(checkIndexRange(indices)); return coeff(indices); } @@ -228,12 +272,23 @@ class Tensor : public TensorBase& indices) { - eigen_assert(checkIndexRange(indices)); return coeffRef(indices); } + // custom indices +#ifdef EIGEN_HAS_SFINAE + template::value) ) + > + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(CustomIndices& indices) + { + return coeffRef(internal::customIndices2Array(indices)); + } +#endif + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(Index index) { eigen_assert(index >= 0 && index < size()); @@ -295,6 +350,7 @@ class Tensor : public TensorBase& dimensions) : m_storage(internal::array_prod(dimensions), dimensions) { @@ -341,7 +397,7 @@ class Tensor : public TensorBase EIGEN_DEVICE_FUNC + template EIGEN_DEVICE_FUNC void resize(Index firstDimension, IndexTypes... otherDimensions) { // The number of dimensions used to resize a tensor must be equal to the rank of the tensor. @@ -350,6 +406,7 @@ class Tensor : public TensorBase& dimensions) { std::size_t i; @@ -367,6 +424,7 @@ class Tensor : public TensorBase& dimensions) { array dims; for (std::size_t i = 0; i < NumIndices; ++i) { @@ -375,6 +433,17 @@ class Tensor : public TensorBase::value) ) + > + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void resize(CustomDimension& dimensions) + { + resize(internal::customIndices2Array(dimensions)); + } +#endif + #ifndef EIGEN_EMULATE_CXX11_META_H template EIGEN_DEVICE_FUNC diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMacros.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMacros.h new file mode 100644 index 000000000..6d9cc4f38 --- /dev/null +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMacros.h @@ -0,0 +1,44 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2015 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CXX11_TENSOR_TENSOR_META_MACROS_H +#define EIGEN_CXX11_TENSOR_TENSOR_META_MACROS_H + + +/** use this macro in sfinae selection in templated functions + * + * template::value , int >::type = 0 + * > + * void foo(){} + * + * becomes => + * + * template::value ) + * > + * void foo(){} + */ + +#ifdef EIGEN_HAS_VARIADIC_TEMPLATES +#define EIGEN_HAS_SFINAE +#endif + +#define EIGEN_SFINAE_ENABLE_IF( __condition__ ) \ + typename internal::enable_if< ( __condition__ ) , int >::type = 0 + + +#if defined(EIGEN_HAS_CONSTEXPR) +#define EIGEN_CONSTEXPR constexpr +#else +#define EIGEN_CONSTEXPR +#endif + + +#endif diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h index 7dfa04760..07735fa5f 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h @@ -32,14 +32,6 @@ template <> struct max_n_1<0> { }; - - -#if defined(EIGEN_HAS_CONSTEXPR) -#define EIGEN_CONSTEXPR constexpr -#else -#define EIGEN_CONSTEXPR -#endif - // Tuple mimics std::pair but works on e.g. nvcc. template struct Tuple { public: @@ -83,7 +75,50 @@ bool operator!=(const Tuple& x, const Tuple& y) { return !(x == y); } -#undef EIGEN_CONSTEXPR + + +#ifdef EIGEN_HAS_SFINAE +namespace internal{ + + template + EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + array customIndices2Array(IndexType& idx, numeric_list) { + return { idx[Is]... }; + } + + /** Make an array (for index/dimensions) out of a custom index */ + template + EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + array customIndices2Array(IndexType& idx) { + return customIndices2Array(idx, typename gen_numeric_list::type{}); + } + + + template + struct is_base_of + { + + typedef char (&yes)[1]; + typedef char (&no)[2]; + + template + struct Host + { + operator BB*() const; + operator DD*(); + }; + + template + static yes check(D*, T); + static no check(B*, int); + + static const bool value = sizeof(check(Host(), int())) == sizeof(yes); + }; + +} +#endif + + } // namespace Eigen diff --git a/unsupported/test/CMakeLists.txt b/unsupported/test/CMakeLists.txt index 7a1737edd..8865892e6 100644 --- a/unsupported/test/CMakeLists.txt +++ b/unsupported/test/CMakeLists.txt @@ -142,6 +142,7 @@ if(EIGEN_TEST_CXX11) ei_add_test(cxx11_tensor_io "-std=c++0x") ei_add_test(cxx11_tensor_generator "-std=c++0x") ei_add_test(cxx11_tensor_custom_op "-std=c++0x") + ei_add_test(cxx11_tensor_custom_index "-std=c++0x") # These tests needs nvcc # ei_add_test(cxx11_tensor_device "-std=c++0x") diff --git a/unsupported/test/cxx11_tensor_custom_index.cpp b/unsupported/test/cxx11_tensor_custom_index.cpp new file mode 100644 index 000000000..4528cc176 --- /dev/null +++ b/unsupported/test/cxx11_tensor_custom_index.cpp @@ -0,0 +1,100 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2015 Benoit Steiner +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "main.h" +#include +#include + +#include +#include + +using Eigen::Tensor; + + +template +static void test_map_as_index() +{ +#ifdef EIGEN_HAS_SFINAE + Tensor tensor(2, 3, 5, 7); + tensor.setRandom(); + + using NormalIndex = DSizes; + using CustomIndex = std::map; + CustomIndex coeffC; + coeffC[0] = 1; + coeffC[1] = 2; + coeffC[2] = 4; + coeffC[3] = 1; + NormalIndex coeff(1,2,4,1); + + VERIFY_IS_EQUAL(tensor.coeff(coeffC), tensor.coeff(coeff)); + VERIFY_IS_EQUAL(tensor.coeffRef(coeffC), tensor.coeffRef(coeff)); +#endif +} + + +template +static void test_matrix_as_index() +{ +#ifdef EIGEN_HAS_SFINAE + Tensor tensor(2, 3, 5, 7); + tensor.setRandom(); + + using NormalIndex = DSizes; + using CustomIndex = Matrix; + CustomIndex coeffC(1,2,4,1); + NormalIndex coeff(1,2,4,1); + + VERIFY_IS_EQUAL(tensor.coeff(coeffC), tensor.coeff(coeff)); + VERIFY_IS_EQUAL(tensor.coeffRef(coeffC), tensor.coeffRef(coeff)); +#endif +} + + +template +static void test_varlist_as_index() +{ +#ifdef EIGEN_HAS_SFINAE + Tensor tensor(2, 3, 5, 7); + tensor.setRandom(); + + DSizes coeff(1,2,4,1); + + VERIFY_IS_EQUAL(tensor.coeff({1,2,4,1}), tensor.coeff(coeff)); + VERIFY_IS_EQUAL(tensor.coeffRef({1,2,4,1}), tensor.coeffRef(coeff)); +#endif +} + + +template +static void test_sizes_as_index() +{ +#ifdef EIGEN_HAS_SFINAE + Tensor tensor(2, 3, 5, 7); + tensor.setRandom(); + + DSizes coeff(1,2,4,1); + Sizes<1,2,4,1> coeffC; + + VERIFY_IS_EQUAL(tensor.coeff(coeffC), tensor.coeff(coeff)); + VERIFY_IS_EQUAL(tensor.coeffRef(coeffC), tensor.coeffRef(coeff)); +#endif +} + + +void test_cxx11_tensor_custom_index() { + test_map_as_index(); + test_map_as_index(); + test_matrix_as_index(); + test_matrix_as_index(); + test_varlist_as_index(); + test_varlist_as_index(); + test_sizes_as_index(); + test_sizes_as_index(); +} diff --git a/unsupported/test/cxx11_tensor_simple.cpp b/unsupported/test/cxx11_tensor_simple.cpp index 8cd2ab7fd..0ce92eed9 100644 --- a/unsupported/test/cxx11_tensor_simple.cpp +++ b/unsupported/test/cxx11_tensor_simple.cpp @@ -293,7 +293,3 @@ void test_cxx11_tensor_simple() CALL_SUBTEST(test_simple_assign()); CALL_SUBTEST(test_resize()); } - -/* - * kate: space-indent on; indent-width 2; mixedindent off; indent-mode cstyle; - */