Leverage the new blocking code in the tensor contraction code.

Benoit Steiner 2016-01-22 16:36:30 -08:00
parent 4beb447e27
commit 3aeeca32af
3 changed files with 8 additions and 9 deletions
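
Both the single-threaded and the thread-pool contraction paths previously set up their panel sizes with the generic matrix-product helpers (internal::gemm_blocking_space and internal::computeProductBlockingSizes); after this commit both go through the tensor-specific internal::TensorContractionBlocking helper. Only the pieces of its interface that are visible in the hunks below are known from this diff (the k, m, n, num_threads constructor order, the ShardByCol tag, and the kc()/mc()/nc() accessors); the following is a hypothetical sketch of that interface, not the actual Eigen implementation:

    // Hypothetical sketch of the interface the hunks below rely on; only the
    // pieces visible in this diff are real (constructor order k, m, n,
    // num_threads and the kc()/mc()/nc() accessors). The body is a placeholder,
    // not the actual Eigen implementation.
    namespace internal {

    enum { ShardByRow = 0, ShardByCol = 1 };

    template <typename LhsMapper, typename RhsMapper, typename Index, int ShardingType>
    class TensorContractionBlocking {
     public:
      TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1)
          : kc_(k), mc_(m), nc_(n) {
        // The real helper shrinks these to cache-friendly panel sizes based on
        // the mapper scalar types, the sharding direction and the thread count
        // (see the Goto paper referenced in the code below).
        (void)num_threads;
      }

      Index kc() const { return kc_; }
      Index mc() const { return mc_; }
      Index nc() const { return nc_; }

     private:
      Index kc_, mc_, nc_;
    };

    }  // namespace internal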


@@ -582,10 +582,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
   OutputMapper output(buffer, m);
-  typedef typename internal::gemm_blocking_space<ColMajor, LhsScalar, RhsScalar, Dynamic, Dynamic, Dynamic> BlockingType;
   // Sizes of the blocks to load in cache. See the Goto paper for details.
-  BlockingType blocking(m, n, k, 1, true);
+  internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, 1);
   const Index kc = blocking.kc();
   const Index mc = numext::mini(m, blocking.mc());
   const Index nc = numext::mini(n, blocking.nc());


@@ -426,15 +426,16 @@ class TensorContractionSubMapper {
 };
-template<typename Scalar, typename Index, int side,
+template<typename Scalar_, typename Index, int side,
          typename Tensor,
          typename nocontract_t, typename contract_t,
          int packet_size,
          bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
 class TensorContractionInputMapper
-  : public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {
+  : public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {
  public:
+  typedef Scalar_ Scalar;
   typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Base;
   typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
   typedef SubMapper VectorMapper;
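
The rename from Scalar to Scalar_ lets the mapper re-export the scalar type under the nested name Scalar, presumably so that consumers such as the new blocking helper can query it as LhsMapper::Scalar. A minimal illustration of the pattern (not Eigen code):

    // Minimal illustration of the renaming pattern (not Eigen code): the template
    // parameter gets a trailing underscore so that the class can publish the plain
    // name "Scalar" as a nested typedef instead of shadowing it.
    template <typename Scalar_>
    class InputMapperLike {
     public:
      typedef Scalar_ Scalar;  // clients may now write InputMapperLike<float>::Scalar
    };

    // A consumer such as a blocking helper can then recover the scalar type from
    // the mapper it is instantiated with:
    template <typename LhsMapper>
    struct BlockingLike {
      typedef typename LhsMapper::Scalar LhsScalar;
    };

    // Example instantiation:
    typedef BlockingLike<InputMapperLike<float> >::LhsScalar LhsScalarIsFloat;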


@@ -176,10 +176,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
   // compute block sizes (which depend on number of threads)
   const Index num_threads = this->m_device.numThreads();
-  Index mc = m;
-  Index nc = n;
-  Index kc = k;
-  internal::computeProductBlockingSizes<LhsScalar,RhsScalar,1>(kc, mc, nc, num_threads);
+  internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, num_threads);
+  Index mc = blocking.mc();
+  Index nc = blocking.nc();
+  Index kc = blocking.kc();
   eigen_assert(mc <= m);
   eigen_assert(nc <= n);
   eigen_assert(kc <= k);
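
In the thread-pool path the panel sizes were previously computed by initializing kc/mc/nc to the full k/m/n and calling internal::computeProductBlockingSizes; they are now taken from the same TensorContractionBlocking helper, sharded by column and informed by the thread count. For context, a generic sketch (not the Eigen kernel) of how kc/mc/nc panel sizes typically drive a Goto-style blocked product:

    #include <algorithm>
    #include <cstddef>

    // Generic illustration (not the Eigen kernel) of how kc/mc/nc panel sizes
    // drive a Goto-style blocked product: the output is swept in nc-wide column
    // strips, the contraction dimension in kc-deep slabs, and the rows in mc-tall
    // panels, with packing and a micro-kernel supplied by the caller.
    typedef std::ptrdiff_t Index;

    template <typename PackLhs, typename PackRhs, typename Kernel>
    void blocked_contraction(Index m, Index n, Index k,
                             Index mc, Index nc, Index kc,
                             PackLhs pack_lhs, PackRhs pack_rhs, Kernel gebp) {
      for (Index j = 0; j < n; j += nc) {              // output column strip
        const Index actual_nc = std::min(nc, n - j);
        for (Index l = 0; l < k; l += kc) {            // contraction slab
          const Index actual_kc = std::min(kc, k - l);
          pack_rhs(l, j, actual_kc, actual_nc);        // pack a kc x nc RHS panel
          for (Index i = 0; i < m; i += mc) {          // output row panel
            const Index actual_mc = std::min(mc, m - i);
            pack_lhs(i, l, actual_mc, actual_kc);      // pack an mc x kc LHS panel
            gebp(i, j, actual_mc, actual_nc, actual_kc);  // micro-kernel on the block
          }
        }
      }
    }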