diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
index 3b22e43e7..ea17a897d 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h
@@ -590,6 +590,25 @@ struct TensorContractionEvaluatorBase
     // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar)
     this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));
+    this->template evalGemmPartial<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
+                                   rhs_inner_dim_reordered, Alignment>(buffer,
+                                   0, k, 1);
+  }
+
+  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
+            bool rhs_inner_dim_reordered, int Alignment>
+  EIGEN_DEVICE_FUNC void evalGemmPartial(Scalar* buffer, Index k_start, Index k_end, int num_threads) const {
+    // columns in left side, rows in right side
+    const Index k = this->m_k_size;
+
+    eigen_assert(k_end >= k_start && k_start >= 0 && k_end <= k);
+    const Index k_slice = k_end - k_start;
+
+    // rows in left side
+    const Index m = this->m_i_size;
+
+    // columns in right side
+    const Index n = this->m_j_size;
 
     // define mr, nr, and all of my data mapper types
     typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
@@ -620,7 +639,7 @@ struct TensorContractionEvaluatorBase
     typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;
 
     // Declare GEBP packing and kernel structs
-    internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, Traits::mr, Traits::LhsProgress, ColMajor> pack_lhs;
+    internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, Traits::mr, Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor> pack_lhs;
     internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor> pack_rhs;
 
     internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper, Traits::mr, Traits::nr, false, false> gebp;
@@ -635,7 +654,7 @@ struct TensorContractionEvaluatorBase
     OutputMapper output(buffer, m);
 
     // Sizes of the blocks to load in cache. See the Goto paper for details.
-    internal::TensorContractionBlocking blocking(k, m, n, 1);
+    internal::TensorContractionBlocking blocking(k_slice, m, n, num_threads);
     const Index kc = blocking.kc();
     const Index mc = numext::mini(m, blocking.mc());
     const Index nc = numext::mini(n, blocking.nc());
@@ -648,7 +667,7 @@ struct TensorContractionEvaluatorBase
     for(Index i2=0; i2<m; i2+=mc)
     {
       const Index actual_mc = numext::mini(i2+mc,m)-i2;
-      for (Index k2 = 0; k2 < k; k2 += kc) {
+      for (Index k2 = k_start; k2 < k_end; k2 += kc) {
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h
     int num_threads = TensorCostModel<ThreadPoolDevice>::numThreads(
         static_cast<double>(n) * m, cost, this->m_device.numThreads());
+    int num_threads_by_k = numThreadsInnerDim(m, n, k);
+    if (false && shardByInnerDim(m, n, k, num_threads, num_threads_by_k)) {
+      // We are in the scenario where it is more effective to shard by the
+      // inner dimension.
+      this->template evalShardedByInnerDim<Alignment>(num_threads_by_k,
+                                                      buffer);
+      return;
+    }
 
     // TODO(dvyukov): this is a stop-gap to prevent regressions while the cost
     // model is not tuned. Remove this when the cost model is tuned.
@@ -242,9 +250,9 @@ struct TensorEvaluator
         contract_t, internal::packet_traits<RhsScalar>::size,
         rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Unaligned>
         RhsMapper;
-    typedef internal::gemm_pack_lhs<LhsScalar, Index,
-                                    typename LhsMapper::SubMapper, Traits::mr,
-                                    Traits::LhsProgress, ColMajor>
+    typedef internal::gemm_pack_lhs<
+        LhsScalar, Index, typename LhsMapper::SubMapper, Traits::mr,
+        Traits::LhsProgress, typename Traits::LhsPacket4Packing, ColMajor>
         LhsPacker;
     typedef internal::gemm_pack_rhs<
         RhsScalar, Index, typename RhsMapper::SubMapper, Traits::nr, ColMajor>
         RhsPacker;
@@ -709,20 +717,9 @@ struct TensorEvaluator
                                           PacketType<RhsScalar, Device>::size);
     const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
     const double kd = static_cast<double>(bk);
-    // Peak VFMA bandwidth is 0.5. However if we have not enough data for
-    // vectorization bandwidth drops. The 4.0 and 2.0 bandwidth is determined
-    // experimentally.
-    double computeBandwidth = bk == 1 ? 4.0 :
-        (shard_by_col ? bn : bm) < Traits::nr ||
-        (shard_by_col ? bm : bn) < Traits::mr ? 2.0 : 0.5;
-#ifndef EIGEN_VECTORIZE_FMA
-    // Bandwidth of all of VFMA/MULPS/ADDPS is 0.5 on latest Intel processors.
-    // However for MULPS/ADDPS we have dependent sequence of 2 such instructions,
-    // so overall bandwidth is 1.0.
-    if (computeBandwidth == 0.5) computeBandwidth = 1.0;
-#endif
+    double compute_bandwidth = computeBandwidth(false, bm, bn, bk);
     // Computations.
-    TensorOpCost cost = TensorOpCost(0, 0, kd * computeBandwidth, true, packed_size);
+    TensorOpCost cost = TensorOpCost(0, 0, kd * compute_bandwidth, true, packed_size);
     // Output stores.
     cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
     if (prepacked) {
@@ -743,6 +740,162 @@ struct TensorEvaluator
+  template <int Alignment>
+  EIGEN_STRONG_INLINE void addToBuffer(size_t n, const Scalar* src_buf,
+                                       Scalar* tgt_buf) const {
+    const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
+    size_t i = 0;
+    const size_t num_packets = n / output_packet_size;
+    for (; i < output_packet_size * num_packets; i += output_packet_size) {
+      const PacketReturnType src_val =
+          internal::pload<PacketReturnType>(src_buf + i);
+      const PacketReturnType tgt_val =
+          internal::ploadt<PacketReturnType, Alignment>(tgt_buf + i);
+      const PacketReturnType sum = internal::padd(src_val, tgt_val);
+      internal::pstoret<Scalar, PacketReturnType, Alignment>(tgt_buf + i, sum);
+    }
+    for (; i < n; ++i) {
+      tgt_buf[i] += src_buf[i];
+    }
+  }
+
+  // Decide whether we want to shard m x k x n contraction over the inner
+  // (contraction) dimension (k).
+  static bool shardByInnerDim(Index m, Index n, Index k, int num_threads,
+                              int num_threads_by_k) {
+    size_t bufsize = m * n * sizeof(Scalar);
+    bool shard_by_k = false;
+    if (n == 1 ||                // If mat*vec or...
+        num_threads_by_k < 2 ||  // running single threaded or...
+        num_threads_by_k <
+            num_threads ||  // sharding by k gives less parallelism or...
+        bufsize > l3CacheSize() / num_threads_by_k ||  // need more buffer space
+                                                       // than L3 cache or...
+        k / num_threads_by_k < 2 * Traits::nr) {  // k per thread is tiny.
+      shard_by_k = false;
+    } else if (numext::maxi(m, n) / num_threads <
+                   Traits::nr ||  // both other dimensions are tiny or...
+               // k per thread is not small and...
+               (k / num_threads_by_k > 8 * Traits::nr &&
+                // one of the outer dimensions is tiny or sharding by k offers
+                // more parallelism.
+                (numext::mini(m, n) < 2 * Traits::nr ||
+                 num_threads_by_k > num_threads))) {
+      shard_by_k = true;
+    }
+    return shard_by_k;
+  }
+
+  template <int Alignment>
+  void evalShardedByInnerDim(int num_threads, Scalar* result) const {
+    const Index m = this->m_i_size;
+    const Index n = this->m_j_size;
+    const Index k = this->m_k_size;
+    // The underlying GEMM kernel assumes that k is a multiple of 8 and
+    // subtle breakage occurs if this is violated.
+    Index block_size = 8 * divup(k, 8 * num_threads);
+    int num_blocks = divup(k, block_size);
+    // we use 'result' for the first block's partial result.
+    MaxSizeVector<Scalar*> block_buffers(num_blocks - 1);
+    Barrier barrier(num_blocks);
+    auto process_block = [=, &barrier](Scalar* buf, Index first, Index last) {
+      ::memset(buf, 0, m * n * sizeof(Scalar));
+      TENSOR_CONTRACTION_DISPATCH(
+          this->template evalGemmPartial, Alignment,
+          (buf, first, last, this->m_device.numThreads()));
+      barrier.Notify();
+    };
+    Index start = 0;
+    for (int blocks_left = num_blocks; blocks_left > 0; --blocks_left) {
+      // The underlying GEMM kernel assumes that k is a multiple of 8 and
+      // subtle breakage occurs if this is violated.
+      block_size = 8 * divup(k - start, 8 * blocks_left);
+      Scalar* buf;
+      if (start == 0) {
+        buf = result;
+      } else {
+        buf = static_cast<Scalar*>(
+            this->m_device.allocate(m * n * sizeof(Scalar)));
+        block_buffers.push_back(buf);
+      }
+      Index end = start + block_size;
+      if (end > k) {
+        end = k;
+      }
+      this->m_device.enqueueNoNotification(
+          [=, &process_block]() { process_block(buf, start, end); });
+      start = end;
+    }
+    barrier.Wait();
+
+    // Add other partial results into first partial result.
+    for (const auto& buf : block_buffers) {
+      addToBuffer<Alignment>(m * n, buf, result);
+      this->m_device.deallocate(buf);
+    }
+  }
+
+  TensorOpCost contractionCostPerInnerDim(Index m, Index n, Index k) const {
+    // Compute cost.
+    const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
+    TensorOpCost cost(0, 0, (computeBandwidth(true, m, n, k) * m) * n);
+    // Output stores.
+    cost += TensorOpCost(0, sizeof(CoeffReturnType), 0, true, output_packet_size);
+    TensorOpCost lhsCost = this->m_leftImpl.costPerCoeff(true) * m;
+    TensorOpCost rhsCost = this->m_rightImpl.costPerCoeff(true) * n;
+    // Since the inner gemm kernel is always sharded by column, the lhs
+    // load cost is negligible.
+    lhsCost.dropMemoryCost();
+    return cost + lhsCost + rhsCost;
+  }
+
+  int numThreadsInnerDim(Index m, Index n, Index k) const {
+    const int output_packet_size = internal::unpacket_traits<PacketReturnType>::size;
+    TensorOpCost cost = contractionCostPerInnerDim(m, n, k);
+    double total_parallel_cost =
+        TensorCostModel<ThreadPoolDevice>::totalCost(k, cost);
+    // Cost of reduction step accumulating the m*n per-thread buffers into the
+    // result.
+    double reduction_cost = TensorCostModel<ThreadPoolDevice>::totalCost(
+        m * n, TensorOpCost(2, 1, 1, true, output_packet_size));
+    Index num_threads = 1;
+    double min_cost = total_parallel_cost;
+    double kPerThreadOverHead = 4000;
+    double kFixedOverHead = 100000;
+    for (int nt = 2; nt <= this->m_device.numThreads(); nt++) {
+      double sequential_cost =
+          kFixedOverHead + nt * (reduction_cost + kPerThreadOverHead);
+      double parallel_cost = total_parallel_cost / nt + sequential_cost;
+      if (parallel_cost < min_cost) {
+        num_threads = nt;
+        min_cost = parallel_cost;
+      }
+    }
+    return num_threads;
+  }
+
+  double computeBandwidth(bool shard_by_col, Index bm, Index bn,
+                          Index bk) const {
+    // Peak VFMA bandwidth is 0.5. However if we have not enough data for
+    // vectorization bandwidth drops. The 4.0 and 2.0 bandwidth is determined
+    // experimentally.
+    double computeBandwidth =
+        bk == 1 ? 4.0
+                : (shard_by_col ? bn : bm) < Traits::nr ||
+                          (shard_by_col ? bm : bn) < Traits::mr
+                      ? 2.0
+                      : 0.5;
+#ifndef EIGEN_VECTORIZE_FMA
+    // Bandwidth of all of VFMA/MULPS/ADDPS is 0.5 on latest Intel processors.
+    // However for MULPS/ADDPS we have dependent sequence of 2 such
+    // instructions,
+    // so overall bandwidth is 1.0.
+    if (computeBandwidth == 0.5) computeBandwidth = 1.0;
+#endif
+    return computeBandwidth;
+  }
+
 #if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM)
     // TODO(ezhulenev): Add support for output kernels and LIBXSMM.
     static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h b/unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h
index bb63baee2..7f79ac30d 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h
@@ -188,7 +188,6 @@ class TensorCostModel {
     return totalCost(output_size, cost_per_coeff) / kTaskSize;
   }
 
- private:
   static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double totalCost(
       double output_size, const TensorOpCost& cost_per_coeff) {
     // Cost of memory fetches from L2 cache. 64 is typical cache line size.
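
For reviewers who want to see the new control flow outside of the evaluator plumbing, the sketch below is not part of the patch: it is a simplified, hypothetical C++ illustration of the scheme that evalShardedByInnerDim implements. It splits the contraction dimension k into blocks whose size is rounded up to a multiple of 8 (block_size = 8 * divup(k, 8 * num_threads)), computes an independent partial GEMM per block into its own m x n buffer, and then sums the per-block buffers into the final result, which is the role addToBuffer plays in the patch. The helper names (contractShardedByK, partialGemm) are invented for the example, and std::thread plus a naive triple loop stand in for Eigen's ThreadPoolDevice, Barrier and GEBP kernels.

#include <algorithm>
#include <cstddef>
#include <thread>
#include <vector>

using Index = std::ptrdiff_t;

// Naive column-major partial GEMM: result += A(:, k0:k1) * B(k0:k1, :).
// Stands in for evalGemmPartial; A is m x k, B is k x n, result is m x n.
static void partialGemm(const float* A, const float* B, float* result,
                        Index m, Index n, Index k, Index k0, Index k1) {
  for (Index j = 0; j < n; ++j)
    for (Index p = k0; p < k1; ++p)
      for (Index i = 0; i < m; ++i)
        result[i + j * m] += A[i + p * m] * B[p + j * k];
}

// Shard the contraction over k, mirroring the block-size rounding in the
// patch: block_size = 8 * divup(k, 8 * num_threads).
static void contractShardedByK(const float* A, const float* B, float* result,
                               Index m, Index n, Index k, int num_threads) {
  auto divup = [](Index x, Index y) { return (x + y - 1) / y; };
  const Index block_size = 8 * divup(k, Index(8) * num_threads);
  const Index num_blocks = divup(k, block_size);

  // One m*n accumulation buffer per block; block 0 writes into 'result'.
  std::vector<std::vector<float>> buffers(num_blocks - 1,
                                          std::vector<float>(m * n, 0.0f));
  std::fill(result, result + m * n, 0.0f);

  std::vector<std::thread> workers;
  Index start = 0;
  for (Index b = 0; b < num_blocks; ++b) {
    const Index end = std::min(start + block_size, k);
    float* buf = (b == 0) ? result : buffers[b - 1].data();
    workers.emplace_back(partialGemm, A, B, buf, m, n, k, start, end);
    start = end;
  }
  for (auto& w : workers) w.join();

  // Reduction step: fold the per-block partial results into 'result',
  // the role addToBuffer plays in the patch.
  for (const auto& buf : buffers)
    for (Index i = 0; i < m * n; ++i) result[i] += buf[i];
}

For example, with m = n = 64, k = 4096 and num_threads = 4, block_size comes out to 1024 and the contraction is evaluated as four independent k-slices of 1024 columns each, three of which are folded into the first at the end.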
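
The thread-count heuristic in numThreadsInnerDim can likewise be read in isolation: it trades the parallel work, which shrinks as 1/nt, against a sequential term made of a fixed overhead, a per-thread overhead, and the cost of reducing nt partial buffers. The snippet below is a hypothetical restatement of that loop with the same constants (100000 fixed, 4000 per thread); the function name is invented and it takes plain double costs instead of TensorOpCost values run through TensorCostModel<ThreadPoolDevice>::totalCost.

// Pick the thread count that minimizes: parallel work / nt + fixed overhead
// + nt * (reduction cost + per-thread overhead). Mirrors the loop in
// numThreadsInnerDim, but with scalar costs instead of TensorOpCost.
int pickNumThreadsForInnerDim(double total_parallel_cost, double reduction_cost,
                              int max_threads) {
  const double kPerThreadOverHead = 4000;
  const double kFixedOverHead = 100000;
  int num_threads = 1;
  double min_cost = total_parallel_cost;  // single-threaded: no overhead terms
  for (int nt = 2; nt <= max_threads; nt++) {
    const double sequential_cost =
        kFixedOverHead + nt * (reduction_cost + kPerThreadOverHead);
    const double parallel_cost = total_parallel_cost / nt + sequential_cost;
    if (parallel_cost < min_cost) {
      num_threads = nt;
      min_cost = parallel_cost;
    }
  }
  return num_threads;
}

The patch only takes the shard-by-k path when this estimate also beats sharding over the output dimensions (see shardByInnerDim), and the "if (false && ...)" guard in evalProduct keeps the new path switched off for now, consistent with the surrounding stop-gap TODO about the untuned cost model.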