Mirror of https://gitlab.com/libeigen/eigen.git
commit e478532625
parent 2ebcb911b2

Reduce the number of template specializations of classes related to tensor contraction to reduce binary size.
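
Background: evalTo() in TensorContractionEvaluatorBase used to expand three runtime flags (m_lhs_inner_dim_contiguous, m_rhs_inner_dim_contiguous, m_rhs_inner_dim_reordered) into eight calls to evalProduct<bool, bool, bool, Alignment>, so each evaluator instantiated its entire product-evaluation path once per flag combination. The patch centralizes that expansion in a single TENSOR_CONTRACTION_DISPATCH macro and pushes the boolean template parameters down to the innermost helpers, so the heavyweight orchestration code is instantiated far fewer times. A minimal standalone sketch of the pattern, with hypothetical names (not the Eigen code itself):

    #include <cstdio>

    // Sketch of the dispatch pattern: expand runtime flags into a call to a
    // method templated on compile-time bools. Only METHOD is specialized per
    // flag combination; the routines that invoke the macro are compiled once.
    #define DISPATCH(METHOD, ARGS)            \
      if (flag_a) {                           \
        if (flag_b) METHOD<true, true>ARGS;   \
        else        METHOD<true, false>ARGS;  \
      } else {                                \
        if (flag_b) METHOD<false, true>ARGS;  \
        else        METHOD<false, false>ARGS; \
      }

    struct Contraction {
      bool flag_a, flag_b;  // runtime properties of the operands

      // Small flag-specialized kernel: cheap to instantiate four (or eight) ways.
      template <bool A, bool B>
      void kernel(int n) const { std::printf("%d %d %d\n", int(A), int(B), n); }

      // Large orchestration routine: compiled once, dispatches at the end.
      void eval(int n) const { DISPATCH(kernel, (n)); }
    };

    int main() {
      Contraction c{true, false};
      c.eval(7);  // prints "1 0 7"
      return 0;
    }

Only kernel() is stamped out once per flag combination; eval(), which stands in for the large evalProduct/evalTo bodies below, is compiled once.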
unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -177,9 +177,9 @@ struct NoOpOutputKernel {
   */
   template <typename Index, typename Scalar>
   EIGEN_ALWAYS_INLINE void operator()(
-      const OutputKernel::OutputMapper<Index, Scalar>& output_mapper,
-      const TensorContractionParams& params, Index i, Index j, Index num_rows,
-      Index num_cols) const {}
+      const OutputKernel::OutputMapper<Index, Scalar>& /*output_mapper*/,
+      const TensorContractionParams& /*params*/, Index /*i*/,
+      Index /*j*/, Index /*num_rows*/, Index /*num_cols*/) const {}
 };
 
 template<typename Indices, typename LhsXprType, typename RhsXprType, typename OutputKernelType = const NoOpOutputKernel>
@@ -467,42 +467,58 @@ struct TensorContractionEvaluatorBase
     }
   }
 
-  EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
-    if (this->m_lhs_inner_dim_contiguous) {
-      if (this->m_rhs_inner_dim_contiguous) {
-        if (this->m_rhs_inner_dim_reordered) {
-          static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer);
-        }
-        else {
-          static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer);
-        }
-      }
-      else {
-        if (this->m_rhs_inner_dim_reordered) {
-          static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer);
-        }
-        else {
-          static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer);
-        }
-      }
-    }
-    else {
-      if (this->m_rhs_inner_dim_contiguous) {
-        if (this->m_rhs_inner_dim_reordered) {
-          static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer);
-        }
-        else {
-          static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer);
-        }
-      }
-      else {
-        if (this->m_rhs_inner_dim_reordered) {
-          static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer);
-        }
-        else {
-          static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer);
-        }
-      }
+#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS)    \
+    if (this->m_lhs_inner_dim_contiguous) {                     \
+      if (this->m_rhs_inner_dim_contiguous) {                   \
+        if (this->m_rhs_inner_dim_reordered) {                  \
+          METHOD<true, true, true, ALIGNMENT>ARGS;              \
+        }                                                       \
+        else {                                                  \
+          METHOD<true, true, false, ALIGNMENT>ARGS;             \
+        }                                                       \
+      }                                                         \
+      else {                                                    \
+        if (this->m_rhs_inner_dim_reordered) {                  \
+          METHOD<true, false, true, ALIGNMENT>ARGS;             \
+        }                                                       \
+        else {                                                  \
+          METHOD<true, false, false, ALIGNMENT>ARGS;            \
+        }                                                       \
+      }                                                         \
+    }                                                           \
+    else {                                                      \
+      if (this->m_rhs_inner_dim_contiguous) {                   \
+        if (this->m_rhs_inner_dim_reordered) {                  \
+          METHOD<false, true, true, ALIGNMENT>ARGS;             \
+        }                                                       \
+        else {                                                  \
+          METHOD<false, true, false, ALIGNMENT>ARGS;            \
+        }                                                       \
+      }                                                         \
+      else {                                                    \
+        if (this->m_rhs_inner_dim_reordered) {                  \
+          METHOD<false, false, true, ALIGNMENT>ARGS;            \
+        }                                                       \
+        else {                                                  \
+          METHOD<false, false, false, ALIGNMENT>ARGS;           \
+        }                                                       \
+      }                                                         \
+    }
+
+  EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
+    static_cast<const Derived*>(this)->template evalProduct<Unaligned>(buffer);
+  }
+
+  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
+            bool rhs_inner_dim_reordered, int Alignment>
+  void evalProductSequential(Scalar* buffer) const {
+    if (this->m_j_size == 1) {
+      this->template evalGemv<lhs_inner_dim_contiguous,
+                              rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
+                              Alignment>(buffer);
+    } else {
+      this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
+                              rhs_inner_dim_reordered, Alignment>(buffer);
     }
   }
 
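
Later in this patch the DefaultDevice evaluator calls TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, Alignment, (buffer)); the preprocessor splices the METHOD and ARGS tokens into each branch of the tree above, so the first branch of the expansion is roughly (remaining flag combinations analogous):

    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          this->template evalProductSequential<true, true, true, Alignment>(buffer);
        }
        else {
          this->template evalProductSequential<true, true, false, Alignment>(buffer);
        }
      }
      // ... six more branches, one per remaining flag combination.
    }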
@@ -623,7 +639,7 @@ struct TensorContractionEvaluatorBase
     OutputMapper output(buffer, m);
 
     // Sizes of the blocks to load in cache. See the Goto paper for details.
-    internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, 1);
+    internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(k, m, n, 1);
     const Index kc = blocking.kc();
     const Index mc = numext::mini(m, blocking.mc());
     const Index nc = numext::mini(n, blocking.nc());
@@ -976,14 +992,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
   EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
       Base(op, device) { }
 
-  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
-  EIGEN_DEVICE_FUNC void evalProduct(Scalar* buffer) const {
-    if (this->m_j_size == 1) {
-      this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
-      return;
-    }
-
-    this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
+  template <int Alignment>
+  void evalProduct(Scalar* buffer) const {
+    TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, Alignment, (buffer));
   }
 };
 
unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h
@@ -21,13 +21,10 @@ enum {
 
 
 // Default Blocking Strategy
-template <typename LhsMapper, typename RhsMapper, typename Index, int ShardingType=ShardByCol>
+template <typename LhsScalar, typename RhsScalar, typename Index, int ShardingType=ShardByCol>
 class TensorContractionBlocking {
  public:
 
-  typedef typename LhsMapper::Scalar LhsScalar;
-  typedef typename RhsMapper::Scalar RhsScalar;
-
   /*
    adding EIGEN_DEVICE_FUNC unconditionally to 'TensorContractionBlocking' constructor in `TensorContractionBlocking.h`
    requires adding EIGEN_DEVICE_FUNC to `computeProductBlockingSizes` in `GeneralBlockPanelKernel.h`
@@ -41,7 +38,7 @@ class TensorContractionBlocking {
   ../Eigen/src/Core/products/GeneralBlockPanelKernel.h(57): error #2901:
   dynamic initialization is not supported for function-scope static variables within a __device__/__global__ function
  */
 
 #if !defined(EIGEN_HIPCC)
 EIGEN_DEVICE_FUNC
 #endif
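
The TensorContractionBlocking change above applies the same idea at the type level: block sizes depend only on the scalar types, so keying the template on LhsScalar/RhsScalar instead of the mapper types lets every mapper combination with the same scalars share one instantiation. A self-contained illustration, with hypothetical stand-in types (MapperA, MapperB, BlockingByMapper, BlockingByScalar are not Eigen names):

    #include <type_traits>

    // Two hypothetical mapper types that happen to share a scalar type.
    template <typename Scalar> struct MapperA { typedef Scalar value_type; };
    template <typename Scalar> struct MapperB { typedef Scalar value_type; };

    template <typename LhsMapper, typename RhsMapper> struct BlockingByMapper {};
    template <typename LhsScalar, typename RhsScalar> struct BlockingByScalar {};

    // Keyed on mappers: two distinct class instantiations for float inputs.
    static_assert(!std::is_same<BlockingByMapper<MapperA<float>, MapperA<float> >,
                                BlockingByMapper<MapperB<float>, MapperB<float> > >::value,
                  "distinct instantiations per mapper pair");

    // Keyed on scalars: both combinations collapse into a single instantiation.
    static_assert(std::is_same<BlockingByScalar<MapperA<float>::value_type,
                                                MapperA<float>::value_type>,
                               BlockingByScalar<MapperB<float>::value_type,
                                                MapperB<float>::value_type> >::value,
                  "one instantiation per scalar pair");

    int main() { return 0; }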

unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
@@ -71,8 +71,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
   TensorEvaluator(const XprType& op, const Device& device) :
       Base(op, device) {}
 
-  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
-            bool rhs_inner_dim_reordered, int Alignment>
+  template <int Alignment>
   void evalProduct(Scalar* buffer) const {
     const Index m = this->m_i_size;
     const Index n = this->m_j_size;
@@ -96,39 +95,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     }
 #endif
 
-    typedef
-        typename internal::remove_const<typename EvalLeftArgType::Scalar>::type
-            LhsScalar;
-    typedef
-        typename internal::remove_const<typename EvalRightArgType::Scalar>::type
-            RhsScalar;
-    typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
-    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
-    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
-    typedef internal::TensorContractionInputMapper<
-        LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
-        contract_t, internal::packet_traits<LhsScalar>::size,
-        lhs_inner_dim_contiguous, false, Unaligned>
-        LhsMapper;
-    typedef internal::TensorContractionInputMapper<
-        RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
-        contract_t, internal::packet_traits<RhsScalar>::size,
-        rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
-        RhsMapper;
-    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
-    typedef internal::gemm_pack_lhs<LhsScalar, Index,
-                                    typename LhsMapper::SubMapper, Traits::mr,
-                                    Traits::LhsProgress, ColMajor>
-        LhsPacker;
-    typedef internal::gemm_pack_rhs<
-        RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor>
-        RhsPacker;
-    typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper,
-                                  Traits::mr, Traits::nr, false, false>
-        GebpKernel;
-
-
-
     // Compute a set of algorithm parameters:
     // - kernel block sizes (bm, bn, bk)
     // - task grain sizes (number of kernels executed per task: gm, gn)
@@ -158,14 +124,14 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     // Again, we don't know number of threads yet, so we use 2.
     Index bm, bn, bk;
     if (shard_by_col) {
-      internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
+      internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index,
                                           internal::ShardByCol>
           blocking(k, m, n, 2);
       bm = blocking.mc();
       bn = blocking.nc();
       bk = blocking.kc();
     } else {
-      internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
+      internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index,
                                           internal::ShardByRow>
          blocking(k, m, n, 2);
       bm = blocking.mc();
@@ -187,29 +153,22 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     if (n == 1) num_threads = 1;
 
     if (num_threads == 1) {
-      // The single-threaded algorithm should be faster in this case.
-      if (n == 1)
-        this->template evalGemv<lhs_inner_dim_contiguous,
-                                rhs_inner_dim_contiguous,
-                                rhs_inner_dim_reordered, Alignment>(buffer);
-      else
-        this->template evalGemm<lhs_inner_dim_contiguous,
-                                rhs_inner_dim_contiguous,
-                                rhs_inner_dim_reordered, Alignment>(buffer);
+      TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential,
+                                  Unaligned, (buffer));
       return;
     }
 
     // Now that we know number of threads, recalculate sharding and blocking.
     shard_by_col = shardByCol(m, n, num_threads);
     if (shard_by_col) {
-      internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
+      internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index,
                                           internal::ShardByCol>
           blocking(k, m, n, num_threads);
       bm = blocking.mc();
       bn = blocking.nc();
       bk = blocking.kc();
     } else {
-      internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
+      internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index,
                                           internal::ShardByRow>
           blocking(k, m, n, num_threads);
       bm = blocking.mc();
@@ -257,34 +216,55 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     // more important in this case.
     if ((shard_by_col ? nm : nn) == 1) parallel_pack = false;
 
-    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides,
-                  this->m_i_strides, this->m_left_contracting_strides,
-                  this->m_k_strides);
-
-    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides,
-                  this->m_j_strides, this->m_right_contracting_strides,
-                  this->m_k_strides);
-
-    Context<LhsPacker, RhsPacker, GebpKernel, LhsMapper, RhsMapper,
-            OutputMapper>(this, num_threads, lhs, rhs, buffer, m, n,
-                          k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0,
-                          shard_by_col, parallel_pack)
-        .run();
+#define CONTEXT_ARGS                                                        \
+  (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
+   nn0, shard_by_col, parallel_pack)                                        \
+      .run()
+
+    TENSOR_CONTRACTION_DISPATCH(Context, Alignment, CONTEXT_ARGS);
+
+#undef CONTEXT_ARGS
   }
 
   // Context coordinates a single parallel gemm operation.
-  template <typename LhsPacker, typename RhsPacker, typename GebpKernel,
-            typename LhsMapper, typename RhsMapper, typename OutputMapper>
+  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
+            bool rhs_inner_dim_reordered, int Alignment>
   class Context {
    public:
-    Context(const Self* self, int num_threads, LhsMapper& lhs,
-            RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk, Index bm,
-            Index bn, Index bk, Index nm, Index nn, Index nk, Index gm,
-            Index gn, Index nm0, Index nn0, bool shard_by_col,
+    typedef internal::TensorContractionInputMapper<
+        LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
+        contract_t, internal::packet_traits<LhsScalar>::size,
+        lhs_inner_dim_contiguous, false, Unaligned>
+        LhsMapper;
+    typedef internal::TensorContractionInputMapper<
+        RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
+        contract_t, internal::packet_traits<RhsScalar>::size,
+        rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
+        RhsMapper;
+    typedef internal::gemm_pack_lhs<LhsScalar, Index,
+                                    typename LhsMapper::SubMapper, Traits::mr,
+                                    Traits::LhsProgress, ColMajor>
+        LhsPacker;
+    typedef internal::gemm_pack_rhs<
+        RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor>
+        RhsPacker;
+    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
+    typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper,
+                                  Traits::mr, Traits::nr, false, false>
+        GebpKernel;
+
+    Context(const Self* self, int num_threads, Scalar* buffer, Index tm, Index tn,
+            Index tk, Index bm, Index bn, Index bk, Index nm, Index nn, Index nk,
+            Index gm, Index gn, Index nm0, Index nn0, bool shard_by_col,
             bool parallel_pack)
         : device_(self->m_device),
-          lhs_(lhs),
-          rhs_(rhs),
+          lhs_(self->m_leftImpl, self->m_left_nocontract_strides,
+               self->m_i_strides, self->m_left_contracting_strides,
+               self->m_k_strides),
+          rhs_(self->m_rightImpl, self->m_right_nocontract_strides,
+               self->m_j_strides, self->m_right_contracting_strides,
+               self->m_k_strides),
          buffer_(buffer),
          output_(buffer, tm),
          output_kernel_(self->m_output_kernel),
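
In the thread-pool path above, the boolean flags migrate from evalProduct onto Context, which now derives its mapper, packer, and kernel typedefs internally and constructs its mappers from self (owning them by value, per the member change in the next hunk); evalProduct itself is compiled once per alignment, and only the worker type varies. A compressed sketch of that structure, with hypothetical, heavily simplified names (not the Eigen classes):

    #include <cstdio>

    struct Orchestrator {
      bool flag;
      int data;

      template <bool Flag>
      class Worker {
       public:
        // The worker builds its own state from the orchestrator, instead of
        // having the caller construct and pass flag-specialized objects in.
        explicit Worker(const Orchestrator* self) : state_(self->data + int(Flag)) {}
        void run() const { std::printf("state = %d\n", state_); }
       private:
        int state_;  // owned by value, like lhs_/rhs_ after this patch
      };

      void eval() const {  // compiled once; only Worker<Flag> varies
        if (flag) Worker<true>(this).run();
        else      Worker<false>(this).run();
      }
    };

    int main() {
      Orchestrator o{true, 41};
      o.eval();  // prints "state = 42"
      return 0;
    }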
@@ -376,8 +356,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
    private:
     Notification done_;
     const Device& device_;
-    LhsMapper& lhs_;
-    RhsMapper& rhs_;
+    LhsMapper lhs_;
+    RhsMapper rhs_;
     Scalar* const buffer_;
     OutputMapper output_;
     OutputKernelType output_kernel_;