Fix TensorContractionOp evaluators for GPU and SYCL

This commit is contained in:
Eugene Zhulenev 2018-07-17 14:09:37 -07:00
parent 038b55464b
commit c95aacab90
2 changed files with 13 additions and 10 deletions

View File

@@ -505,9 +505,9 @@ template<typename Scalar, typename Index, typename LhsMapper,
__global__ void __global__ void
#if defined(EIGEN_HIPCC) #if defined(EIGEN_HIPCC)
__launch_bounds__(512, 1) __launch_bounds__(512, 1)
#else #else
__launch_bounds__(512) __launch_bounds__(512)
#endif #endif
EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs, EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
const OutputMapper output, const OutputMapper output,
const Index m_size, const Index n_size, const Index k_size) { const Index m_size, const Index n_size, const Index k_size) {
@@ -698,7 +698,7 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
#undef prefetch_lhs #undef prefetch_lhs
#undef add_vals #undef add_vals
Index horiz_base = threadIdx.y*4+base_n; Index horiz_base = threadIdx.y*4+base_n;
if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) { if (!CHECK_LHS_BOUNDARY && !CHECK_RHS_BOUNDARY) {
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
@@ -1137,7 +1137,7 @@ template<typename Index, typename LhsMapper,
__global__ void __global__ void
#if defined(EIGEN_HIPCC) #if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1) __launch_bounds__(256, 1)
#else #else
__launch_bounds__(256) __launch_bounds__(256)
#endif #endif
EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs, EigenFloatContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
@@ -1184,7 +1184,7 @@ template<typename Index, typename LhsMapper,
__global__ void __global__ void
#if defined(EIGEN_HIPCC) #if defined(EIGEN_HIPCC)
__launch_bounds__(256, 1) __launch_bounds__(256, 1)
#else #else
__launch_bounds__(256) __launch_bounds__(256)
#endif #endif
EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs, EigenFloatContractionKernel16x16(const LhsMapper lhs, const RhsMapper rhs,

View File

@@ -23,15 +23,18 @@
namespace Eigen { namespace Eigen {
template <typename Index, typename LhsScalar, typename RhsScalar,bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered> struct LaunchSyclKernels; template <typename Index, typename LhsScalar, typename RhsScalar,bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered> struct LaunchSyclKernels;
template<typename Indices, typename LeftArgType, typename RightArgType> template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, const Eigen::SyclDevice> : struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, const Eigen::SyclDevice> :
public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, const Eigen::SyclDevice> > { public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, const Eigen::SyclDevice> > {
static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
"SYCL tensor contraction does not support output kernels.");
typedef const Eigen::SyclDevice Device; typedef const Eigen::SyclDevice Device;
typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self; typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
typedef TensorContractionEvaluatorBase<Self> Base; typedef TensorContractionEvaluatorBase<Self> Base;
typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType; typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar; typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
typedef typename XprType::Index Index; typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::CoeffReturnType CoeffReturnType;