Mirror of https://gitlab.com/libeigen/eigen.git

Add support for custom packed Lhs/Rhs blocks in tensor contractions

This commit is contained in:
  parent 45e65fbb77
  commit 4e2f6de1a8
@@ -105,7 +105,9 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKern
   static const int NumDimensions = traits<LhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value;
   static const int Layout = traits<LhsXprType>::Layout;
   typedef typename conditional<Pointer_type_promotion<typename LhsXprType::Scalar, Scalar>::val,
-                               typename traits<LhsXprType>::PointerType, typename traits<RhsXprType>::PointerType>::type PointerType;
+                               typename traits<LhsXprType>::PointerType,
+                               typename traits<RhsXprType>::PointerType>::type
+      PointerType;

   enum {
     Flags = 0
@@ -136,6 +138,80 @@ struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_,
   static const int NumDimensions = traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value;
 };

+// Helper class to allocate and deallocate temporary memory for packed buffers.
+template <typename LhsScalar, typename RhsScalar>
+struct TensorContractionBlockMemAllocator {
+  typedef void* BlockMemHandle;
+
+  template <typename Device>
+  EIGEN_DEVICE_FUNC static BlockMemHandle allocate(Device& d, const Index bm,
+                                                   const Index bk,
+                                                   const Index bn,
+                                                   LhsScalar** lhs_block,
+                                                   RhsScalar** rhs_block) {
+    eigen_assert(lhs_block);
+    eigen_assert(rhs_block);
+    BlockSizes sz = ComputeLhsRhsBlockSizes(bm, bk, bn);
+    char* block_mem = static_cast<char*>(d.allocate(sz.lhs_size + sz.rhs_size));
+    eigen_assert(block_mem);
+    *lhs_block = reinterpret_cast<LhsScalar*>(block_mem);
+    *rhs_block = reinterpret_cast<RhsScalar*>(block_mem + sz.lhs_size);
+    return block_mem;
+  }
+
+  template <typename Device>
+  EIGEN_DEVICE_FUNC static BlockMemHandle allocateSlices(
+      Device& d, const Index bm, const Index bk, const Index bn,
+      const Index num_lhs, const Index num_rhs, const Index num_slices,
+      std::vector<LhsScalar*>* lhs_blocks,
+      std::vector<RhsScalar*>* rhs_blocks) {
+    eigen_assert(num_slices > 0);
+    eigen_assert(num_lhs >= 0 && num_rhs >= 0);
+    eigen_assert(num_lhs == 0 || lhs_blocks);
+    eigen_assert(num_rhs == 0 || rhs_blocks);
+    BlockSizes sz = ComputeLhsRhsBlockSizes(bm, bk, bn);
+    void* block_mem = d.allocate(
+        (num_lhs * sz.lhs_size + num_rhs * sz.rhs_size) * num_slices);
+    eigen_assert(block_mem);
+    char* mem = static_cast<char*>(block_mem);
+
+    for (Index x = 0; x < num_slices; x++) {
+      if (num_lhs > 0) lhs_blocks[x].resize(num_lhs);
+      for (Index m = 0; m < num_lhs; m++) {
+        lhs_blocks[x][m] = reinterpret_cast<LhsScalar*>(mem);
+        mem += sz.lhs_size;
+      }
+      if (num_rhs > 0) rhs_blocks[x].resize(num_rhs);
+      for (Index n = 0; n < num_rhs; n++) {
+        rhs_blocks[x][n] = reinterpret_cast<RhsScalar*>(mem);
+        mem += sz.rhs_size;
+      }
+    }
+
+    return block_mem;
+  }
+
+  template <typename Device>
+  EIGEN_DEVICE_FUNC static void deallocate(Device& d, BlockMemHandle handle) {
+    d.deallocate(handle);
+  }
+
+ private:
+  struct BlockSizes {
+    Index lhs_size;
+    Index rhs_size;
+  };
+  EIGEN_DEVICE_FUNC static BlockSizes ComputeLhsRhsBlockSizes(const Index bm,
+                                                              const Index bk,
+                                                              const Index bn) {
+    Index align = numext::maxi(EIGEN_MAX_ALIGN_BYTES, 1);
+    BlockSizes sz;
+    sz.lhs_size = divup<Index>(bm * bk * sizeof(LhsScalar), align) * align;
+    sz.rhs_size = divup<Index>(bn * bk * sizeof(RhsScalar), align) * align;
+    return sz;
+  }
+};
+
 // WARNING: In this code we assume that Lhs and Rhs tensor expressions are in
 // ColMajor storage order. This property is guaranteed by the
 // TensorContractionOp evaluator. TensorContractionKernel specifies how we pack
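As an illustration of the allocator introduced above (not part of the patch): the sketch below drives TensorContractionBlockMemAllocator directly on the default device. The block dimensions, scalar types and function name are invented for the example; only the allocate/deallocate API comes from the code above.

    #include <unsupported/Eigen/CXX11/Tensor>

    // Hypothetical driver: carve one aligned device allocation into an Lhs
    // block followed by an Rhs block.
    void packed_block_demo() {
      typedef Eigen::internal::TensorContractionBlockMemAllocator<float, float>
          Allocator;

      Eigen::DefaultDevice device;
      const Eigen::Index bm = 64, bk = 128, bn = 32;  // illustrative block sizes

      float* lhs_block = nullptr;
      float* rhs_block = nullptr;
      Allocator::BlockMemHandle handle =
          Allocator::allocate(device, bm, bk, bn, &lhs_block, &rhs_block);

      // lhs_block now addresses bm * bk floats and rhs_block bk * bn floats,
      // both inside the single allocation owned by handle.

      Allocator::deallocate(device, handle);
    }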
@@ -168,12 +244,24 @@ template<typename ResScalar, typename LhsScalar, typename RhsScalar,
          typename StorageIndex, typename OutputMapper, typename LhsMapper,
          typename RhsMapper>
 struct TensorContractionKernel {
+  TensorContractionKernel(StorageIndex m, StorageIndex k, StorageIndex n,
+                          StorageIndex bm, StorageIndex bk, StorageIndex bn)
+      : m(m), k(k), n(n), bm(bm), bk(bk), bn(bn) {}
+
+  // Pack blocks of Lhs and Rhs into contiguous blocks in memory.
+  typedef LhsScalar* LhsBlock;
+  typedef RhsScalar* RhsBlock;
+
+  // Packed Lhs/Rhs block memory allocator.
+  typedef TensorContractionBlockMemAllocator<LhsScalar, RhsScalar>
+      BlockMemAllocator;
+  typedef typename BlockMemAllocator::BlockMemHandle BlockMemHandle;
+
   typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;

-  typedef internal::gemm_pack_lhs<LhsScalar, StorageIndex,
-                                  typename LhsMapper::SubMapper,
-                                  Traits::mr, Traits::LhsProgress,
-                                  typename Traits::LhsPacket4Packing, ColMajor>
+  typedef internal::gemm_pack_lhs<
+      LhsScalar, StorageIndex, typename LhsMapper::SubMapper, Traits::mr,
+      Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor>
       LhsPacker;

   typedef internal::gemm_pack_rhs<RhsScalar, StorageIndex,
@@ -186,29 +274,61 @@ struct TensorContractionKernel {
                                   /*ConjugateLhs*/ false, /*ConjugateRhs*/ false>
       GebpKernel;

-  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE
-  static void packLhs(LhsScalar* lhsBlock,
-                      const typename LhsMapper::SubMapper& data_mapper,
+  template <typename Device>
+  EIGEN_DEVICE_FUNC BlockMemHandle allocate(Device& d, LhsBlock* lhs_block,
+                                            RhsBlock* rhs_block) {
+    return BlockMemAllocator::allocate(d, bm, bk, bn, lhs_block, rhs_block);
+  }
+
+  template <typename Device>
+  EIGEN_DEVICE_FUNC BlockMemHandle allocateSlices(
+      Device& d, const StorageIndex num_lhs, const StorageIndex num_rhs,
+      const StorageIndex num_slices, std::vector<LhsBlock>* lhs_blocks,
+      std::vector<RhsBlock>* rhs_blocks) {
+    return BlockMemAllocator::allocateSlices(
+        d, bm, bk, bn, num_lhs, num_rhs, num_slices, lhs_blocks, rhs_blocks);
+  }
+
+  template <typename Device>
+  EIGEN_DEVICE_FUNC static void deallocate(Device& d, BlockMemHandle handle) {
+    BlockMemAllocator::deallocate(d, handle);
+  }
+
+  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packLhs(
+      LhsBlock* lhsBlock, const typename LhsMapper::SubMapper& data_mapper,
       const StorageIndex depth, const StorageIndex rows) {
-    LhsPacker()(lhsBlock, data_mapper, depth, rows, /*stride*/ 0, /*offset*/ 0);
+    LhsPacker()(*lhsBlock, data_mapper, depth, rows, /*stride*/ 0,
+                /*offset*/ 0);
   }

-  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE
-  static void packRhs(RhsScalar* rhsBlock,
-                      const typename RhsMapper::SubMapper& data_mapper,
+  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void packRhs(
+      RhsBlock* rhsBlock, const typename RhsMapper::SubMapper& data_mapper,
       const StorageIndex depth, const StorageIndex cols) {
-    RhsPacker()(rhsBlock, data_mapper, depth, cols);
+    RhsPacker()(*rhsBlock, data_mapper, depth, cols);
   }

-  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE
-  static void invoke(const OutputMapper& output_mapper,
-                     const LhsScalar* lhsBlock, const RhsScalar* rhsBlock,
-                     const StorageIndex rows, const StorageIndex depth,
-                     const StorageIndex cols, const ResScalar alpha) {
+  EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void invoke(
+      const OutputMapper& output_mapper, const LhsBlock& lhsBlock,
+      const RhsBlock& rhsBlock, const StorageIndex rows,
+      const StorageIndex depth, const StorageIndex cols,
+      const ResScalar alpha) {
+    static const int kComputeStrideFromBlockDimensions = -1;
     GebpKernel()(output_mapper, lhsBlock, rhsBlock, rows, depth, cols, alpha,
-                 /*strideA*/ -1, /*strideB*/ -1,
+                 /*strideA*/ kComputeStrideFromBlockDimensions,
+                 /*strideB*/ kComputeStrideFromBlockDimensions,
                  /*offsetA*/ 0, /*offsetB*/ 0);
   }
+
+ private:
+  // These are dimensions of the original Tensors, and selected block sizes.
+  // The actual block sizes passed to all functions above might be smaller
+  // because of the partial blocks at the end.
+  const StorageIndex m;
+  const StorageIndex k;
+  const StorageIndex n;
+  const StorageIndex bm;
+  const StorageIndex bk;
+  const StorageIndex bn;
 };

 }  // end namespace internal
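The reason LhsBlock/RhsBlock are typedefs and the pack/invoke entry points are members of a kernel object is that an alternative contraction kernel can substitute its own packed-block representation. The skeleton below is not part of the patch and uses an invented name; it only sketches, under that assumption, the surface such a kernel would have to expose so it can be called the same way the evaluators further down call the default TensorContractionKernel (device-function and inlining macros omitted).

    #include <vector>

    // Hypothetical custom kernel skeleton; every name here is illustrative.
    template <typename ResScalar, typename LhsScalar, typename RhsScalar,
              typename StorageIndex, typename OutputMapper, typename LhsMapper,
              typename RhsMapper>
    struct MyCustomContractionKernel {
      // Opaque packed-block types; they no longer have to be raw scalar pointers.
      struct LhsBlock { /* custom packed layout */ };
      struct RhsBlock { /* custom packed layout */ };
      typedef void* BlockMemHandle;

      MyCustomContractionKernel(StorageIndex m, StorageIndex k, StorageIndex n,
                                StorageIndex bm, StorageIndex bk, StorageIndex bn);

      // Allocation entry points, mirroring allocate/allocateSlices/deallocate above.
      template <typename Device>
      BlockMemHandle allocate(Device& d, LhsBlock* lhs_block, RhsBlock* rhs_block);

      template <typename Device>
      BlockMemHandle allocateSlices(Device& d, StorageIndex num_lhs,
                                    StorageIndex num_rhs, StorageIndex num_slices,
                                    std::vector<LhsBlock>* lhs_blocks,
                                    std::vector<RhsBlock>* rhs_blocks);

      template <typename Device>
      static void deallocate(Device& d, BlockMemHandle handle);

      // Packing and the block-level GEMM, mirroring packLhs/packRhs/invoke above.
      void packLhs(LhsBlock* lhsBlock,
                   const typename LhsMapper::SubMapper& data_mapper,
                   StorageIndex depth, StorageIndex rows);
      void packRhs(RhsBlock* rhsBlock,
                   const typename RhsMapper::SubMapper& data_mapper,
                   StorageIndex depth, StorageIndex cols);
      void invoke(const OutputMapper& output_mapper, const LhsBlock& lhsBlock,
                  const RhsBlock& rhsBlock, StorageIndex rows, StorageIndex depth,
                  StorageIndex cols, ResScalar alpha);
    };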
@@ -737,11 +857,18 @@ struct TensorContractionEvaluatorBase
     const Index kc = blocking.kc();
     const Index mc = numext::mini(m, blocking.mc());
     const Index nc = numext::mini(n, blocking.nc());
-    const Index sizeA = mc * kc;
-    const Index sizeB = kc * nc;

-    LhsScalar* blockA = static_cast<LhsScalar *>(this->m_device.allocate(sizeA * sizeof(LhsScalar)));
-    RhsScalar* blockB = static_cast<RhsScalar *>(this->m_device.allocate(sizeB * sizeof(RhsScalar)));
+    typedef typename TensorContractionKernel::LhsBlock LhsBlock;
+    typedef typename TensorContractionKernel::RhsBlock RhsBlock;
+
+    LhsBlock blockA;
+    RhsBlock blockB;
+
+    TensorContractionKernel kernel(m, k_slice, n, mc, kc, nc);
+
+    typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
+    const BlockMemHandle packed_mem =
+        kernel.allocate(this->m_device, &blockA, &blockB);

     for(Index i2=0; i2<m; i2+=mc)
     {
@@ -749,22 +876,20 @@ struct TensorContractionEvaluatorBase
       for (Index k2 = k_start; k2 < k_end; k2 += kc) {
         // make sure we don't overshoot right edge of left matrix, then pack vertical panel
         const Index actual_kc = numext::mini(k2 + kc, k_end) - k2;
-        TensorContractionKernel::packLhs(blockA, lhs.getSubMapper(i2, k2),
-                                         actual_kc, actual_mc);
+        kernel.packLhs(&blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);

         // series of horizontal blocks
         for (Index j2 = 0; j2 < n; j2 += nc) {
           // make sure we don't overshoot right edge of right matrix, then pack block
           const Index actual_nc = numext::mini(j2 + nc, n) - j2;
-          TensorContractionKernel::packRhs(blockB, rhs.getSubMapper(k2, j2),
-                                           actual_kc, actual_nc);
+          kernel.packRhs(&blockB, rhs.getSubMapper(k2, j2), actual_kc,
+                         actual_nc);

           // call gebp (matrix kernel)
           // The parameters here are copied from Eigen's GEMM implementation
           const OutputMapper output_mapper = output.getSubMapper(i2, j2);
-          TensorContractionKernel::invoke(output_mapper, blockA, blockB,
-                                          actual_mc, actual_kc, actual_nc,
-                                          Scalar(1));
+          kernel.invoke(output_mapper, blockA, blockB, actual_mc, actual_kc,
+                        actual_nc, Scalar(1));

           // We are done with this [i2, j2] output block.
           if (use_output_kernel && k2 + kc >= k_end) {
@@ -775,8 +900,7 @@ struct TensorContractionEvaluatorBase
         }
       }

-    this->m_device.deallocate(blockA);
-    this->m_device.deallocate(blockB);
+    kernel.deallocate(this->m_device, packed_mem);
   }

   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
@@ -24,12 +24,17 @@ enum {
 */
 /// The make pointer class is used by sycl in order to build the mapper class on the device. For other platform the default make pointer is used which
 /// is scalar * for CoeffLoader.
-template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_ = MakePointer> struct CoeffLoader;
-template<typename Scalar, typename Index, int side, typename Tensor, typename nocontract_t, typename contract_t,
-         int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
-         template <class> class MakePointer_ = MakePointer> class BaseTensorContractionMapper;
+template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_ = MakePointer>
+struct CoeffLoader;
+
+template <typename Scalar, typename Index, int side, typename Tensor,
+          typename nocontract_t, typename contract_t, int packet_size,
+          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
+          template <class> class MakePointer_ = MakePointer>
+class BaseTensorContractionMapper;

-template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_> struct CoeffLoader {
+template <typename Tensor, bool HasRawAccess, template <class> class MakePointer_>
+struct CoeffLoader {
   enum {
     DirectOffsets = false
   };
@@ -40,6 +45,12 @@ template <typename Tensor, bool HasRawAccess, template <class> class MakePointer
     eigen_assert(false && "unsupported");
   }

+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type
+  data() const {
+    eigen_assert(false && "unsupported");
+    return NULL;
+  }
+
   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return m_tensor.coeff(index); }

   template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
@@ -48,12 +59,12 @@ template <typename Tensor, bool HasRawAccess, template <class> class MakePointer
     return m_tensor.template packet<LoadMode>(index);
   }

-
  private:
   const Tensor m_tensor;
 };

-template <typename Tensor, template <class> class MakePointer_> struct CoeffLoader<Tensor, true, MakePointer_> {
+template <typename Tensor, template <class> class MakePointer_>
+struct CoeffLoader<Tensor, true, MakePointer_> {
   enum {
     DirectOffsets = true
   };
@@ -64,6 +75,11 @@ template <typename Tensor, template <class> class MakePointer_> struct CoeffLoad
     m_data += offset;
   }

+  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE const typename MakePointer_<const typename Tensor::Scalar>::Type
+  data() const {
+    return m_data;
+  }
+
   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE typename Tensor::Scalar coeff(typename Tensor::Index index) const { return loadConstant(m_data+index); }

   template<int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
@@ -214,6 +230,17 @@ class SimpleTensorContractionMapper {
     return ((side == Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1;
   }

+  const CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_>& tensor() const {
+    return m_tensor;
+  }
+
+  const nocontract_t& nocontract_strides() const {
+    return m_nocontract_strides;
+  }
+  const nocontract_t& ij_strides() const { return m_ij_strides; }
+  const contract_t& contract_strides() const { return m_contract_strides; }
+  const contract_t& k_strides() const { return m_k_strides; }
+
  protected:
   CoeffLoader<Tensor, Tensor::RawAccess, MakePointer_> m_tensor;
   const nocontract_t m_nocontract_strides;
@@ -445,6 +472,10 @@ class TensorContractionSubMapper {
     return false;
   }

+  const ParentMapper& base_mapper() const { return m_base_mapper; }
+  Index vert_offset() const { return m_vert_offset; }
+  Index horiz_offset() const { return m_horiz_offset; }
+
  private:
   ParentMapper m_base_mapper;
   const Index m_vert_offset;
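The accessors added to CoeffLoader, SimpleTensorContractionMapper and TensorContractionSubMapper expose the underlying storage and strides, so a packing routine does not have to go through coeff()/packet(). A rough sketch of that access chain, assuming the mapped expression has raw access and the default MakePointer (a plain pointer); for the non-raw CoeffLoader, data() asserts and returns NULL:

    #include <unsupported/Eigen/CXX11/Tensor>

    // Illustrative only: read the raw pointer and offsets behind a sub-mapper.
    template <typename SubMapper>
    void inspect_sub_mapper(const SubMapper& sub) {
      // base_mapper(), vert_offset() and horiz_offset() are the new accessors.
      const auto& base = sub.base_mapper();
      const Eigen::Index row_offset = sub.vert_offset();
      const Eigen::Index col_offset = sub.horiz_offset();

      // tensor().data() reaches the raw buffer; the stride accessors
      // (nocontract_strides(), contract_strides(), ...) describe its layout.
      const auto* raw_data = base.tensor().data();

      (void)raw_data;
      (void)row_offset;
      (void)col_offset;
    }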
@@ -280,6 +280,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
         Scalar, LhsScalar, RhsScalar, Index, OutputMapper, LhsMapper, RhsMapper>
         TensorContractionKernel;

+    typedef typename TensorContractionKernel::LhsBlock LhsBlock;
+    typedef typename TensorContractionKernel::RhsBlock RhsBlock;
+    typedef typename TensorContractionKernel::BlockMemHandle BlockMemHandle;
+
     Context(const Self* self, int num_threads, Scalar* buffer, Index tm, Index tn,
             Index tk, Index bm, Index bn, Index bk, Index nm, Index nn, Index nk,
             Index gm, Index gn, Index nm0, Index nn0, bool shard_by_col,
@@ -311,7 +315,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
           gm_(gm),
           gn_(gn),
           nm0_(nm0),
-          nn0_(nn0)
+          nn0_(nn0),
+          kernel_(m_, k_, n_, bm_, bk_, bn_)
     {
       // These two options are mutually exclusive.
       eigen_assert(!(parallel_pack && parallelize_by_sharding_dim_only));
@@ -342,26 +347,12 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
       }

       // Allocate memory for packed rhs/lhs matrices.
-      size_t align = numext::maxi(EIGEN_MAX_ALIGN_BYTES, 1);
-      size_t lhs_size =
-          divup<size_t>(bm_ * bk_ * sizeof(LhsScalar), align) * align;
-      size_t rhs_size =
-          divup<size_t>(bn_ * bk_ * sizeof(RhsScalar), align) * align;
-      packed_mem_ = static_cast<char*>(device_.allocate(
-          (nm0_ * lhs_size + nn0_ * rhs_size) * std::min<size_t>(nk_, P - 1)));
-      char* mem = static_cast<char*>(packed_mem_);
-      for (Index x = 0; x < numext::mini<Index>(nk_, P - 1); x++) {
-        packed_lhs_[x].resize(nm0_);
-        for (Index m = 0; m < nm0_; m++) {
-          packed_lhs_[x][m] = reinterpret_cast<LhsScalar*>(mem);
-          mem += lhs_size;
-        }
-        packed_rhs_[x].resize(nn0_);
-        for (Index n = 0; n < nn0_; n++) {
-          packed_rhs_[x][n] = reinterpret_cast<RhsScalar*>(mem);
-          mem += rhs_size;
-        }
-      }
+      packed_mem_ = kernel_.allocateSlices(            //
+          device_,                                     //
+          /*num_lhs=*/nm0_,                            //
+          /*num_rhs=*/nn0_,                            //
+          /*num_slices=*/std::min<Index>(nk_, P - 1),  //
+          packed_lhs_, packed_rhs_);

       if (parallelize_by_sharding_dim_only_) {
         const int num_worker_threads = device_.numThreadsInPool();
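In the constructor above, kernel_.allocateSlices forwards to TensorContractionBlockMemAllocator::allocateSlices, which replaces the hand-rolled pointer arithmetic it deletes: one device allocation covers num_lhs + num_rhs blocks for each slice, and the per-slice vectors are filled with pointers into it. A standalone sketch with invented block sizes and counts:

    #include <vector>
    #include <unsupported/Eigen/CXX11/Tensor>

    // Hypothetical: two slices, each with three Lhs blocks and four Rhs blocks.
    void packed_slices_demo() {
      typedef Eigen::internal::TensorContractionBlockMemAllocator<float, float>
          Allocator;

      Eigen::DefaultDevice device;
      const Eigen::Index bm = 48, bk = 96, bn = 24;  // illustrative block sizes
      const Eigen::Index num_lhs = 3, num_rhs = 4, num_slices = 2;

      std::vector<float*> lhs_blocks[2];  // one vector of block pointers per slice
      std::vector<float*> rhs_blocks[2];

      Allocator::BlockMemHandle handle = Allocator::allocateSlices(
          device, bm, bk, bn, num_lhs, num_rhs, num_slices, lhs_blocks, rhs_blocks);

      // lhs_blocks[s][i] and rhs_blocks[s][j] now point into one shared allocation.

      Allocator::deallocate(device, handle);
    }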
@@ -373,14 +364,13 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
                 std::memory_order_relaxed);

           Index num_blocks = num_worker_threads * gn_;
-          thread_local_packed_mem_ = device_.allocate(num_blocks * rhs_size);
-          mem = static_cast<char*>(thread_local_packed_mem_);
+          thread_local_packed_mem_ = kernel_.allocateSlices(  //
+              device_,                                        //
+              /*num_lhs=*/0,                                  //
+              /*num_rhs=*/num_blocks,                         //
+              /*num_slices=*/1,                               //
+              /*lhs_blocks=*/nullptr, &thread_local_packed_rhs_);

-          thread_local_packed_rhs_.resize(num_blocks, nullptr);
-          for (Index i = 0; i < num_blocks; ++i) {
-            thread_local_packed_rhs_[i] = reinterpret_cast<RhsScalar*>(mem);
-            mem += rhs_size;
-          }
         } else {
           can_use_thread_local_packed_ = new std::atomic<bool>[nm_];
           for (int i = 0; i < nm_; ++i)
@@ -388,14 +378,12 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
                 std::memory_order_relaxed);

           Index num_blocks = num_worker_threads * gm_;
-          thread_local_packed_mem_ = device_.allocate(num_blocks * lhs_size);
-          mem = static_cast<char*>(thread_local_packed_mem_);
-
-          thread_local_packed_lhs_.resize(num_blocks, nullptr);
-          for (Index i = 0; i < num_blocks; ++i) {
-            thread_local_packed_lhs_[i] = reinterpret_cast<LhsScalar*>(mem);
-            mem += lhs_size;
-          }
+          thread_local_packed_mem_ = kernel_.allocateSlices(  //
+              device_,                                        //
+              /*num_lhs=*/num_blocks,                         //
+              /*num_rhs=*/0,                                  //
+              /*num_slices=*/1, &thread_local_packed_lhs_,    //
+              /*rhs_blocks=*/nullptr);
         }
       }
     }
@@ -405,9 +393,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
         for (Index m = 0; m < nm_; m++) delete[] state_kernel_[x][m];
         delete[] state_kernel_[x];
       }
-      device_.deallocate(packed_mem_);
+      kernel_.deallocate(device_, packed_mem_);
       if (parallelize_by_sharding_dim_only_) {
-        device_.deallocate(thread_local_packed_mem_);
+        kernel_.deallocate(device_, thread_local_packed_mem_);
         delete[] can_use_thread_local_packed_;
       }
     }
@@ -455,6 +443,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     // coarsening).
     const Index nm0_;
     const Index nn0_;
+    // Tensor contraction kernel.
+    TensorContractionKernel kernel_;

     // Parallelization strategy.
     //
@@ -491,9 +481,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     // actively executing + one to track completion of kernels in the second
     // slice.
     static const Index P = 3;
-    void* packed_mem_;
-    std::vector<LhsScalar*> packed_lhs_[P - 1];
-    std::vector<RhsScalar*> packed_rhs_[P - 1];
+
+    // Handle to the allocated temporary storage for Lhs/Rhs blocks.
+    BlockMemHandle packed_mem_;
+    std::vector<LhsBlock> packed_lhs_[P - 1];
+    std::vector<RhsBlock> packed_rhs_[P - 1];

     // If we choose to parallelize only by the sharding dimension, each thread
     // will have it's own "thead local" (not a c++ thread local storage) memory
@@ -511,11 +503,11 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     // completion of the K-1 kernel, so we have to allocate "global" packed_lhs_
     // and packed_rhs_ to allow kernels to be executed later on a thread
     // different from the thread that was used for packing.
-    void* thread_local_packed_mem_;
+    BlockMemHandle thread_local_packed_mem_;

     // Only one of these will be initialized depending on shard_by_col value.
-    std::vector<LhsScalar*> thread_local_packed_lhs_;
-    std::vector<RhsScalar*> thread_local_packed_rhs_;
+    std::vector<LhsBlock> thread_local_packed_lhs_;
+    std::vector<RhsBlock> thread_local_packed_rhs_;

     // After a particular shard for Kth slice missed thread local execution
     // opportunity (K-1 slice didn't complete kernels execution), we can no
@@ -532,7 +524,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     std::atomic<Index> state_packing_ready_[P];
     std::atomic<Index> state_switch_[P];

-    LhsScalar* packed_lhs(Index m, Index k, Index m1, bool use_thread_local) {
+    LhsBlock& packed_lhs(Index m, Index k, Index m1, bool use_thread_local) {
       if (use_thread_local) {
         eigen_assert(!shard_by_col_);

@@ -546,7 +538,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
       }
     }

-    RhsScalar* packed_rhs(Index n, Index k, Index n1, bool use_thread_local) {
+    RhsBlock& packed_rhs(Index n, Index k, Index n1, bool use_thread_local) {
       if (use_thread_local) {
         eigen_assert(shard_by_col_);

@@ -580,7 +572,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
       } else {
         // If we can't guarantee that all kernels in `k` slice will be
         // executed sequentially in current thread, it's no longer safe to use
-        // thread local memory in followig slices along the k dimensions.
+        // thread local memory in following slices along the k dimensions.
         eigen_assert(k > 0);
         can_use_thread_local_packed_[m].store(false,
                                               std::memory_order_relaxed);
@@ -589,9 +581,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT

       const Index mend = m * gm_ + gm(m);
       for (Index m1 = m * gm_; m1 < mend; m1++)
-        TensorContractionKernel::packLhs(packed_lhs(m, k, m1, use_thread_local),
-                                         lhs_.getSubMapper(m1 * bm_, k * bk_),
-                                         bk(k), bm(m1));
+        kernel_.packLhs(&packed_lhs(m, k, m1, use_thread_local),
+                        lhs_.getSubMapper(m1 * bm_, k * bk_), bk(k), bm(m1));

       if (!parallel_pack_ && shard_by_col_) {
         assert(!use_thread_local);
@@ -634,9 +625,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
           // deadlocks.
           memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar));
         }
-        TensorContractionKernel::packRhs(packed_rhs(n, k, n1, use_thread_local),
-                                         rhs_.getSubMapper(k * bk_, n1 * bn_),
-                                         bk(k), bn(n1));
+        kernel_.packRhs(&packed_rhs(n, k, n1, use_thread_local),
+                        rhs_.getSubMapper(k * bk_, n1 * bn_), bk(k), bn(n1));
       }

       if (parallel_pack_ || shard_by_col_) {
@@ -661,7 +651,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
         for (Index n1 = n * gn_; n1 < nend; n1++) {
           for (Index m1 = m * gm_; m1 < mend; m1++) {
             const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
-            TensorContractionKernel::invoke(
+            kernel_.invoke(
                 output_mapper,
                 packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
                 packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
@@ -678,7 +668,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
         for (Index m1 = m * gm_; m1 < mend; m1++)
           for (Index n1 = n * gn_; n1 < nend; n1++) {
             const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
-            TensorContractionKernel::invoke(
+            kernel_.invoke(
                 output_mapper,
                 packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
                 packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
|
Loading…
x
Reference in New Issue
Block a user