mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-16 11:23:14 +08:00
Leverage the new blocking code in the tensor contraction code.
This commit is contained in:
parent
4beb447e27
commit
3aeeca32af
@ -582,10 +582,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
|||||||
|
|
||||||
OutputMapper output(buffer, m);
|
OutputMapper output(buffer, m);
|
||||||
|
|
||||||
typedef typename internal::gemm_blocking_space<ColMajor, LhsScalar, RhsScalar, Dynamic, Dynamic, Dynamic> BlockingType;
|
|
||||||
|
|
||||||
// Sizes of the blocks to load in cache. See the Goto paper for details.
|
// Sizes of the blocks to load in cache. See the Goto paper for details.
|
||||||
BlockingType blocking(m, n, k, 1, true);
|
internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, 1);
|
||||||
const Index kc = blocking.kc();
|
const Index kc = blocking.kc();
|
||||||
const Index mc = numext::mini(m, blocking.mc());
|
const Index mc = numext::mini(m, blocking.mc());
|
||||||
const Index nc = numext::mini(n, blocking.nc());
|
const Index nc = numext::mini(n, blocking.nc());
|
||||||
|
@ -426,15 +426,16 @@ class TensorContractionSubMapper {
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
template<typename Scalar, typename Index, int side,
|
template<typename Scalar_, typename Index, int side,
|
||||||
typename Tensor,
|
typename Tensor,
|
||||||
typename nocontract_t, typename contract_t,
|
typename nocontract_t, typename contract_t,
|
||||||
int packet_size,
|
int packet_size,
|
||||||
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
|
bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment>
|
||||||
class TensorContractionInputMapper
|
class TensorContractionInputMapper
|
||||||
: public BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {
|
: public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
typedef Scalar_ Scalar;
|
||||||
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Base;
|
typedef BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> Base;
|
||||||
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
|
typedef TensorContractionSubMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> SubMapper;
|
||||||
typedef SubMapper VectorMapper;
|
typedef SubMapper VectorMapper;
|
||||||
|
@ -176,10 +176,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
|||||||
|
|
||||||
// compute block sizes (which depend on number of threads)
|
// compute block sizes (which depend on number of threads)
|
||||||
const Index num_threads = this->m_device.numThreads();
|
const Index num_threads = this->m_device.numThreads();
|
||||||
Index mc = m;
|
internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, num_threads);
|
||||||
Index nc = n;
|
Index mc = blocking.mc();
|
||||||
Index kc = k;
|
Index nc = blocking.nc();
|
||||||
internal::computeProductBlockingSizes<LhsScalar,RhsScalar,1>(kc, mc, nc, num_threads);
|
Index kc = blocking.kc();
|
||||||
eigen_assert(mc <= m);
|
eigen_assert(mc <= m);
|
||||||
eigen_assert(nc <= n);
|
eigen_assert(nc <= n);
|
||||||
eigen_assert(kc <= k);
|
eigen_assert(kc <= k);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user