Fix TensorForcedEval in the case of the evaluator being copied.

This commit is contained in:
Antonio Sánchez 2024-01-10 00:45:39 +00:00 committed by Rasmus Munk Larsen
parent 3f3bc6d862
commit 5385773015

View File

@ -13,6 +13,8 @@
// IWYU pragma: private
#include "./InternalHeaderCheck.h"
#include <memory>
namespace Eigen {
/** \class TensorForcedEval
@ -91,6 +93,26 @@ struct non_integral_type_placement_new<Eigen::SyclDevice, CoeffReturnType> {
};
} // end namespace internal
/**
 * Owns one temporary device allocation for as long as the holder is alive.
 *
 * The constructor requests `size` bytes through `device.allocate_temp(size)`
 * and the destructor returns the pointer through `deallocate_temp`. The holder
 * keeps its own copy of the device object so that the release call does not
 * depend on the lifetime of the `device` argument.
 *
 * Copying and moving are disabled: the destructor releases `ptr_`
 * unconditionally, so a second holder carrying the same pointer would release
 * it a second time. Share a holder through a smart pointer instead.
 *
 * \tparam Device type providing `allocate_temp(size_t)` (callable on a const
 *         object) and `deallocate_temp(void*)`.
 */
template <typename Device>
class DeviceTempPointerHolder {
 public:
  /** Allocates `size` bytes of temporary storage from `device`. */
  DeviceTempPointerHolder(const Device& device, size_t size)
      : device_(device), size_(size), ptr_(device.allocate_temp(size)) {}

  // Exactly one holder may own a given allocation (see class comment).
  DeviceTempPointerHolder(const DeviceTempPointerHolder&) = delete;
  DeviceTempPointerHolder& operator=(const DeviceTempPointerHolder&) = delete;
  DeviceTempPointerHolder(DeviceTempPointerHolder&&) = delete;
  DeviceTempPointerHolder& operator=(DeviceTempPointerHolder&&) = delete;

  /** Returns the allocation to the device and clears the bookkeeping fields. */
  ~DeviceTempPointerHolder() {
    device_.deallocate_temp(ptr_);
    size_ = 0;
    ptr_ = nullptr;
  }

  /** Returns the pointer obtained from `allocate_temp`, untyped. */
  void* ptr() { return ptr_; }

 private:
  Device device_;  // Copy of the device used for the matching deallocation.
  size_t size_;    // Size in bytes that was requested at construction.
  void* ptr_;      // The allocation owned by this holder.
};
template <typename ArgType_, typename Device>
struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
typedef const internal::remove_all_t<ArgType_> ArgType;
@ -124,7 +146,11 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
//===--------------------------------------------------------------------===//
// Builds the nested evaluator from op.expression() and keeps a copy of that
// expression in m_op. No buffer is allocated here: the result pointer (and,
// in the longer initializer list, its holder) starts out null.
// NOTE(review): two initializer lists appear in this listing, which looks
// like the pre- and post-change versions shown together -- confirm which
// one the actual file contains.
TensorEvaluator(const XprType& op, const Device& device)
: m_impl(op.expression(), device), m_op(op.expression()), m_device(device), m_buffer(NULL) {}
: m_impl(op.expression(), device),
m_op(op.expression()),
m_device(device),
m_buffer_holder(nullptr),
m_buffer(nullptr) {}
~TensorEvaluator() { cleanup(); }
@ -132,11 +158,8 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType) {
const Index numValues = internal::array_prod(m_impl.dimensions());
if (m_buffer != nullptr) {
m_device.deallocate_temp(m_buffer);
}
m_buffer = m_device.get((CoeffReturnType*)m_device.allocate_temp(numValues * sizeof(CoeffReturnType)));
m_buffer_holder = std::make_shared<DeviceTempPointerHolder<Device>>(m_device, numValues * sizeof(CoeffReturnType));
m_buffer = static_cast<EvaluatorPointerType>(m_buffer_holder->ptr());
internal::non_integral_type_placement_new<Device, CoeffReturnType>()(numValues, m_buffer);
@ -154,10 +177,9 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
template <typename EvalSubExprsCallback>
EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(EvaluatorPointerType, EvalSubExprsCallback done) {
const Index numValues = internal::array_prod(m_impl.dimensions());
if (m_buffer != nullptr) {
m_device.deallocate_temp(m_buffer);
}
m_buffer = m_device.get((CoeffReturnType*)m_device.allocate_temp(numValues * sizeof(CoeffReturnType)));
m_buffer_holder = std::make_shared<DeviceTempPointerHolder<Device>>(m_device, numValues * sizeof(CoeffReturnType));
m_buffer = static_cast<EvaluatorPointerType>(m_buffer_holder->ptr());
typedef TensorEvalToOp<const std::remove_const_t<ArgType>> EvalTo;
EvalTo evalToTmp(m_device.get(m_buffer), m_op);
@ -171,8 +193,8 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
#endif
// Gives up this evaluator's result buffer and nulls the pointers to it.
// NOTE(review): this listing contains both a direct deallocate_temp() call
// and the reset of the shared holder; these look like the pre- and
// post-change versions shown together -- confirm only one is in the file.
EIGEN_STRONG_INLINE void cleanup() {
m_device.deallocate_temp(m_buffer);
m_buffer = NULL;
// Drops this evaluator's reference to the shared holder.
m_buffer_holder = nullptr;
m_buffer = nullptr;
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const { return m_buffer[index]; }
@ -188,7 +210,7 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
// Returns the block described by `desc`, materialized from m_buffer using the
// dimensions of the nested evaluator. Asserts that m_buffer is non-null.
// The third parameter is unnamed and unused.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorBlock block(TensorBlockDesc& desc, TensorBlockScratch& scratch,
bool /*root_of_expr_ast*/ = false) const {
eigen_assert(m_buffer != NULL);
eigen_assert(m_buffer != nullptr);
return TensorBlock::materialize(m_buffer, m_impl.dimensions(), desc, scratch);
}
@ -196,13 +218,14 @@ struct TensorEvaluator<const TensorForcedEvalOp<ArgType_>, Device> {
return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
}
// Returns m_buffer, the pointer to the evaluated result; it is null until a
// buffer has been assigned.
// NOTE(review): two variants of this accessor appear here, differing only in
// the inlining macro -- confirm which one the actual file contains.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EvaluatorPointerType data() const { return m_buffer; }
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE EvaluatorPointerType data() const { return m_buffer; }
private:
TensorEvaluator<ArgType, Device> m_impl;
const ArgType m_op;
const Device EIGEN_DEVICE_REF m_device;
EvaluatorPointerType m_buffer;
std::shared_ptr<DeviceTempPointerHolder<Device>> m_buffer_holder;
EvaluatorPointerType m_buffer; // Cached copy of the value stored in m_buffer_holder.
};
} // end namespace Eigen