From 5385773015ae4f8c4949d083e3c66d80f992e6a6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antonio=20S=C3=A1nchez?=
Date: Wed, 10 Jan 2024 00:45:39 +0000
Subject: [PATCH] Fix TensorForcedEval in the case of the evaluator being copied.

---
 .../Eigen/CXX11/src/Tensor/TensorForcedEval.h | 53 +++++++++++++------
 1 file changed, 38 insertions(+), 15 deletions(-)

diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h
index 0e87fac04..d0fbfb38a 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h
@@ -13,6 +13,8 @@
 // IWYU pragma: private
 #include "./InternalHeaderCheck.h"
 
+#include <memory>
+
 namespace Eigen {
 
 /** \class TensorForcedEval
@@ -91,6 +93,26 @@ struct non_integral_type_placement_new<Eigen::SyclDevice, CoeffReturnType> {
 };
 }  // end namespace internal
 
+template <typename Device>
+class DeviceTempPointerHolder {
+ public:
+  DeviceTempPointerHolder(const Device& device, size_t size)
+      : device_(device), size_(size), ptr_(device.allocate_temp(size)) {}
+
+  ~DeviceTempPointerHolder() {
+    device_.deallocate_temp(ptr_);
+    size_ = 0;
+    ptr_ = nullptr;
+  }
+
+  void* ptr() { return ptr_; }
+
+ private:
+  Device device_;
+  size_t size_;
+  void* ptr_;
+};
+
 template <typename ArgType_, typename Device>
 struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
   typedef const internal::remove_all_t<ArgType_> ArgType;
@@ -124,7 +146,11 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
   //===--------------------------------------------------------------------===//
 
   TensorEvaluator(const XprType& op, const Device& device)
-      : m_impl(op.expression(), device), m_op(op.expression()), m_device(device), m_buffer(NULL) {}
+      : m_impl(op.expression(), device),
+        m_op(op.expression()),
+        m_device(device),
+        m_buffer_holder(nullptr),
+        m_buffer(nullptr) {}
 
   ~TensorEvaluator() { cleanup(); }
 
@@ -132,11 +158,8 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
 
   EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) {
     const Index numValues = internal::array_prod(m_impl.dimensions());
-
-    if (m_buffer != nullptr) {
-      m_device.deallocate_temp(m_buffer);
-    }
-    m_buffer = m_device.get((CoeffReturnType*)m_device.allocate_temp(numValues * sizeof(CoeffReturnType)));
+    m_buffer_holder = std::make_shared<DeviceTempPointerHolder<Device>>(m_device, numValues * sizeof(CoeffReturnType));
+    m_buffer = static_cast<EvaluatorPointerType>(m_buffer_holder->ptr());
 
     internal::non_integral_type_placement_new<Device, CoeffReturnType>()(numValues, m_buffer);
 
@@ -154,10 +177,9 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
   template <typename EvalSubExprsCallback>
   EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(EvaluatorPointerType, EvalSubExprsCallback done) {
     const Index numValues = internal::array_prod(m_impl.dimensions());
-    if (m_buffer != nullptr) {
-      m_device.deallocate_temp(m_buffer);
-    }
-    m_buffer = m_device.get((CoeffReturnType*)m_device.allocate_temp(numValues * sizeof(CoeffReturnType)));
+    m_buffer_holder = std::make_shared<DeviceTempPointerHolder<Device>>(m_device, numValues * sizeof(CoeffReturnType));
+    m_buffer = static_cast<EvaluatorPointerType>(m_buffer_holder->ptr());
+
     typedef TensorEvalToOp<const std::remove_const_t<ArgType>> EvalTo;
     EvalTo evalToTmp(m_device.get(m_buffer), m_op);
 
@@ -171,8 +193,8 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
 #endif
 
   EIGEN_STRONG_INLINE void cleanup() {
-    m_device.deallocate_temp(m_buffer);
-    m_buffer = NULL;
+    m_buffer_holder = nullptr;
+    m_buffer = nullptr;
   }
 
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { return m_buffer[index]; }
@@ -188,7 +210,7 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
 
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock block(TensorBlockDesc& desc, TensorBlockScratch& scratch,
                                                           bool /*root_of_expr_ast*/ = false) const {
-    eigen_assert(m_buffer != NULL);
+    eigen_assert(m_buffer != nullptr);
     return TensorBlock::materialize(m_buffer, m_impl.dimensions(), desc, scratch);
   }
 
@@ -196,13 +218,14 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
     return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
   }
 
-  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EvaluatorPointerType data() const { return m_buffer; }
+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE EvaluatorPointerType data() const { return m_buffer; }
 
  private:
   TensorEvaluator<ArgType, Device> m_impl;
   const ArgType m_op;
   const Device EIGEN_DEVICE_REF m_device;
-  EvaluatorPointerType m_buffer;
+  std::shared_ptr<DeviceTempPointerHolder<Device>> m_buffer_holder;
+  EvaluatorPointerType m_buffer;  // Cached copy of the value stored in m_buffer_holder.
 };
 
 }  // end namespace Eigen