Tightened the definition of isOfNormalIndex to take into account integer types in addition to arrays of indices

Only compile the custom index code  when EIGEN_HAS_SFINAE is defined. For the time beeing, EIGEN_HAS_SFINAE is a synonym for EIGEN_HAS_VARIADIC_TEMPLATES, but this might evolve in the future.
Moved some code around.
This commit is contained in:
Benoit Steiner 2015-10-14 09:31:37 -07:00
parent fc7478c04d
commit 6585efc553
6 changed files with 41 additions and 52 deletions

View File

@ -57,9 +57,9 @@
#endif #endif
#include "src/Tensor/TensorMacros.h"
#include "src/Tensor/TensorForwardDeclarations.h" #include "src/Tensor/TensorForwardDeclarations.h"
#include "src/Tensor/TensorMeta.h" #include "src/Tensor/TensorMeta.h"
#include "src/Tensor/TensorMetaMacros.h"
#include "src/Tensor/TensorDeviceType.h" #include "src/Tensor/TensorDeviceType.h"
#include "src/Tensor/TensorIndexList.h" #include "src/Tensor/TensorIndexList.h"
#include "src/Tensor/TensorDimensionList.h" #include "src/Tensor/TensorDimensionList.h"

View File

@ -88,10 +88,14 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
protected: protected:
TensorStorage<Scalar, Dimensions, Options> m_storage; TensorStorage<Scalar, Dimensions, Options> m_storage;
#ifdef EIGEN_HAS_SFINAE
template<typename CustomIndices> template<typename CustomIndices>
struct isOfNormalIndex{ struct isOfNormalIndex{
static const bool value = internal::is_base_of< array<Index, NumIndices>, CustomIndices >::value; static const bool is_array = internal::is_base_of<array<Index, NumIndices>, CustomIndices >::value;
static const bool is_int = NumTraits<CustomIndices>::IsInteger;
static const bool value = is_array | is_int;
}; };
#endif
public: public:
// Metadata // Metadata
@ -129,6 +133,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
} }
// custom indices // custom indices
#ifdef EIGEN_HAS_SFINAE
template<typename CustomIndices, template<typename CustomIndices,
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) ) EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) )
> >
@ -136,8 +141,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
{ {
return coeff(internal::customIndices2Array<Index,NumIndices>(indices)); return coeff(internal::customIndices2Array<Index,NumIndices>(indices));
} }
#endif
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(Index index) const EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(Index index) const
{ {
@ -163,6 +167,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
} }
// custom indices // custom indices
#ifdef EIGEN_HAS_SFINAE
template<typename CustomIndices, template<typename CustomIndices,
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) ) EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) )
> >
@ -170,7 +175,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
{ {
return coeffRef(internal::customIndices2Array<Index,NumIndices>(indices)); return coeffRef(internal::customIndices2Array<Index,NumIndices>(indices));
} }
#endif
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index)
{ {
@ -210,19 +215,19 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
#endif #endif
// custom indices // custom indices
#ifdef EIGEN_HAS_SFINAE
template<typename CustomIndices, template<typename CustomIndices,
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) ) EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) )
> >
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(const CustomIndices & indices) const EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(const CustomIndices & indices) const
{ {
//eigen_assert(checkIndexRange(indices)); /* already in coeff */
return coeff(internal::customIndices2Array<Index,NumIndices>(indices)); return coeff(internal::customIndices2Array<Index,NumIndices>(indices));
} }
#endif
// normal indices // normal indices
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, NumIndices>& indices) const EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, NumIndices>& indices) const
{ {
//eigen_assert(checkIndexRange(indices)); /* already in coeff */
return coeff(indices); return coeff(indices);
} }
@ -273,19 +278,19 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
// normal indices // normal indices
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, NumIndices>& indices) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, NumIndices>& indices)
{ {
//eigen_assert(checkIndexRange(indices)); /* already in coeff */
return coeffRef(indices); return coeffRef(indices);
} }
// custom indices // custom indices
#ifdef EIGEN_HAS_SFINAE
template<typename CustomIndices, template<typename CustomIndices,
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) ) EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndices>::value) )
> >
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(const CustomIndices & indices) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(const CustomIndices & indices)
{ {
//eigen_assert(checkIndexRange(indices)); /* already in coeff */
return coeffRef(internal::customIndices2Array<Index,NumIndices>(indices)); return coeffRef(internal::customIndices2Array<Index,NumIndices>(indices));
} }
#endif
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(Index index) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(Index index)
{ {
@ -355,13 +360,6 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
EIGEN_INITIALIZE_COEFFS_IF_THAT_OPTION_IS_ENABLED EIGEN_INITIALIZE_COEFFS_IF_THAT_OPTION_IS_ENABLED
} }
/** Custom Dimension (delegating constructor c++11) */
template<typename CustomDimension,
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomDimension>::value) )
>
inline explicit Tensor(const CustomDimension & dimensions) : Tensor(internal::customIndices2Array<Index,NumIndices>(dimensions))
{}
template<typename OtherDerived> template<typename OtherDerived>
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Tensor(const TensorBase<OtherDerived, ReadOnlyAccessors>& other) EIGEN_STRONG_INLINE Tensor(const TensorBase<OtherDerived, ReadOnlyAccessors>& other)
@ -429,7 +427,6 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
#endif #endif
} }
// Why this overload, DSizes is derived from array ??? // // Why this overload, DSizes is derived from array ??? //
EIGEN_DEVICE_FUNC void resize(const DSizes<Index, NumIndices>& dimensions) { EIGEN_DEVICE_FUNC void resize(const DSizes<Index, NumIndices>& dimensions) {
array<Index, NumIndices> dims; array<Index, NumIndices> dims;
@ -440,15 +437,15 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
} }
/** Custom Dimension */ /** Custom Dimension */
#ifdef EIGEN_HAS_SFINAE
template<typename CustomDimension, template<typename CustomDimension,
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomDimension>::value) ) EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomDimension>::value) )
> >
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void resize(const CustomDimension & dimensions) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void resize(const CustomDimension & dimensions)
{ {
//eigen_assert(checkIndexRange(indices)); /* already in coeff */
return coeffRef(internal::customIndices2Array<Index,NumIndices>(dimensions)); return coeffRef(internal::customIndices2Array<Index,NumIndices>(dimensions));
} }
#endif
#ifndef EIGEN_EMULATE_CXX11_META_H #ifndef EIGEN_EMULATE_CXX11_META_H
template <typename std::ptrdiff_t... Indices> template <typename std::ptrdiff_t... Indices>

