Added support for tensor contractions

Updated expression evaluation mechanism to also compute the size of the tensor result
Misc fixes and improvements.
This commit is contained in:
Benoit Steiner 2014-06-04 09:21:48 -07:00
parent 736267cf6b
commit 6fa6cdd2b9
14 changed files with 370 additions and 96 deletions

View File

@ -39,6 +39,7 @@
#include "unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h"
#include "unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h"
#include "unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h"
#include "unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h"
#include "unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h"

View File

@ -23,6 +23,8 @@ template <typename T, size_t n> class array {
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const T& operator[] (size_t index) const { return values[index]; } EIGEN_STRONG_INLINE const T& operator[] (size_t index) const { return values[index]; }
static const std::size_t size = n;
T values[n]; T values[n];
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC

View File

@ -81,7 +81,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_> >
typedef typename Base::PacketReturnType PacketReturnType; typedef typename Base::PacketReturnType PacketReturnType;
enum { enum {
IsAligned = bool(EIGEN_ALIGN), IsAligned = bool(EIGEN_ALIGN) & !(Options_&DontAlign),
PacketAccess = true, PacketAccess = true,
}; };
@ -94,11 +94,11 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_> >
TensorStorage<Scalar, NumIndices, Dynamic, Options> m_storage; TensorStorage<Scalar, NumIndices, Dynamic, Options> m_storage;
public: public:
EIGEN_STRONG_INLINE Index dimension(std::size_t n) const { return m_storage.dimensions()[n]; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(std::size_t n) const { return m_storage.dimensions()[n]; }
EIGEN_STRONG_INLINE const DSizes<DenseIndex, NumIndices_>& dimensions() const { return m_storage.dimensions(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const DSizes<DenseIndex, NumIndices_>& dimensions() const { return m_storage.dimensions(); }
EIGEN_STRONG_INLINE Index size() const { return m_storage.size(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const { return m_storage.size(); }
EIGEN_STRONG_INLINE Scalar *data() { return m_storage.data(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar *data() { return m_storage.data(); }
EIGEN_STRONG_INLINE const Scalar *data() const { return m_storage.data(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar *data() const { return m_storage.data(); }
// This makes EIGEN_INITIALIZE_COEFFS_IF_THAT_OPTION_IS_ENABLED // This makes EIGEN_INITIALIZE_COEFFS_IF_THAT_OPTION_IS_ENABLED
// work, because that uses base().coeffRef() - and we don't yet // work, because that uses base().coeffRef() - and we don't yet
@ -116,13 +116,13 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_> >
} }
#endif #endif
inline const Scalar& coeff(const array<Index, NumIndices>& indices) const EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(const array<Index, NumIndices>& indices) const
{ {
eigen_internal_assert(checkIndexRange(indices)); eigen_internal_assert(checkIndexRange(indices));
return m_storage.data()[linearizedIndex(indices)]; return m_storage.data()[linearizedIndex(indices)];
} }
inline const Scalar& coeff(Index index) const EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(Index index) const
{ {
eigen_internal_assert(index >= 0 && index < size()); eigen_internal_assert(index >= 0 && index < size());
return m_storage.data()[index]; return m_storage.data()[index];
@ -138,13 +138,13 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_> >
} }
#endif #endif
inline Scalar& coeffRef(const array<Index, NumIndices>& indices) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices)
{ {
eigen_internal_assert(checkIndexRange(indices)); eigen_internal_assert(checkIndexRange(indices));
return m_storage.data()[linearizedIndex(indices)]; return m_storage.data()[linearizedIndex(indices)];
} }
inline Scalar& coeffRef(Index index) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index)
{ {
eigen_internal_assert(index >= 0 && index < size()); eigen_internal_assert(index >= 0 && index < size());
return m_storage.data()[index]; return m_storage.data()[index];
@ -160,19 +160,19 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_> >
} }
#endif #endif
inline const Scalar& operator()(const array<Index, NumIndices>& indices) const EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(const array<Index, NumIndices>& indices) const
{ {
eigen_assert(checkIndexRange(indices)); eigen_assert(checkIndexRange(indices));
return coeff(indices); return coeff(indices);
} }
inline const Scalar& operator()(Index index) const EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator()(Index index) const
{ {
eigen_internal_assert(index >= 0 && index < size()); eigen_internal_assert(index >= 0 && index < size());
return coeff(index); return coeff(index);
} }
inline const Scalar& operator[](Index index) const EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& operator[](Index index) const
{ {
// The bracket operator is only for vectors, use the parenthesis operator instead. // The bracket operator is only for vectors, use the parenthesis operator instead.
EIGEN_STATIC_ASSERT(NumIndices == 1, YOU_MADE_A_PROGRAMMING_MISTAKE); EIGEN_STATIC_ASSERT(NumIndices == 1, YOU_MADE_A_PROGRAMMING_MISTAKE);
@ -189,19 +189,19 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_> >
} }
#endif #endif
inline Scalar& operator()(const array<Index, NumIndices>& indices) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(const array<Index, NumIndices>& indices)
{ {
eigen_assert(checkIndexRange(indices)); eigen_assert(checkIndexRange(indices));
return coeffRef(indices); return coeffRef(indices);
} }
inline Scalar& operator()(Index index) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator()(Index index)
{ {
eigen_assert(index >= 0 && index < size()); eigen_assert(index >= 0 && index < size());
return coeffRef(index); return coeffRef(index);
} }
inline Scalar& operator[](Index index) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& operator[](Index index)
{ {
// The bracket operator is only for vectors, use the parenthesis operator instead // The bracket operator is only for vectors, use the parenthesis operator instead
EIGEN_STATIC_ASSERT(NumIndices == 1, YOU_MADE_A_PROGRAMMING_MISTAKE) EIGEN_STATIC_ASSERT(NumIndices == 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
@ -223,11 +223,10 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_> >
#ifdef EIGEN_HAS_VARIADIC_TEMPLATES #ifdef EIGEN_HAS_VARIADIC_TEMPLATES
template<typename... IndexTypes> template<typename... IndexTypes>
inline Tensor(Index firstDimension, IndexTypes... otherDimensions) inline Tensor(Index firstDimension, IndexTypes... otherDimensions)
: m_storage() : m_storage(internal::array_prod(array<Index, NumIndices>{{firstDimension, otherDimensions...}}), array<Index, NumIndices>{{firstDimension, otherDimensions...}})
{ {
// The number of dimensions used to construct a tensor must be equal to the rank of the tensor. // The number of dimensions used to construct a tensor must be equal to the rank of the tensor.
EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 1 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE) EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 1 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
resize(array<Index, NumIndices>{{firstDimension, otherDimensions...}});
} }
#endif #endif
@ -237,7 +236,6 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_> >
EIGEN_INITIALIZE_COEFFS_IF_THAT_OPTION_IS_ENABLED EIGEN_INITIALIZE_COEFFS_IF_THAT_OPTION_IS_ENABLED
} }
template<typename OtherDerived> template<typename OtherDerived>
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Tensor& operator=(const OtherDerived& other) EIGEN_STRONG_INLINE Tensor& operator=(const OtherDerived& other)
@ -306,7 +304,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_> >
array_zip_and_reduce<logical_and_op, lesser_op>(indices, m_storage.dimensions()); array_zip_and_reduce<logical_and_op, lesser_op>(indices, m_storage.dimensions());
} }
inline Index linearizedIndex(const array<Index, NumIndices>& indices) const EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index linearizedIndex(const array<Index, NumIndices>& indices) const
{ {
if (Options&RowMajor) { if (Options&RowMajor) {
return m_storage.dimensions().IndexOfRowMajor(indices); return m_storage.dimensions().IndexOfRowMajor(indices);

View File

@ -53,7 +53,6 @@ template<typename Derived1, typename Derived2>
struct TensorAssign<Derived1, Derived2, true> struct TensorAssign<Derived1, Derived2, true>
{ {
typedef typename Derived1::Index Index; typedef typename Derived1::Index Index;
EIGEN_DEVICE_FUNC
static inline void run(Derived1& dst, const Derived2& src) static inline void run(Derived1& dst, const Derived2& src)
{ {
TensorEvaluator<Derived1> evalDst(dst); TensorEvaluator<Derived1> evalDst(dst);
@ -63,7 +62,7 @@ struct TensorAssign<Derived1, Derived2, true>
static const int LhsStoreMode = TensorEvaluator<Derived1>::IsAligned ? Aligned : Unaligned; static const int LhsStoreMode = TensorEvaluator<Derived1>::IsAligned ? Aligned : Unaligned;
static const int RhsLoadMode = TensorEvaluator<Derived2>::IsAligned ? Aligned : Unaligned; static const int RhsLoadMode = TensorEvaluator<Derived2>::IsAligned ? Aligned : Unaligned;
static const int PacketSize = unpacket_traits<typename TensorEvaluator<Derived1>::PacketReturnType>::size; static const int PacketSize = unpacket_traits<typename TensorEvaluator<Derived1>::PacketReturnType>::size;
static const int VectorizedSize = (size / PacketSize) * PacketSize; const int VectorizedSize = (size / PacketSize) * PacketSize;
for (Index i = 0; i < VectorizedSize; i += PacketSize) { for (Index i = 0; i < VectorizedSize; i += PacketSize) {
evalDst.template writePacket<LhsStoreMode>(i, evalSrc.template packet<RhsLoadMode>(i)); evalDst.template writePacket<LhsStoreMode>(i, evalSrc.template packet<RhsLoadMode>(i));
@ -148,7 +147,7 @@ struct TensorAssignMultiThreaded
// GPU: the evaluation of the expressions is offloaded to a GPU. // GPU: the evaluation of the expressions is offloaded to a GPU.
#ifdef EIGEN_USE_GPU #if defined(EIGEN_USE_GPU) && defined(__CUDACC__)
template <typename LhsEvaluator, typename RhsEvaluator> template <typename LhsEvaluator, typename RhsEvaluator>
__global__ void EigenMetaKernelNoCheck(LhsEvaluator evalDst, const RhsEvaluator evalSrc) { __global__ void EigenMetaKernelNoCheck(LhsEvaluator evalDst, const RhsEvaluator evalSrc) {
const int index = blockIdx.x * blockDim.x + threadIdx.x; const int index = blockIdx.x * blockDim.x + threadIdx.x;

View File

@ -30,13 +30,16 @@ class TensorBase
typedef Scalar CoeffReturnType; typedef Scalar CoeffReturnType;
typedef typename internal::packet_traits<Scalar>::type PacketReturnType; typedef typename internal::packet_traits<Scalar>::type PacketReturnType;
Derived& setZero() { EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Derived& setZero() {
return setConstant(Scalar(0)); return setConstant(Scalar(0));
} }
Derived& setConstant(const Scalar& val) { EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Derived& setConstant(const Scalar& val) {
return derived() = constant(val); return derived() = constant(val);
} }
Derived& setRandom() { EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Derived& setRandom() {
return derived() = random(); return derived() = random();
} }
@ -45,13 +48,13 @@ class TensorBase
EIGEN_STRONG_INLINE const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> EIGEN_STRONG_INLINE const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived>
constant(const Scalar& value) const { constant(const Scalar& value) const {
return TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> return TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived>
(internal::scalar_constant_op<Scalar>(value)); (derived(), internal::scalar_constant_op<Scalar>(value));
} }
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const TensorCwiseNullaryOp<internal::scalar_random_op<Scalar>, const Derived> EIGEN_STRONG_INLINE const TensorCwiseNullaryOp<internal::scalar_random_op<Scalar>, const Derived>
random() const { random() const {
return TensorCwiseNullaryOp<internal::scalar_random_op<Scalar>, const Derived>(); return TensorCwiseNullaryOp<internal::scalar_random_op<Scalar>, const Derived>(derived());
} }
// Coefficient-wise unary operators // Coefficient-wise unary operators
@ -191,10 +194,19 @@ class TensorBase
return TensorCwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived()); return TensorCwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
} }
// Contractions.
typedef std::pair<Index, Index> DimensionPair;
template<typename OtherDerived, typename Dimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorContractionOp<const Dimensions, const Derived, const OtherDerived>
contract(const OtherDerived& other, const Dimensions& dims) const {
return TensorContractionOp<const Dimensions, const Derived, const OtherDerived>(derived(), other.derived(), dims);
}
// Coefficient-wise ternary operators. // Coefficient-wise ternary operators.
template<typename ThenDerived,typename ElseDerived> template<typename ThenDerived, typename ElseDerived>
inline const TensorSelectOp<const Derived, const ThenDerived, const ElseDerived> inline const TensorSelectOp<const Derived, const ThenDerived, const ElseDerived>
select(const ThenDerived& thenTensor, const ElseDerived& elseTensor) const{ select(const ThenDerived& thenTensor, const ElseDerived& elseTensor) const {
return TensorSelectOp<const Derived, const ThenDerived, const ElseDerived>(derived(), thenTensor.derived(), elseTensor.derived()); return TensorSelectOp<const Derived, const ThenDerived, const ElseDerived>(derived(), thenTensor.derived(), elseTensor.derived());
} }

View File

@ -0,0 +1,229 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
namespace Eigen {
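/** \class TensorContraction
* \ingroup CXX11_Tensor_Module
*
* \brief Tensor contraction class.
*
* Expression of the contraction of two tensor expressions along a list of
* dimension pairs: the first element of each pair names a dimension of the
* left operand, the second element a dimension of the right operand.
*/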
namespace internal {
// Compile-time traits of a contraction expression: the scalar, packet,
// storage-kind and index types derived from both operands, plus the types
// used to hold each operand when it is nested inside the expression.
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
{
// Type promotion to handle the case where the types of the lhs and the rhs are different.
typedef typename internal::promote_storage_type<typename LhsXprType::Scalar,
typename RhsXprType::Scalar>::ret Scalar;
// Packet type associated with the promoted scalar type.
typedef typename internal::packet_traits<Scalar>::type Packet;
// Storage kind and index type are promoted from the two operands as well.
typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
typename traits<RhsXprType>::StorageKind>::ret StorageKind;
typedef typename promote_index_type<typename traits<LhsXprType>::Index,
typename traits<RhsXprType>::Index>::type Index;
// Nested types of the operands, with and without a possible reference qualifier.
typedef typename LhsXprType::Nested LhsNested;
typedef typename RhsXprType::Nested RhsNested;
typedef typename remove_reference<LhsNested>::type _LhsNested;
typedef typename remove_reference<RhsNested>::type _RhsNested;
};
// eval<> specialization for dense contraction expressions: `type` is a const
// reference to the expression itself.
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, Eigen::Dense>
{
typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType>& type;
};
// nested<> specialization for contraction expressions: `type` is the
// expression type itself (no reference qualifier).
// NOTE(review): presumably this makes enclosing expressions hold the
// contraction by value - confirm against how `Nested` is used elsewhere.
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >::type>
{
typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType> type;
};
} // end namespace internal
// Expression node describing the contraction of two tensor expressions.
// It only records its two operands and the list of dimension pairs to
// contract; no arithmetic is performed in this class.
//
// Indices:    container type holding the pairs of dimensions to contract.
// LhsXprType: type of the left-hand side expression.
// RhsXprType: type of the right-hand side expression.
template<typename Indices, typename LhsXprType, typename RhsXprType>
class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType> >
{
public:
typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar;
typedef typename Eigen::internal::traits<TensorContractionOp>::Packet Packet;
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
// Coefficient and packet return types are promoted from those of the two operands.
typedef typename internal::promote_storage_type<typename LhsXprType::CoeffReturnType,
typename RhsXprType::CoeffReturnType>::ret CoeffReturnType;
typedef typename internal::promote_storage_type<typename LhsXprType::PacketReturnType,
typename RhsXprType::PacketReturnType>::ret PacketReturnType;
typedef typename Eigen::internal::nested<TensorContractionOp>::type Nested;
typedef typename Eigen::internal::traits<TensorContractionOp>::StorageKind StorageKind;
typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index;
// Builds the expression from its two operands and the dimension pairs to contract.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims)
: m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {}
/** \returns the pairs of dimensions to contract */
EIGEN_DEVICE_FUNC
const Indices& indices() const { return m_indices; }
/** \returns the left-hand side nested expression */
EIGEN_DEVICE_FUNC
const typename internal::remove_all<typename LhsXprType::Nested>::type&
lhsExpression() const { return m_lhs_xpr; }
/** \returns the right-hand side nested expression */
EIGEN_DEVICE_FUNC
const typename internal::remove_all<typename RhsXprType::Nested>::type&
rhsExpression() const { return m_rhs_xpr; }
protected:
// Operands are held through their own Nested types; the index list is copied.
typename LhsXprType::Nested m_lhs_xpr;
typename RhsXprType::Nested m_rhs_xpr;
const Indices m_indices;
};
// Compile-time max(n, 1): `size` equals n, except that 0 is mapped to 1.
template <size_t n> struct max_n_1 {
  static const size_t size = (n == 0) ? 1 : n;
};
/** Evaluator for TensorContractionOp expressions.
 *
 * The result has one dimension for every non-contracted dimension of the left
 * operand, followed by one dimension for every non-contracted dimension of
 * the right operand. When every dimension of both operands is contracted the
 * result is a single coefficient, reported as a rank-1 tensor of size 1.
 *
 * Strides are accumulated starting from dimension 0, i.e. dimension 0 is
 * treated as the fastest varying one.
 *
 * NOTE(review): contracted dimensions are matched by position, not by pair:
 * the k-th contracted dimension of the left operand (in increasing dimension
 * order) is multiplied against the k-th contracted dimension of the right
 * operand (in increasing dimension order). The extent of each contracted
 * dimension is read from the left operand only; nothing checks that the right
 * operand agrees.
 */
template<typename Indices, typename LeftArgType, typename RightArgType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType> >
{
  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;

  // Rank of the result: each contracted pair removes one dimension from each
  // operand. max_n_1 keeps the rank at 1 when everything is contracted.
  static const int NumDims = max_n_1<TensorEvaluator<LeftArgType>::Dimensions::count + TensorEvaluator<RightArgType>::Dimensions::count - 2 * Indices::size>::size;
  typedef typename XprType::Index Index;
  typedef DSizes<Index, NumDims> Dimensions;

  enum {
    IsAligned = TensorEvaluator<LeftArgType>::IsAligned & TensorEvaluator<RightArgType>::IsAligned,
    PacketAccess = /*TensorEvaluator<LeftArgType>::PacketAccess & TensorEvaluator<RightArgType>::PacketAccess */
                   false,
  };

  /** Precomputes the result dimensions and, for every contracted dimension,
   * the strides needed to walk along it in each operand.
   */
  TensorEvaluator(const XprType& op)
      : m_leftImpl(op.lhsExpression()), m_rightImpl(op.rhsExpression())
  {
    Index index = 0;    // next result dimension to fill
    Index stride = 1;   // stride of the current operand dimension
    m_shiftright = 1;   // product of the non-contracted left dimensions

    int skipped = 0;    // number of contracted dimensions seen so far
    const typename TensorEvaluator<LeftArgType>::Dimensions& left_dims = m_leftImpl.dimensions();
    for (int i = 0; i < TensorEvaluator<LeftArgType>::Dimensions::count; ++i) {
      bool skip = false;
      for (int j = 0; j < Indices::size; ++j) {
        if (op.indices()[j].first == i) {
          skip = true;
          // Stride of the contracted dimension and stride of the dimension
          // that follows it.
          m_leftOffsets[2*skipped] = stride;
          m_leftOffsets[2*skipped+1] = stride * left_dims[i];
          m_stitchsize[skipped] = left_dims[i];
          break;
        }
      }
      if (!skip) {
        m_dimensions[index++] = left_dims[i];
        m_shiftright *= left_dims[i];
      } else {
        ++skipped;
      }
      stride *= left_dims[i];
    }

    stride = 1;
    skipped = 0;
    const typename TensorEvaluator<RightArgType>::Dimensions& right_dims = m_rightImpl.dimensions();
    for (int i = 0; i < TensorEvaluator<RightArgType>::Dimensions::count; ++i) {
      bool skip = false;
      for (int j = 0; j < Indices::size; ++j) {
        if (op.indices()[j].second == i) {
          skip = true;
          m_rightOffsets[2*skipped] = stride;
          m_rightOffsets[2*skipped+1] = stride * right_dims[i];
          break;
        }
      }
      if (!skip) {
        m_dimensions[index++] = right_dims[i];
      } else {
        ++skipped;
      }
      stride *= right_dims[i];
    }

    // Scalar case: every dimension of BOTH operands is contracted, so no
    // result dimension was written above. The test must count the left and
    // the right ranks; counting the left rank twice would also fire when only
    // the left operand is fully contracted (e.g. vector * matrix) and would
    // overwrite the real result dimension with 1.
    if (TensorEvaluator<LeftArgType>::Dimensions::count + TensorEvaluator<RightArgType>::Dimensions::count == 2 * Indices::size) {
      m_dimensions[0] = 1;
    }
  }

  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename XprType::PacketReturnType PacketReturnType;

  /** \returns the dimensions of the contraction result */
  const Dimensions& dimensions() const { return m_dimensions; }

  /** Adds every coefficient of the result to the corresponding entry of \a buffer.
   *
   * NOTE(review): this accumulates (+=) rather than assigns, so \a buffer must
   * hold meaningful values (e.g. zeros) on entry - confirm this is intended.
   */
  void evalTo(typename XprType::Scalar* buffer) const {
    // Iterate with Index rather than int so that the counter has the same
    // width as the indices accepted by coeff().
    const Index total = static_cast<Index>(dimensions().TotalSize());
    for (Index i = 0; i < total; ++i) {
      buffer[i] += coeff(i);
    }
  }

  /** \returns the coefficient of the result at linear position \a index */
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
  {
    // The non-contracted left dimensions vary fastest in the result, so the
    // remainder addresses the left operand and the quotient the right one.
    const Index startLeft = index % m_shiftright;
    const Index startRight = index / m_shiftright;
    CoeffReturnType result = CoeffReturnType(0);
    partialStitch(startLeft, startRight, 0, result);
    return result;
  }

  /* TODO: vectorization
  template<int LoadMode>
  EIGEN_DEVICE_FUNC PacketReturnType packet(Index index) const
  {
    assert(false);
  }*/

 private:
  /** Walks along contracted dimension number \a StitchIndex in both operands,
   * recursing into the next contracted dimension if there is one, and adds
   * the products of the matching coefficients to \a accum.
   *
   * \a startLeft and \a startRight are linear indices into the operands in
   * which contracted dimension \a StitchIndex has not been accounted for yet.
   */
  EIGEN_DEVICE_FUNC void partialStitch(Index startLeft, Index startRight, int StitchIndex, CoeffReturnType& accum) const {
    // Make room for the contracted dimension in each linear index: the part
    // above the contracted dimension is rescaled to the next stride, the part
    // below is kept as is.
    Index firstLeft = (startLeft / m_leftOffsets[2*StitchIndex]) * m_leftOffsets[2*StitchIndex+1] + (startLeft % m_leftOffsets[2*StitchIndex]);
    Index firstRight = (startRight / m_rightOffsets[2*StitchIndex]) * m_rightOffsets[2*StitchIndex+1] + (startRight % m_rightOffsets[2*StitchIndex]);

    for (int j = 0; j < m_stitchsize[StitchIndex]; ++j) {
      const Index left = firstLeft+j*m_leftOffsets[2*StitchIndex];
      const Index right = firstRight+j*m_rightOffsets[2*StitchIndex];
      if (StitchIndex < Indices::size-1) {
        partialStitch(left, right, StitchIndex+1, accum);
      } else {
        accum += m_leftImpl.coeff(left) * m_rightImpl.coeff(right);
      }
    }
  }

  // For contracted dimension k: entry 2*k is its stride in the operand and
  // entry 2*k+1 is the stride of the dimension that follows it.
  array<Index, 2*Indices::size> m_leftOffsets;
  array<Index, 2*Indices::size> m_rightOffsets;
  // Extent of each contracted dimension, read from the left operand.
  array<Index, Indices::size> m_stitchsize;
  // Product of the non-contracted dimensions of the left operand.
  Index m_shiftright;
  Dimensions m_dimensions;
  TensorEvaluator<LeftArgType> m_leftImpl;
  TensorEvaluator<RightArgType> m_rightImpl;
};
} // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H

View File

@ -59,7 +59,7 @@ template <typename ExpressionType> class TensorDevice<ExpressionType, ThreadPool
#endif #endif
#ifdef EIGEN_USE_GPU #if defined(EIGEN_USE_GPU) && defined(__CUDACC__)
template <typename ExpressionType> class TensorDevice<ExpressionType, GpuDevice> template <typename ExpressionType> class TensorDevice<ExpressionType, GpuDevice>
{ {
public: public:

View File

@ -37,17 +37,14 @@ struct ThreadPoolDevice {
// GPU offloading // GPU offloading
#ifdef EIGEN_USE_GPU #ifdef EIGEN_USE_GPU
struct GpuDevice { struct GpuDevice {
// todo: support for multiple gpu; // The cudastream is not owned: the caller is responsible for its initialization and eventual destruction.
GpuDevice() { GpuDevice(const cudaStream_t* stream) : stream_(stream) { eigen_assert(stream); }
cudaStreamCreate(&stream_);
} const cudaStream_t& stream() const { return *stream_; }
~GpuDevice() {
cudaStreamDestroy(stream_);
}
const cudaStream_t& stream() const { return stream_; }
private: private:
cudaStream_t stream_; // TODO: multigpu.
const cudaStream_t* stream_;
}; };
#endif #endif

View File

@ -35,14 +35,14 @@ namespace Eigen {
namespace internal { namespace internal {
template<std::size_t n, typename Dimension> struct dget { template<std::size_t n, typename Dimension> struct dget {
static const std::size_t value = internal::get<n, typename Dimension::Base>::value; static const std::size_t value = get<n, typename Dimension::Base>::value;
}; };
template<typename Index, std::size_t NumIndices, std::size_t n, bool RowMajor> template<typename Index, std::size_t NumIndices, std::size_t n, bool RowMajor>
struct fixed_size_tensor_index_linearization_helper struct fixed_size_tensor_index_linearization_helper
{ {
template <typename Dimensions> template <typename Dimensions> EIGEN_DEVICE_FUNC
static inline Index run(array<Index, NumIndices> const& indices, static inline Index run(array<Index, NumIndices> const& indices,
const Dimensions& dimensions) const Dimensions& dimensions)
{ {
@ -55,7 +55,7 @@ struct fixed_size_tensor_index_linearization_helper
template<typename Index, std::size_t NumIndices, bool RowMajor> template<typename Index, std::size_t NumIndices, bool RowMajor>
struct fixed_size_tensor_index_linearization_helper<Index, NumIndices, 0, RowMajor> struct fixed_size_tensor_index_linearization_helper<Index, NumIndices, 0, RowMajor>
{ {
template <typename Dimensions> template <typename Dimensions> EIGEN_DEVICE_FUNC
static inline Index run(array<Index, NumIndices> const& indices, static inline Index run(array<Index, NumIndices> const& indices,
const Dimensions&) const Dimensions&)
{ {
@ -93,11 +93,11 @@ struct Sizes : internal::numeric_list<std::size_t, Indices...> {
return *this; return *this;
} }
template <typename DenseIndex> template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
size_t IndexOfColMajor(const array<DenseIndex, Base::count>& indices) const { size_t IndexOfColMajor(const array<DenseIndex, Base::count>& indices) const {
return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count - 1, false>::run(indices, *static_cast<const Base*>(this)); return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count - 1, false>::run(indices, *static_cast<const Base*>(this));
} }
template <typename DenseIndex> template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
size_t IndexOfRowMajor(const array<DenseIndex, Base::count>& indices) const { size_t IndexOfRowMajor(const array<DenseIndex, Base::count>& indices) const {
return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count - 1, true>::run(indices, *static_cast<const Base*>(this)); return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count - 1, true>::run(indices, *static_cast<const Base*>(this));
} }
@ -139,11 +139,11 @@ template <std::size_t V1=0, std::size_t V2=0, std::size_t V3=0, std::size_t V4=0
return *this; return *this;
} }
template <typename DenseIndex> template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
size_t IndexOfColMajor(const array<DenseIndex, Base::count>& indices) const { size_t IndexOfColMajor(const array<DenseIndex, Base::count>& indices) const {
return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count - 1, false>::run(indices, *this); return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count - 1, false>::run(indices, *this);
} }
template <typename DenseIndex> template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
size_t IndexOfRowMajor(const array<DenseIndex, Base::count>& indices) const { size_t IndexOfRowMajor(const array<DenseIndex, Base::count>& indices) const {
return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count - 1, true>::run(indices, *this); return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count - 1, true>::run(indices, *this);
} }
@ -180,13 +180,18 @@ struct tensor_index_linearization_helper<Index, NumIndices, 0, RowMajor>
template <typename DenseIndex, std::size_t NumDims> template <typename DenseIndex, std::size_t NumDims>
struct DSizes : array<DenseIndex, NumDims> { struct DSizes : array<DenseIndex, NumDims> {
typedef array<DenseIndex, NumDims> Base; typedef array<DenseIndex, NumDims> Base;
static const std::size_t count = NumDims;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t TotalSize() const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t TotalSize() const {
return internal::array_prod(*static_cast<const Base*>(this)); return internal::array_prod(*static_cast<const Base*>(this));
} }
DSizes() { } EIGEN_DEVICE_FUNC DSizes() {
explicit DSizes(const array<DenseIndex, NumDims>& a) : Base(a) { } for (int i = 0 ; i < NumDims; ++i) {
(*this)[i] = 0;
}
}
EIGEN_DEVICE_FUNC explicit DSizes(const array<DenseIndex, NumDims>& a) : Base(a) { }
DSizes& operator = (const array<DenseIndex, NumDims>& other) { DSizes& operator = (const array<DenseIndex, NumDims>& other) {
*static_cast<Base*>(this) = other; *static_cast<Base*>(this) = other;
@ -194,10 +199,10 @@ struct DSizes : array<DenseIndex, NumDims> {
} }
// A constexpr would be so much better here // A constexpr would be so much better here
size_t IndexOfColMajor(const array<DenseIndex, NumDims>& indices) const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t IndexOfColMajor(const array<DenseIndex, NumDims>& indices) const {
return internal::tensor_index_linearization_helper<DenseIndex, NumDims, NumDims - 1, false>::run(indices, *static_cast<const Base*>(this)); return internal::tensor_index_linearization_helper<DenseIndex, NumDims, NumDims - 1, false>::run(indices, *static_cast<const Base*>(this));
} }
size_t IndexOfRowMajor(const array<DenseIndex, NumDims>& indices) const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t IndexOfRowMajor(const array<DenseIndex, NumDims>& indices) const {
return internal::tensor_index_linearization_helper<DenseIndex, NumDims, NumDims - 1, true>::run(indices, *static_cast<const Base*>(this)); return internal::tensor_index_linearization_helper<DenseIndex, NumDims, NumDims - 1, true>::run(indices, *static_cast<const Base*>(this));
} }
}; };

View File

@ -21,7 +21,6 @@ namespace Eigen {
* *
* TODO: add support for more types of expressions, in particular expressions * TODO: add support for more types of expressions, in particular expressions
* leading to lvalues (slicing, reshaping, etc...) * leading to lvalues (slicing, reshaping, etc...)
* TODO: add support for vectorization
*/ */
template<typename Derived> template<typename Derived>
@ -32,16 +31,19 @@ struct TensorEvaluator
typedef typename Derived::Packet Packet; typedef typename Derived::Packet Packet;
typedef typename Derived::Scalar CoeffReturnType; typedef typename Derived::Scalar CoeffReturnType;
typedef typename Derived::Packet PacketReturnType; typedef typename Derived::Packet PacketReturnType;
typedef typename Derived::Dimensions Dimensions;
enum { enum {
IsAligned = Derived::IsAligned, IsAligned = Derived::IsAligned,
PacketAccess = Derived::PacketAccess, PacketAccess = Derived::PacketAccess,
}; };
TensorEvaluator(Derived& m) EIGEN_DEVICE_FUNC TensorEvaluator(Derived& m)
: m_data(const_cast<Scalar*>(m.data())) : m_data(const_cast<Scalar*>(m.data())), m_dims(m.dimensions())
{ } { }
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_dims; }
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const { EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const {
return m_data[index]; return m_data[index];
} }
@ -64,29 +66,34 @@ struct TensorEvaluator
protected: protected:
Scalar* m_data; Scalar* m_data;
Dimensions m_dims;
}; };
// -------------------- CwiseNullaryOp -------------------- // -------------------- CwiseNullaryOp --------------------
// Evaluator specialization for TensorCwiseNullaryOp expressions.
// The updated (right-hand) version renames the second template parameter to ArgType and keeps an
// evaluator for the nested expression (m_argImpl), whose dimensions it reports as its own.
template<typename NullaryOp, typename PlainObjectType> template<typename NullaryOp, typename ArgType>
struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, PlainObjectType> > struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType> >
{ {
typedef TensorCwiseNullaryOp<NullaryOp, PlainObjectType> XprType; typedef TensorCwiseNullaryOp<NullaryOp, ArgType> XprType;
enum { enum {
IsAligned = true, IsAligned = true,
PacketAccess = internal::functor_traits<NullaryOp>::PacketAccess, PacketAccess = internal::functor_traits<NullaryOp>::PacketAccess,
}; };
// Constructor: initializes m_functor from op.functor(); the updated version also builds m_argImpl from op.nestedExpression().
EIGEN_DEVICE_FUNC
TensorEvaluator(const XprType& op) TensorEvaluator(const XprType& op)
: m_functor(op.functor()) : m_functor(op.functor()), m_argImpl(op.nestedExpression())
{ } { }
typedef typename XprType::Index Index; typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType; typedef typename XprType::PacketReturnType PacketReturnType;
typedef typename TensorEvaluator<ArgType>::Dimensions Dimensions;
// Forwards the dimensions reported by the nested-expression evaluator.
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
{ {
@ -101,6 +108,7 @@ struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, PlainObjectType> >
private: private:
// The functor is stored by value (const copy).
const NullaryOp m_functor; const NullaryOp m_functor;
// Evaluator of the nested expression; the only visible use is in dimensions().
TensorEvaluator<ArgType> m_argImpl;
}; };
@ -117,7 +125,7 @@ struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType> >
PacketAccess = TensorEvaluator<ArgType>::PacketAccess & internal::functor_traits<UnaryOp>::PacketAccess, PacketAccess = TensorEvaluator<ArgType>::PacketAccess & internal::functor_traits<UnaryOp>::PacketAccess,
}; };
// Constructor: initializes m_functor from op.functor() and m_argImpl from op.nestedExpression().
// The updated (right-hand) version only adds the EIGEN_DEVICE_FUNC qualifier.
TensorEvaluator(const XprType& op) EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op)
: m_functor(op.functor()), : m_functor(op.functor()),
m_argImpl(op.nestedExpression()) m_argImpl(op.nestedExpression())
{ } { }
@ -125,6 +133,9 @@ struct TensorEvaluator<const TensorCwiseUnaryOp<UnaryOp, ArgType> >
typedef typename XprType::Index Index; typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType; typedef typename XprType::PacketReturnType PacketReturnType;
typedef typename TensorEvaluator<ArgType>::Dimensions Dimensions;
// Forwards the dimensions reported by the argument evaluator m_argImpl.
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); }
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
{ {
@ -156,7 +167,7 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
internal::functor_traits<BinaryOp>::PacketAccess, internal::functor_traits<BinaryOp>::PacketAccess,
}; };
TensorEvaluator(const XprType& op) EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op)
: m_functor(op.functor()), : m_functor(op.functor()),
m_leftImpl(op.lhsExpression()), m_leftImpl(op.lhsExpression()),
m_rightImpl(op.rhsExpression()) m_rightImpl(op.rhsExpression())
@ -165,6 +176,13 @@ struct TensorEvaluator<const TensorCwiseBinaryOp<BinaryOp, LeftArgType, RightArg
typedef typename XprType::Index Index; typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType; typedef typename XprType::PacketReturnType PacketReturnType;
typedef typename TensorEvaluator<LeftArgType>::Dimensions Dimensions;
// Reports the dimensions of the binary expression by forwarding those of the left operand's evaluator.
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
{
  // TODO: prefer the right operand's dimensions when those are known at compile time.
  const Dimensions& leftDims = m_leftImpl.dimensions();
  return leftDims;
}
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
{ {
@ -196,7 +214,7 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
TensorEvaluator<IfArgType>::PacketAccess*/, TensorEvaluator<IfArgType>::PacketAccess*/,
}; };
TensorEvaluator(const XprType& op) EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op)
: m_condImpl(op.ifExpression()), : m_condImpl(op.ifExpression()),
m_thenImpl(op.thenExpression()), m_thenImpl(op.thenExpression()),
m_elseImpl(op.elseExpression()) m_elseImpl(op.elseExpression())
@ -205,7 +223,13 @@ struct TensorEvaluator<const TensorSelectOp<IfArgType, ThenArgType, ElseArgType>
typedef typename XprType::Index Index; typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename XprType::PacketReturnType PacketReturnType; typedef typename XprType::PacketReturnType PacketReturnType;
typedef typename TensorEvaluator<IfArgType>::Dimensions Dimensions;
// Reports the dimensions of the select expression by forwarding those of the condition's evaluator.
EIGEN_DEVICE_FUNC const Dimensions& dimensions() const
{
  // TODO: prefer the then/else operand's dimensions when those happen to be known at compile time.
  const Dimensions& condDims = m_condImpl.dimensions();
  return condDims;
}
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
{ {
return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index); return m_condImpl.coeff(index) ? m_thenImpl.coeff(index) : m_elseImpl.coeff(index);

View File

@ -28,13 +28,13 @@ namespace Eigen {
* *
*/ */
namespace internal { namespace internal {
// Traits for TensorCwiseNullaryOp: derives from the traits of the wrapped type and re-exports that
// type's Packet, Scalar and Nested typedefs. _XprTypeNested is XprTypeNested with any reference removed.
// The updated (right-hand) version only renames the template parameter PlainObjectType to XprType.
template<typename NullaryOp, typename PlainObjectType> template<typename NullaryOp, typename XprType>
struct traits<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> > struct traits<TensorCwiseNullaryOp<NullaryOp, XprType> >
: traits<PlainObjectType> : traits<XprType>
{ {
typedef typename PlainObjectType::Packet Packet; typedef typename XprType::Packet Packet;
typedef typename PlainObjectType::Scalar Scalar; typedef typename XprType::Scalar Scalar;
typedef typename PlainObjectType::Nested XprTypeNested; typedef typename XprType::Nested XprTypeNested;
typedef typename remove_reference<XprTypeNested>::type _XprTypeNested; typedef typename remove_reference<XprTypeNested>::type _XprTypeNested;
}; };
@ -42,27 +42,31 @@ struct traits<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> >
// Expression node pairing a nullary functor with, in the updated (right-hand) version, a nested
// expression of type XprType. The old version stored only the functor.
template<typename NullaryOp, typename PlainObjectType> template<typename NullaryOp, typename XprType>
class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, PlainObjectType> > class TensorCwiseNullaryOp : public TensorBase<TensorCwiseNullaryOp<NullaryOp, XprType> >
{ {
public: public:
typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar; typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Scalar Scalar;
typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Packet Packet; typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Packet Packet;
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar; typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
typedef typename PlainObjectType::CoeffReturnType CoeffReturnType; typedef typename XprType::CoeffReturnType CoeffReturnType;
typedef typename PlainObjectType::PacketReturnType PacketReturnType; typedef typename XprType::PacketReturnType PacketReturnType;
typedef TensorCwiseNullaryOp<NullaryOp, PlainObjectType> Nested; typedef TensorCwiseNullaryOp<NullaryOp, XprType> Nested;
typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind; typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::StorageKind StorageKind;
typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index; typedef typename Eigen::internal::traits<TensorCwiseNullaryOp>::Index Index;
// Constructor: the updated version takes the nested expression as a required first argument;
// the functor keeps its default-constructed default value.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const NullaryOp& func = NullaryOp()) EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorCwiseNullaryOp(const XprType& xpr, const NullaryOp& func = NullaryOp())
: m_functor(func) {} : m_xpr(xpr), m_functor(func) {}
// Accessor for the stored nested expression m_xpr (new in the updated version).
EIGEN_DEVICE_FUNC
const typename internal::remove_all<typename XprType::Nested>::type&
nestedExpression() const { return m_xpr; }
// Accessor returning a const reference to the stored functor.
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
const NullaryOp& functor() const { return m_functor; } const NullaryOp& functor() const { return m_functor; }
protected: protected:
// NOTE(review): in the row below, the left column is the removed todo comment and the right column is the added m_xpr member declaration.
// todo: add tensor dimension to be able to do some sanity checks typename XprType::Nested m_xpr;
const NullaryOp m_functor; const NullaryOp m_functor;
}; };

View File

@ -52,7 +52,7 @@ class TensorFixedSize : public TensorBase<TensorFixedSize<Scalar_, Dimensions_,
public: public:
// Size of the n-th dimension, read from the storage's dimensions.
EIGEN_STRONG_INLINE Index dimension(std::size_t n) const { return m_storage.dimensions()[n]; } EIGEN_STRONG_INLINE Index dimension(std::size_t n) const { return m_storage.dimensions()[n]; }
// All dimensions: the old version returned array<Index, NumIndices> by value; the updated version returns const Dimensions& instead.
EIGEN_STRONG_INLINE array<Index, NumIndices> dimensions() const { return m_storage.dimensions(); } EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_storage.dimensions(); }
// Forwards to m_storage.size().
EIGEN_STRONG_INLINE Index size() const { return m_storage.size(); } EIGEN_STRONG_INLINE Index size() const { return m_storage.size(); }
// Mutable and const access to the storage's data pointer.
EIGEN_STRONG_INLINE Scalar *data() { return m_storage.data(); } EIGEN_STRONG_INLINE Scalar *data() { return m_storage.data(); }
EIGEN_STRONG_INLINE const Scalar *data() const { return m_storage.data(); } EIGEN_STRONG_INLINE const Scalar *data() const { return m_storage.data(); }

View File

@ -21,6 +21,8 @@ template<typename NullaryOp, typename PlainObjectType> class TensorCwiseNullaryO
// Forward declarations of class templates (no definitions here).
// TensorReductionOp and TensorContractionOp appear only in the updated (right-hand) version.
template<typename UnaryOp, typename XprType> class TensorCwiseUnaryOp; template<typename UnaryOp, typename XprType> class TensorCwiseUnaryOp;
template<typename BinaryOp, typename LeftXprType, typename RightXprType> class TensorCwiseBinaryOp; template<typename BinaryOp, typename LeftXprType, typename RightXprType> class TensorCwiseBinaryOp;
template<typename IfXprType, typename ThenXprType, typename ElseXprType> class TensorSelectOp; template<typename IfXprType, typename ThenXprType, typename ElseXprType> class TensorSelectOp;
template <typename XprType> class TensorReductionOp;
template<typename Dimensions, typename LeftXprType, typename RightXprType> class TensorContractionOp;
template<typename ExpressionType, typename DeviceType> class TensorDevice; template<typename ExpressionType, typename DeviceType> class TensorDevice;

View File

@ -53,7 +53,7 @@ class TensorStorage
EIGEN_STRONG_INLINE const T *data() const { return m_data; } EIGEN_STRONG_INLINE const T *data() const { return m_data; }
// Dimensions accessor: the old version returned a const FixedDimensions by value; the updated version returns a const reference to m_dimensions.
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const FixedDimensions dimensions() const { return m_dimensions; } EIGEN_STRONG_INLINE const FixedDimensions& dimensions() const { return m_dimensions; }
// Forwards to m_dimensions.TotalSize().
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE DenseIndex size() const { return m_dimensions.TotalSize(); } EIGEN_STRONG_INLINE DenseIndex size() const { return m_dimensions.TotalSize(); }
@ -111,7 +111,8 @@ class TensorStorage<T, NumIndices_, Dynamic, Options_, typename internal::gen_nu
// Destructor: hands m_data to internal::conditional_aligned_delete_auto together with
// internal::array_prod(m_dimensions) (presumably the total element count; confirm against array_prod).
// The boolean template argument is true when Options_ does not have the DontAlign bit set.
~TensorStorage() { internal::conditional_aligned_delete_auto<T,(Options_&DontAlign)==0>(m_data, internal::array_prod(m_dimensions)); } ~TensorStorage() { internal::conditional_aligned_delete_auto<T,(Options_&DontAlign)==0>(m_data, internal::array_prod(m_dimensions)); }
// Exchanges both the data pointer and the dimensions with `other` using std::swap.
void swap(Self_& other) void swap(Self_& other)
{ std::swap(m_data,other.m_data); std::swap(m_dimensions,other.m_dimensions); } { std::swap(m_data,other.m_data); std::swap(m_dimensions,other.m_dimensions); }
// Returns a const reference to m_dimensions. The first row is the old declaration; the second is the
// updated one, which differs only by the added EIGEN_DEVICE_FUNC and EIGEN_STRONG_INLINE qualifiers.
const DSizes<DenseIndex, NumIndices_>& dimensions() const {return m_dimensions;}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const DSizes<DenseIndex, NumIndices_>& dimensions() const {return m_dimensions;}
void conservativeResize(DenseIndex size, const array<DenseIndex, NumIndices_>& nbDimensions) void conservativeResize(DenseIndex size, const array<DenseIndex, NumIndices_>& nbDimensions)
{ {
@ -132,10 +133,10 @@ class TensorStorage<T, NumIndices_, Dynamic, Options_, typename internal::gen_nu
m_dimensions = nbDimensions; m_dimensions = nbDimensions;
} }
// Mutable and const access to the raw data pointer. The updated (right-hand) versions of the three
// accessors below only add the EIGEN_DEVICE_FUNC and EIGEN_STRONG_INLINE qualifiers.
T *data() { return m_data; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T *data() { return m_data; }
const T *data() const { return m_data; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T *data() const { return m_data; }
// Forwards to m_dimensions.TotalSize().
DenseIndex size() const { return m_dimensions.TotalSize(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex size() const { return m_dimensions.TotalSize(); }
}; };