Fix TensorForcedEval in the case of the evaluator being copied.

This commit is contained in:
Antonio Sánchez 2024-01-10 00:45:39 +00:00 committed by Rasmus Munk Larsen
parent 3f3bc6d862
commit 5385773015

View File

@ -13,6 +13,8 @@
// IWYU pragma: private
#include "./InternalHeaderCheck.h"
#include <memory>
namespace Eigen {

/** \class TensorForcedEval
@ -91,6 +93,26 @@ struct non_integral_type_placement_new<Eigen::SyclDevice, CoeffReturnType> {
};

}  // end namespace internal
// RAII owner of a temporary device allocation.
//
// Holds its own copy of `device` so the buffer can be released even after the
// evaluator that created it has been copied or destroyed.  Instances are meant
// to be shared through std::shared_ptr so that copies of the evaluator share a
// single buffer, and the buffer is freed exactly once when the last owner dies.
template <typename Device>
class DeviceTempPointerHolder {
 public:
  // Allocates `size` bytes of temporary storage on `device`.
  DeviceTempPointerHolder(const Device& device, size_t size)
      : device_(device), size_(size), ptr_(device.allocate_temp(size)) {}

  // Non-copyable: a defaulted copy would duplicate ptr_ and lead to
  // deallocate_temp() being called twice on the same pointer (double free).
  DeviceTempPointerHolder(const DeviceTempPointerHolder&) = delete;
  DeviceTempPointerHolder& operator=(const DeviceTempPointerHolder&) = delete;

  ~DeviceTempPointerHolder() {
    device_.deallocate_temp(ptr_);
    size_ = 0;
    ptr_ = nullptr;
  }

  // Raw pointer to the owned device memory.
  void* ptr() { return ptr_; }

 private:
  Device device_;  // Stored by value: must outlive ptr_ for the destructor's deallocate_temp.
  size_t size_;
  void* ptr_;
};
template <typename ArgType_, typename Device> template <typename ArgType_, typename Device>
struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> { struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
typedef const internal::remove_all_t<ArgType_> ArgType; typedef const internal::remove_all_t<ArgType_> ArgType;
@ -124,7 +146,11 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
//===--------------------------------------------------------------------===// //===--------------------------------------------------------------------===//
TensorEvaluator(const XprType& op, const Device& device) TensorEvaluator(const XprType& op, const Device& device)
: m_impl(op.expression(), device), m_op(op.expression()), m_device(device), m_buffer(NULL) {} : m_impl(op.expression(), device),
m_op(op.expression()),
m_device(device),
m_buffer_holder(nullptr),
m_buffer(nullptr) {}
~TensorEvaluator() { cleanup(); } ~TensorEvaluator() { cleanup(); }
@ -132,11 +158,8 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) { EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) {
const Index numValues = internal::array_prod(m_impl.dimensions()); const Index numValues = internal::array_prod(m_impl.dimensions());
m_buffer_holder = std::make_shared<DeviceTempPointerHolder<Device>>(m_device, numValues * sizeof(CoeffReturnType));
if (m_buffer != nullptr) { m_buffer = static_cast<EvaluatorPointerType>(m_buffer_holder->ptr());
m_device.deallocate_temp(m_buffer);
}
m_buffer = m_device.get((CoeffReturnType*)m_device.allocate_temp(numValues * sizeof(CoeffReturnType)));
internal::non_integral_type_placement_new<Device, CoeffReturnType>()(numValues, m_buffer); internal::non_integral_type_placement_new<Device, CoeffReturnType>()(numValues, m_buffer);
@ -154,10 +177,9 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
template <typename EvalSubExprsCallback> template <typename EvalSubExprsCallback>
EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(EvaluatorPointerType, EvalSubExprsCallback done) { EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(EvaluatorPointerType, EvalSubExprsCallback done) {
const Index numValues = internal::array_prod(m_impl.dimensions()); const Index numValues = internal::array_prod(m_impl.dimensions());
if (m_buffer != nullptr) { m_buffer_holder = std::make_shared<DeviceTempPointerHolder<Device>>(m_device, numValues * sizeof(CoeffReturnType));
m_device.deallocate_temp(m_buffer); m_buffer = static_cast<EvaluatorPointerType>(m_buffer_holder->ptr());
}
m_buffer = m_device.get((CoeffReturnType*)m_device.allocate_temp(numValues * sizeof(CoeffReturnType)));
typedef TensorEvalToOp<const std::remove_const_t<ArgType>> EvalTo; typedef TensorEvalToOp<const std::remove_const_t<ArgType>> EvalTo;
EvalTo evalToTmp(m_device.get(m_buffer), m_op); EvalTo evalToTmp(m_device.get(m_buffer), m_op);
@ -171,8 +193,8 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
#endif #endif
EIGEN_STRONG_INLINE void cleanup() { EIGEN_STRONG_INLINE void cleanup() {
m_device.deallocate_temp(m_buffer); m_buffer_holder = nullptr;
m_buffer = NULL; m_buffer = nullptr;
} }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { return m_buffer[index]; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { return m_buffer[index]; }
@ -188,7 +210,7 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock block(TensorBlockDesc& desc, TensorBlockScratch& scratch, EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock block(TensorBlockDesc& desc, TensorBlockScratch& scratch,
bool /*root_of_expr_ast*/ = false) const { bool /*root_of_expr_ast*/ = false) const {
eigen_assert(m_buffer != NULL); eigen_assert(m_buffer != nullptr);
return TensorBlock::materialize(m_buffer, m_impl.dimensions(), desc, scratch); return TensorBlock::materialize(m_buffer, m_impl.dimensions(), desc, scratch);
} }
@ -196,13 +218,14 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize); return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
} }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EvaluatorPointerType data() const { return m_buffer; } EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE EvaluatorPointerType data() const { return m_buffer; }
private: private:
TensorEvaluator<ArgType, Device> m_impl; TensorEvaluator<ArgType, Device> m_impl;
const ArgType m_op; const ArgType m_op;
const Device EIGEN_DEVICE_REF m_device; const Device EIGEN_DEVICE_REF m_device;
EvaluatorPointerType m_buffer; std::shared_ptr<DeviceTempPointerHolder<Device>> m_buffer_holder;
EvaluatorPointerType m_buffer; // Cached copy of the value stored in m_buffer_holder.
};

}  // end namespace Eigen