Reduce the number of template specializations of classes related to tensor contraction, in order to shrink binary size.

Rasmus Munk Larsen 2018-07-27 12:36:34 -07:00
parent 2ebcb911b2
commit e478532625
3 changed files with 109 additions and 121 deletions


@@ -177,9 +177,9 @@ struct NoOpOutputKernel {
    */
   template <typename Index, typename Scalar>
   EIGEN_ALWAYS_INLINE void operator()(
-      const OutputKernel::OutputMapper<Index, Scalar>& output_mapper,
-      const TensorContractionParams& params, Index i, Index j, Index num_rows,
-      Index num_cols) const {}
+      const OutputKernel::OutputMapper<Index, Scalar>& /*output_mapper*/,
+      const TensorContractionParams& /*params*/, Index /*i*/,
+      Index /*j*/, Index /*num_rows*/, Index /*num_cols*/) const {}
 };
 
 template<typename Indices, typename LhsXprType, typename RhsXprType, typename OutputKernelType = const NoOpOutputKernel>
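The only change to NoOpOutputKernel is cosmetic: the default kernel's body is empty, so keeping named parameters triggers unused-parameter warnings. A tiny illustration of the idiom (names here are invented, not Eigen's):

// Commenting out the name keeps the signature self-documenting while
// silencing -Wunused-parameter for an intentionally empty body.
struct NoOpSketch {
  void operator()(float* /*output*/, int /*row*/, int /*col*/) const {}
};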
@@ -467,42 +467,58 @@ struct TensorContractionEvaluatorBase
     }
   }
 
-  EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
-    if (this->m_lhs_inner_dim_contiguous) {
-      if (this->m_rhs_inner_dim_contiguous) {
-        if (this->m_rhs_inner_dim_reordered) {
-          static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer);
-        }
-        else {
-          static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer);
-        }
-      }
-      else {
-        if (this->m_rhs_inner_dim_reordered) {
-          static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer);
-        }
-        else {
-          static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer);
-        }
-      }
-    }
-    else {
-      if (this->m_rhs_inner_dim_contiguous) {
-        if (this->m_rhs_inner_dim_reordered) {
-          static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer);
-        }
-        else {
-          static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer);
-        }
-      }
-      else {
-        if (this->m_rhs_inner_dim_reordered) {
-          static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer);
-        }
-        else {
-          static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer);
-        }
-      }
-    }
-  }
+#define TENSOR_CONTRACTION_DISPATCH(METHOD, ALIGNMENT, ARGS)    \
+    if (this->m_lhs_inner_dim_contiguous) {                     \
+      if (this->m_rhs_inner_dim_contiguous) {                   \
+        if (this->m_rhs_inner_dim_reordered) {                  \
+          METHOD<true, true, true, ALIGNMENT>ARGS;              \
+        }                                                       \
+        else {                                                  \
+          METHOD<true, true, false, ALIGNMENT>ARGS;             \
+        }                                                       \
+      }                                                         \
+      else {                                                    \
+        if (this->m_rhs_inner_dim_reordered) {                  \
+          METHOD<true, false, true, ALIGNMENT>ARGS;             \
+        }                                                       \
+        else {                                                  \
+          METHOD<true, false, false, ALIGNMENT>ARGS;            \
+        }                                                       \
+      }                                                         \
+    }                                                           \
+    else {                                                      \
+      if (this->m_rhs_inner_dim_contiguous) {                   \
+        if (this->m_rhs_inner_dim_reordered) {                  \
+          METHOD<false, true, true, ALIGNMENT>ARGS;             \
+        }                                                       \
+        else {                                                  \
+          METHOD<false, true, false, ALIGNMENT>ARGS;            \
+        }                                                       \
+      }                                                         \
+      else {                                                    \
+        if (this->m_rhs_inner_dim_reordered) {                  \
+          METHOD<false, false, true, ALIGNMENT>ARGS;            \
+        }                                                       \
+        else {                                                  \
+          METHOD<false, false, false, ALIGNMENT>ARGS;           \
+        }                                                       \
+      }                                                         \
+    }
+
+  EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
+    static_cast<const Derived*>(this)->template evalProduct<Unaligned>(buffer);
+  }
+
+  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
+            bool rhs_inner_dim_reordered, int Alignment>
+  void evalProductSequential(Scalar* buffer) const {
+    if (this->m_j_size == 1) {
+      this->template evalGemv<lhs_inner_dim_contiguous,
+                              rhs_inner_dim_contiguous, rhs_inner_dim_reordered,
+                              Alignment>(buffer);
+    } else {
+      this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
+                              rhs_inner_dim_reordered, Alignment>(buffer);
+    }
+  }
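The macro introduced above is the heart of the commit: the three runtime flags are turned into template parameters at a single dispatch point, so the heavily templated gemm/gemv bodies are instantiated once per flag combination in the base class instead of once per combination in every derived evaluator. A standalone sketch of the pattern, with illustrative names that are not Eigen's:

#include <cstdio>

// Stand-in for the heavy templated kernel (evalGemm/evalGemv in Eigen).
template <bool lhs_contig, bool rhs_contig, bool rhs_reordered, int Alignment>
void kernel(float* /*buffer*/) {
  std::printf("kernel<%d,%d,%d,%d>\n", lhs_contig, rhs_contig, rhs_reordered,
              Alignment);
}

// Dispatch macro in the spirit of TENSOR_CONTRACTION_DISPATCH: METHOD is a
// template name, ARGS the parenthesized argument list; the three bools come
// from the enclosing scope.
#define DISPATCH(METHOD, ALIGNMENT, ARGS)                            \
  if (lhs_contig) {                                                  \
    if (rhs_contig) {                                                \
      if (rhs_reordered) METHOD<true, true, true, ALIGNMENT> ARGS;   \
      else               METHOD<true, true, false, ALIGNMENT> ARGS;  \
    } else {                                                         \
      if (rhs_reordered) METHOD<true, false, true, ALIGNMENT> ARGS;  \
      else               METHOD<true, false, false, ALIGNMENT> ARGS; \
    }                                                                \
  } else {                                                           \
    if (rhs_contig) {                                                \
      if (rhs_reordered) METHOD<false, true, true, ALIGNMENT> ARGS;  \
      else               METHOD<false, true, false, ALIGNMENT> ARGS; \
    } else {                                                         \
      if (rhs_reordered) METHOD<false, false, true, ALIGNMENT> ARGS; \
      else               METHOD<false, false, false, ALIGNMENT> ARGS;\
    }                                                                \
  }

// Every caller funnels through this one function, so the eight kernel
// instantiations appear in exactly one place.
void run(bool lhs_contig, bool rhs_contig, bool rhs_reordered, float* buffer) {
  DISPATCH(kernel, /*Alignment=*/0, (buffer));
}

int main() {
  float buf[4] = {0};
  run(true, false, true, buf);   // prints kernel<1,0,1,0>
  run(false, true, false, buf);  // prints kernel<0,1,0,0>
}

The companion evalProductSequential keeps the old gemv-versus-gemm choice in one place, so derived evaluators only have to expose evalProduct<Alignment>.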
@@ -623,7 +639,7 @@ struct TensorContractionEvaluatorBase
     OutputMapper output(buffer, m);
 
     // Sizes of the blocks to load in cache. See the Goto paper for details.
-    internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, 1);
+    internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index, internal::ShardByCol> blocking(k, m, n, 1);
     const Index kc = blocking.kc();
     const Index mc = numext::mini(m, blocking.mc());
     const Index nc = numext::mini(n, blocking.nc());
@@ -976,14 +992,9 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
   EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
       Base(op, device) { }
 
-  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
-  EIGEN_DEVICE_FUNC void evalProduct(Scalar* buffer) const {
-    if (this->m_j_size == 1) {
-      this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
-      return;
-    }
-    this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
+  template <int Alignment>
+  void evalProduct(Scalar* buffer) const {
+    TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential, Alignment, (buffer));
   }
 };


@@ -21,13 +21,10 @@ enum {
 // Default Blocking Strategy
-template <typename LhsMapper, typename RhsMapper, typename Index, int ShardingType=ShardByCol>
+template <typename LhsScalar, typename RhsScalar, typename Index, int ShardingType=ShardByCol>
 class TensorContractionBlocking {
  public:
-  typedef typename LhsMapper::Scalar LhsScalar;
-  typedef typename RhsMapper::Scalar RhsScalar;
   /*
     adding EIGEN_DEVICE_FUNC unconditionally to 'TensorContractionBlocking' constructor in `TensorContractionBlocking.h`
     requires adding EIGEN_DEVICE_FUNC to `computeProductBlockingSizes` in `GeneralBlockPanelKernel.h`
@@ -41,7 +38,7 @@ class TensorContractionBlocking {
     ../Eigen/src/Core/products/GeneralBlockPanelKernel.h(57): error #2901:
     dynamic initialization is not supported for function-scope static variables within a __device__/__global__ function
   */
 #if !defined(EIGEN_HIPCC)
   EIGEN_DEVICE_FUNC
 #endif
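TensorContractionBlocking only ever used the mapper types to reach their scalar types, so re-templating it directly on LhsScalar/RhsScalar lets all eight mapper configurations share a single blocking instantiation. A simplified sketch of the idea (placeholder body, not the real class, which delegates to computeProductBlockingSizes):

#include <cstddef>

// Sketch only: the real Eigen class computes cache-aware block sizes; here the
// constructor just records its inputs to keep the example self-contained.
template <typename LhsScalar, typename RhsScalar, typename Index,
          int ShardingType = 0>
class BlockingSketch {
 public:
  BlockingSketch(Index k, Index m, Index n, Index /*num_threads*/)
      : kc_(k), mc_(m), nc_(n) {}
  Index kc() const { return kc_; }
  Index mc() const { return mc_; }
  Index nc() const { return nc_; }

 private:
  Index kc_, mc_, nc_;
};

// Before: one instantiation per (LhsMapper, RhsMapper) pair, and the mappers
// themselves vary with the contiguity/reordering/alignment flags.
// After: every flag combination contracting float by float shares this one.
using FloatBlocking = BlockingSketch<float, float, std::ptrdiff_t>;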


@@ -71,8 +71,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
   TensorEvaluator(const XprType& op, const Device& device) :
       Base(op, device) {}
 
-  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
-            bool rhs_inner_dim_reordered, int Alignment>
+  template <int Alignment>
   void evalProduct(Scalar* buffer) const {
     const Index m = this->m_i_size;
     const Index n = this->m_j_size;
@@ -96,39 +95,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     }
 #endif
 
-    typedef
-        typename internal::remove_const<typename EvalLeftArgType::Scalar>::type
-            LhsScalar;
-    typedef
-        typename internal::remove_const<typename EvalRightArgType::Scalar>::type
-            RhsScalar;
-    typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;
-    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
-    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
-    typedef internal::TensorContractionInputMapper<
-        LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
-        contract_t, internal::packet_traits<LhsScalar>::size,
-        lhs_inner_dim_contiguous, false, Unaligned>
-        LhsMapper;
-    typedef internal::TensorContractionInputMapper<
-        RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
-        contract_t, internal::packet_traits<RhsScalar>::size,
-        rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
-        RhsMapper;
-    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
-    typedef internal::gemm_pack_lhs<LhsScalar, Index,
-                                    typename LhsMapper::SubMapper, Traits::mr,
-                                    Traits::LhsProgress, ColMajor>
-        LhsPacker;
-    typedef internal::gemm_pack_rhs<
-        RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor>
-        RhsPacker;
-    typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper,
-                                  Traits::mr, Traits::nr, false, false>
-        GebpKernel;
-
     // Compute a set of algorithm parameters:
     // - kernel block sizes (bm, bn, bk)
     // - task grain sizes (number of kernels executed per task: gm, gn)
@@ -158,14 +124,14 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     // Again, we don't know number of threads yet, so we use 2.
     Index bm, bn, bk;
     if (shard_by_col) {
-      internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
+      internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index,
                                           internal::ShardByCol>
           blocking(k, m, n, 2);
       bm = blocking.mc();
       bn = blocking.nc();
       bk = blocking.kc();
     } else {
-      internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
+      internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index,
                                           internal::ShardByRow>
          blocking(k, m, n, 2);
       bm = blocking.mc();
@@ -187,29 +153,22 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     if (n == 1) num_threads = 1;
 
     if (num_threads == 1) {
-      // The single-threaded algorithm should be faster in this case.
-      if (n == 1)
-        this->template evalGemv<lhs_inner_dim_contiguous,
-                                rhs_inner_dim_contiguous,
-                                rhs_inner_dim_reordered, Alignment>(buffer);
-      else
-        this->template evalGemm<lhs_inner_dim_contiguous,
-                                rhs_inner_dim_contiguous,
-                                rhs_inner_dim_reordered, Alignment>(buffer);
+      TENSOR_CONTRACTION_DISPATCH(this->template evalProductSequential,
+                                  Unaligned, (buffer));
       return;
     }
 
     // Now that we know number of threads, recalculate sharding and blocking.
     shard_by_col = shardByCol(m, n, num_threads);
     if (shard_by_col) {
-      internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
+      internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index,
                                           internal::ShardByCol>
           blocking(k, m, n, num_threads);
       bm = blocking.mc();
       bn = blocking.nc();
       bk = blocking.kc();
     } else {
-      internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index,
+      internal::TensorContractionBlocking<LhsScalar, RhsScalar, Index,
                                           internal::ShardByRow>
           blocking(k, m, n, num_threads);
       bm = blocking.mc();
@@ -257,34 +216,55 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     // more important in this case.
     if ((shard_by_col ? nm : nn) == 1) parallel_pack = false;
 
-    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides,
-                  this->m_i_strides, this->m_left_contracting_strides,
-                  this->m_k_strides);
-
-    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides,
-                  this->m_j_strides, this->m_right_contracting_strides,
-                  this->m_k_strides);
-
-    Context<LhsPacker, RhsPacker, GebpKernel, LhsMapper, RhsMapper,
-            OutputMapper>(this, num_threads, lhs, rhs, buffer, m, n,
-                          k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0,
-                          shard_by_col, parallel_pack)
-        .run();
+#define CONTEXT_ARGS                                                        \
+  (this, num_threads, buffer, m, n, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, \
+   nn0, shard_by_col, parallel_pack)                                        \
+      .run()
+
+    TENSOR_CONTRACTION_DISPATCH(Context, Alignment, CONTEXT_ARGS);
+
+#undef CONTEXT_ARGS
   }
 
   // Context coordinates a single parallel gemm operation.
-  template <typename LhsPacker, typename RhsPacker, typename GebpKernel,
-            typename LhsMapper, typename RhsMapper, typename OutputMapper>
+  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
+            bool rhs_inner_dim_reordered, int Alignment>
   class Context {
    public:
-    Context(const Self* self, int num_threads, LhsMapper& lhs,
-            RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk, Index bm,
-            Index bn, Index bk, Index nm, Index nn, Index nk, Index gm,
-            Index gn, Index nm0, Index nn0, bool shard_by_col,
+    typedef internal::TensorContractionInputMapper<
+        LhsScalar, Index, internal::Lhs, LeftEvaluator, left_nocontract_t,
+        contract_t, internal::packet_traits<LhsScalar>::size,
+        lhs_inner_dim_contiguous, false, Unaligned>
+        LhsMapper;
+    typedef internal::TensorContractionInputMapper<
+        RhsScalar, Index, internal::Rhs, RightEvaluator, right_nocontract_t,
+        contract_t, internal::packet_traits<RhsScalar>::size,
+        rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
+        RhsMapper;
+    typedef internal::gemm_pack_lhs<LhsScalar, Index,
+                                    typename LhsMapper::SubMapper, Traits::mr,
+                                    Traits::LhsProgress, ColMajor>
+        LhsPacker;
+    typedef internal::gemm_pack_rhs<
+        RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor>
+        RhsPacker;
+    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
+    typedef internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper,
+                                  Traits::mr, Traits::nr, false, false>
+        GebpKernel;
+
+    Context(const Self* self, int num_threads, Scalar* buffer, Index tm, Index tn,
+            Index tk, Index bm, Index bn, Index bk, Index nm, Index nn, Index nk,
+            Index gm, Index gn, Index nm0, Index nn0, bool shard_by_col,
             bool parallel_pack)
         : device_(self->m_device),
-          lhs_(lhs),
-          rhs_(rhs),
+          lhs_(self->m_leftImpl, self->m_left_nocontract_strides,
+               self->m_i_strides, self->m_left_contracting_strides,
+               self->m_k_strides),
+          rhs_(self->m_rightImpl, self->m_right_nocontract_strides,
+               self->m_j_strides, self->m_right_contracting_strides,
+               self->m_k_strides),
          buffer_(buffer),
          output_(buffer, tm),
          output_kernel_(self->m_output_kernel),
@@ -376,8 +356,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
    private:
     Notification done_;
     const Device& device_;
-    LhsMapper& lhs_;
-    RhsMapper& rhs_;
+    LhsMapper lhs_;
+    RhsMapper rhs_;
     Scalar* const buffer_;
     OutputMapper output_;
     OutputKernelType output_kernel_;
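Finally, the thread-pool Context is now templated on the three inner-dimension flags plus the alignment instead of on pre-built packer/mapper/kernel types, and it constructs and owns its input mappers from self rather than holding references supplied by the caller. A stripped-down sketch of that ownership change, with invented names used purely for illustration:

// Stand-in for the evaluator that previously built the mappers itself.
struct SelfSketch {
  const float* left_data;
  const float* right_data;
};

template <bool lhs_contig, bool rhs_contig, bool rhs_reordered, int Alignment>
class ContextSketch {
 public:
  explicit ContextSketch(const SelfSketch* self)
      // The mappers are built here from `self`, so the caller no longer needs
      // to know (or instantiate) their exact types.
      : lhs_(self->left_data), rhs_(self->right_data) {}

  void run() { /* packing and gebp-style kernel calls would go here */ }

 private:
  struct Mapper {
    explicit Mapper(const float* d) : data(d) {}
    const float* data;
  };
  Mapper lhs_;  // owned by value (previously: LhsMapper& lhs_;)
  Mapper rhs_;  // owned by value (previously: RhsMapper& rhs_;)
};

// Usage: the evaluator spells out a flag combination only at the dispatch
// site, e.g. ContextSketch<true, true, false, 0>(&self).run();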