mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Added support for tensor references
This commit is contained in:
parent
f786897e4b
commit
debc97821c
@ -76,6 +76,8 @@
|
|||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h"
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorMap.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorMap.h"
|
||||||
|
|
||||||
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorRef.h"
|
||||||
|
|
||||||
#include "unsupported/Eigen/CXX11/src/Tensor/TensorIO.h"
|
#include "unsupported/Eigen/CXX11/src/Tensor/TensorIO.h"
|
||||||
|
|
||||||
#include "Eigen/src/Core/util/ReenableStupidWarnings.h"
|
#include "Eigen/src/Core/util/ReenableStupidWarnings.h"
|
||||||
|
@ -15,6 +15,7 @@ namespace Eigen {
|
|||||||
template<typename Scalar_, std::size_t NumIndices_, int Options_ = 0> class Tensor;
|
template<typename Scalar_, std::size_t NumIndices_, int Options_ = 0> class Tensor;
|
||||||
template<typename Scalar_, typename Dimensions, int Options_ = 0> class TensorFixedSize;
|
template<typename Scalar_, typename Dimensions, int Options_ = 0> class TensorFixedSize;
|
||||||
template<typename PlainObjectType, int Options_ = Unaligned> class TensorMap;
|
template<typename PlainObjectType, int Options_ = Unaligned> class TensorMap;
|
||||||
|
template<typename PlainObjectType> class TensorRef;
|
||||||
template<typename Derived, int AccessLevel = internal::accessors_level<Derived>::value> class TensorBase;
|
template<typename Derived, int AccessLevel = internal::accessors_level<Derived>::value> class TensorBase;
|
||||||
|
|
||||||
template<typename NullaryOp, typename PlainObjectType> class TensorCwiseNullaryOp;
|
template<typename NullaryOp, typename PlainObjectType> class TensorCwiseNullaryOp;
|
||||||
|
360
unsupported/Eigen/CXX11/src/Tensor/TensorRef.h
Normal file
360
unsupported/Eigen/CXX11/src/Tensor/TensorRef.h
Normal file
@ -0,0 +1,360 @@
|
|||||||
|
// This file is part of Eigen, a lightweight C++ template library
|
||||||
|
// for linear algebra.
|
||||||
|
//
|
||||||
|
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla
|
||||||
|
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||||
|
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
#ifndef EIGEN_CXX11_TENSOR_TENSOR_REF_H
|
||||||
|
#define EIGEN_CXX11_TENSOR_TENSOR_REF_H
|
||||||
|
|
||||||
|
namespace Eigen {
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
template <typename Dimensions, typename Scalar>
|
||||||
|
class TensorLazyBaseEvaluator {
|
||||||
|
public:
|
||||||
|
TensorLazyBaseEvaluator() : m_refcount(0) { }
|
||||||
|
virtual ~TensorLazyBaseEvaluator() { }
|
||||||
|
|
||||||
|
virtual const Dimensions& dimensions() const = 0;
|
||||||
|
virtual const Scalar* data() const = 0;
|
||||||
|
|
||||||
|
virtual const Scalar coeff(DenseIndex index) const = 0;
|
||||||
|
virtual Scalar& coeffRef(DenseIndex index) = 0;
|
||||||
|
|
||||||
|
void incrRefCount() { ++m_refcount; }
|
||||||
|
void decrRefCount() { --m_refcount; }
|
||||||
|
int refCount() const { return m_refcount; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
// No copy, no assigment;
|
||||||
|
TensorLazyBaseEvaluator(const TensorLazyBaseEvaluator& other);
|
||||||
|
TensorLazyBaseEvaluator& operator = (const TensorLazyBaseEvaluator& other);
|
||||||
|
|
||||||
|
int m_refcount;
|
||||||
|
};
|
||||||
|
|
||||||
|
static char dummy[8];
|
||||||
|
|
||||||
|
template <typename Dimensions, typename Expr, typename Device>
|
||||||
|
class TensorLazyEvaluatorReadOnly : public TensorLazyBaseEvaluator<Dimensions, typename TensorEvaluator<Expr, Device>::Scalar> {
|
||||||
|
public:
|
||||||
|
// typedef typename TensorEvaluator<Expr, Device>::Dimensions Dimensions;
|
||||||
|
typedef typename TensorEvaluator<Expr, Device>::Scalar Scalar;
|
||||||
|
|
||||||
|
TensorLazyEvaluatorReadOnly(const Expr& expr, const Device& device) : m_impl(expr, device) {
|
||||||
|
m_dims = m_impl.dimensions();
|
||||||
|
m_impl.evalSubExprsIfNeeded(NULL);
|
||||||
|
}
|
||||||
|
virtual ~TensorLazyEvaluatorReadOnly() {
|
||||||
|
m_impl.cleanup();
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const Dimensions& dimensions() const {
|
||||||
|
return m_dims;
|
||||||
|
}
|
||||||
|
virtual const Scalar* data() const {
|
||||||
|
return m_impl.data();
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual const Scalar coeff(DenseIndex index) const {
|
||||||
|
return m_impl.coeff(index);
|
||||||
|
}
|
||||||
|
virtual Scalar& coeffRef(DenseIndex index) {
|
||||||
|
eigen_assert(false && "can't reference the coefficient of a rvalue");
|
||||||
|
return *reinterpret_cast<Scalar*>(dummy);
|
||||||
|
};
|
||||||
|
|
||||||
|
protected:
|
||||||
|
TensorEvaluator<Expr, Device> m_impl;
|
||||||
|
Dimensions m_dims;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Dimensions, typename Expr, typename Device>
|
||||||
|
class TensorLazyEvaluatorWritable : public TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> {
|
||||||
|
public:
|
||||||
|
typedef TensorLazyEvaluatorReadOnly<Dimensions, Expr, Device> Base;
|
||||||
|
typedef typename Base::Scalar Scalar;
|
||||||
|
|
||||||
|
TensorLazyEvaluatorWritable(const Expr& expr, const Device& device) : Base(expr, device) {
|
||||||
|
}
|
||||||
|
virtual ~TensorLazyEvaluatorWritable() {
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual Scalar& coeffRef(DenseIndex index) {
|
||||||
|
return this->m_impl.coeffRef(index);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Dimensions, typename Expr, typename Device>
|
||||||
|
class TensorLazyEvaluator : public internal::conditional<bool(internal::is_lvalue<Expr>::value),
|
||||||
|
TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
|
||||||
|
TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >::type {
|
||||||
|
public:
|
||||||
|
typedef typename internal::conditional<bool(internal::is_lvalue<Expr>::value),
|
||||||
|
TensorLazyEvaluatorWritable<Dimensions, Expr, Device>,
|
||||||
|
TensorLazyEvaluatorReadOnly<Dimensions, const Expr, Device> >::type Base;
|
||||||
|
typedef typename Base::Scalar Scalar;
|
||||||
|
|
||||||
|
TensorLazyEvaluator(const Expr& expr, const Device& device) : Base(expr, device) {
|
||||||
|
}
|
||||||
|
virtual ~TensorLazyEvaluator() {
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace internal
|
||||||
|
|
||||||
|
|
||||||
|
/** \class TensorRef
|
||||||
|
* \ingroup CXX11_Tensor_Module
|
||||||
|
*
|
||||||
|
* \brief A reference to a tensor expression
|
||||||
|
* The expression will be evaluated lazily (as much as possible).
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef<PlainObjectType> >
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
typedef TensorRef<PlainObjectType> Self;
|
||||||
|
typedef typename PlainObjectType::Base Base;
|
||||||
|
typedef typename Eigen::internal::nested<Self>::type Nested;
|
||||||
|
typedef typename internal::traits<PlainObjectType>::StorageKind StorageKind;
|
||||||
|
typedef typename internal::traits<PlainObjectType>::Index Index;
|
||||||
|
typedef typename internal::traits<PlainObjectType>::Scalar Scalar;
|
||||||
|
typedef typename internal::packet_traits<Scalar>::type Packet;
|
||||||
|
typedef typename NumTraits<Scalar>::Real RealScalar;
|
||||||
|
typedef typename Base::CoeffReturnType CoeffReturnType;
|
||||||
|
typedef Scalar* PointerType;
|
||||||
|
typedef PointerType PointerArgType;
|
||||||
|
|
||||||
|
static const Index NumIndices = PlainObjectType::NumIndices;
|
||||||
|
typedef typename PlainObjectType::Dimensions Dimensions;
|
||||||
|
|
||||||
|
enum {
|
||||||
|
IsAligned = false,
|
||||||
|
PacketAccess = false,
|
||||||
|
};
|
||||||
|
|
||||||
|
EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) {
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Expression>
|
||||||
|
EIGEN_STRONG_INLINE TensorRef(const Expression& expr) : m_evaluator(new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice())) {
|
||||||
|
m_evaluator->incrRefCount();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Expression>
|
||||||
|
EIGEN_STRONG_INLINE TensorRef& operator = (const Expression& expr) {
|
||||||
|
unrefEvaluator();
|
||||||
|
m_evaluator = new internal::TensorLazyEvaluator<Dimensions, Expression, DefaultDevice>(expr, DefaultDevice());
|
||||||
|
m_evaluator->incrRefCount();
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
~TensorRef() {
|
||||||
|
unrefEvaluator();
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorRef(const TensorRef& other) : m_evaluator(other.m_evaluator) {
|
||||||
|
eigen_assert(m_evaluator->refCount() > 0);
|
||||||
|
m_evaluator->incrRefCount();
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorRef& operator = (const TensorRef& other) {
|
||||||
|
if (this != &other) {
|
||||||
|
unrefEvaluator();
|
||||||
|
m_evaluator = other.m_evaluator;
|
||||||
|
eigen_assert(m_evaluator->refCount() > 0);
|
||||||
|
m_evaluator->incrRefCount();
|
||||||
|
}
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; }
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_evaluator->dimensions(); }
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE Index size() const { return m_evaluator->dimensions().TotalSize(); }
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const Scalar* data() const { return m_evaluator->data(); }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const Scalar operator()(Index index) const
|
||||||
|
{
|
||||||
|
return m_evaluator->coeff(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef EIGEN_HAS_VARIADIC_TEMPLATES
|
||||||
|
template<typename... IndexTypes> EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const Scalar operator()(Index firstIndex, IndexTypes... otherIndices) const
|
||||||
|
{
|
||||||
|
const std::size_t NumIndices = (sizeof...(otherIndices) + 1);
|
||||||
|
const array<Index, NumIndices> indices{{firstIndex, otherIndices...}};
|
||||||
|
return coeff(indices);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1) const
|
||||||
|
{
|
||||||
|
array<Index, 2> indices;
|
||||||
|
indices[0] = i0;
|
||||||
|
indices[1] = i1;
|
||||||
|
return coeff(indices);
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2) const
|
||||||
|
{
|
||||||
|
array<Index, 3> indices;
|
||||||
|
indices[0] = i0;
|
||||||
|
indices[1] = i1;
|
||||||
|
indices[2] = i2;
|
||||||
|
return coeff(indices);
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3) const
|
||||||
|
{
|
||||||
|
array<Index, 4> indices;
|
||||||
|
indices[0] = i0;
|
||||||
|
indices[1] = i1;
|
||||||
|
indices[2] = i2;
|
||||||
|
indices[3] = i3;
|
||||||
|
return coeff(indices);
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const Scalar operator()(Index i0, Index i1, Index i2, Index i3, Index i4) const
|
||||||
|
{
|
||||||
|
array<Index, 5> indices;
|
||||||
|
indices[0] = i0;
|
||||||
|
indices[1] = i1;
|
||||||
|
indices[2] = i2;
|
||||||
|
indices[3] = i3;
|
||||||
|
indices[4] = i4;
|
||||||
|
return coeff(indices);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const Scalar coeff(const array<Index, NumIndices>& indices) const
|
||||||
|
{
|
||||||
|
const Dimensions& dims = this->dimensions();
|
||||||
|
Index index = 0;
|
||||||
|
if (PlainObjectType::Options&RowMajor) {
|
||||||
|
index += indices[0];
|
||||||
|
for (int i = 1; i < NumIndices; ++i) {
|
||||||
|
index = index * dims[i] + indices[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
index += indices[NumIndices-1];
|
||||||
|
for (int i = NumIndices-2; i >= 0; --i) {
|
||||||
|
index = index * dims[i] + indices[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m_evaluator->coeff(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
|
||||||
|
{
|
||||||
|
return m_evaluator->coeff(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE Scalar& coeffRef(Index index)
|
||||||
|
{
|
||||||
|
return m_evaluator->coeffRef(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
EIGEN_STRONG_INLINE void unrefEvaluator() {
|
||||||
|
if (m_evaluator) {
|
||||||
|
m_evaluator->decrRefCount();
|
||||||
|
if (m_evaluator->refCount() == 0) {
|
||||||
|
delete m_evaluator;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
internal::TensorLazyBaseEvaluator<Dimensions, Scalar>* m_evaluator;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
// evaluator for rvalues
|
||||||
|
template<typename Derived, typename Device>
|
||||||
|
struct TensorEvaluator<const TensorRef<Derived>, Device>
|
||||||
|
{
|
||||||
|
typedef typename Derived::Index Index;
|
||||||
|
typedef typename Derived::Scalar Scalar;
|
||||||
|
typedef typename Derived::Packet Packet;
|
||||||
|
typedef typename Derived::Scalar CoeffReturnType;
|
||||||
|
typedef typename Derived::Packet PacketReturnType;
|
||||||
|
typedef typename Derived::Dimensions Dimensions;
|
||||||
|
|
||||||
|
enum {
|
||||||
|
IsAligned = false,
|
||||||
|
PacketAccess = false,
|
||||||
|
};
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef<Derived>& m, const Device&)
|
||||||
|
: m_ref(m)
|
||||||
|
{ }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_ref.dimensions(); }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
|
||||||
|
return m_ref.coeff(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
|
||||||
|
return m_ref.coeffRef(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
Scalar* data() const { return m_ref.data(); }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
TensorRef<Derived> m_ref;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
// evaluator for lvalues
|
||||||
|
template<typename Derived, typename Device>
|
||||||
|
struct TensorEvaluator<TensorRef<Derived>, Device> : public TensorEvaluator<const TensorRef<Derived>, Device>
|
||||||
|
{
|
||||||
|
typedef typename Derived::Index Index;
|
||||||
|
typedef typename Derived::Scalar Scalar;
|
||||||
|
typedef typename Derived::Packet Packet;
|
||||||
|
typedef typename Derived::Scalar CoeffReturnType;
|
||||||
|
typedef typename Derived::Packet PacketReturnType;
|
||||||
|
typedef typename Derived::Dimensions Dimensions;
|
||||||
|
|
||||||
|
typedef TensorEvaluator<const TensorRef<Derived>, Device> Base;
|
||||||
|
|
||||||
|
enum {
|
||||||
|
IsAligned = false,
|
||||||
|
PacketAccess = false,
|
||||||
|
};
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(TensorRef<Derived>& m, const Device& d) : Base(m, d)
|
||||||
|
{ }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
|
||||||
|
return this->m_ref.coeffRef(index);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
} // end namespace Eigen
|
||||||
|
|
||||||
|
#endif // EIGEN_CXX11_TENSOR_TENSOR_REF_H
|
@ -84,6 +84,20 @@ struct traits<TensorMap<PlainObjectType, Options_> >
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<typename PlainObjectType>
|
||||||
|
struct traits<TensorRef<PlainObjectType> >
|
||||||
|
: public traits<PlainObjectType>
|
||||||
|
{
|
||||||
|
typedef traits<PlainObjectType> BaseTraits;
|
||||||
|
typedef typename BaseTraits::Scalar Scalar;
|
||||||
|
typedef typename BaseTraits::StorageKind StorageKind;
|
||||||
|
typedef typename BaseTraits::Index Index;
|
||||||
|
enum {
|
||||||
|
Options = BaseTraits::Options,
|
||||||
|
Flags = ((BaseTraits::Flags | LvalueBit) & ~AlignedBit) | (Options&Aligned ? AlignedBit : 0),
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
template<typename _Scalar, std::size_t NumIndices_, int Options>
|
template<typename _Scalar, std::size_t NumIndices_, int Options>
|
||||||
struct eval<Tensor<_Scalar, NumIndices_, Options>, Eigen::Dense>
|
struct eval<Tensor<_Scalar, NumIndices_, Options>, Eigen::Dense>
|
||||||
@ -121,6 +135,19 @@ struct eval<const TensorMap<PlainObjectType, Options>, Eigen::Dense>
|
|||||||
typedef const TensorMap<PlainObjectType, Options>& type;
|
typedef const TensorMap<PlainObjectType, Options>& type;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<typename PlainObjectType>
|
||||||
|
struct eval<TensorRef<PlainObjectType>, Eigen::Dense>
|
||||||
|
{
|
||||||
|
typedef const TensorRef<PlainObjectType>& type;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename PlainObjectType>
|
||||||
|
struct eval<const TensorRef<PlainObjectType>, Eigen::Dense>
|
||||||
|
{
|
||||||
|
typedef const TensorRef<PlainObjectType>& type;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
template <typename Scalar_, std::size_t NumIndices_, int Options_>
|
template <typename Scalar_, std::size_t NumIndices_, int Options_>
|
||||||
struct nested<Tensor<Scalar_, NumIndices_, Options_>, 1, typename eval<Tensor<Scalar_, NumIndices_, Options_> >::type>
|
struct nested<Tensor<Scalar_, NumIndices_, Options_>, 1, typename eval<Tensor<Scalar_, NumIndices_, Options_> >::type>
|
||||||
{
|
{
|
||||||
@ -145,6 +172,7 @@ struct nested<const TensorFixedSize<Scalar_, Dimensions, Options>, 1, typename e
|
|||||||
typedef const TensorFixedSize<Scalar_, Dimensions, Options>& type;
|
typedef const TensorFixedSize<Scalar_, Dimensions, Options>& type;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
template <typename PlainObjectType, int Options>
|
template <typename PlainObjectType, int Options>
|
||||||
struct nested<TensorMap<PlainObjectType, Options>, 1, typename eval<TensorMap<PlainObjectType, Options> >::type>
|
struct nested<TensorMap<PlainObjectType, Options>, 1, typename eval<TensorMap<PlainObjectType, Options> >::type>
|
||||||
{
|
{
|
||||||
@ -157,6 +185,18 @@ struct nested<const TensorMap<PlainObjectType, Options>, 1, typename eval<Tensor
|
|||||||
typedef const TensorMap<PlainObjectType, Options>& type;
|
typedef const TensorMap<PlainObjectType, Options>& type;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename PlainObjectType>
|
||||||
|
struct nested<TensorRef<PlainObjectType>, 1, typename eval<TensorRef<PlainObjectType> >::type>
|
||||||
|
{
|
||||||
|
typedef const TensorRef<PlainObjectType>& type;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename PlainObjectType>
|
||||||
|
struct nested<const TensorRef<PlainObjectType>, 1, typename eval<TensorRef<PlainObjectType> >::type>
|
||||||
|
{
|
||||||
|
typedef const TensorRef<PlainObjectType>& type;
|
||||||
|
};
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|
||||||
|
@ -126,5 +126,6 @@ if(EIGEN_TEST_CXX11)
|
|||||||
ei_add_test(cxx11_tensor_striding "-std=c++0x")
|
ei_add_test(cxx11_tensor_striding "-std=c++0x")
|
||||||
# ei_add_test(cxx11_tensor_device "-std=c++0x")
|
# ei_add_test(cxx11_tensor_device "-std=c++0x")
|
||||||
ei_add_test(cxx11_tensor_thread_pool "-std=c++0x")
|
ei_add_test(cxx11_tensor_thread_pool "-std=c++0x")
|
||||||
|
ei_add_test(cxx11_tensor_ref "-std=c++0x")
|
||||||
ei_add_test(cxx11_tensor_io "-std=c++0x")
|
ei_add_test(cxx11_tensor_io "-std=c++0x")
|
||||||
endif()
|
endif()
|
||||||
|
192
unsupported/test/cxx11_tensor_ref.cpp
Normal file
192
unsupported/test/cxx11_tensor_ref.cpp
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
// This file is part of Eigen, a lightweight C++ template library
|
||||||
|
// for linear algebra.
|
||||||
|
//
|
||||||
|
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla
|
||||||
|
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||||
|
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
#include "main.h"
|
||||||
|
|
||||||
|
#include <Eigen/CXX11/Tensor>
|
||||||
|
|
||||||
|
using Eigen::Tensor;
|
||||||
|
using Eigen::RowMajor;
|
||||||
|
|
||||||
|
static void test_simple_lvalue_ref()
|
||||||
|
{
|
||||||
|
Tensor<int, 1> input(6);
|
||||||
|
input.setRandom();
|
||||||
|
|
||||||
|
TensorRef<Tensor<int, 1>> ref3(input);
|
||||||
|
TensorRef<Tensor<int, 1>> ref4 = input;
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(ref3.data(), input.data());
|
||||||
|
VERIFY_IS_EQUAL(ref4.data(), input.data());
|
||||||
|
|
||||||
|
for (int i = 0; i < 6; ++i) {
|
||||||
|
VERIFY_IS_EQUAL(ref3(i), input(i));
|
||||||
|
VERIFY_IS_EQUAL(ref4(i), input(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < 6; ++i) {
|
||||||
|
ref3.coeffRef(i) = i;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < 6; ++i) {
|
||||||
|
VERIFY_IS_EQUAL(input(i), i);
|
||||||
|
}
|
||||||
|
for (int i = 0; i < 6; ++i) {
|
||||||
|
ref4.coeffRef(i) = -i * 2;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < 6; ++i) {
|
||||||
|
VERIFY_IS_EQUAL(input(i), -i*2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static void test_simple_rvalue_ref()
|
||||||
|
{
|
||||||
|
Tensor<int, 1> input1(6);
|
||||||
|
input1.setRandom();
|
||||||
|
Tensor<int, 1> input2(6);
|
||||||
|
input2.setRandom();
|
||||||
|
|
||||||
|
TensorRef<Tensor<int, 1>> ref3(input1 + input2);
|
||||||
|
TensorRef<Tensor<int, 1>> ref4 = input1 + input2;
|
||||||
|
|
||||||
|
VERIFY_IS_NOT_EQUAL(ref3.data(), input1.data());
|
||||||
|
VERIFY_IS_NOT_EQUAL(ref4.data(), input1.data());
|
||||||
|
VERIFY_IS_NOT_EQUAL(ref3.data(), input2.data());
|
||||||
|
VERIFY_IS_NOT_EQUAL(ref4.data(), input2.data());
|
||||||
|
|
||||||
|
for (int i = 0; i < 6; ++i) {
|
||||||
|
VERIFY_IS_EQUAL(ref3(i), input1(i) + input2(i));
|
||||||
|
VERIFY_IS_EQUAL(ref4(i), input1(i) + input2(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static void test_multiple_dims()
|
||||||
|
{
|
||||||
|
Tensor<float, 3> input(3,5,7);
|
||||||
|
input.setRandom();
|
||||||
|
|
||||||
|
TensorRef<Tensor<float, 3>> ref(input);
|
||||||
|
VERIFY_IS_EQUAL(ref.data(), input.data());
|
||||||
|
VERIFY_IS_EQUAL(ref.dimension(0), 3);
|
||||||
|
VERIFY_IS_EQUAL(ref.dimension(1), 5);
|
||||||
|
VERIFY_IS_EQUAL(ref.dimension(2), 7);
|
||||||
|
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
for (int j = 0; j < 5; ++j) {
|
||||||
|
for (int k = 0; k < 7; ++k) {
|
||||||
|
VERIFY_IS_EQUAL(ref(i,j,k), input(i,j,k));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static void test_slice()
|
||||||
|
{
|
||||||
|
Tensor<float, 5> tensor(2,3,5,7,11);
|
||||||
|
tensor.setRandom();
|
||||||
|
|
||||||
|
Eigen::DSizes<ptrdiff_t, 5> indices(1,2,3,4,5);
|
||||||
|
Eigen::DSizes<ptrdiff_t, 5> sizes(1,1,1,1,1);
|
||||||
|
TensorRef<Tensor<float, 5>> slice = tensor.slice(indices, sizes);
|
||||||
|
VERIFY_IS_EQUAL(slice(0,0,0,0,0), tensor(1,2,3,4,5));
|
||||||
|
|
||||||
|
Eigen::DSizes<ptrdiff_t, 5> indices2(1,1,3,4,5);
|
||||||
|
Eigen::DSizes<ptrdiff_t, 5> sizes2(1,1,2,2,3);
|
||||||
|
slice = tensor.slice(indices2, sizes2);
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
for (int j = 0; j < 2; ++j) {
|
||||||
|
for (int k = 0; k < 3; ++k) {
|
||||||
|
VERIFY_IS_EQUAL(slice(0,0,i,j,k), tensor(1,1,3+i,4+j,5+k));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Eigen::DSizes<ptrdiff_t, 5> indices3(0,0,0,0,0);
|
||||||
|
Eigen::DSizes<ptrdiff_t, 5> sizes3(2,3,1,1,1);
|
||||||
|
slice = tensor.slice(indices3, sizes3);
|
||||||
|
VERIFY_IS_EQUAL(slice.data(), tensor.data());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static void test_ref_of_ref()
|
||||||
|
{
|
||||||
|
Tensor<float, 3> input(3,5,7);
|
||||||
|
input.setRandom();
|
||||||
|
|
||||||
|
TensorRef<Tensor<float, 3>> ref(input);
|
||||||
|
TensorRef<Tensor<float, 3>> ref_of_ref(ref);
|
||||||
|
TensorRef<Tensor<float, 3>> ref_of_ref2;
|
||||||
|
ref_of_ref2 = ref;
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(ref_of_ref.data(), input.data());
|
||||||
|
VERIFY_IS_EQUAL(ref_of_ref.dimension(0), 3);
|
||||||
|
VERIFY_IS_EQUAL(ref_of_ref.dimension(1), 5);
|
||||||
|
VERIFY_IS_EQUAL(ref_of_ref.dimension(2), 7);
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(ref_of_ref2.data(), input.data());
|
||||||
|
VERIFY_IS_EQUAL(ref_of_ref2.dimension(0), 3);
|
||||||
|
VERIFY_IS_EQUAL(ref_of_ref2.dimension(1), 5);
|
||||||
|
VERIFY_IS_EQUAL(ref_of_ref2.dimension(2), 7);
|
||||||
|
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
for (int j = 0; j < 5; ++j) {
|
||||||
|
for (int k = 0; k < 7; ++k) {
|
||||||
|
VERIFY_IS_EQUAL(ref_of_ref(i,j,k), input(i,j,k));
|
||||||
|
VERIFY_IS_EQUAL(ref_of_ref2(i,j,k), input(i,j,k));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static void test_ref_in_expr()
|
||||||
|
{
|
||||||
|
Tensor<float, 3> input(3,5,7);
|
||||||
|
input.setRandom();
|
||||||
|
TensorRef<Tensor<float, 3>> input_ref(input);
|
||||||
|
|
||||||
|
Tensor<float, 3> result(3,5,7);
|
||||||
|
result.setRandom();
|
||||||
|
TensorRef<Tensor<float, 3>> result_ref(result);
|
||||||
|
|
||||||
|
Tensor<float, 3> bias(3,5,7);
|
||||||
|
bias.setRandom();
|
||||||
|
|
||||||
|
result_ref = input_ref + bias;
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
for (int j = 0; j < 5; ++j) {
|
||||||
|
for (int k = 0; k < 7; ++k) {
|
||||||
|
VERIFY_IS_EQUAL(result_ref(i,j,k), input(i,j,k) + bias(i,j,k));
|
||||||
|
VERIFY_IS_NOT_EQUAL(result(i,j,k), input(i,j,k) + bias(i,j,k));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = result_ref;
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
for (int j = 0; j < 5; ++j) {
|
||||||
|
for (int k = 0; k < 7; ++k) {
|
||||||
|
VERIFY_IS_EQUAL(result(i,j,k), input(i,j,k) + bias(i,j,k));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void test_cxx11_tensor_ref()
|
||||||
|
{
|
||||||
|
CALL_SUBTEST(test_simple_lvalue_ref());
|
||||||
|
CALL_SUBTEST(test_simple_rvalue_ref());
|
||||||
|
CALL_SUBTEST(test_multiple_dims());
|
||||||
|
CALL_SUBTEST(test_slice());
|
||||||
|
CALL_SUBTEST(test_ref_of_ref());
|
||||||
|
CALL_SUBTEST(test_ref_in_expr());
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user