From a669052f12d6d71ba815764d6419726d64fef675 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Mon, 9 Jun 2014 09:45:30 -0700 Subject: [PATCH] Improved support for rvalues in tensor expressions. --- .../Eigen/CXX11/src/Tensor/TensorBase.h | 58 ++++++++++++++----- .../CXX11/src/Tensor/TensorContraction.h | 4 ++ .../CXX11/src/Tensor/TensorConvolution.h | 4 ++ .../Eigen/CXX11/src/Tensor/TensorExpr.h | 8 +++ .../src/Tensor/TensorForwardDeclarations.h | 6 +- .../Eigen/CXX11/src/Tensor/TensorMorphing.h | 5 +- .../Eigen/CXX11/src/Tensor/TensorTraits.h | 6 +- 7 files changed, 71 insertions(+), 20 deletions(-) diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 932e5c82d..e447a5d40 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -22,7 +22,7 @@ namespace Eigen { */ template -class TensorBase +class TensorBase { public: typedef typename internal::traits::Scalar Scalar; @@ -30,19 +30,6 @@ class TensorBase typedef Scalar CoeffReturnType; typedef typename internal::packet_traits::type PacketReturnType; - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE Derived& setZero() { - return setConstant(Scalar(0)); - } - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE Derived& setConstant(const Scalar& val) { - return derived() = constant(val); - } - EIGEN_DEVICE_FUNC - EIGEN_STRONG_INLINE Derived& setRandom() { - return derived() = random(); - } - // Nullary operators EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const TensorCwiseNullaryOp, const Derived> @@ -224,14 +211,53 @@ class TensorBase return TensorReshapingOp(derived(), newDimensions); } + protected: + template friend class TensorBase; + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE const Derived& derived() const { return *static_cast(this); } +}; + + +template +class TensorBase : public TensorBase { + public: + typedef typename internal::traits::Scalar Scalar; + typedef typename internal::traits::Index Index; + typedef Scalar CoeffReturnType; + typedef typename internal::packet_traits::type PacketReturnType; + + template friend class TensorBase; + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Derived& setZero() { + return setConstant(Scalar(0)); + } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Derived& setConstant(const Scalar& val) { + return derived() = this->constant(val); + } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Derived& setRandom() { + return derived() = this->random(); + } + + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Derived& operator+=(const OtherDerived& other) { + return derived() = TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); + } + + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + Derived& operator-=(const OtherDerived& other) { + return derived() = TensorCwiseBinaryOp, const Derived, const OtherDerived>(derived(), other.derived()); + } + // Select the device on which to evaluate the expression. template TensorDevice device(const DeviceType& device) { return TensorDevice(device, derived()); } - protected: - template friend class TensorBase; + protected: EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& derived() { return *static_cast(this); } EIGEN_DEVICE_FUNC diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index d424df36e..d371eb76d 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -35,6 +35,10 @@ struct traits > typedef typename RhsXprType::Nested RhsNested; typedef typename remove_reference::type _LhsNested; typedef typename remove_reference::type _RhsNested; + + enum { + Flags = 0, + }; }; template diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h index ca2e0e562..501e9a522 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h @@ -35,6 +35,10 @@ struct traits > typedef typename KernelXprType::Nested RhsNested; typedef typename remove_reference::type _LhsNested; typedef typename remove_reference::type _RhsNested; + + enum { + Flags = 0, + }; }; template diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h index 60908ee94..de66da13f 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h @@ -36,6 +36,10 @@ struct traits > typedef typename XprType::Scalar Scalar; typedef typename XprType::Nested XprTypeNested; typedef typename remove_reference::type _XprTypeNested; + + enum { + Flags = 0, + }; }; } // end namespace internal @@ -153,6 +157,10 @@ struct traits > typedef typename RhsXprType::Nested RhsNested; typedef typename remove_reference::type _LhsNested; typedef typename remove_reference::type _RhsNested; + + enum { + Flags = 0, + }; }; template diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h index b8833362c..1fb90478f 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h @@ -15,7 +15,7 @@ namespace Eigen { template class Tensor; template class TensorFixedSize; template class TensorMap; -template class TensorBase; +template::value> class TensorBase; template class TensorCwiseNullaryOp; template class TensorCwiseUnaryOp; @@ -29,6 +29,10 @@ template class TensorDevice; // Move to internal? template struct TensorEvaluator; +namespace internal { +template struct TensorAssign; +} // end namespace internal + } // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_FORWARD_DECLARATIONS_H diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h index 3e089fe1e..7d5f9271e 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h @@ -21,7 +21,7 @@ namespace Eigen { */ namespace internal { template -struct traits > +struct traits > : public traits { // Type promotion to handle the case where the types of the lhs and the rhs are different. typedef typename XprType::Scalar Scalar; @@ -81,6 +81,7 @@ template struct TensorEvaluator > { typedef TensorReshapingOp XprType; + typedef NewDimensions Dimensions; enum { IsAligned = TensorEvaluator::IsAligned, @@ -95,7 +96,7 @@ struct TensorEvaluator > typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; - const NewDimensions& dimensions() const { return m_dimensions; } + const Dimensions& dimensions() const { return m_dimensions; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h b/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h index 2de698a57..40f805741 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h @@ -52,7 +52,7 @@ struct traits > typedef DenseIndex Index; enum { Options = Options_, - Flags = compute_tensor_flags::ret, + Flags = compute_tensor_flags::ret | LvalueBit, }; }; @@ -63,6 +63,10 @@ struct traits > typedef Scalar_ Scalar; typedef Dense StorageKind; typedef DenseIndex Index; + enum { + Options = Options_, + Flags = compute_tensor_flags::ret | LvalueBit, + }; };