Add support for custom packed Lhs/Rhs blocks in tensor contractions

Eugene Zhulenev 2019-04-01 11:47:31 -07:00
parent 45e65fbb77
commit 4e2f6de1a8
3 changed files with 247 additions and 102 deletions


@@ -105,7 +105,9 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKern
  static const int NumDimensions = traits<LhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value;
  static const int Layout = traits<LhsXprType>::Layout;
  typedef typename conditional<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
                               typename traits<LhsXprType>::PointerType,
                               typename traits<RhsXprType>::PointerType>::type
      PointerType;

  enum {
    Flags = 0

@@ -136,6 +138,80 @@ struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_,
  static const int NumDimensions = traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value;
};
// Helper class to allocate and deallocate temporary memory for packed buffers.
template <typename LhsScalar, typename RhsScalar>
struct TensorContractionBlockMemAllocator {
typedef void* BlockMemHandle;
template <typename Device>
EIGEN_DEVICE_FUNC static BlockMemHandle allocate(Device& d, const Index bm,
const Index bk,
const Index bn,
LhsScalar** lhs_block,
RhsScalar** rhs_block) {
eigen_assert(lhs_block);
eigen_assert(rhs_block);
BlockSizes sz = ComputeLhsRhsBlockSizes(bm, bk, bn);
char* block_mem = static_cast<char*>(d.allocate(sz.lhs_size + sz.rhs_size));
eigen_assert(block_mem);
*lhs_block = reinterpret_cast<LhsScalar*>(block_mem);
*rhs_block = reinterpret_cast<RhsScalar*>(block_mem + sz.lhs_size);
return block_mem;
}
template <typename Device>
EIGEN_DEVICE_FUNC static BlockMemHandle allocateSlices(
Device& d, const Index bm, const Index bk, const Index bn,
const Index num_lhs, const Index num_rhs, const Index num_slices,
std::vector<LhsScalar*>* lhs_blocks,
std::vector<RhsScalar*>* rhs_blocks) {
eigen_assert(num_slices > 0);
eigen_assert(num_lhs >= 0 && num_rhs >= 0);
eigen_assert(num_lhs == 0 || lhs_blocks);
eigen_assert(num_rhs == 0 || rhs_blocks);
BlockSizes sz = ComputeLhsRhsBlockSizes(bm, bk, bn);
void* block_mem = d.allocate(
(num_lhs * sz.lhs_size + num_rhs * sz.rhs_size) * num_slices);
eigen_assert(block_mem);
char* mem = static_cast<char*>(block_mem);
for (Index x = 0; x < num_slices; x++) {
if (num_lhs > 0) lhs_blocks[x].resize(num_lhs);
for (Index m = 0; m < num_lhs; m++) {
lhs_blocks[x][m] = reinterpret_cast<LhsScalar*>(mem);
mem += sz.lhs_size;
}
if (num_rhs > 0) rhs_blocks[x].resize(num_rhs);
for (Index n = 0; n < num_rhs; n++) {
rhs_blocks[x][n] = reinterpret_cast<RhsScalar*>(mem);
mem += sz.rhs_size;
}
}
return block_mem;
}
template <typename Device>
EIGEN_DEVICE_FUNC static void deallocate(Device& d, BlockMemHandle handle) {
d.deallocate(handle);
}
private:
struct BlockSizes {
Index lhs_size;
Index rhs_size;
};
EIGEN_DEVICE_FUNC static BlockSizes ComputeLhsRhsBlockSizes(const Index bm,
const Index bk,
const Index bn) {
Index align = numext::maxi(EIGEN_MAX_ALIGN_BYTES, 1);
BlockSizes sz;
sz.lhs_size = divup<Index>(bm * bk * sizeof(LhsScalar), align) * align;
sz.rhs_size = divup<Index>(bn * bk * sizeof(RhsScalar), align) * align;
return sz;
}
};
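
The allocator above hands back a single opaque handle that owns the memory for both packed blocks. A minimal usage sketch (not part of the patch; it assumes the patched headers and Eigen::DefaultDevice, with illustrative block sizes):

#include <unsupported/Eigen/CXX11/Tensor>

// Sketch only: pair allocate() with deallocate() on the same device.
void block_alloc_sketch() {
  using Allocator =
      Eigen::internal::TensorContractionBlockMemAllocator<float, float>;
  Eigen::DefaultDevice device;
  float* lhs_block = nullptr;
  float* rhs_block = nullptr;
  // One aligned device allocation backs both blocks; the handle owns it.
  Allocator::BlockMemHandle handle = Allocator::allocate(
      device, /*bm=*/64, /*bk=*/128, /*bn=*/64, &lhs_block, &rhs_block);
  // ... pack into lhs_block / rhs_block and run the GEBP kernel ...
  Allocator::deallocate(device, handle);
}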
// WARNING: In this code we assume that Lhs and Rhs tensor expressions are in
// ColMajor storage order. This property is guaranteed by the
// TensorContractionOp evaluator. TensorContractionKernel specifies how we pack
@@ -164,16 +240,28 @@ struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_,
// TensorContractionInputMapper, or some specialization of it based on the
// type of tensor expression (e.g. TensorImagePatchOp has optimized input
// mapper).
template <typename ResScalar, typename LhsScalar, typename RhsScalar,
          typename StorageIndex, typename OutputMapper, typename LhsMapper,
          typename RhsMapper>
struct TensorContractionKernel {
TensorContractionKernel(StorageIndex m, StorageIndex k, StorageIndex n,
StorageIndex bm, StorageIndex bk, StorageIndex bn)
: m(m), k(k), n(n), bm(bm), bk(bk), bn(bn) {}
// Pack blocks of Lhs and Rhs into contiguous blocks in memory.
typedef LhsScalar* LhsBlock;
typedef RhsScalar* RhsBlock;
// Packed Lhs/Rhs block memory allocator.
typedef TensorContractionBlockMemAllocator<LhsScalar, RhsScalar>
BlockMemAllocator;
typedef typename BlockMemAllocator::BlockMemHandle BlockMemHandle;
  typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;

  typedef internal::gemm_pack_lhs<
      LhsScalar, StorageIndex, typename LhsMapper::SubMapper, Traits::mr,
      Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor>
      LhsPacker;

  typedef internal::gemm_pack_rhs<RhsScalar, StorageIndex,
@@ -186,29 +274,61 @@ struct TensorContractionKernel {
                               /*ConjugateLhs*/ false, /*ConjugateRhs*/ false>
      GebpKernel;
  template <typename Device>
  EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device& d, LhsBlock* lhs_block,
                                            RhsBlock* rhs_block) {
    return BlockMemAllocator::allocate(d, bm, bk, bn, lhs_block, rhs_block);
  }

  template <typename Device>
  EIGEN_DEVICE_FUNC BlockMemHandle allocateSlices(
      Device& d, const StorageIndex num_lhs, const StorageIndex num_rhs,
      const StorageIndex num_slices, std::vector<LhsBlock>* lhs_blocks,
      std::vector<RhsBlock>* rhs_blocks) {
    return BlockMemAllocator::allocateSlices(
        d, bm, bk, bn, num_lhs, num_rhs, num_slices, lhs_blocks, rhs_blocks);
  }

  template <typename Device>
  EIGEN_DEVICE_FUNC static void deallocate(Device& d, BlockMemHandle handle) {
    BlockMemAllocator::deallocate(d, handle);
  }

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packLhs(
      LhsBlock* lhsBlock, const typename LhsMapper::SubMapper& data_mapper,
      const StorageIndex depth, const StorageIndex rows) {
    LhsPacker()(*lhsBlock, data_mapper, depth, rows, /*stride*/ 0,
                /*offset*/ 0);
  }

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packRhs(
      RhsBlock* rhsBlock, const typename RhsMapper::SubMapper& data_mapper,
      const StorageIndex depth, const StorageIndex cols) {
    RhsPacker()(*rhsBlock, data_mapper, depth, cols);
  }

  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void invoke(
      const OutputMapper& output_mapper, const LhsBlock& lhsBlock,
      const RhsBlock& rhsBlock, const StorageIndex rows,
      const StorageIndex depth, const StorageIndex cols,
      const ResScalar alpha) {
    static const int kComputeStrideFromBlockDimensions = -1;
    GebpKernel()(output_mapper, lhsBlock, rhsBlock, rows, depth, cols, alpha,
                 /*strideA*/ kComputeStrideFromBlockDimensions,
                 /*strideB*/ kComputeStrideFromBlockDimensions,
                 /*offsetA*/ 0, /*offsetB*/ 0);
  }
private:
  // These are the dimensions of the original Tensors, and the selected block
  // sizes. The actual block sizes passed to the functions above may be smaller
  // because of the partial blocks at the end.
const StorageIndex m;
const StorageIndex k;
const StorageIndex n;
const StorageIndex bm;
const StorageIndex bk;
const StorageIndex bn;
};
}  // end namespace internal
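
With the kernel now carrying its block typedefs and an instance-based allocate/pack/invoke API, the evaluators only need to follow one protocol, regardless of how a kernel represents its packed blocks. A minimal sketch of that protocol (illustration only, not from the patch; the mapper and device types are template parameters, and contract_single_block, bm, bk, bn are hypothetical names):

// Sketch: the allocate -> pack -> invoke -> deallocate sequence the
// evaluators follow for any kernel type that models this interface.
template <typename Kernel, typename Scalar, typename Device,
          typename LhsMapper, typename RhsMapper, typename OutputMapper,
          typename StorageIndex>
void contract_single_block(Kernel& kernel, Device& device,
                           const LhsMapper& lhs, const RhsMapper& rhs,
                           const OutputMapper& output, StorageIndex bm,
                           StorageIndex bk, StorageIndex bn) {
  typename Kernel::LhsBlock lhs_block;
  typename Kernel::RhsBlock rhs_block;
  // One device allocation backs both packed blocks.
  typename Kernel::BlockMemHandle mem =
      kernel.allocate(device, &lhs_block, &rhs_block);
  kernel.packLhs(&lhs_block, lhs.getSubMapper(0, 0), /*depth=*/bk,
                 /*rows=*/bm);
  kernel.packRhs(&rhs_block, rhs.getSubMapper(0, 0), /*depth=*/bk,
                 /*cols=*/bn);
  kernel.invoke(output, lhs_block, rhs_block, bm, bk, bn, Scalar(1));
  kernel.deallocate(device, mem);
}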
@@ -737,11 +857,18 @@ struct TensorContractionEvaluatorBase
    const Index kc = blocking.kc();
    const Index mc = numext::mini(m, blocking.mc());
    const Index nc = numext::mini(n, blocking.nc());

    typedef typename TensorContractionKernel::LhsBlock LhsBlock;
    typedef typename TensorContractionKernel::RhsBlock RhsBlock;
LhsBlock blockA;
RhsBlock blockB;
TensorContractionKernel kernel(m, k_slice, n, mc, kc, nc);
typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
const BlockMemHandle packed_mem =
kernel.allocate(this->m_device, &blockA, &blockB);
    for(Index i2=0; i2<m; i2+=mc)
    {
@@ -749,22 +876,20 @@ struct TensorContractionEvaluatorBase
      for (Index k2 = k_start; k2 < k_end; k2 += kc) {
        // make sure we don't overshoot right edge of left matrix, then pack vertical panel
        const Index actual_kc = numext::mini(k2 + kc, k_end) - k2;
        kernel.packLhs(&blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);

        // series of horizontal blocks
        for (Index j2 = 0; j2 < n; j2 += nc) {
          // make sure we don't overshoot right edge of right matrix, then pack block
          const Index actual_nc = numext::mini(j2 + nc, n) - j2;
          kernel.packRhs(&blockB, rhs.getSubMapper(k2, j2), actual_kc,
                         actual_nc);

          // call gebp (matrix kernel)
          // The parameters here are copied from Eigen's GEMM implementation
          const OutputMapper output_mapper = output.getSubMapper(i2, j2);
          kernel.invoke(output_mapper, blockA, blockB, actual_mc, actual_kc,
                        actual_nc, Scalar(1));

          // We are done with this [i2, j2] output block.
          if (use_output_kernel && k2 + kc >= k_end) {

@@ -775,8 +900,7 @@ struct TensorContractionEvaluatorBase
        }
      }

    kernel.deallocate(this->m_device, packed_mem);
    }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {


@@ -24,12 +24,17 @@ enum {
 */

/// The make pointer class is used by sycl in order to build the mapper class on the device. For other platform the default make pointer is used which
/// is scalar * for CoeffLoader.
template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_ = MakePointer>
struct CoeffLoader;

template <typename Scalar, typename Index, int side, typename Tensor,
          typename nocontract_t, typename contract_t, int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
          template <class> class MakePointer_ = MakePointer>
class BaseTensorContractionMapper;

template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_>
struct CoeffLoader {
  enum {
    DirectOffsets = false
  };
@@ -40,6 +45,12 @@ template <typename Tensor, bool HasRawAccess, template <class> class MakePointer
    eigen_assert(false && "unsupported");
  }
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type
data() const {
eigen_assert(false && "unsupported");
return NULL;
}
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return m_tensor.coeff(index); }

  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE

@@ -48,12 +59,12 @@ template <typename Tensor, bool HasRawAccess, template <class> class MakePointer
    return m_tensor.template packet<LoadMode>(index);
  }

 private:
  const Tensor m_tensor;
};
template <typename Tensor, template <class> class MakePointer_>
struct CoeffLoader<Tensor, true, MakePointer_> {
  enum {
    DirectOffsets = true
  };
@@ -64,6 +75,11 @@ template <typename Tensor, template <class> class MakePointer_> struct CoeffLoad
    m_data += offset;
  }
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type
data() const {
return m_data;
}
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return loadConstant(m_data+index); }

  template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE

@@ -214,6 +230,17 @@ class SimpleTensorContractionMapper {
    return ((side == Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1;
  }
const CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>& tensor() const {
return m_tensor;
}
const nocontract_t& nocontract_strides() const {
return m_nocontract_strides;
}
const nocontract_t& ij_strides() const { return m_ij_strides; }
const contract_t& contract_strides() const { return m_contract_strides; }
const contract_t& k_strides() const { return m_k_strides; }
 protected:
  CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_> m_tensor;
  const nocontract_t m_nocontract_strides;
@@ -445,6 +472,10 @@ class TensorContractionSubMapper {
    return false;
  }
const ParentMapper& base_mapper() const { return m_base_mapper; }
Index vert_offset() const { return m_vert_offset; }
Index horiz_offset() const { return m_horiz_offset; }
 private:
  ParentMapper m_base_mapper;
  const Index m_vert_offset;
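
The new read-only accessors on the mappers (tensor(), the stride getters, base_mapper(), vert_offset(), horiz_offset()) are what make custom packing possible: a kernel specialization can locate its block inside the logical contraction matrix and, when Tensor::RawAccess holds, reach the underlying buffer. A hedged sketch of such a packer (illustration only; custom_pack_column_major is a hypothetical helper and uses a plain column-major copy rather than the gebp packing layout):

// Sketch: use the sub-mapper's new accessors to pack one block.
template <typename SubMapper, typename Scalar, typename StorageIndex>
void custom_pack_column_major(Scalar* packed, const SubMapper& sub_mapper,
                              StorageIndex depth, StorageIndex cols) {
  // Offsets of this block inside the logical contraction matrix.
  const StorageIndex row0 = sub_mapper.vert_offset();
  const StorageIndex col0 = sub_mapper.horiz_offset();
  // Copy coefficients through the parent mapper; with Tensor::RawAccess the
  // raw buffer is also reachable via sub_mapper.base_mapper().tensor().data().
  for (StorageIndex c = 0; c < cols; ++c)
    for (StorageIndex d = 0; d < depth; ++d)
      *packed++ = sub_mapper.base_mapper()(row0 + d, col0 + c);
}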


@@ -280,6 +280,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
        Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper>
        TensorContractionKernel;
typedef typename TensorContractionKernel::LhsBlock LhsBlock;
typedef typename TensorContractionKernel::RhsBlock RhsBlock;
typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
    Context(const Self* self, int num_threads, Scalar* buffer, Index tm, Index tn,
            Index tk, Index bm, Index bn, Index bk, Index nm, Index nn, Index nk,
            Index gm, Index gn, Index nm0, Index nn0, bool shard_by_col,
@@ -311,7 +315,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
          gm_(gm),
          gn_(gn),
          nm0_(nm0),
          nn0_(nn0),
          kernel_(m_, k_, n_, bm_, bk_, bn_)
    {
      // These two options are mutually exclusive.
      eigen_assert(!(parallel_pack && parallelize_by_sharding_dim_only));
@@ -342,26 +347,12 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
      }

      // Allocate memory for packed rhs/lhs matrices.
      packed_mem_ = kernel_.allocateSlices(            //
          device_,                                     //
          /*num_lhs=*/nm0_,                            //
          /*num_rhs=*/nn0_,                            //
          /*num_slices=*/std::min<Index>(nk_, P - 1),  //
          packed_lhs_, packed_rhs_);

      if (parallelize_by_sharding_dim_only_) {
        const int num_worker_threads = device_.numThreadsInPool();
@@ -373,14 +364,13 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
                                std::memory_order_relaxed);

          Index num_blocks = num_worker_threads * gn_;
          thread_local_packed_mem_ = kernel_.allocateSlices(  //
              device_,                                        //
              /*num_lhs=*/0,                                  //
              /*num_rhs=*/num_blocks,                         //
              /*num_slices=*/1,                               //
              /*lhs_blocks=*/nullptr, &thread_local_packed_rhs_);

        } else {
          can_use_thread_local_packed_ = new std::atomic<bool>[nm_];
          for (int i = 0; i < nm_; ++i)
@@ -388,14 +378,12 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
                                std::memory_order_relaxed);

          Index num_blocks = num_worker_threads * gm_;
          thread_local_packed_mem_ = kernel_.allocateSlices(  //
              device_,                                        //
              /*num_lhs=*/num_blocks,                         //
              /*num_rhs=*/0,                                  //
              /*num_slices=*/1, &thread_local_packed_lhs_,    //
              /*rhs_blocks=*/nullptr);
        }
      }
    }
@@ -405,9 +393,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
        for (Index m = 0; m < nm_; m++) delete[] state_kernel_[x][m];
        delete[] state_kernel_[x];
      }
      kernel_.deallocate(device_, packed_mem_);
      if (parallelize_by_sharding_dim_only_) {
        kernel_.deallocate(device_, thread_local_packed_mem_);
        delete[] can_use_thread_local_packed_;
      }
    }
@@ -455,6 +443,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
    // coarsening).
    const Index nm0_;
    const Index nn0_;
// Tensor contraction kernel.
TensorContractionKernel kernel_;
    // Parallelization strategy.
    //

@@ -491,9 +481,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
    // actively executing + one to track completion of kernels in the second
    // slice.
    static const Index P = 3;

    // Handle to the allocated temporary storage for Lhs/Rhs blocks.
    BlockMemHandle packed_mem_;
    std::vector<LhsBlock> packed_lhs_[P - 1];
    std::vector<RhsBlock> packed_rhs_[P - 1];
    // If we choose to parallelize only by the sharding dimension, each thread
    // will have its own "thread local" (not a C++ thread local storage) memory

@@ -511,11 +503,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
    // completion of the K-1 kernel, so we have to allocate "global" packed_lhs_
    // and packed_rhs_ to allow kernels to be executed later on a thread
    // different from the thread that was used for packing.
    BlockMemHandle thread_local_packed_mem_;

    // Only one of these will be initialized depending on shard_by_col value.
    std::vector<LhsBlock> thread_local_packed_lhs_;
    std::vector<RhsBlock> thread_local_packed_rhs_;

    // After a particular shard for Kth slice missed thread local execution
    // opportunity (K-1 slice didn't complete kernels execution), we can no
@@ -532,7 +524,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
    std::atomic<Index> state_packing_ready_[P];
    std::atomic<Index> state_switch_[P];

    LhsBlock& packed_lhs(Index m, Index k, Index m1, bool use_thread_local) {
      if (use_thread_local) {
        eigen_assert(!shard_by_col_);
@@ -546,7 +538,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
      }
    }

    RhsBlock& packed_rhs(Index n, Index k, Index n1, bool use_thread_local) {
      if (use_thread_local) {
        eigen_assert(shard_by_col_);
@@ -580,7 +572,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
      } else {
        // If we can't guarantee that all kernels in `k` slice will be
        // executed sequentially in current thread, it's no longer safe to use
        // thread local memory in following slices along the k dimensions.
        eigen_assert(k > 0);
        can_use_thread_local_packed_[m].store(false,
                                              std::memory_order_relaxed);
@@ -589,9 +581,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
      const Index mend = m * gm_ + gm(m);
      for (Index m1 = m * gm_; m1 < mend; m1++)
        kernel_.packLhs(&packed_lhs(m, k, m1, use_thread_local),
                        lhs_.getSubMapper(m1 * bm_, k * bk_), bk(k), bm(m1));

      if (!parallel_pack_ && shard_by_col_) {
        assert(!use_thread_local);
@@ -634,9 +625,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
          // deadlocks.
          memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar));
        }
        kernel_.packRhs(&packed_rhs(n, k, n1, use_thread_local),
                        rhs_.getSubMapper(k * bk_, n1 * bn_), bk(k), bn(n1));
      }

      if (parallel_pack_ || shard_by_col_) {
@@ -661,7 +651,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
        for (Index n1 = n * gn_; n1 < nend; n1++) {
          for (Index m1 = m * gm_; m1 < mend; m1++) {
            const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
            kernel_.invoke(
                output_mapper,
                packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
                packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
@@ -678,7 +668,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
        for (Index m1 = m * gm_; m1 < mend; m1++)
          for (Index n1 = n * gn_; n1 < nend; n1++) {
            const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
            kernel_.invoke(
                output_mapper,
                packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
                packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),