diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h b/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h index a2a925775..3bfe80c9e 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h @@ -102,6 +102,7 @@ struct TensorEvaluator, Device> { } typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; typedef typename TensorEvaluator::Dimensions Dimensions; @@ -112,9 +113,14 @@ struct TensorEvaluator, Device> return m_rightImpl.dimensions(); } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() { - m_leftImpl.evalSubExprsIfNeeded(); - m_rightImpl.evalSubExprsIfNeeded(); + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) { + eigen_assert(internal::dimensions_match(m_leftImpl.dimensions(), m_rightImpl.dimensions())); + m_leftImpl.evalSubExprsIfNeeded(NULL); + // If the lhs provides raw access to its storage area (i.e. if m_leftImpl.data() returns a non + // null value), attempt to evaluate the rhs expression in place. Returns true iff in place + // evaluation isn't supported and the caller still needs to manually assign the values generated + // by the rhs to the lhs. + return m_rightImpl.evalSubExprsIfNeeded(m_leftImpl.data()); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_leftImpl.cleanup(); diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h index ac9829ce9..0f969036c 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h @@ -39,13 +39,20 @@ struct TensorEvaluator PacketAccess = Derived::PacketAccess, }; - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device&) - : m_data(const_cast(m.data())), m_dims(m.dimensions()) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const Derived& m, const Device& device) + : m_data(const_cast(m.data())), m_dims(m.dimensions()), m_device(device) { } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dims; } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() { } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* dest) { + if (dest) { + m_device.memcpy((void*)dest, m_data, sizeof(Scalar) * m_dims.TotalSize()); + return false; + } + return true; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { @@ -70,9 +77,12 @@ struct TensorEvaluator return internal::pstoret(m_data + index, x); } + Scalar* data() const { return m_data; } + protected: Scalar* m_data; Dimensions m_dims; + const Device& m_device; }; @@ -98,7 +108,7 @@ struct TensorEvaluator EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dims; } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() { } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) { return true; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { @@ -112,6 +122,8 @@ struct TensorEvaluator return internal::ploadt(m_data + index); } + const Scalar* data() const { return m_data; } + protected: const Scalar* m_data; Dimensions m_dims; @@ -138,13 +150,14 @@ struct TensorEvaluator, Device> { } typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; typedef typename TensorEvaluator::Dimensions Dimensions; EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() { } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) { return true; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { } EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const @@ -158,6 +171,8 @@ struct TensorEvaluator, Device> return m_functor.packetOp(index); } + Scalar* data() const { return NULL; } + private: const NullaryOp m_functor; TensorEvaluator m_argImpl; @@ -183,14 +198,16 @@ struct TensorEvaluator, Device> { } typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; typedef typename TensorEvaluator::Dimensions Dimensions; EIGEN_DEVICE_FUNC const Dimensions& dimensions() const { return m_argImpl.dimensions(); } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() { - m_argImpl.evalSubExprsIfNeeded(); + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) { + m_argImpl.evalSubExprsIfNeeded(NULL); + return true; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_argImpl.cleanup(); @@ -207,6 +224,8 @@ struct TensorEvaluator, Device> return m_functor.packetOp(m_argImpl.template packet(index)); } + Scalar* data() const { return NULL; } + private: const UnaryOp m_functor; TensorEvaluator m_argImpl; @@ -233,6 +252,7 @@ struct TensorEvaluator::Dimensions Dimensions; @@ -243,9 +263,10 @@ struct TensorEvaluator(index), m_rightImpl.template packet(index)); } + Scalar* data() const { return NULL; } + private: const BinaryOp m_functor; TensorEvaluator m_leftImpl; @@ -289,6 +312,7 @@ struct TensorEvaluator { } typedef typename XprType::Index Index; + typedef typename XprType::Scalar Scalar; typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::PacketReturnType PacketReturnType; typedef typename TensorEvaluator::Dimensions Dimensions; @@ -299,10 +323,11 @@ struct TensorEvaluator return m_condImpl.dimensions(); } - EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeeded() { - m_condImpl.evalSubExprsIfNeeded(); - m_thenImpl.evalSubExprsIfNeeded(); - m_elseImpl.evalSubExprsIfNeeded(); + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar*) { + m_condImpl.evalSubExprsIfNeeded(NULL); + m_thenImpl.evalSubExprsIfNeeded(NULL); + m_elseImpl.evalSubExprsIfNeeded(NULL); + return true; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { m_condImpl.cleanup(); @@ -327,6 +352,8 @@ struct TensorEvaluator m_elseImpl.template packet(index)); } + Scalar* data() const { return NULL; } + private: TensorEvaluator m_condImpl; TensorEvaluator m_thenImpl;