From 47981c5925caa8316205ea84b17616dd69073678 Mon Sep 17 00:00:00 2001
From: Benoit Steiner <benoit.steiner.goog@gmail.com>
Date: Mon, 7 Jul 2014 14:07:57 -0700
Subject: [PATCH] Added support for tensor slicing

---
 .../Eigen/CXX11/src/Tensor/TensorMorphing.h | 343 +++++++++++++++++-
 1 file changed, 327 insertions(+), 16 deletions(-)

diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
index 764bba4e6..55954a3a7 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h
@@ -20,10 +20,9 @@ namespace Eigen {
   *
   */
 namespace internal {
-template<typename XprType, typename NewDimensions>
-struct traits<TensorReshapingOp<XprType, NewDimensions> > : public traits<XprType>
+template<typename NewDimensions, typename XprType>
+struct traits<TensorReshapingOp<NewDimensions, XprType> > : public traits<XprType>
 {
-  // Type promotion to handle the case where the types of the lhs and the rhs are different.
   typedef typename XprType::Scalar Scalar;
   typedef typename internal::packet_traits<Scalar>::type Packet;
   typedef typename traits<XprType>::StorageKind StorageKind;
@@ -32,24 +31,24 @@ struct traits<TensorReshapingOp<XprType, NewDimensions> > : public traits<XprTy
   typedef typename XprType::Nested Nested;
   typedef typename remove_reference<Nested>::type _Nested;
 };
 
-template<typename XprType, typename NewDimensions>
-struct eval<TensorReshapingOp<XprType, NewDimensions>, Eigen::Dense>
+template<typename NewDimensions, typename XprType>
+struct eval<TensorReshapingOp<NewDimensions, XprType>, Eigen::Dense>
 {
-  typedef const TensorReshapingOp<XprType, NewDimensions>& type;
+  typedef const TensorReshapingOp<NewDimensions, XprType>& type;
 };
 
-template<typename XprType, typename NewDimensions>
-struct nested<TensorReshapingOp<XprType, NewDimensions>, 1, typename eval<TensorReshapingOp<XprType, NewDimensions> >::type>
+template<typename NewDimensions, typename XprType>
+struct nested<TensorReshapingOp<NewDimensions, XprType>, 1, typename eval<TensorReshapingOp<NewDimensions, XprType> >::type>
 {
-  typedef TensorReshapingOp<XprType, NewDimensions> type;
+  typedef TensorReshapingOp<NewDimensions, XprType> type;
 };
 
 } // end namespace internal
 
 
-template<typename XprType, typename NewDimensions>
-class TensorReshapingOp : public TensorBase<TensorReshapingOp<XprType, NewDimensions> >
+template<typename NewDimensions, typename XprType>
+class TensorReshapingOp : public TensorBase<TensorReshapingOp<NewDimensions, XprType>, WriteAccessors>
 {
   public:
   typedef typename Eigen::internal::traits<TensorReshapingOp>::Scalar Scalar;
@@ -71,16 +70,27 @@ class TensorReshapingOp : public TensorBase<TensorReshapingOp<XprType, NewDimen
     EIGEN_DEVICE_FUNC
     const typename internal::remove_all<typename XprType::Nested>::type&
     expression() const { return m_xpr; }
 
+    template<typename OtherDerived>
+    EIGEN_DEVICE_FUNC
+    EIGEN_STRONG_INLINE TensorReshapingOp& operator = (const OtherDerived& other)
+    {
+      typedef TensorAssignOp<TensorReshapingOp, const OtherDerived> Assign;
+      Assign assign(*this, other);
+      internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
+      return *this;
+    }
+
   protected:
     typename XprType::Nested m_xpr;
     const NewDimensions m_dims;
 };
 
-template<typename ArgType, typename NewDimensions, typename Device>
-struct TensorEvaluator<TensorReshapingOp<ArgType, NewDimensions>, Device>
+// Eval as rvalue
+template<typename NewDimensions, typename ArgType, typename Device>
+struct TensorEvaluator<const TensorReshapingOp<NewDimensions, ArgType>, Device>
 {
-  typedef TensorReshapingOp<ArgType, NewDimensions> XprType;
+  typedef TensorReshapingOp<NewDimensions, ArgType> XprType;
   typedef NewDimensions Dimensions;
 
   enum {
@@ -88,7 +98,7 @@ struct TensorEvaluator<TensorReshapingOp<ArgType, NewDimensions>, Device>
     PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
   };
 
-  TensorEvaluator(const XprType& op, const Device& device)
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
       : m_impl(op.expression(), device), m_dimensions(op.dimensions())
   { }
 
@@ -96,7 +106,7 @@ struct TensorEvaluator<TensorReshapingOp<ArgType, NewDimensions>, Device>
   typedef typename XprType::CoeffReturnType CoeffReturnType;
   typedef typename XprType::PacketReturnType PacketReturnType;
 
-  const Dimensions& dimensions() const { return m_dimensions; }
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
 
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() {
     m_impl.evalSubExprsIfNeeded();
@@ -116,12 +126,313 @@ struct TensorEvaluator<TensorReshapingOp<ArgType, NewDimensions>, Device>
     return m_impl.template packet<LoadMode>(index);
   }
 
+ protected:
+  NewDimensions m_dimensions;
+  TensorEvaluator<ArgType, Device> m_impl;
+};
+
+
+// Eval as lvalue
+// TODO(bsteiner): share the code with the evaluator for rvalue reshapes.
+template<typename NewDimensions, typename ArgType, typename Device>
+struct TensorEvaluator<TensorReshapingOp<NewDimensions, ArgType>, Device>
+{
+  typedef TensorReshapingOp<NewDimensions, ArgType> XprType;
+  typedef NewDimensions Dimensions;
+
+  enum {
+    IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
+    PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
+  };
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
+      : m_impl(op.expression(), device), m_dimensions(op.dimensions())
+  { }
+
+  typedef typename XprType::Index Index;
+  typedef typename XprType::CoeffReturnType CoeffReturnType;
+  typedef typename XprType::PacketReturnType PacketReturnType;
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() {
+    m_impl.evalSubExprsIfNeeded();
+  }
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
+    m_impl.cleanup();
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
+  {
+    return m_impl.coeff(index);
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index)
+  {
+    return m_impl.coeffRef(index);
+  }
+  template <int StoreMode> EIGEN_STRONG_INLINE
+  void writePacket(Index index, const PacketReturnType& x)
+  {
+    m_impl.template writePacket<StoreMode>(index, x);
+  }
+  template<int LoadMode>
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
+  {
+    return m_impl.template packet<LoadMode>(index);
+  }
+
  private:
   NewDimensions m_dimensions;
   TensorEvaluator<ArgType, Device> m_impl;
 };
 
 
+/** \class TensorSlicing
+  * \ingroup CXX11_Tensor_Module
+  *
+  * \brief Tensor slicing class.
+  *
+  *
+  */
+namespace internal {
+template<typename StartIndices, typename Sizes, typename XprType>
+struct traits<TensorSlicingOp<StartIndices, Sizes, XprType> > : public traits<XprType>
+{
+  typedef typename XprType::Scalar Scalar;
+  typedef typename internal::packet_traits<Scalar>::type Packet;
+  typedef typename traits<XprType>::StorageKind StorageKind;
+  typedef typename traits<XprType>::Index Index;
+  typedef typename XprType::Nested Nested;
+  typedef typename remove_reference<Nested>::type _Nested;
+};
+
+template<typename StartIndices, typename Sizes, typename XprType>
+struct eval<TensorSlicingOp<StartIndices, Sizes, XprType>, Eigen::Dense>
+{
+  typedef const TensorSlicingOp<StartIndices, Sizes, XprType>& type;
+};
+
+template<typename StartIndices, typename Sizes, typename XprType>
+struct nested<TensorSlicingOp<StartIndices, Sizes, XprType>, 1, typename eval<TensorSlicingOp<StartIndices, Sizes, XprType> >::type>
+{
+  typedef TensorSlicingOp<StartIndices, Sizes, XprType> type;
+};
+
+} // end namespace internal
+
+
+
+template<typename StartIndices, typename Sizes, typename XprType>
+class TensorSlicingOp : public TensorBase<TensorSlicingOp<StartIndices, Sizes, XprType> >
+{
+  public:
+  typedef typename Eigen::internal::traits<TensorSlicingOp>::Scalar Scalar;
+  typedef typename Eigen::internal::traits<TensorSlicingOp>::Packet Packet;
+  typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
+  typedef typename XprType::CoeffReturnType CoeffReturnType;
+  typedef typename XprType::PacketReturnType PacketReturnType;
+  typedef typename Eigen::internal::nested<TensorSlicingOp>::type Nested;
+  typedef typename Eigen::internal::traits<TensorSlicingOp>::StorageKind StorageKind;
+  typedef typename Eigen::internal::traits<TensorSlicingOp>::Index Index;
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorSlicingOp(const XprType& expr, const StartIndices& indices, const Sizes& sizes)
+      : m_xpr(expr), m_indices(indices), m_sizes(sizes) {}
+
+    EIGEN_DEVICE_FUNC
+    const StartIndices& startIndices() const { return m_indices; }
+    EIGEN_DEVICE_FUNC
+    const Sizes& sizes() const { return m_sizes; }
+
+    EIGEN_DEVICE_FUNC
+    const typename internal::remove_all<typename XprType::Nested>::type&
+    expression() const { return m_xpr; }
+
+    template<typename OtherDerived>
+    EIGEN_DEVICE_FUNC
+    EIGEN_STRONG_INLINE TensorSlicingOp& operator = (const OtherDerived& other)
+    {
+      typedef TensorAssignOp<TensorSlicingOp, const OtherDerived> Assign;
+      Assign assign(*this, other);
+      internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
+      return *this;
+    }
+
+  protected:
+    typename XprType::Nested m_xpr;
+    const StartIndices m_indices;
+    const Sizes m_sizes;
+};
+
+
+// Eval as rvalue
+template<typename StartIndices, typename Sizes, typename ArgType, typename Device>
+struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Device>
+{
+  typedef TensorSlicingOp<StartIndices, Sizes, ArgType> XprType;
+  static const int NumDims = internal::array_size<Sizes>::value;
+
+  enum {
+    IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
+    PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/false,
+  };
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
+      : m_impl(op.expression(), device), m_dimensions(op.sizes()), m_offsets(op.startIndices())
+  {
+    for (int i = 0; i < internal::array_size<Dimensions>::value; ++i) {
+      eigen_assert(m_impl.dimensions()[i] >= op.sizes()[i] + op.startIndices()[i]);
+    }
+
+    const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
+    for (int i = 0; i < NumDims; ++i) {
+      if (i > 0) {
+        m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
+      } else {
+        m_inputStrides[0] = 1;
+      }
+    }
+
+    const Sizes& output_dims = op.sizes();
+    for (int i = 0; i < NumDims; ++i) {
+      if (i > 0) {
+        m_outputStrides[i] = m_outputStrides[i-1] * output_dims[i-1];
+      } else {
+        m_outputStrides[0] = 1;
+      }
+    }
+  }
+
+  typedef typename XprType::Index Index;
+  typedef typename XprType::CoeffReturnType CoeffReturnType;
+  typedef typename XprType::PacketReturnType PacketReturnType;
+  typedef Sizes Dimensions;
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() {
+    m_impl.evalSubExprsIfNeeded();
+  }
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
+    m_impl.cleanup();
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
+  {
+    Index inputIndex = 0;
+    for (int i = NumDims - 1; i >= 0; --i) {
+      const Index idx = index / m_outputStrides[i];
+      inputIndex += (idx + m_offsets[i]) * m_inputStrides[i];
+      index -= idx * m_outputStrides[i];
+    }
+    return m_impl.coeff(inputIndex);
+  }
+
+  /* template<int LoadMode>
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
+  {
+    return m_impl.template packet<LoadMode>(index);
+  }*/
+
+ private:
+  Dimensions m_dimensions;
+  array<Index, NumDims> m_outputStrides;
+  array<Index, NumDims> m_inputStrides;
+  const StartIndices m_offsets;
+  TensorEvaluator<ArgType, Device> m_impl;
+};
+
+
+// Eval as lvalue
+// TODO(bsteiner): share the code with the evaluator for rvalue slices.
+template<typename StartIndices, typename Sizes, typename ArgType, typename Device>
+struct TensorEvaluator<TensorSlicingOp<StartIndices, Sizes, ArgType>, Device>
+{
+  typedef TensorSlicingOp<StartIndices, Sizes, ArgType> XprType;
+  static const int NumDims = internal::array_size<Sizes>::value;
+
+  enum {
+    IsAligned = TensorEvaluator<ArgType, Device>::IsAligned,
+    PacketAccess = /*TensorEvaluator<ArgType, Device>::PacketAccess*/false,
+  };
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
+      : m_impl(op.expression(), device), m_dimensions(op.sizes()), m_offsets(op.startIndices())
+  {
+    for (int i = 0; i < internal::array_size<Dimensions>::value; ++i) {
+      eigen_assert(m_impl.dimensions()[i] >= op.sizes()[i] + op.startIndices()[i]);
+    }
+
+    const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
+    for (int i = 0; i < NumDims; ++i) {
+      if (i > 0) {
+        m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1];
+      } else {
+        m_inputStrides[0] = 1;
+      }
+    }
+
+    const Sizes& output_dims = op.sizes();
+    for (int i = 0; i < NumDims; ++i) {
+      if (i > 0) {
+        m_outputStrides[i] = m_outputStrides[i-1] * output_dims[i-1];
+      } else {
+        m_outputStrides[0] = 1;
+      }
+    }
+  }
+
+  typedef typename XprType::Index Index;
+  typedef typename XprType::CoeffReturnType CoeffReturnType;
+  typedef typename XprType::PacketReturnType PacketReturnType;
+  typedef Sizes Dimensions;
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() {
+    m_impl.evalSubExprsIfNeeded();
+  }
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
+    m_impl.cleanup();
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
+  {
+    Index inputIndex = 0;
+    for (int i = NumDims - 1; i >= 0; --i) {
+      const Index idx = index / m_outputStrides[i];
+      inputIndex += (idx + m_offsets[i]) * m_inputStrides[i];
+      index -= idx * m_outputStrides[i];
+    }
+    return m_impl.coeff(inputIndex);
+  }
+
+  /* template<int LoadMode>
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
+  {
+    return m_impl.template packet<LoadMode>(index);
+  }*/
+
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType& coeffRef(Index index)
+  {
+    Index inputIndex = 0;
+    for (int i = NumDims - 1; i >= 0; --i) {
+      const Index idx = index / m_outputStrides[i];
+      inputIndex += (idx + m_offsets[i]) * m_inputStrides[i];
+      index -= idx * m_outputStrides[i];
+    }
+    return m_impl.coeffRef(inputIndex);
+  }
+
+ private:
+  Dimensions m_dimensions;
+  array<Index, NumDims> m_outputStrides;
+  array<Index, NumDims> m_inputStrides;
+  const StartIndices m_offsets;
+  TensorEvaluator<ArgType, Device> m_impl;
+};
+
+
 } // end namespace Eigen
 
 #endif // EIGEN_CXX11_TENSOR_TENSOR_MORPHING_H
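
A note on the index arithmetic used by coeff() and coeffRef() above: both slicing evaluators map a linear index into the (column-major) output block back into a linear index into the input by peeling off one coordinate per dimension, outermost first, and shifting each coordinate by the slice offset. The standalone sketch below replays that arithmetic outside of Eigen to make the mapping concrete; the shapes, offsets, and variable names are illustrative only and are not part of the patch.

    // Standalone sketch of the slicing index mapping (illustrative, not Eigen API).
    #include <array>
    #include <cassert>
    #include <cstdio>

    int main() {
      const int NumDims = 3;
      // Column-major input of shape 4x5x6, sliced at offsets (1,2,0) with sizes (2,3,4).
      std::array<long, 3> input_dims  = {4, 5, 6};
      std::array<long, 3> offsets     = {1, 2, 0};
      std::array<long, 3> output_dims = {2, 3, 4};

      // Strides exactly as the evaluator constructor builds them:
      // stride[0] = 1, stride[i] = stride[i-1] * dims[i-1].
      std::array<long, 3> in_strides, out_strides;
      in_strides[0] = out_strides[0] = 1;
      for (int i = 1; i < NumDims; ++i) {
        in_strides[i]  = in_strides[i-1]  * input_dims[i-1];
        out_strides[i] = out_strides[i-1] * output_dims[i-1];
      }

      // coeff(index): recover one coordinate per dimension, from outermost to
      // innermost, and shift it by the slice offset before re-linearizing.
      long index = 13;  // linear index into the 2x3x4 output block
      long inputIndex = 0;
      for (int i = NumDims - 1; i >= 0; --i) {
        const long idx = index / out_strides[i];       // coordinate in dimension i
        inputIndex += (idx + offsets[i]) * in_strides[i];
        index -= idx * out_strides[i];
      }
      // 13 -> output coords (1,0,2) -> input coords (2,2,2) -> 2 + 2*4 + 2*20 = 50.
      assert(inputIndex == 50);
      std::printf("input index: %ld\n", inputIndex);
      return 0;
    }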
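For context, a minimal usage sketch of the two ops this patch touches. It assumes the reshape() and slice() helpers that TensorBase exposes for constructing TensorReshapingOp and TensorSlicingOp (those helpers live in TensorBase.h and are not part of this diff), and it is written against Eigen's current Tensor API; the shapes and values are illustrative.

    #include <unsupported/Eigen/CXX11/Tensor>
    #include <iostream>

    int main() {
      // 2x3x5 column-major tensor filled with distinct values.
      Eigen::Tensor<float, 3> input(2, 3, 5);
      for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 3; ++j)
          for (int k = 0; k < 5; ++k)
            input(i, j, k) = 100.0f * i + 10.0f * j + k;

      // Rvalue slicing: read a 1x2x3 block starting at (0, 1, 2).
      Eigen::array<Eigen::DenseIndex, 3> offsets = {{0, 1, 2}};
      Eigen::array<Eigen::DenseIndex, 3> extents = {{1, 2, 3}};
      Eigen::Tensor<float, 3> block = input.slice(offsets, extents);
      std::cout << block(0, 0, 0) << "\n";  // input(0, 1, 2) == 12

      // Rvalue reshaping: view the 6 sliced coefficients as a 2x3 tensor.
      Eigen::array<Eigen::DenseIndex, 2> new_shape = {{2, 3}};
      Eigen::Tensor<float, 2> mat = block.reshape(new_shape);

      // Lvalue use: the operator= overloads added by this patch let a slice
      // (or reshape) appear on the left-hand side of an assignment, routed
      // through TensorAssignOp and TensorExecutor on the DefaultDevice.
      input.slice(offsets, extents) = block.constant(0.0f);
      return 0;
    }

Note the design choice visible in the diff: assigning through a slice goes coefficient by coefficient via coeffRef(), since PacketAccess is explicitly disabled (commented out) for the slicing evaluators at this stage, whereas reshaping simply forwards both coeff() and packet() accesses to the underlying evaluator.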