mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 11:49:02 +08:00
Fix TensorContractionOp evaluators for GPU and SYCL
This commit is contained in:
parent
038b55464b
commit
c95aacab90
@ -23,15 +23,18 @@
|
|||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
|
|
||||||
template <typename Index, typename LhsScalar, typename RhsScalar,bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered> struct LaunchSyclKernels;
|
template <typename Index, typename LhsScalar, typename RhsScalar,bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered> struct LaunchSyclKernels;
|
||||||
template<typename Indices, typename LeftArgType, typename RightArgType>
|
template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
|
||||||
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, const Eigen::SyclDevice> :
|
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, const Eigen::SyclDevice> :
|
||||||
public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, const Eigen::SyclDevice> > {
|
public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, const Eigen::SyclDevice> > {
|
||||||
|
|
||||||
|
static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
|
||||||
|
"SYCL tensor contraction does not support output kernels.");
|
||||||
|
|
||||||
typedef const Eigen::SyclDevice Device;
|
typedef const Eigen::SyclDevice Device;
|
||||||
|
|
||||||
typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
|
typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
|
||||||
typedef TensorContractionEvaluatorBase<Self> Base;
|
typedef TensorContractionEvaluatorBase<Self> Base;
|
||||||
typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
|
typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
|
||||||
typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
|
typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
|
||||||
typedef typename XprType::Index Index;
|
typedef typename XprType::Index Index;
|
||||||
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user