mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Updated the custom indexing code: we can now use any container that provides the [] operator to index a tensor. Added unit tests to validate the use of std::map and a few more types as valid custom index containers
This commit is contained in:
parent
6585efc553
commit
de1e9f29f4
@ -91,7 +91,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
|
|||||||
#ifdef EIGEN_HAS_SFINAE
|
#ifdef EIGEN_HAS_SFINAE
|
||||||
template<typename CustomIndices>
|
template<typename CustomIndices>
|
||||||
struct isOfNormalIndex{
|
struct isOfNormalIndex{
|
||||||
static const bool is_array = internal::is_base_of<array<Index, NumIndices>, CustomIndices >::value;
|
static const bool is_array = internal::is_base_of<array<Index, NumIndices>, CustomIndices>::value;
|
||||||
static const bool is_int = NumTraits<CustomIndices>::IsInteger;
|
static const bool is_int = NumTraits<CustomIndices>::IsInteger;
|
||||||
static const bool value = is_array | is_int;
|
static const bool value = is_array | is_int;
|
||||||
};
|
};
|
||||||
@ -120,11 +120,8 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
|
|||||||
EIGEN_STATIC_ASSERT(sizeof...(otherIndices) + 2 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
|
EIGEN_STATIC_ASSERT(sizeof...(otherIndices) + 2 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
|
||||||
return coeff(array<Index, NumIndices>{{firstIndex, secondIndex, otherIndices...}});
|
return coeff(array<Index, NumIndices>{{firstIndex, secondIndex, otherIndices...}});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
// normal indices
|
// normal indices
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(const array<Index, NumIndices>& indices) const
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(const array<Index, NumIndices>& indices) const
|
||||||
{
|
{
|
||||||
@ -137,7 +134,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
|
|||||||
template<typename CustomIndices,
|
template<typename CustomIndices,
|
||||||
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) )
|
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) )
|
||||||
>
|
>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(const CustomIndices & indices) const
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(CustomIndices& indices) const
|
||||||
{
|
{
|
||||||
return coeff(internal::customIndices2Array<Index,NumIndices>(indices));
|
return coeff(internal::customIndices2Array<Index,NumIndices>(indices));
|
||||||
}
|
}
|
||||||
@ -171,7 +168,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
|
|||||||
template<typename CustomIndices,
|
template<typename CustomIndices,
|
||||||
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) )
|
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) )
|
||||||
>
|
>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const CustomIndices & indices)
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(CustomIndices& indices)
|
||||||
{
|
{
|
||||||
return coeffRef(internal::customIndices2Array<Index,NumIndices>(indices));
|
return coeffRef(internal::customIndices2Array<Index,NumIndices>(indices));
|
||||||
}
|
}
|
||||||
@ -219,7 +216,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
|
|||||||
template<typename CustomIndices,
|
template<typename CustomIndices,
|
||||||
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) )
|
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) )
|
||||||
>
|
>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(const CustomIndices & indices) const
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(CustomIndices& indices) const
|
||||||
{
|
{
|
||||||
return coeff(internal::customIndices2Array<Index,NumIndices>(indices));
|
return coeff(internal::customIndices2Array<Index,NumIndices>(indices));
|
||||||
}
|
}
|
||||||
@ -286,7 +283,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
|
|||||||
template<typename CustomIndices,
|
template<typename CustomIndices,
|
||||||
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) )
|
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) )
|
||||||
>
|
>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(const CustomIndices & indices)
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(CustomIndices& indices)
|
||||||
{
|
{
|
||||||
return coeffRef(internal::customIndices2Array<Index,NumIndices>(indices));
|
return coeffRef(internal::customIndices2Array<Index,NumIndices>(indices));
|
||||||
}
|
}
|
||||||
@ -441,9 +438,9 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
|
|||||||
template<typename CustomDimension,
|
template<typename CustomDimension,
|
||||||
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomDimension>::value) )
|
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomDimension>::value) )
|
||||||
>
|
>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void resize(const CustomDimension & dimensions)
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void resize(CustomDimension& dimensions)
|
||||||
{
|
{
|
||||||
return coeffRef(internal::customIndices2Array<Index,NumIndices>(dimensions));
|
resize(internal::customIndices2Array<Index,NumIndices>(dimensions));
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -82,15 +82,15 @@ namespace internal{
|
|||||||
|
|
||||||
template<typename IndexType, Index... Is>
|
template<typename IndexType, Index... Is>
|
||||||
EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
array<Index,sizeof...(Is)> customIndices2Array(const IndexType & idx, numeric_list<Index,Is...>) {
|
array<Index, sizeof...(Is)> customIndices2Array(IndexType& idx, numeric_list<Index, Is...>) {
|
||||||
return { idx(Is)... };
|
return { idx[Is]... };
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Make an array (for index/dimensions) out of a custom index */
|
/** Make an array (for index/dimensions) out of a custom index */
|
||||||
template<typename Index, int NumIndices, typename IndexType>
|
template<typename Index, int NumIndices, typename IndexType>
|
||||||
EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
array<Index,NumIndices> customIndices2Array(const IndexType & idx) {
|
array<Index, NumIndices> customIndices2Array(IndexType& idx) {
|
||||||
return customIndices2Array(idx, typename gen_numeric_list<Index,NumIndices>::type{});
|
return customIndices2Array(idx, typename gen_numeric_list<Index, NumIndices>::type{});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,6 +9,7 @@
|
|||||||
|
|
||||||
#include "main.h"
|
#include "main.h"
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
#include <Eigen/Dense>
|
#include <Eigen/Dense>
|
||||||
#include <Eigen/CXX11/Tensor>
|
#include <Eigen/CXX11/Tensor>
|
||||||
@ -17,22 +18,83 @@ using Eigen::Tensor;
|
|||||||
|
|
||||||
|
|
||||||
template <int DataLayout>
|
template <int DataLayout>
|
||||||
static void test_custom_index() {
|
static void test_map_as_index()
|
||||||
|
{
|
||||||
|
#ifdef EIGEN_HAS_SFINAE
|
||||||
Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7);
|
Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7);
|
||||||
tensor.setRandom();
|
tensor.setRandom();
|
||||||
|
|
||||||
using NormalIndex = DSizes<ptrdiff_t, 4>;
|
using NormalIndex = DSizes<ptrdiff_t, 4>;
|
||||||
using CustomIndex = Matrix<unsigned int , 4, 1>;
|
using CustomIndex = std::map<ptrdiff_t, ptrdiff_t>;
|
||||||
|
CustomIndex coeffC;
|
||||||
|
coeffC[0] = 1;
|
||||||
|
coeffC[1] = 2;
|
||||||
|
coeffC[2] = 4;
|
||||||
|
coeffC[3] = 1;
|
||||||
|
NormalIndex coeff(1,2,4,1);
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(tensor.coeff(coeffC), tensor.coeff(coeff));
|
||||||
|
VERIFY_IS_EQUAL(tensor.coeffRef(coeffC), tensor.coeffRef(coeff));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <int DataLayout>
|
||||||
|
static void test_matrix_as_index()
|
||||||
|
{
|
||||||
|
#ifdef EIGEN_HAS_SFINAE
|
||||||
|
Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7);
|
||||||
|
tensor.setRandom();
|
||||||
|
|
||||||
|
using NormalIndex = DSizes<ptrdiff_t, 4>;
|
||||||
|
using CustomIndex = Matrix<unsigned int, 4, 1>;
|
||||||
CustomIndex coeffC(1,2,4,1);
|
CustomIndex coeffC(1,2,4,1);
|
||||||
NormalIndex coeff(1,2,4,1);
|
NormalIndex coeff(1,2,4,1);
|
||||||
|
|
||||||
VERIFY_IS_EQUAL(tensor.coeff(coeffC), tensor.coeff(coeff));
|
VERIFY_IS_EQUAL(tensor.coeff(coeffC), tensor.coeff(coeff));
|
||||||
VERIFY_IS_EQUAL(tensor.coeffRef(coeffC), tensor.coeffRef(coeff));
|
VERIFY_IS_EQUAL(tensor.coeffRef(coeffC), tensor.coeffRef(coeff));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <int DataLayout>
|
||||||
|
static void test_varlist_as_index()
|
||||||
|
{
|
||||||
|
#ifdef EIGEN_HAS_SFINAE
|
||||||
|
Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7);
|
||||||
|
tensor.setRandom();
|
||||||
|
|
||||||
|
DSizes<ptrdiff_t, 4> coeff(1,2,4,1);
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(tensor.coeff({1,2,4,1}), tensor.coeff(coeff));
|
||||||
|
VERIFY_IS_EQUAL(tensor.coeffRef({1,2,4,1}), tensor.coeffRef(coeff));
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <int DataLayout>
|
||||||
|
static void test_sizes_as_index()
|
||||||
|
{
|
||||||
|
#ifdef EIGEN_HAS_SFINAE
|
||||||
|
Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7);
|
||||||
|
tensor.setRandom();
|
||||||
|
|
||||||
|
DSizes<ptrdiff_t, 4> coeff(1,2,4,1);
|
||||||
|
Sizes<1,2,4,1> coeffC;
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(tensor.coeff(coeffC), tensor.coeff(coeff));
|
||||||
|
VERIFY_IS_EQUAL(tensor.coeffRef(coeffC), tensor.coeffRef(coeff));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
void test_cxx11_tensor_custom_index() {
|
void test_cxx11_tensor_custom_index() {
|
||||||
test_custom_index<ColMajor>();
|
test_map_as_index<ColMajor>();
|
||||||
test_custom_index<RowMajor>();
|
test_map_as_index<RowMajor>();
|
||||||
|
test_matrix_as_index<ColMajor>();
|
||||||
|
test_matrix_as_index<RowMajor>();
|
||||||
|
test_varlist_as_index<ColMajor>();
|
||||||
|
test_varlist_as_index<RowMajor>();
|
||||||
|
test_sizes_as_index<ColMajor>();
|
||||||
|
test_sizes_as_index<RowMajor>();
|
||||||
}
|
}
|
||||||
|
@ -293,7 +293,3 @@ void test_cxx11_tensor_simple()
|
|||||||
CALL_SUBTEST(test_simple_assign());
|
CALL_SUBTEST(test_simple_assign());
|
||||||
CALL_SUBTEST(test_resize());
|
CALL_SUBTEST(test_resize());
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
|
||||||
* kate: space-indent on; indent-width 2; mixedindent off; indent-mode cstyle;
|
|
||||||
*/
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user