// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Mehdi Goli    Codeplay Software Ltd.
// Ralph Potter  Codeplay Software Ltd.
// Luke Iwanski  Codeplay Software Ltd.
// Contact: <eigen@codeplay.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

/*****************************************************************
 * TensorContractionSycl.h
 *
 * \brief:
 *  TensorContractionSycl
 *
 *****************************************************************/

#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H

namespace Eigen {

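// Illustrative usage (a sketch, not part of this header): a contraction assigned
// on a SYCL device is evaluated by the specialisation below, e.g.
//
//   Eigen::TensorMap<Eigen::Tensor<float, 2>> A(a_ptr, m, k), B(b_ptr, k, n), C(c_ptr, m, n);
//   Eigen::array<Eigen::IndexPair<long>, 1> dims = {Eigen::IndexPair<long>(1, 0)};
//   C.device(sycl_device) = A.contract(B, dims);  // sycl_device is an Eigen::SyclDevice
//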
template <typename Index, typename LhsScalar, typename RhsScalar, bool lhs_inner_dim_contiguous,
          bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered> struct LaunchSyclKernels;

template<typename Indices, typename LeftArgType, typename RightArgType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, const Eigen::SyclDevice> :
    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, const Eigen::SyclDevice> > {

  typedef const Eigen::SyclDevice Device;

  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;
  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
  };

  // Most of the code is assuming that both input tensors are ColMajor. If the
  // inputs are RowMajor, we will "cheat" by swapping the LHS and RHS:
  // If we want to compute A * B = C, where A is LHS and B is RHS, the code
  // will pretend B is LHS and A is RHS.
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
      static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

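  // This works because C = A * B is equivalent to C^T = B^T * A^T, and a
  // RowMajor tensor has the same memory layout as a ColMajor tensor with the
  // dimensions reversed, so the ColMajor-only kernel below directly produces
  // the RowMajor result.
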
  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;

  typedef array<Index, LDims> left_dim_mapper_t;
  typedef array<Index, RDims> right_dim_mapper_t;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef DSizes<Index, NumDims> Dimensions;

  // typedefs needed in evalTo
  typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
  typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;

  typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
  typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;

  typedef typename LeftEvaluator::Dimensions LeftDimensions;
  typedef typename RightEvaluator::Dimensions RightDimensions;

  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) {}

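  // Returning false tells the caller that the result has been written directly
  // into the user-supplied buffer; returning true means a temporary device
  // buffer (m_result) was allocated and holds the result instead.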
  // We need to redefine this method to make nvcc happy
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    this->m_leftImpl.evalSubExprsIfNeeded(NULL);
    this->m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      this->m_result = static_cast<Scalar*>(this->m_device.allocate(this->dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(this->m_result);
      return true;
    }
  }
  const Eigen::SyclDevice& device() const { return this->m_device; }
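  // evalTo maps the three runtime layout flags (lhs contiguous, rhs contiguous,
  // rhs reordered) onto the eight compile-time instantiations of evalTyped;
  // all of them use Unaligned access.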
  void evalTo(Scalar* buffer) const {
    // Here is the result
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, true, true, Unaligned>(buffer);
        } else {
          evalTyped<true, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<true, false, true, Unaligned>(buffer);
        } else {
          evalTyped<true, false, false, Unaligned>(buffer);
        }
      }
    } else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, true, true, Unaligned>(buffer);
        } else {
          evalTyped<false, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          evalTyped<false, false, true, Unaligned>(buffer);
        } else {
          evalTyped<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }

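  // The contraction is evaluated as a matrix product: m and n are the collapsed
  // non-contracting extents of the left and right side, k is the collapsed
  // contracting extent, and the result buffer holds the m x n output.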
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
  void evalTyped(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;
    EIGEN_UNUSED_VARIABLE(k)
    // rows in left side
    const Index m = this->m_i_size;
    // columns in right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
    LaunchSyclKernels<Index, LhsScalar, RhsScalar, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered>::Run(*this, buffer, m, n, k,
        this->m_k_strides, this->m_left_contracting_strides, this->m_right_contracting_strides,
        this->m_i_strides, this->m_j_strides, this->m_left_nocontract_strides, this->m_right_nocontract_strides);
  }
  // Required by SYCL to construct the expression on the device: returns the original left_impl.
  const TensorEvaluator<LeftArgType, Device>& left_impl() const {
    return choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(), this->m_leftImpl, this->m_rightImpl);
  }
  // Required by SYCL to construct the expression on the device: returns the original right_impl.
  const TensorEvaluator<RightArgType, Device>& right_impl() const {
    return choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(), this->m_rightImpl, this->m_leftImpl);
  }
};

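// KernelConstructor is the SYCL kernel functor implementing a tiled,
// double-buffered matrix multiply: each work-group computes a
// TileSizeDimM x TileSizeDimN block of the output, and each thread within it
// accumulates a WorkLoadPerThreadM x WorkLoadPerThreadN sub-block in registers.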
template <typename HostExpr, typename OutScalar, typename LhsScalar, typename RhsScalar, typename LHSFunctorExpr, typename RHSFunctorExpr,
          typename LhsLocalAcc, typename RhsLocalAcc, typename OutAccessor, typename Index, typename ContractT, typename LeftNocontractT,
          typename RightNocontractT, bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered,
          typename HostExpr::Index TileSizeDimM, typename HostExpr::Index TileSizeDimN, typename HostExpr::Index TileSizeDimK,
          typename HostExpr::Index WorkLoadPerThreadM, typename HostExpr::Index WorkLoadPerThreadN,
          typename HostExpr::Index LocalThreadSizeM, typename HostExpr::Index LocalThreadSizeN,
          typename HostExpr::Index LoadPerThreadLhs, typename HostExpr::Index LoadPerThreadRhs,
          typename LHSTupleType, typename RHSTupleType, typename Device>
struct KernelConstructor {
  typedef typename Eigen::internal::traits<HostExpr>::_LhsNested LHSHostExpr;
  typedef typename Eigen::internal::traits<HostExpr>::_RhsNested RHSHostExpr;
  typedef typename Eigen::TensorSycl::internal::createPlaceHolderExpression<LHSHostExpr>::Type LHSPlaceHolderExpr;
  typedef typename Eigen::TensorSycl::internal::createPlaceHolderExpression<RHSHostExpr>::Type RHSPlaceHolderExpr;
  LHSFunctorExpr lhs_functors;
  RHSFunctorExpr rhs_functors;
  LhsLocalAcc localLhs;
  RhsLocalAcc localRhs;
  OutAccessor out_res;
  size_t out_offset;
  Index roundUpK, M, N, K;
  ContractT m_k_strides, m_left_contracting_strides, m_right_contracting_strides;
  LeftNocontractT m_i_strides, m_left_nocontract_strides;
  RightNocontractT m_j_strides, m_right_nocontract_strides;
  LHSTupleType left_tuple_of_accessors;
  RHSTupleType right_tuple_of_accessors;
  Device dev;

  KernelConstructor(LHSFunctorExpr lhs_functors_, RHSFunctorExpr rhs_functors_, LhsLocalAcc localLhs_, RhsLocalAcc localRhs_, OutAccessor out_res_, size_t out_offset_,
                    Index roundUpK_, Index M_, Index N_, Index K_, ContractT m_k_strides_, ContractT m_left_contracting_strides_,
                    ContractT m_right_contracting_strides_, LeftNocontractT m_i_strides_, RightNocontractT m_j_strides_,
                    LeftNocontractT m_left_nocontract_strides_, RightNocontractT m_right_nocontract_strides_, LHSTupleType left_tuple_of_accessors_, RHSTupleType right_tuple_of_accessors_, Device dev_)
      : lhs_functors(lhs_functors_), rhs_functors(rhs_functors_), localLhs(localLhs_), localRhs(localRhs_), out_res(out_res_),
        out_offset(out_offset_), roundUpK(roundUpK_), M(M_), N(N_), K(K_),
        m_k_strides(m_k_strides_), m_left_contracting_strides(m_left_contracting_strides_),
        m_right_contracting_strides(m_right_contracting_strides_),
        m_i_strides(m_i_strides_), m_left_nocontract_strides(m_left_nocontract_strides_),
        m_j_strides(m_j_strides_), m_right_nocontract_strides(m_right_nocontract_strides_),
        left_tuple_of_accessors(left_tuple_of_accessors_), right_tuple_of_accessors(right_tuple_of_accessors_), dev(dev_) {}

  void operator()(cl::sycl::nd_item<2> itemID) {
    typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression<HostExpr>::Type DevExpr;
    typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression<LHSHostExpr>::Type LHSDevExpr;
    typedef typename Eigen::TensorSycl::internal::ConvertToDeviceExpression<RHSHostExpr>::Type RHSDevExpr;
    auto lhs_dev_expr = Eigen::TensorSycl::internal::createDeviceExpression<LHSDevExpr, LHSPlaceHolderExpr>(lhs_functors, left_tuple_of_accessors);
    auto rhs_dev_expr = Eigen::TensorSycl::internal::createDeviceExpression<RHSDevExpr, RHSPlaceHolderExpr>(rhs_functors, right_tuple_of_accessors);
    typedef decltype(lhs_dev_expr.expr) LeftArgType;
    typedef decltype(rhs_dev_expr.expr) RightArgType;
    typedef typename internal::conditional<static_cast<int>(Eigen::internal::traits<DevExpr>::Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
    typedef typename internal::conditional<static_cast<int>(Eigen::internal::traits<DevExpr>::Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;
    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, LeftNocontractT,
                                                   ContractT, 1,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned, MakeGlobalPointer> LhsMapper;

    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, RightNocontractT,
                                                   ContractT, 1,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned, MakeGlobalPointer> RhsMapper;
    // initializing the data mappers must happen inside the kernel for device evaluation
    LhsMapper lhs(LeftEvaluator(choose(Cond<static_cast<int>(Eigen::internal::traits<DevExpr>::Layout) == static_cast<int>(ColMajor)>(),
                  lhs_dev_expr.expr, rhs_dev_expr.expr), dev), m_left_nocontract_strides, m_i_strides, m_left_contracting_strides, m_k_strides);
    RhsMapper rhs(RightEvaluator(choose(Cond<static_cast<int>(Eigen::internal::traits<DevExpr>::Layout) == static_cast<int>(ColMajor)>(),
                  rhs_dev_expr.expr, lhs_dev_expr.expr), dev), m_right_nocontract_strides, m_j_strides, m_right_contracting_strides, m_k_strides);
    auto out_ptr = ConvertToActualTypeSycl(OutScalar, out_res);
    // Matmul Kernel
    // Thread identifiers
    const Index mLocalThreadId = itemID.get_local(0); // Local ID row
    const Index nLocalThreadId = itemID.get_local(1); // Local ID col
    const Index mGroupId = itemID.get_group(0); // Work-group ID row
    const Index nGroupId = itemID.get_group(1); // Work-group ID col
    const Index linearLocalThreadId = nLocalThreadId * LocalThreadSizeM + mLocalThreadId; // linear local thread ID
    // Allocate register space
    LhsScalar privateLhs;
    RhsScalar privateRhs[WorkLoadPerThreadN];
    OutScalar privateRes[WorkLoadPerThreadM][WorkLoadPerThreadN];
    // Initialise the privateRes accumulation registers
    for (Index wLPTM = 0; wLPTM < WorkLoadPerThreadM; wLPTM++) {
      for (Index wLPTN = 0; wLPTN < WorkLoadPerThreadN; wLPTN++) {
        privateRes[wLPTM][wLPTN] = static_cast<OutScalar>(0);
      }
    }

    // Tile Lhs
    for (Index lPTL = 0; lPTL < LoadPerThreadLhs; lPTL++) {
      Index localLhsLinearId = lPTL * LocalThreadSizeN * LocalThreadSizeM + linearLocalThreadId;
      Index localLhsRow = localLhsLinearId % TileSizeDimM;
      Index localLhsCol = localLhsLinearId / TileSizeDimM;
      // Load the first tile (tile index 0) into buffer 0, guarding out-of-bounds reads with zero
      Index GlobalLhsColId = TileSizeDimK * 0 + localLhsCol;
      localLhs[0 + ((localLhsCol * TileSizeDimM + localLhsRow) * 2)] = ((GlobalLhsColId < K) && (mGroupId * TileSizeDimM + localLhsRow < M)) ? lhs(mGroupId * TileSizeDimM + localLhsRow, GlobalLhsColId) : static_cast<OutScalar>(0);
    }
    // Tile Rhs
    for (Index lPTR = 0; lPTR < LoadPerThreadRhs; lPTR++) {
      Index localRhsLinearId = lPTR * LocalThreadSizeN * LocalThreadSizeM + linearLocalThreadId;
      Index localRhsRow = localRhsLinearId % TileSizeDimN;
      Index localRhsCol = localRhsLinearId / TileSizeDimN;
      // Load the first tile (tile index 0) into buffer 0, guarding out-of-bounds reads with zero
      Index GlobalRhsRowId = TileSizeDimK * 0 + localRhsCol;
      localRhs[0 + ((localRhsCol * TileSizeDimN + localRhsRow) * 2)] = ((GlobalRhsRowId < K) && (nGroupId * TileSizeDimN + localRhsRow < N)) ? rhs(GlobalRhsRowId, nGroupId * TileSizeDimN + localRhsRow) : static_cast<OutScalar>(0);
    }
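    // Double buffering: localLhs/localRhs hold two interleaved tiles (stride-2
    // indexing, slot selected by "% 2"). Each iteration consumes the tile in
    // slot (firstHalf % 2) while prefetching the next tile into slot
    // (nextHalf % 2); the barrier at the top of the loop guarantees the
    // prefetched tile is complete before it is read.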
    // Loop over all tiles
    const Index numTiles = roundUpK / TileSizeDimK;
    Index firstHalf = 0;
    do {
      // Synchronise
      itemID.barrier(cl::sycl::access::fence_space::local_space);
      // Load the next tile of Lhs and Rhs into local memory
      Index nextHalf = firstHalf + 1;
      if (nextHalf < numTiles) {
        // Tile A
        for (Index lPTL = 0; lPTL < LoadPerThreadLhs; lPTL++) {
          Index localLhsLinearId = lPTL * LocalThreadSizeN * LocalThreadSizeM + linearLocalThreadId;
          Index localLhsRow = localLhsLinearId % TileSizeDimM;
          Index localLhsCol = localLhsLinearId / TileSizeDimM;
          // global K id
          Index GlobalLhsColId = TileSizeDimK * nextHalf + localLhsCol;
          // Store the loaded value into local memory
          localLhs[(nextHalf % 2) + ((localLhsCol * TileSizeDimM + localLhsRow) * 2)] = ((GlobalLhsColId < K) && (mGroupId * TileSizeDimM + localLhsRow < M)) ? lhs(mGroupId * TileSizeDimM + localLhsRow, GlobalLhsColId) : static_cast<OutScalar>(0);
        }
        // Tile B
        for (Index lPTR = 0; lPTR < LoadPerThreadRhs; lPTR++) {
          Index localRhsLinearId = lPTR * LocalThreadSizeN * LocalThreadSizeM + linearLocalThreadId;
          Index localRhsRow = localRhsLinearId % TileSizeDimN;
          Index localRhsCol = localRhsLinearId / TileSizeDimN;
          // global K id
          Index GlobalRhsRowId = TileSizeDimK * nextHalf + localRhsCol;
          // Store the loaded value into local memory
          localRhs[(nextHalf % 2) + ((localRhsCol * TileSizeDimN + localRhsRow) * 2)] = ((GlobalRhsRowId < K) && (nGroupId * TileSizeDimN + localRhsRow < N)) ? rhs(GlobalRhsRowId, nGroupId * TileSizeDimN + localRhsRow) : static_cast<OutScalar>(0);
        }
      }
      // Loop over the values of a single tile
      for (Index k = 0; k < TileSizeDimK; k++) {
        // Cache the values of localRhs in registers
        for (Index wLPTN = 0; wLPTN < WorkLoadPerThreadN; wLPTN++) {
          Index localRhsCol = nLocalThreadId + wLPTN * LocalThreadSizeN;
          privateRhs[wLPTN] = localRhs[(firstHalf % 2) + ((k * TileSizeDimN + localRhsCol) * 2)];
        }
        // Perform the computation
        for (Index wLPTM = 0; wLPTM < WorkLoadPerThreadM; wLPTM++) {
          Index localLhsRow = mLocalThreadId + wLPTM * LocalThreadSizeM;
          privateLhs = localLhs[(firstHalf % 2) + ((k * TileSizeDimM + localLhsRow) * 2)];
          for (Index wLPTN = 0; wLPTN < WorkLoadPerThreadN; wLPTN++) {
            privateRes[wLPTM][wLPTN] += privateLhs * privateRhs[wLPTN];
          }
        }
      }
      // Next tile
      firstHalf++;
    } while (firstHalf < numTiles);

    // Store the final results in C
    for (Index wLPTM = 0; wLPTM < WorkLoadPerThreadM; wLPTM++) {
      Index globalRow = mGroupId * TileSizeDimM + mLocalThreadId + wLPTM * LocalThreadSizeM;
      if (globalRow < M) {
        for (Index wLPTN = 0; wLPTN < WorkLoadPerThreadN; wLPTN++) {
          Index globalCol = nGroupId * TileSizeDimN + nLocalThreadId + wLPTN * LocalThreadSizeN;
          if (globalCol < N)
            out_ptr[globalCol * M + globalRow + ConvertToActualSyclOffset(OutScalar, out_offset)] = privateRes[wLPTM][wLPTN];
        }
      }
    }
  }
};

template <typename Index, typename LhsScalar, typename RhsScalar, bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered>
struct LaunchSyclKernels {

  static const Index TileSizeDimM = 32ul;       // Tile size for dimension M
  static const Index TileSizeDimN = 32ul;       // Tile size for dimension N
  static const Index TileSizeDimK = 16ul;       // Tile size for dimension K
  static const Index WorkLoadPerThreadM = 4ul;  // Work load per thread in dimension M
  static const Index WorkLoadPerThreadN = 4ul;  // Work load per thread in dimension N
  static const Index LocalThreadSizeM = (TileSizeDimM / WorkLoadPerThreadM); // Local thread size for the first dimension (M here)
  static const Index LocalThreadSizeN = (TileSizeDimN / WorkLoadPerThreadN); // Local thread size for the second dimension (N here)
  static const Index LoadPerThreadLhs = ((TileSizeDimK * WorkLoadPerThreadM * WorkLoadPerThreadN) / (TileSizeDimN)); // workload per thread for Lhs expression
  static const Index LoadPerThreadRhs = ((TileSizeDimK * WorkLoadPerThreadM * WorkLoadPerThreadN) / (TileSizeDimM)); // workload per thread for Rhs expression

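  // With the defaults above: LocalThreadSizeM = LocalThreadSizeN = 32/4 = 8,
  // i.e. 64 threads per work-group, and LoadPerThreadLhs = LoadPerThreadRhs =
  // (16*4*4)/32 = 8, so the 64 threads cooperatively fill the 32x16 Lhs tile
  // and the 16x32 Rhs tile (512 elements each) in 8 loads per thread.
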
  // RoundUp: rounds x up to the nearest multiple of y, so that the global
  // thread count is divisible by the work-group size (e.g. RoundUp(100, 32) == 128)
  static Index RoundUp(Index x, Index y) {
    return ((((x) + (y) - 1) / (y)) * (y));
  }

  template <typename Self, typename OutScalar, typename ContractT, typename LeftNocontractT, typename RightNocontractT>
  static void Run(const Self& self, OutScalar* buffer, Index M, Index N, Index K,
                  ContractT m_k_strides, ContractT m_left_contracting_strides, ContractT m_right_contracting_strides,
                  LeftNocontractT m_i_strides, RightNocontractT m_j_strides, LeftNocontractT m_left_nocontract_strides, RightNocontractT m_right_nocontract_strides) {

    typedef typename Self::XprType HostExpr;
    typedef typename Eigen::internal::traits<HostExpr>::_LhsNested LHSHostExpr;
    typedef typename Eigen::internal::traits<HostExpr>::_RhsNested RHSHostExpr;
    typedef TensorEvaluator<LHSHostExpr, const Eigen::SyclDevice> OrigLHSExpr;
    typedef TensorEvaluator<RHSHostExpr, const Eigen::SyclDevice> OrigRHSExpr;
    typedef Eigen::TensorSycl::internal::FunctorExtractor<OrigLHSExpr> LHSFunctorExpr;
    typedef Eigen::TensorSycl::internal::FunctorExtractor<OrigRHSExpr> RHSFunctorExpr;
    // extract lhs functor list
    LHSFunctorExpr lhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.left_impl());
    // extract rhs functor list
    RHSFunctorExpr rhs_functors = Eigen::TensorSycl::internal::extractFunctors(self.right_impl());

    Index roundUpK = RoundUp(K, TileSizeDimK);
    Index roundUpM = RoundUp(M, TileSizeDimM);
    Index roundUpN = RoundUp(N, TileSizeDimN);
    ptrdiff_t out_offset = self.device().get_offset(buffer);
    self.device().sycl_queue().submit([&](cl::sycl::handler &cgh) {
      /// work-around for gcc bug
      typedef decltype(Eigen::TensorSycl::internal::createTupleOfAccessors<OrigLHSExpr>(cgh, self.left_impl())) LHSTupleType;
      /// work-around for gcc bug
      typedef decltype(Eigen::TensorSycl::internal::createTupleOfAccessors<OrigRHSExpr>(cgh, self.right_impl())) RHSTupleType;
      // create lhs tuple of accessors
      LHSTupleType left_tuple_of_accessors = Eigen::TensorSycl::internal::createTupleOfAccessors<OrigLHSExpr>(cgh, self.left_impl());
      // create rhs tuple of accessors
      RHSTupleType right_tuple_of_accessors = Eigen::TensorSycl::internal::createTupleOfAccessors<OrigRHSExpr>(cgh, self.right_impl());

      // Local memory for elements of Lhs; sized for two tiles to allow double buffering
      typedef cl::sycl::accessor<LhsScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> LhsLocalAcc;
      LhsLocalAcc localLhs(cl::sycl::range<1>(2 * TileSizeDimM * TileSizeDimK), cgh);
      // Local memory for elements of Rhs; sized for two tiles to allow double buffering
      typedef cl::sycl::accessor<RhsScalar, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local> RhsLocalAcc;
      RhsLocalAcc localRhs(cl::sycl::range<1>(2 * TileSizeDimK * TileSizeDimN), cgh);

      typedef cl::sycl::accessor<uint8_t, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::global_buffer> OutAccessor;
      // OutScalar memory
      OutAccessor out_res = self.device().template get_sycl_accessor<cl::sycl::access::mode::read_write>(cgh, buffer);
      // sycl parallel for
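      // Example of the launch geometry: for M = N = 100, roundUpM = roundUpN = 128,
      // giving a global range of (128/4, 128/4) = (32, 32) threads, a local range
      // of (8, 8), and 4 x 4 = 16 work-groups, each producing one 32x32 output tile.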
      cgh.parallel_for(cl::sycl::nd_range<2>(cl::sycl::range<2>(roundUpM / WorkLoadPerThreadM, roundUpN / WorkLoadPerThreadN),
                                             cl::sycl::range<2>(LocalThreadSizeM, LocalThreadSizeN)),
          KernelConstructor<HostExpr, OutScalar, LhsScalar, RhsScalar, LHSFunctorExpr, RHSFunctorExpr, LhsLocalAcc, RhsLocalAcc, OutAccessor, Index, ContractT, LeftNocontractT,
                            RightNocontractT, lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, TileSizeDimM, TileSizeDimN, TileSizeDimK,
                            WorkLoadPerThreadM, WorkLoadPerThreadN, LocalThreadSizeM, LocalThreadSizeN, LoadPerThreadLhs, LoadPerThreadRhs, LHSTupleType, RHSTupleType, Eigen::SyclKernelDevice>(
              lhs_functors, rhs_functors, localLhs, localRhs, out_res, out_offset, roundUpK, M, N, K, m_k_strides, m_left_contracting_strides, m_right_contracting_strides, m_i_strides, m_j_strides,
              m_left_nocontract_strides, m_right_nocontract_strides, left_tuple_of_accessors, right_tuple_of_accessors, Eigen::SyclKernelDevice()));
    });
    self.device().asynchronousExec();
  }
};

} // end namespace Eigen

#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_SYCL_H