mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-13 12:19:12 +08:00
added CustomIndex capability only to Tensor and not yet to TensorBase.
using Sfinae and is_base_of to select correct template which converts to array<Index,NumIndices> user: Gabriel Nützi <gnuetzi@gmx.ch> branch 'default' added unsupported/Eigen/CXX11/src/Tensor/TensorMetaMacros.h added unsupported/test/cxx11_tensor_customIndex.cpp changed unsupported/Eigen/CXX11/Tensor changed unsupported/Eigen/CXX11/src/Tensor/Tensor.h changed unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h changed unsupported/test/CMakeLists.txt
This commit is contained in:
parent
b4c79ee1d3
commit
6edae2d30d
@ -59,6 +59,7 @@
|
|||||||
|
|
||||||
#include "src/Tensor/TensorForwardDeclarations.h"
|
#include "src/Tensor/TensorForwardDeclarations.h"
|
||||||
#include "src/Tensor/TensorMeta.h"
|
#include "src/Tensor/TensorMeta.h"
|
||||||
|
#include "src/Tensor/TensorMetaMacros.h"
|
||||||
#include "src/Tensor/TensorDeviceType.h"
|
#include "src/Tensor/TensorDeviceType.h"
|
||||||
#include "src/Tensor/TensorIndexList.h"
|
#include "src/Tensor/TensorIndexList.h"
|
||||||
#include "src/Tensor/TensorDimensionList.h"
|
#include "src/Tensor/TensorDimensionList.h"
|
||||||
|
@ -88,6 +88,11 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
|
|||||||
protected:
|
protected:
|
||||||
TensorStorage<Scalar, Dimensions, Options> m_storage;
|
TensorStorage<Scalar, Dimensions, Options> m_storage;
|
||||||
|
|
||||||
|
template<typename CustomIndex>
|
||||||
|
struct isOfNormalIndex{
|
||||||
|
static const bool value = internal::is_base_of< array<Index, NumIndices>, CustomIndex >::value;
|
||||||
|
};
|
||||||
|
|
||||||
public:
|
public:
|
||||||
// Metadata
|
// Metadata
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rank() const { return NumIndices; }
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rank() const { return NumIndices; }
|
||||||
@ -111,14 +116,29 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
|
|||||||
EIGEN_STATIC_ASSERT(sizeof...(otherIndices) + 2 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
|
EIGEN_STATIC_ASSERT(sizeof...(otherIndices) + 2 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
|
||||||
return coeff(array<Index, NumIndices>{{firstIndex, secondIndex, otherIndices...}});
|
return coeff(array<Index, NumIndices>{{firstIndex, secondIndex, otherIndices...}});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
/** Normal Index */
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(const array<Index, NumIndices>& indices) const
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(const array<Index, NumIndices>& indices) const
|
||||||
{
|
{
|
||||||
eigen_internal_assert(checkIndexRange(indices));
|
eigen_internal_assert(checkIndexRange(indices));
|
||||||
return m_storage.data()[linearizedIndex(indices)];
|
return m_storage.data()[linearizedIndex(indices)];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Custom Index */
|
||||||
|
template<typename CustomIndex,
|
||||||
|
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndex>::value) )
|
||||||
|
>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(const CustomIndex & indices) const
|
||||||
|
{
|
||||||
|
return coeff(internal::customIndex2Array<Index,NumIndices>(indices));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(Index index) const
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(Index index) const
|
||||||
{
|
{
|
||||||
eigen_internal_assert(index >= 0 && index < size());
|
eigen_internal_assert(index >= 0 && index < size());
|
||||||
@ -135,12 +155,23 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/** Normal Index */
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices)
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices)
|
||||||
{
|
{
|
||||||
eigen_internal_assert(checkIndexRange(indices));
|
eigen_internal_assert(checkIndexRange(indices));
|
||||||
return m_storage.data()[linearizedIndex(indices)];
|
return m_storage.data()[linearizedIndex(indices)];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Custom Index */
|
||||||
|
template<typename CustomIndex,
|
||||||
|
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndex>::value) )
|
||||||
|
>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const CustomIndex & indices)
|
||||||
|
{
|
||||||
|
return coeffRef(internal::customIndex2Array<Index,NumIndices>(indices));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index)
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index)
|
||||||
{
|
{
|
||||||
eigen_internal_assert(index >= 0 && index < size());
|
eigen_internal_assert(index >= 0 && index < size());
|
||||||
@ -178,9 +209,20 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/** Custom Index */
|
||||||
|
template<typename CustomIndex,
|
||||||
|
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndex>::value) )
|
||||||
|
>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(const CustomIndex & indices) const
|
||||||
|
{
|
||||||
|
//eigen_assert(checkIndexRange(indices)); /* already in coeff */
|
||||||
|
return coeff(internal::customIndex2Array<Index,NumIndices>(indices));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Normal Index */
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, NumIndices>& indices) const
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, NumIndices>& indices) const
|
||||||
{
|
{
|
||||||
eigen_assert(checkIndexRange(indices));
|
//eigen_assert(checkIndexRange(indices)); /* already in coeff */
|
||||||
return coeff(indices);
|
return coeff(indices);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -228,12 +270,23 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/** Normal Index */
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, NumIndices>& indices)
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, NumIndices>& indices)
|
||||||
{
|
{
|
||||||
eigen_assert(checkIndexRange(indices));
|
//eigen_assert(checkIndexRange(indices)); /* already in coeff */
|
||||||
return coeffRef(indices);
|
return coeffRef(indices);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Custom Index */
|
||||||
|
template<typename CustomIndex,
|
||||||
|
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomIndex>::value) )
|
||||||
|
>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(const CustomIndex & indices)
|
||||||
|
{
|
||||||
|
//eigen_assert(checkIndexRange(indices)); /* already in coeff */
|
||||||
|
return coeffRef(internal::customIndex2Array<Index,NumIndices>(indices));
|
||||||
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(Index index)
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(Index index)
|
||||||
{
|
{
|
||||||
eigen_assert(index >= 0 && index < size());
|
eigen_assert(index >= 0 && index < size());
|
||||||
@ -295,12 +348,20 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/** Normal Dimension */
|
||||||
inline explicit Tensor(const array<Index, NumIndices>& dimensions)
|
inline explicit Tensor(const array<Index, NumIndices>& dimensions)
|
||||||
: m_storage(internal::array_prod(dimensions), dimensions)
|
: m_storage(internal::array_prod(dimensions), dimensions)
|
||||||
{
|
{
|
||||||
EIGEN_INITIALIZE_COEFFS_IF_THAT_OPTION_IS_ENABLED
|
EIGEN_INITIALIZE_COEFFS_IF_THAT_OPTION_IS_ENABLED
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Custom Dimension (delegating constructor c++11) */
|
||||||
|
template<typename CustomDimension,
|
||||||
|
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomDimension>::value) )
|
||||||
|
>
|
||||||
|
inline explicit Tensor(const CustomDimension & dimensions) : Tensor(internal::customIndex2Array<Index,NumIndices>(dimensions))
|
||||||
|
{}
|
||||||
|
|
||||||
template<typename OtherDerived>
|
template<typename OtherDerived>
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE Tensor(const TensorBase<OtherDerived, ReadOnlyAccessors>& other)
|
EIGEN_STRONG_INLINE Tensor(const TensorBase<OtherDerived, ReadOnlyAccessors>& other)
|
||||||
@ -350,6 +411,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/** Normal Dimension */
|
||||||
EIGEN_DEVICE_FUNC void resize(const array<Index, NumIndices>& dimensions)
|
EIGEN_DEVICE_FUNC void resize(const array<Index, NumIndices>& dimensions)
|
||||||
{
|
{
|
||||||
std::size_t i;
|
std::size_t i;
|
||||||
@ -367,6 +429,8 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Why this overload, DSizes is derived from array ??? //
|
||||||
EIGEN_DEVICE_FUNC void resize(const DSizes<Index, NumIndices>& dimensions) {
|
EIGEN_DEVICE_FUNC void resize(const DSizes<Index, NumIndices>& dimensions) {
|
||||||
array<Index, NumIndices> dims;
|
array<Index, NumIndices> dims;
|
||||||
for (std::size_t i = 0; i < NumIndices; ++i) {
|
for (std::size_t i = 0; i < NumIndices; ++i) {
|
||||||
@ -375,6 +439,17 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
|
|||||||
resize(dims);
|
resize(dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Custom Dimension */
|
||||||
|
template<typename CustomDimension,
|
||||||
|
EIGEN_SFINAE_ENABLE_IF( !(isOfNormalIndex<CustomDimension>::value) )
|
||||||
|
>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void resize(const CustomDimension & dimensions)
|
||||||
|
{
|
||||||
|
//eigen_assert(checkIndexRange(indices)); /* already in coeff */
|
||||||
|
return coeffRef(internal::customIndex2Array<Index,NumIndices>(dimensions));
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
#ifndef EIGEN_EMULATE_CXX11_META_H
|
#ifndef EIGEN_EMULATE_CXX11_META_H
|
||||||
template <typename std::ptrdiff_t... Indices>
|
template <typename std::ptrdiff_t... Indices>
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
|
@ -34,6 +34,9 @@ template <> struct max_n_1<0> {
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#if defined(EIGEN_HAS_CONSTEXPR)
|
#if defined(EIGEN_HAS_CONSTEXPR)
|
||||||
#define EIGEN_CONSTEXPR constexpr
|
#define EIGEN_CONSTEXPR constexpr
|
||||||
#else
|
#else
|
||||||
@ -83,6 +86,54 @@ bool operator!=(const Tuple<U, V>& x, const Tuple<U, V>& y) {
|
|||||||
return !(x == y);
|
return !(x == y);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
namespace internal{
|
||||||
|
|
||||||
|
template<typename IndexType, Index... Is>
|
||||||
|
EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
array<Index,sizeof...(Is)> customIndex2Array(const IndexType & idx, numeric_list<Index,Is...>) {
|
||||||
|
return { idx(Is)... };
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Make an array (for index/dimensions) out of a custom index */
|
||||||
|
template<typename Index, int NumIndices, typename IndexType>
|
||||||
|
EIGEN_CONSTEXPR EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
array<Index,NumIndices> customIndex2Array(const IndexType & idx) {
|
||||||
|
return customIndex2Array(idx, typename gen_numeric_list<Index,NumIndices>::type{});
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <typename B, typename D>
|
||||||
|
struct is_base_of
|
||||||
|
{
|
||||||
|
|
||||||
|
typedef char (&yes)[1];
|
||||||
|
typedef char (&no)[2];
|
||||||
|
|
||||||
|
template <typename BB, typename DD>
|
||||||
|
struct Host
|
||||||
|
{
|
||||||
|
operator BB*() const;
|
||||||
|
operator DD*();
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
static yes check(D*, T);
|
||||||
|
static no check(B*, int);
|
||||||
|
|
||||||
|
static const bool value = sizeof(check(Host<B,D>(), int())) == sizeof(yes);
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#undef EIGEN_CONSTEXPR
|
#undef EIGEN_CONSTEXPR
|
||||||
|
|
||||||
} // namespace Eigen
|
} // namespace Eigen
|
||||||
|
33
unsupported/Eigen/CXX11/src/Tensor/TensorMetaMacros.h
Normal file
33
unsupported/Eigen/CXX11/src/Tensor/TensorMetaMacros.h
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
// This file is part of Eigen, a lightweight C++ template library
|
||||||
|
// for linear algebra.
|
||||||
|
//
|
||||||
|
// Copyright (C) 2015 Benoit Steiner <benoit.steiner.goog@gmail.com>
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla
|
||||||
|
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||||
|
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
#ifndef EIGEN_CXX11_TENSOR_TENSOR_META_MACROS_H
|
||||||
|
#define EIGEN_CXX11_TENSOR_TENSOR_META_MACROS_H
|
||||||
|
|
||||||
|
|
||||||
|
/** use this macro in sfinae selection in templated functions
|
||||||
|
*
|
||||||
|
* template<typename T,
|
||||||
|
* typename std::enable_if< isBanana<T>::value , int >::type = 0
|
||||||
|
* >
|
||||||
|
* void foo(){}
|
||||||
|
*
|
||||||
|
* becomes =>
|
||||||
|
*
|
||||||
|
* template<typename TopoType,
|
||||||
|
* SFINAE_ENABLE_IF( isBanana<T>::value )
|
||||||
|
* >
|
||||||
|
* void foo(){}
|
||||||
|
*/
|
||||||
|
|
||||||
|
#define EIGEN_SFINAE_ENABLE_IF( __condition__ ) \
|
||||||
|
typename internal::enable_if< ( __condition__ ) , int >::type = 0
|
||||||
|
|
||||||
|
|
||||||
|
#endif
|
@ -142,6 +142,7 @@ if(EIGEN_TEST_CXX11)
|
|||||||
ei_add_test(cxx11_tensor_io "-std=c++0x")
|
ei_add_test(cxx11_tensor_io "-std=c++0x")
|
||||||
ei_add_test(cxx11_tensor_generator "-std=c++0x")
|
ei_add_test(cxx11_tensor_generator "-std=c++0x")
|
||||||
ei_add_test(cxx11_tensor_custom_op "-std=c++0x")
|
ei_add_test(cxx11_tensor_custom_op "-std=c++0x")
|
||||||
|
ei_add_test(cxx11_tensor_customIndex "-std=c++0x")
|
||||||
|
|
||||||
# These tests needs nvcc
|
# These tests needs nvcc
|
||||||
# ei_add_test(cxx11_tensor_device "-std=c++0x")
|
# ei_add_test(cxx11_tensor_device "-std=c++0x")
|
||||||
|
41
unsupported/test/cxx11_tensor_customIndex.cpp
Normal file
41
unsupported/test/cxx11_tensor_customIndex.cpp
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
// This file is part of Eigen, a lightweight C++ template library
|
||||||
|
// for linear algebra.
|
||||||
|
//
|
||||||
|
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla
|
||||||
|
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||||
|
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
#include "main.h"
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include <Eigen/Dense>
|
||||||
|
#include <Eigen/CXX11/Tensor>
|
||||||
|
|
||||||
|
using Eigen::Tensor;
|
||||||
|
|
||||||
|
|
||||||
|
template <int DataLayout>
|
||||||
|
static void test_customIndex() {
|
||||||
|
|
||||||
|
Tensor<float, 4, DataLayout> tensor(2, 3, 5, 7);
|
||||||
|
|
||||||
|
using NormalIndex = DSizes<ptrdiff_t, 4>;
|
||||||
|
using CustomIndex = Matrix<unsigned int , 4, 1>;
|
||||||
|
|
||||||
|
tensor.setRandom();
|
||||||
|
|
||||||
|
CustomIndex coeffC(1,2,4,1);
|
||||||
|
NormalIndex coeff(1,2,4,1);
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(tensor.coeff( coeffC ), tensor.coeff( coeff ));
|
||||||
|
VERIFY_IS_EQUAL(tensor.coeffRef( coeffC ), tensor.coeffRef( coeff ));
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void test_cxx11_tensor_customIndex() {
|
||||||
|
CALL_SUBTEST(test_customIndex<ColMajor>());
|
||||||
|
CALL_SUBTEST(test_customIndex<RowMajor>());
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user