diff --git a/unsupported/Eigen/CXX11/src/Tensor/Tensor.h b/unsupported/Eigen/CXX11/src/Tensor/Tensor.h index 57d44baf9..3ac465d24 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/Tensor.h +++ b/unsupported/Eigen/CXX11/src/Tensor/Tensor.h @@ -91,7 +91,7 @@ class Tensor : public TensorBase struct isOfNormalIndex{ - static const bool is_array = internal::is_base_of, CustomIndices >::value; + static const bool is_array = internal::is_base_of, CustomIndices>::value; static const bool is_int = NumTraits::IsInteger; static const bool value = is_array | is_int; }; @@ -120,11 +120,8 @@ class Tensor : public TensorBase{{firstIndex, secondIndex, otherIndices...}}); } - - #endif - // normal indices EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(const array& indices) const { @@ -137,7 +134,7 @@ class Tensor : public TensorBase::value) ) > - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(const CustomIndices & indices) const + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(CustomIndices& indices) const { return coeff(internal::customIndices2Array(indices)); } @@ -171,7 +168,7 @@ class Tensor : public TensorBase::value) ) > - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const CustomIndices & indices) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(CustomIndices& indices) { return coeffRef(internal::customIndices2Array(indices)); } @@ -219,7 +216,7 @@ class Tensor : public TensorBase::value) ) > - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(const CustomIndices & indices) const + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(CustomIndices& indices) const { return coeff(internal::customIndices2Array(indices)); } @@ -286,7 +283,7 @@ class Tensor : public TensorBase::value) ) > - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(const CustomIndices & indices) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(CustomIndices& indices) { return coeffRef(internal::customIndices2Array(indices)); } @@ -441,9 +438,9 @@ class Tensor : public TensorBase::value) ) > - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void resize(const CustomDimension & dimensions) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void resize(CustomDimension& dimensions) { - return coeffRef(internal::customIndices2Array(dimensions)); + resize(internal::customIndices2Array(dimensions)); } #endif diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h index d1efc1a87..07735fa5f 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h @@ -82,15 +82,15 @@ namespace internal{ template EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - array customIndices2Array(const IndexType & idx, numeric_list) { - return { idx(Is)... }; + array customIndices2Array(IndexType& idx, numeric_list) { + return { idx[Is]... }; } /** Make an array (for index/dimensions) out of a custom index */ template EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE - array customIndices2Array(const IndexType & idx) { - return customIndices2Array(idx, typename gen_numeric_list::type{}); + array customIndices2Array(IndexType& idx) { + return customIndices2Array(idx, typename gen_numeric_list::type{}); } diff --git a/unsupported/test/cxx11_tensor_custom_index.cpp b/unsupported/test/cxx11_tensor_custom_index.cpp index ff9545a7a..4528cc176 100644 --- a/unsupported/test/cxx11_tensor_custom_index.cpp +++ b/unsupported/test/cxx11_tensor_custom_index.cpp @@ -9,6 +9,7 @@ #include "main.h" #include +#include #include #include @@ -17,22 +18,83 @@ using Eigen::Tensor; template -static void test_custom_index() { - +static void test_map_as_index() +{ +#ifdef EIGEN_HAS_SFINAE Tensor tensor(2, 3, 5, 7); tensor.setRandom(); using NormalIndex = DSizes; - using CustomIndex = Matrix; + using CustomIndex = std::map; + CustomIndex coeffC; + coeffC[0] = 1; + coeffC[1] = 2; + coeffC[2] = 4; + coeffC[3] = 1; + NormalIndex coeff(1,2,4,1); + + VERIFY_IS_EQUAL(tensor.coeff(coeffC), tensor.coeff(coeff)); + VERIFY_IS_EQUAL(tensor.coeffRef(coeffC), tensor.coeffRef(coeff)); +#endif +} + + +template +static void test_matrix_as_index() +{ +#ifdef EIGEN_HAS_SFINAE + Tensor tensor(2, 3, 5, 7); + tensor.setRandom(); + + using NormalIndex = DSizes; + using CustomIndex = Matrix; CustomIndex coeffC(1,2,4,1); NormalIndex coeff(1,2,4,1); VERIFY_IS_EQUAL(tensor.coeff(coeffC), tensor.coeff(coeff)); VERIFY_IS_EQUAL(tensor.coeffRef(coeffC), tensor.coeffRef(coeff)); +#endif +} + + +template +static void test_varlist_as_index() +{ +#ifdef EIGEN_HAS_SFINAE + Tensor tensor(2, 3, 5, 7); + tensor.setRandom(); + + DSizes coeff(1,2,4,1); + + VERIFY_IS_EQUAL(tensor.coeff({1,2,4,1}), tensor.coeff(coeff)); + VERIFY_IS_EQUAL(tensor.coeffRef({1,2,4,1}), tensor.coeffRef(coeff)); +#endif +} + + +template +static void test_sizes_as_index() +{ +#ifdef EIGEN_HAS_SFINAE + Tensor tensor(2, 3, 5, 7); + tensor.setRandom(); + + DSizes coeff(1,2,4,1); + Sizes<1,2,4,1> coeffC; + + VERIFY_IS_EQUAL(tensor.coeff(coeffC), tensor.coeff(coeff)); + VERIFY_IS_EQUAL(tensor.coeffRef(coeffC), tensor.coeffRef(coeff)); +#endif } void test_cxx11_tensor_custom_index() { - test_custom_index(); - test_custom_index(); + test_map_as_index(); + test_map_as_index(); + test_matrix_as_index(); + test_matrix_as_index(); + test_varlist_as_index(); + test_varlist_as_index(); + test_sizes_as_index(); + test_sizes_as_index(); } diff --git a/unsupported/test/cxx11_tensor_simple.cpp b/unsupported/test/cxx11_tensor_simple.cpp index 8cd2ab7fd..0ce92eed9 100644 --- a/unsupported/test/cxx11_tensor_simple.cpp +++ b/unsupported/test/cxx11_tensor_simple.cpp @@ -293,7 +293,3 @@ void test_cxx11_tensor_simple() CALL_SUBTEST(test_simple_assign()); CALL_SUBTEST(test_resize()); } - -/* - * kate: space-indent on; indent-width 2; mixedindent off; indent-mode cstyle; - */