Creating a pointer type in TensorCustomOp.h

This commit is contained in:
Mehdi Goli 2018-08-08 11:19:02 +01:00
parent 10d286f55b
commit 3055e3a7c2

View File

@ -88,6 +88,7 @@ struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Devi
typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType; typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
static const int PacketSize = PacketType<CoeffReturnType, Device>::size; static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
typedef typename internal::remove_all<typename Eigen::internal::traits<XprType>::PointerType>::type * PointerType;
enum { enum {
IsAligned = false, IsAligned = false,
@ -106,7 +107,7 @@ struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Devi
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(PointerType data) {
if (data) { if (data) {
evalTo(data); evalTo(data);
return false; return false;
@ -139,23 +140,22 @@ struct TensorEvaluator<const TensorCustomUnaryOp<CustomUnaryFunc, XprType>, Devi
return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize); return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
} }
EIGEN_DEVICE_FUNC typename Eigen::internal::traits<XprType>::PointerType data() const { return m_result; } EIGEN_DEVICE_FUNC PointerType data() const { return m_result; }
#ifdef EIGEN_USE_SYCL #ifdef EIGEN_USE_SYCL
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Device& device() const { return m_device; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Device& device() const { return m_device; }
#endif #endif
protected: protected:
EIGEN_DEVICE_FUNC void evalTo(Scalar* data) { EIGEN_DEVICE_FUNC void evalTo(PointerType data) {
TensorMap<Tensor<CoeffReturnType, NumDims, Layout, Index> > result( TensorMap<Tensor<CoeffReturnType, NumDims, Layout, Index> > result(data, m_dimensions);
data, m_dimensions);
m_op.func().eval(m_op.expression(), result, m_device); m_op.func().eval(m_op.expression(), result, m_device);
} }
Dimensions m_dimensions; Dimensions m_dimensions;
const ArgType m_op; const ArgType m_op;
const Device& m_device; const Device& m_device;
CoeffReturnType* m_result; PointerType m_result;
}; };
@ -250,6 +250,7 @@ struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType,
typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType; typedef typename internal::remove_const<typename XprType::CoeffReturnType>::type CoeffReturnType;
typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType; typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
static const int PacketSize = PacketType<CoeffReturnType, Device>::size; static const int PacketSize = PacketType<CoeffReturnType, Device>::size;
typedef typename internal::remove_all<typename Eigen::internal::traits<XprType>::PointerType>::type * PointerType;
enum { enum {
IsAligned = false, IsAligned = false,
@ -268,12 +269,12 @@ struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType,
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(CoeffReturnType* data) { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(PointerType data) {
if (data) { if (data) {
evalTo(data); evalTo(data);
return false; return false;
} else { } else {
m_result = static_cast<Scalar *>(m_device.allocate_temp(dimensions().TotalSize() * sizeof(Scalar))); m_result = static_cast<PointerType>(m_device.allocate_temp(dimensions().TotalSize() * sizeof(CoeffReturnType)));
evalTo(m_result); evalTo(m_result);
return true; return true;
} }
@ -300,22 +301,22 @@ struct TensorEvaluator<const TensorCustomBinaryOp<CustomBinaryFunc, LhsXprType,
return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize); return TensorOpCost(sizeof(CoeffReturnType), 0, 0, vectorized, PacketSize);
} }
EIGEN_DEVICE_FUNC typename internal::traits<XprType>::PointerType data() const { return m_result; } EIGEN_DEVICE_FUNC PointerType data() const { return m_result; }
#ifdef EIGEN_USE_SYCL #ifdef EIGEN_USE_SYCL
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Device& device() const { return m_device; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Device& device() const { return m_device; }
#endif #endif
protected: protected:
EIGEN_DEVICE_FUNC void evalTo(Scalar* data) { EIGEN_DEVICE_FUNC void evalTo(PointerType data) {
TensorMap<Tensor<Scalar, NumDims, Layout> > result(data, m_dimensions); TensorMap<Tensor<CoeffReturnType, NumDims, Layout> > result(data, m_dimensions);
m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device); m_op.func().eval(m_op.lhsExpression(), m_op.rhsExpression(), result, m_device);
} }
Dimensions m_dimensions; Dimensions m_dimensions;
const XprType m_op; const XprType m_op;
const Device& m_device; const Device& m_device;
CoeffReturnType* m_result; PointerType m_result;
}; };