// This file is part of Eigen, a lightweight C++ template library // for linear algebra. // // Copyright (C) 2023 Charlie Schlosser // // This Source Code Form is subject to the terms of the Mozilla // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. #ifndef EIGEN_DEVICEWRAPPER_H #define EIGEN_DEVICEWRAPPER_H namespace Eigen { template struct DeviceWrapper { using Base = EigenBase>; using Scalar = typename Derived::Scalar; EIGEN_DEVICE_FUNC DeviceWrapper(Base& xpr, Device& device) : m_xpr(xpr.derived()), m_device(device) {} EIGEN_DEVICE_FUNC DeviceWrapper(const Base& xpr, Device& device) : m_xpr(xpr.derived()), m_device(device) {} template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator=(const EigenBase& other) { using AssignOp = internal::assign_op; internal::call_assignment(*this, other.derived(), AssignOp()); return m_xpr; } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator+=(const EigenBase& other) { using AddAssignOp = internal::add_assign_op; internal::call_assignment(*this, other.derived(), AddAssignOp()); return m_xpr; } template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator-=(const EigenBase& other) { using SubAssignOp = internal::sub_assign_op; internal::call_assignment(*this, other.derived(), SubAssignOp()); return m_xpr; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& derived() { return m_xpr; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Device& device() { return m_device; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE NoAlias noalias() { return NoAlias(*this); } Derived& m_xpr; Device& m_device; }; namespace internal { // this is where we differentiate between lazy assignment and specialized kernels (e.g. matrix products) template ::Shape, typename evaluator_traits::Shape>::Kind, typename EnableIf = void> struct AssignmentWithDevice; // unless otherwise specified, use the default product implementation template struct AssignmentWithDevice, Functor, Device, Dense2Dense, Weak> { using SrcXprType = Product; using Base = Assignment; static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(DstXprType& dst, const SrcXprType& src, const Functor& func, Device&) { Base::run(dst, src, func); }; }; // specialization for coeffcient-wise assignment template struct AssignmentWithDevice { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(DstXprType& dst, const SrcXprType& src, const Functor& func, Device& device) { #ifndef EIGEN_NO_DEBUG internal::check_for_aliasing(dst, src); #endif call_dense_assignment_loop(dst, src, func, device); } }; // this allows us to use the default evaluation scheme if it is not specialized for the device template struct dense_assignment_loop_with_device { using Base = dense_assignment_loop; static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR void run(Kernel& kernel, Device&) { Base::run(kernel); } }; // entry point for a generic expression with device template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR void call_assignment_no_alias(DeviceWrapper dst, const Src& src, const Func& func) { enum { NeedToTranspose = ((int(Dst::RowsAtCompileTime) == 1 && int(Src::ColsAtCompileTime) == 1) || (int(Dst::ColsAtCompileTime) == 1 && int(Src::RowsAtCompileTime) == 1)) && int(Dst::SizeAtCompileTime) != 1 }; using ActualDstTypeCleaned = std::conditional_t, Dst>; using ActualDstType = std::conditional_t, Dst&>; ActualDstType actualDst(dst.derived()); // TODO check whether this is the right place to perform these checks: EIGEN_STATIC_ASSERT_LVALUE(Dst) EIGEN_STATIC_ASSERT_SAME_MATRIX_SIZE(ActualDstTypeCleaned, Src) EIGEN_CHECK_BINARY_COMPATIBILIY(Func, typename ActualDstTypeCleaned::Scalar, typename Src::Scalar); // this provides a mechanism for specializing simple assignments, matrix products, etc AssignmentWithDevice::run(actualDst, src, func, dst.device()); } // copy and pasted from AssignEvaluator except forward device to kernel template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE EIGEN_CONSTEXPR void call_dense_assignment_loop(DstXprType& dst, const SrcXprType& src, const Functor& func, Device& device) { using DstEvaluatorType = evaluator; using SrcEvaluatorType = evaluator; SrcEvaluatorType srcEvaluator(src); // NOTE To properly handle A = (A*A.transpose())/s with A rectangular, // we need to resize the destination after the source evaluator has been created. resize_if_allowed(dst, src, func); DstEvaluatorType dstEvaluator(dst); using Kernel = generic_dense_assignment_kernel; Kernel kernel(dstEvaluator, srcEvaluator, func, dst.const_cast_derived()); dense_assignment_loop_with_device::run(kernel, device); } } // namespace internal template template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DeviceWrapper EigenBase::device(Device& device) { return DeviceWrapper(derived(), device); } template template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DeviceWrapper EigenBase::device( Device& device) const { return DeviceWrapper(derived(), device); } } // namespace Eigen #endif