Add beta to TensorContractionKernel and make memset optional

Eugene Zhulenev 2019-10-02 11:06:02 -07:00
parent bd0fac456f
commit 6e40454a6e
2 changed files with 39 additions and 22 deletions
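
In short: a kernel that advertises `HasBeta = true` must implement `C <- alpha * A * B + beta * C`, which lets the contraction evaluators skip the up-front memset of the output buffer and instead pass `beta = 0` for the first slice of the contraction (k) dimension. A minimal standalone sketch of that contract follows; it is illustrative only (the names `ExampleKernel`/`invoke`, the float element type, and the dense column-major layout are assumptions, not Eigen's actual kernel):

// Sketch only: a stand-in for a contraction kernel that supports beta.
#include <cstddef>

struct ExampleKernel {
  enum { HasBeta = true };  // invoke() honors beta, so callers may skip memset

  // C (m x n), A (m x k), B (k x n), all column-major and densely packed.
  static void invoke(float* C, const float* A, const float* B,
                     std::size_t m, std::size_t k, std::size_t n,
                     float alpha, float beta) {
    for (std::size_t j = 0; j < n; ++j) {
      for (std::size_t i = 0; i < m; ++i) {
        float acc = 0.0f;
        for (std::size_t p = 0; p < k; ++p) acc += A[i + p * m] * B[p + j * k];
        const float scaled = alpha * acc;
        // beta == 0 must overwrite C without reading it, so an uninitialized
        // output buffer does not need to be zeroed beforehand.
        C[i + j * m] = (beta == 0.0f) ? scaled : scaled + beta * C[i + j * m];
      }
    }
  }
};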

unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h

@@ -180,6 +180,10 @@ template <typename ResScalar, typename LhsScalar, typename RhsScalar,
           typename StorageIndex, typename OutputMapper, typename LhsMapper,
           typename RhsMapper>
 struct TensorContractionKernel {
+  // True if `invoke()` supports `beta` in `C <- alpha * A * B + beta * C`
+  // (otherwise beta should be always equal to 1).
+  enum { HasBeta = false };
   EIGEN_DEVICE_FUNC
   TensorContractionKernel(StorageIndex m_, StorageIndex k_, StorageIndex n_,
                           StorageIndex bm_, StorageIndex bk_, StorageIndex bn_)
@@ -248,7 +252,9 @@ struct TensorContractionKernel {
       const OutputMapper& output_mapper, const LhsBlock& lhsBlock,
       const RhsBlock& rhsBlock, const StorageIndex rows,
       const StorageIndex depth, const StorageIndex cols,
-      const ResScalar alpha) {
+      const ResScalar alpha, const ResScalar beta) {
+    // Default GEBP kernel does not support beta.
+    eigen_assert(beta == ResScalar(1));
     static const int kComputeStrideFromBlockDimensions = -1;
     GebpKernel()(output_mapper, lhsBlock, rhsBlock, rows, depth, cols, alpha,
                  /*strideA*/ kComputeStrideFromBlockDimensions,
@@ -772,15 +778,6 @@ struct TensorContractionEvaluatorBase
   void evalGemm(Scalar* buffer) const {
-    // columns in left side, rows in right side
-    const Index k = this->m_k_size;
-    // rows in left side
-    const Index m = this->m_i_size;
-    // columns in right side
-    const Index n = this->m_j_size;
-    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar)
-    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
     this->template evalGemmPartial<lhs_inner_dim_contiguous,
                                    rhs_inner_dim_contiguous,
                                    rhs_inner_dim_reordered,
@@ -866,6 +863,12 @@ struct TensorContractionEvaluatorBase
     const BlockMemHandle packed_mem =
         kernel.allocate(this->m_device, &blockA, &blockB);
+    // If a contraction kernel does not support beta, explicitly initialize
+    // output buffer with zeroes.
+    if (!TensorContractionKernel::HasBeta) {
+      this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
+    }
     for(Index i2=0; i2<m; i2+=mc)
     {
       const Index actual_mc = numext::mini(i2+mc,m)-i2;
@@ -874,6 +877,13 @@ struct TensorContractionEvaluatorBase
         const Index actual_kc = numext::mini(k2 + kc, k_end) - k2;
         kernel.packLhs(&blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc);
+        // If kernel supports beta, there is no need to initialize output
+        // buffer with zeroes.
+        const Scalar alpha = Scalar(1);
+        const Scalar beta = (TensorContractionKernel::HasBeta && k2 == k_start)
+                                ? Scalar(0)
+                                : Scalar(1);
         // series of horizontal blocks
         for (Index j2 = 0; j2 < n; j2 += nc) {
           // make sure we don't overshoot right edge of right matrix, then pack block
@@ -885,7 +895,7 @@ struct TensorContractionEvaluatorBase
           // The parameters here are copied from Eigen's GEMM implementation
           const OutputMapper output_mapper = output.getSubMapper(i2, j2);
           kernel.invoke(output_mapper, blockA, blockB, actual_mc, actual_kc,
-                        actual_nc, Scalar(1));
+                        actual_nc, alpha, beta);
           // We are done with this [i2, j2] output block.
           if (use_output_kernel && k2 + kc >= k_end) {

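The caller-side pattern in evalGemmPartial above boils down to: zero the buffer only when the kernel cannot take beta; otherwise let the first k-block overwrite the output with beta = 0 and accumulate the remaining k-blocks with beta = 1. A compact sketch of that flow, using plain column-major arrays and alpha = 1 instead of Eigen's mappers and packed blocks (illustrative assumptions, not Eigen code):

#include <cstddef>
#include <cstring>

// Blocked accumulation over the contraction (k) dimension.
// C (m x n), A (m x k), B (k x n), column-major; kc is the k block size.
void contract_over_k_blocks(float* C, const float* A, const float* B,
                            std::size_t m, std::size_t k, std::size_t n,
                            std::size_t kc, bool kernel_has_beta) {
  if (!kernel_has_beta) {
    // Mirrors the HasBeta == false path: zero the output once up front.
    std::memset(C, 0, m * n * sizeof(float));
  }
  for (std::size_t k2 = 0; k2 < k; k2 += kc) {
    const std::size_t k_end = (k2 + kc < k) ? (k2 + kc) : k;
    // First k-block overwrites C (beta = 0); later blocks accumulate (beta = 1).
    const float beta = (kernel_has_beta && k2 == 0) ? 0.0f : 1.0f;
    for (std::size_t j = 0; j < n; ++j) {
      for (std::size_t i = 0; i < m; ++i) {
        float acc = 0.0f;
        for (std::size_t p = k2; p < k_end; ++p) acc += A[i + p * m] * B[p + j * k];
        C[i + j * m] = (beta == 0.0f) ? acc : acc + beta * C[i + j * m];
      }
    }
  }
}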
unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h

@@ -904,14 +904,16 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     const Index nend = n * gn_ + gn(n);
     for (Index n1 = n * gn_; n1 < nend; n1++) {
-      if (k == 0) {
-        // Zero the output memory in parallel.
-        // On 10000x2x10000 mm zeroing can easily take half of time.
-        // Zero (bn x m) row. Safe to do here because all kernels that will
-        // write to this memory depend on completion of this task.
-        // Note: don't call device_.memset() here. device_.memset() blocks on
-        // thread pool worker thread, which can lead to underutilization and
-        // deadlocks.
+      if (!TensorContractionKernel::HasBeta && k == 0) {
+        // Zero the output memory in parallel, only if contraction kernel does
+        // not support `beta`. Otherwise we will pass beta 0.0 to the first
+        // call to the `TensorContractionKernel::invoke()`.
+        //
+        // On 10000x2x10000 mm zeroing can easily take half of time. Zero (bn
+        // x m) row. Safe to do here because all kernels that will write to
+        // this memory depend on completion of this task. Note: don't call
+        // device_.memset() here. device_.memset() blocks on thread pool
+        // worker thread, which can lead to underutilization and deadlocks.
         memset(buffer_ + n1 * bn_ * m_, 0, bn(n1) * m_ * sizeof(Scalar));
       }
       kernel_.packRhs(&packed_rhs(n, k, n1, use_thread_local),
@@ -936,6 +938,12 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     // (rhs fits into L2$ while lhs only into L3$).
     const Index nend = n * gn_ + gn(n);
     const Index mend = m * gm_ + gm(m);
+    // NOTE: output = alpha * LHS * RHS + beta * output.
+    const Scalar alpha = Scalar(1);
+    const Scalar beta =
+        (TensorContractionKernel::HasBeta && k == 0) ? Scalar(0) : Scalar(1);
     if (shard_by_col_) {
       for (Index n1 = n * gn_; n1 < nend; n1++) {
         for (Index m1 = m * gm_; m1 < mend; m1++) {
@@ -944,7 +952,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
               output_mapper,
               packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
               packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
-              bk(k), bn(n1), Scalar(1));
+              bk(k), bn(n1), alpha, beta);
           // We are done with the last task for the [m1, n1] block.
           if (k + 1 == nk_) {
@@ -961,7 +969,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
               output_mapper,
               packed_lhs(m, k, m1, !shard_by_col_ && use_thread_local),
               packed_rhs(n, k, n1, shard_by_col_ && use_thread_local), bm(m1),
-              bk(k), bn(n1), Scalar(1));
+              bk(k), bn(n1), alpha, beta);
           // We are done with the last task for the [m1, n1] block.
           if (k + 1 == nk_) {
@@ -1266,7 +1274,6 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
   template <int Alignment>
   void processBlock(Index block_idx, Index begin, Index end) {
     Scalar* buf = block_buffers[block_idx];
-    ::memset(buf, 0, buffer_size_bytes);
     TENSOR_CONTRACTION_DISPATCH(
         evaluator->template evalGemmPartialWithoutOutputKernel, Alignment,