View File

@ -26,8 +26,19 @@
* void foo(){} * void foo(){}
*/ */
#ifdef EIGEN_HAS_VARIADIC_TEMPLATES
#define EIGEN_HAS_SFINAE
#endif
#define EIGEN_SFINAE_ENABLE_IF( __condition__ ) \ #define EIGEN_SFINAE_ENABLE_IF( __condition__ ) \
typename internal::enable_if< ( __condition__ ) , int >::type = 0 typename internal::enable_if< ( __condition__ ) , int >::type = 0
#if defined(EIGEN_HAS_CONSTEXPR)
#define EIGEN_CONSTEXPR constexpr
#else
#define EIGEN_CONSTEXPR
#endif
#endif #endif

View File

@ -32,17 +32,6 @@ template <> struct max_n_1<0> {
}; };
#if defined(EIGEN_HAS_CONSTEXPR)
#define EIGEN_CONSTEXPR constexpr
#else
#define EIGEN_CONSTEXPR
#endif
// Tuple mimics std::pair but works on e.g. nvcc. // Tuple mimics std::pair but works on e.g. nvcc.
template <typename U, typename V> struct Tuple { template <typename U, typename V> struct Tuple {
public: public:
@ -88,7 +77,7 @@ bool operator!=(const Tuple<U, V>& x, const Tuple<U, V>& y) {
#ifdef EIGEN_HAS_SFINAE
namespace internal{ namespace internal{
template<typename IndexType, Index... Is> template<typename IndexType, Index... Is>
@ -127,15 +116,10 @@ namespace internal{
}; };
} }
#endif
#undef EIGEN_CONSTEXPR
} // namespace Eigen } // namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_META_H #endif // EIGEN_CXX11_TENSOR_TENSOR_META_H

View File

@ -142,7 +142,7 @@ if(EIGEN_TEST_CXX11)
ei_add_test(cxx11_tensor_io "-std=c++0x") ei_add_test(cxx11_tensor_io "-std=c++0x")
ei_add_test(cxx11_tensor_generator "-std=c++0x") ei_add_test(cxx11_tensor_generator "-std=c++0x")
ei_add_test(cxx11_tensor_custom_op "-std=c++0x") ei_add_test(cxx11_tensor_custom_op "-std=c++0x")
ei_add_test(cxx11_tensor_customIndex "-std=c++0x") ei_add_test(cxx11_tensor_custom_index "-std=c++0x")
# These tests needs nvcc # These tests needs nvcc
# ei_add_test(cxx11_tensor_device "-std=c++0x") # ei_add_test(cxx11_tensor_device "-std=c++0x")

View File

@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library // This file is part of Eigen, a lightweight C++ template library
// for linear algebra. // for linear algebra.
// //
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> // Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
// //
// This Source Code Form is subject to the terms of the Mozilla // This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed // Public License v. 2.0. If a copy of the MPL was not distributed
@ -17,25 +17,22 @@ using Eigen::Tensor;
template <int DataLayout> template <int DataLayout>
static void test_customIndex() { static void test_custom_index() {
Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7); Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7);
tensor.setRandom();
using NormalIndex = DSizes<ptrdiff_t, 4>; using NormalIndex = DSizes<ptrdiff_t, 4>;
using CustomIndex = Matrix<unsigned int , 4, 1>; using CustomIndex = Matrix<unsigned int , 4, 1>;
tensor.setRandom();
CustomIndex coeffC(1,2,4,1); CustomIndex coeffC(1,2,4,1);
NormalIndex coeff(1,2,4,1); NormalIndex coeff(1,2,4,1);
VERIFY_IS_EQUAL(tensor.coeff(coeffC), tensor.coeff(coeff)); VERIFY_IS_EQUAL(tensor.coeff(coeffC), tensor.coeff(coeff));
VERIFY_IS_EQUAL(tensor.coeffRef(coeffC), tensor.coeffRef(coeff)); VERIFY_IS_EQUAL(tensor.coeffRef(coeffC), tensor.coeffRef(coeff));
} }
void test_cxx11_tensor_customIndex() { void test_cxx11_tensor_custom_index() {
CALL_SUBTEST(test_customIndex<ColMajor>()); test_custom_index<ColMajor>();
CALL_SUBTEST(test_customIndex<RowMajor>()); test_custom_index<RowMajor>();
} }