mirror of https://gitlab.com/libeigen/eigen.git
synced 2025-07-13 00:21:49 +08:00

Fuse computations into the Tensor contractions using output kernel

parent 5539587b1f
commit 01fd4096d3
@@ -517,9 +517,15 @@ class TensorBase<Derived, ReadOnlyAccessors>
     typedef Eigen::IndexPair<Index> DimensionPair;

     template<typename OtherDerived, typename Dimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
-    const TensorContractionOp<const Dimensions, const Derived, const OtherDerived>
+    const TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const NoOpOutputKernel>
     contract(const OtherDerived& other, const Dimensions& dims) const {
-      return TensorContractionOp<const Dimensions, const Derived, const OtherDerived>(derived(), other.derived(), dims);
+      return TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const NoOpOutputKernel>(derived(), other.derived(), dims);
     }

+    template<typename OtherDerived, typename Dimensions, typename OutputKernel> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
+    const TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const OutputKernel>
+    contract(const OtherDerived& other, const Dimensions& dims, const OutputKernel& output_kernel) const {
+      return TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const OutputKernel>(derived(), other.derived(), dims, output_kernel);
+    }
+
     // Convolutions.
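For orientation, a minimal sketch of the resulting API (hypothetical driver code, not part of the commit; it assumes the SqrtOutputKernel functor added in the tests further down is in scope):

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  Eigen::Tensor<float, 2> a(64, 32), b(32, 48);
  a.setRandom();
  b.setRandom();
  // Keep values positive so the fused Sqrt below stays real-valued.
  a += a.constant(1.0f);
  b += b.constant(1.0f);

  Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};

  // Old overload, unchanged behavior: NoOpOutputKernel is the default.
  Eigen::Tensor<float, 2> c = a.contract(b, dims);

  // New overload: the kernel is applied to each finished output block.
  Eigen::Tensor<float, 2> d = a.contract(b, dims, SqrtOutputKernel());
  return 0;
}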
@@ -85,8 +85,8 @@ template<typename LhsScalar, typename RhsScalar, typename Scalar>
 #endif


-template<typename Dimensions, typename LhsXprType, typename RhsXprType>
-struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
+template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
+struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> >
 {
   // Type promotion to handle the case where the types of the lhs and the rhs are different.
   typedef typename gebp_traits<typename remove_const<typename LhsXprType::Scalar>::type,
@@ -112,23 +112,24 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
   };
 };

-template<typename Dimensions, typename LhsXprType, typename RhsXprType>
-struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, Eigen::Dense>
+template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
+struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>, Eigen::Dense>
 {
-  typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType>& type;
+  typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>& type;
 };

-template<typename Dimensions, typename LhsXprType, typename RhsXprType>
-struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >::type>
+template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
+struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> >::type>
 {
-  typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType> type;
+  typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> type;
 };

-template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename Device_>
-struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_>, Device_> > {
+template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename OutputKernelType_, typename Device_>
+struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_, OutputKernelType_>, Device_> > {
   typedef Indices_ Indices;
   typedef LeftArgType_ LeftArgType;
   typedef RightArgType_ RightArgType;
+  typedef OutputKernelType_ OutputKernelType;
   typedef Device_ Device;

   // From NumDims below.
@@ -137,8 +138,52 @@ struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_,

 } // end namespace internal

-template<typename Indices, typename LhsXprType, typename RhsXprType>
-class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType>, ReadOnlyAccessors>
+// Tensor contraction params that should enable to get from output matrix
+// 2-dimensional coordinates to the output tensor dimensions.
+struct TensorContractionParams {
+  // TensorContraction evaluator assumes that both tensors are in ColMajor
+  // layout, if tensors are in RowMajor evaluator swap lhs with rhs.
+  bool swapped_arguments;
+};
+
+// Output kernel allows to fuse operations into the tensor contraction.
+//
+// Examples:
+// 1. Elementwise Relu transformation following Conv2D.
+// 2. AddBias to the Conv2D output channels dimension.
+//
+// See expected implementation in NoOpOutputKernel.
+struct OutputKernel {
+  template <typename Index, typename Scalar>
+  using OutputMapper = internal::blas_data_mapper<Scalar, Index, ColMajor>;
+};
+
+// Output kernel that does absolutely nothing.
+struct NoOpOutputKernel {
+  /**
+   * Tensor contraction evaluator calls this kernel after finishing each block
+   * of output matrix. Output blocks belong to the 2-dimensional output tensor.
+   *
+   * TensorContractionParams contains contraction dimensions information
+   * required to map output 2-d space into the expected output tensor space
+   * (potentially higher dimensional).
+   *
+   * \param[in] output_mapper Access to output tensor memory
+   * \param[in] params   Tensor contraction parameters
+   * \param[in] i        Index of a first row available through output_mapper
+   * \param[in] j        Index of a first column available through output_mapper
+   * \param[in] num_rows Number of available rows
+   * \param[in] num_cols Number of available columns
+   */
+  template <typename Index, typename Scalar>
+  EIGEN_ALWAYS_INLINE void operator()(
+      const OutputKernel::OutputMapper<Index, Scalar>& output_mapper,
+      const TensorContractionParams& params, Index i, Index j, Index num_rows,
+      Index num_cols) const {}
+};
+
+template<typename Indices, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
+class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>, ReadOnlyAccessors>
 {
   public:
   typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar;
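The commit message names Relu and bias fusion as the motivating examples. As an illustration only (this functor is not part of the commit), a kernel matching the NoOpOutputKernel signature could clamp each finished block in place:

#include <unsupported/Eigen/CXX11/Tensor>

// Hypothetical example: fuse an elementwise Relu into the contraction.
struct ReluOutputKernel {
  template <typename Index, typename Scalar>
  EIGEN_ALWAYS_INLINE void operator()(
      const Eigen::OutputKernel::OutputMapper<Index, Scalar>& output_mapper,
      const Eigen::TensorContractionParams& /*params*/, Index /*i*/, Index /*j*/,
      Index num_rows, Index num_cols) const {
    // The mapper exposes only the freshly finished num_rows x num_cols block,
    // so every output element is visited exactly once, while still in cache.
    for (Index c = 0; c < num_cols; ++c) {
      for (Index r = 0; r < num_rows; ++r) {
        Scalar& v = output_mapper(r, c);
        v = v < Scalar(0) ? Scalar(0) : v;
      }
    }
  }
};

With the TensorBase overload above this would be invoked as t_result = t_left.contract(t_right, dims, ReluOutputKernel());, transforming each block right after gebp finishes it instead of in a second pass over the whole tensor.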
@@ -149,8 +194,10 @@ class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXp
   typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index;

   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(
-      const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims)
-      : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {}
+      const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims,
+      const OutputKernelType& output_kernel = OutputKernelType())
+      : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims),
+        m_output_kernel(output_kernel) {}

   EIGEN_DEVICE_FUNC
   const Indices& indices() const { return m_indices; }
@@ -164,10 +211,14 @@ class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXp
   const typename internal::remove_all<typename RhsXprType::Nested>::type&
   rhsExpression() const { return m_rhs_xpr; }

+  EIGEN_DEVICE_FUNC
+  const OutputKernelType& outputKernel() const { return m_output_kernel; }
+
   protected:
   typename LhsXprType::Nested m_lhs_xpr;
   typename RhsXprType::Nested m_rhs_xpr;
   const Indices m_indices;
+  const OutputKernelType m_output_kernel;
 };

@@ -177,9 +228,10 @@ struct TensorContractionEvaluatorBase
   typedef typename internal::traits<Derived>::Indices Indices;
   typedef typename internal::traits<Derived>::LeftArgType LeftArgType;
   typedef typename internal::traits<Derived>::RightArgType RightArgType;
+  typedef typename internal::traits<Derived>::OutputKernelType OutputKernelType;
   typedef typename internal::traits<Derived>::Device Device;

-  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
+  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
   typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
   typedef typename XprType::Index Index;
   typedef typename XprType::CoeffReturnType CoeffReturnType;
@@ -221,6 +273,7 @@ struct TensorContractionEvaluatorBase
         op.lhsExpression(), op.rhsExpression()), device),
     m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
         op.rhsExpression(), op.lhsExpression()), device),
+    m_output_kernel(op.outputKernel()),
     m_device(device),
     m_result(NULL) {
     EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
@@ -391,6 +444,13 @@ struct TensorContractionEvaluatorBase
         numext::swap(m_dimensions[i], m_dimensions[j]);
       }
     }
+
+    // A set of parameters that will allow output kernel to get from output
+    // tensor dimensions (i, j) into the original tensor dimensions.
+    // TODO(ezhulenev): Add parameters required to infer output tensor index for
+    // more complex contractions than 2x2 on internal dimension.
+    m_tensor_contraction_params = {
+        /**swapped_arguments=*/static_cast<int>(Layout) == RowMajor};
   }

   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
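swapped_arguments is the piece a kernel needs when the mapping from the block's (i, j) coordinates back to expression coordinates matters: for RowMajor layouts the evaluator exchanges lhs and rhs, which flips the roles of the block's row and column axes. A hypothetical sketch of a kernel consuming it (the bias indexing is illustrative only, not taken from the commit):

#include <unsupported/Eigen/CXX11/Tensor>

// Hypothetical sketch: add a per-output-channel bias inside the contraction.
struct BiasAddOutputKernel {
  explicit BiasAddOutputKernel(const float* bias) : bias_(bias) {}

  template <typename Index, typename Scalar>
  EIGEN_ALWAYS_INLINE void operator()(
      const Eigen::OutputKernel::OutputMapper<Index, Scalar>& output_mapper,
      const Eigen::TensorContractionParams& params, Index i, Index j,
      Index num_rows, Index num_cols) const {
    for (Index c = 0; c < num_cols; ++c) {
      for (Index r = 0; r < num_rows; ++r) {
        // (i + r, j + c) are coordinates in the 2-d output; when the evaluator
        // has swapped its arguments the two axes trade roles.
        const Index channel = params.swapped_arguments ? (i + r) : (j + c);
        output_mapper(r, c) += Scalar(bias_[channel]);
      }
    }
  }

  const float* bias_;  // not owned; must outlive the contraction
};

Usage would mirror the other kernels: t_result = t_left.contract(t_right, dims, BiasAddOutputKernel(bias_ptr));.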
@@ -585,7 +645,15 @@ struct TensorContractionEvaluatorBase

             // call gebp (matrix kernel)
             // The parameters here are copied from Eigen's GEMM implementation
-            gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, Scalar(1), -1, -1, 0, 0);
+            const auto output_mapper = output.getSubMapper(i2, j2);
+            gebp(output_mapper, blockA, blockB, actual_mc, actual_kc, actual_nc,
+                 Scalar(1), -1, -1, 0, 0);
+
+            // We are done with this [i2, j2] output block.
+            if (k2 + kc >= k) {
+              m_output_kernel(output_mapper, m_tensor_contraction_params, i2, j2,
+                              actual_mc, actual_nc);
+            }
           }
         }
       }
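The guard k2 + kc >= k is what makes the fusion correct: gebp accumulates one kc-sized slice of the contraction dimension at a time into the same output block, so the kernel may only run after the final slice, and then exactly once per block. A standalone toy model of that control flow (illustrative, not from the commit):

#include <cassert>

int main() {
  const int k = 100, kc = 32;  // arbitrary sizes for the toy model
  int kernel_calls = 0;
  for (int k2 = 0; k2 < k; k2 += kc) {
    // ... gebp would accumulate slice [k2, k2 + kc) into the block here ...
    if (k2 + kc >= k) ++kernel_calls;  // true only on the final slice
  }
  assert(kernel_calls == 1);  // the output kernel fires once per block
  return 0;
}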
@@ -848,23 +916,26 @@ protected:
   Index m_j_size;
   Index m_k_size;

+  TensorContractionParams m_tensor_contraction_params;
+
   TensorEvaluator<EvalLeftArgType, Device> m_leftImpl;
   TensorEvaluator<EvalRightArgType, Device> m_rightImpl;
   const Device& m_device;
+  OutputKernelType m_output_kernel;
   Scalar* m_result;
   bool m_can_use_xsmm;
 };


 // evaluator for default device
-template<typename Indices, typename LeftArgType, typename RightArgType, typename Device>
-struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> :
+template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType, typename Device>
+struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> :
     public TensorContractionEvaluatorBase<
-      TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> > {
-  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
+      TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> > {
+  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
   typedef TensorContractionEvaluatorBase<Self> Base;

-  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
+  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
   typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
   typedef typename XprType::Index Index;
   typedef typename XprType::CoeffReturnType CoeffReturnType;
@@ -56,16 +56,16 @@ struct packRhsAndKernelArg {
 } // end namespace internal
 #endif // EIGEN_USE_SIMPLE_THREAD_POOL

-template<typename Indices, typename LeftArgType, typename RightArgType>
-struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> :
-    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> > {
+template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
+struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> :
+    public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> > {

   typedef ThreadPoolDevice Device;

-  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
+  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
   typedef TensorContractionEvaluatorBase<Self> Base;

-  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
+  typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
   typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
   typedef typename XprType::Index Index;
   typedef typename XprType::CoeffReturnType CoeffReturnType;
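The multithreaded path accepts the same kernels. Mirroring the test added further down, a device-driven sketch (assumes t_left, t_right, t_result, and dims set up as in that test):

// Run the fused contraction on a thread pool. Blocks may finish on different
// worker threads, so the kernel is invoked concurrently on disjoint blocks;
// kernels should therefore be cheap to copy and safe to run in parallel
// (stateless, or sharing only immutable state).
Eigen::ThreadPool threads(4);
Eigen::ThreadPoolDevice device(&threads, 4);
t_result.device(device) = t_left.contract(t_right, dims, SqrtOutputKernel());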
@@ -308,7 +308,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
                     this->m_k_strides);

     Context<LhsPacker, RhsPacker, GebpKernel, LhsMapper, RhsMapper,
-            OutputMapper>(this->m_device, num_threads, lhs, rhs, buffer, m, n,
+            OutputMapper>(this, num_threads, lhs, rhs, buffer, m, n,
                           k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0,
                           shard_by_col, parallel_pack)
         .run();
@@ -319,16 +319,18 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
             typename LhsMapper, typename RhsMapper, typename OutputMapper>
   class Context {
    public:
-    Context(const Device& device, int num_threads, LhsMapper& lhs,
+    Context(const Self* self, int num_threads, LhsMapper& lhs,
             RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk, Index bm,
             Index bn, Index bk, Index nm, Index nn, Index nk, Index gm,
             Index gn, Index nm0, Index nn0, bool shard_by_col,
             bool parallel_pack)
-        : device_(device),
+        : device_(self->m_device),
           lhs_(lhs),
           rhs_(rhs),
           buffer_(buffer),
           output_(buffer, tm),
+          output_kernel_(self->m_output_kernel),
+          tensor_contraction_params_(self->m_tensor_contraction_params),
           num_threads_(num_threads),
           shard_by_col_(shard_by_col),
           parallel_pack_(parallel_pack),
@@ -420,6 +422,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     RhsMapper& rhs_;
     Scalar* const buffer_;
     OutputMapper output_;
+    OutputKernelType output_kernel_;
+    TensorContractionParams tensor_contraction_params_;
     const int num_threads_;
     const bool shard_by_col_;
     const bool parallel_pack_;
@@ -536,19 +540,32 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     const Index mend = m * gm_ + gm(m);
     if (shard_by_col_) {
       for (Index n1 = n * gn_; n1 < nend; n1++) {
-        for (Index m1 = m * gm_; m1 < mend; m1++)
-          GebpKernel()(output_.getSubMapper(m1 * bm_, n1 * bn_),
-                       packed_lhs_[k % (P - 1)][m1],
+        for (Index m1 = m * gm_; m1 < mend; m1++) {
+          const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
+          GebpKernel()(output_mapper, packed_lhs_[k % (P - 1)][m1],
                        packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1),
                        Scalar(1), -1, -1, 0, 0);
+
+          // We are done with the last task for the [m1, n1] block.
+          if (k + 1 == nk_) {
+            output_kernel_(output_mapper, tensor_contraction_params_,
+                           m1 * bm_, n1 * bn_, bm(m1), bn(n1));
+          }
+        }
       }
     } else {
       for (Index m1 = m * gm_; m1 < mend; m1++)
         for (Index n1 = n * gn_; n1 < nend; n1++) {
-          GebpKernel()(output_.getSubMapper(m1 * bm_, n1 * bn_),
-                       packed_lhs_[k % (P - 1)][m1],
+          const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
+          GebpKernel()(output_mapper, packed_lhs_[k % (P - 1)][m1],
                        packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1),
                        Scalar(1), -1, -1, 0, 0);
+
+          // We are done with the last task for the [m1, n1] block.
+          if (k + 1 == nk_) {
+            output_kernel_(output_mapper, tensor_contraction_params_,
+                           m1 * bm_, n1 * bn_, bm(m1), bn(n1));
+          }
         }
     }
     signal_kernel(m, n, k + 1, false);
@@ -747,6 +764,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
   }

 #else // EIGEN_USE_SIMPLE_THREAD_POOL
+  // TODO(ezhulenev): SimpleThreadPool will be removed in the future, and seems
+  // like it's not worth adding output kernel support here.
+  static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
+                "SimpleThreadPool does not support contraction output kernels.");

   template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
   void evalProduct(Scalar* buffer) const {
@@ -1065,6 +1086,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
   }

 #if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM)
+  // TODO(ezhulenev): Add support for output kernels and LIBXSMM.
+  static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
+                "XSMM does not support contraction output kernels.");
+
   template<int Alignment>
   class ContextXsmm {
    public:
@@ -65,7 +65,7 @@ template<typename Op, typename Dims, typename XprType, template <class> class Ma
 template<typename XprType> class TensorIndexTupleOp;
 template<typename ReduceOp, typename Dims, typename XprType> class TensorTupleReducerOp;
 template<typename Axis, typename LeftXprType, typename RightXprType> class TensorConcatenationOp;
-template<typename Dimensions, typename LeftXprType, typename RightXprType> class TensorContractionOp;
+template<typename Dimensions, typename LeftXprType, typename RightXprType, typename OutputKernelType> class TensorContractionOp;
 template<typename TargetType, typename XprType> class TensorConversionOp;
 template<typename Dimensions, typename InputXprType, typename KernelXprType> class TensorConvolutionOp;
 template<typename FFT, typename XprType, int FFTDataType, int FFTDirection> class TensorFFTOp;
@@ -97,6 +97,8 @@ template<typename XprType> class TensorForcedEvalOp;
 template<typename ExpressionType, typename DeviceType> class TensorDevice;
 template<typename Derived, typename Device> struct TensorEvaluator;

+class NoOpOutputKernel;
+
 struct DefaultDevice;
 struct ThreadPoolDevice;
 struct GpuDevice;
@@ -510,6 +510,55 @@ static void test_const_inputs()
   VERIFY_IS_APPROX(mat3(1,1), mat1(1,0)*mat2(0,1) + mat1(1,1)*mat2(1,1) + mat1(1,2)*mat2(2,1));
 }

+// Apply Sqrt to all output elements.
+struct SqrtOutputKernel {
+  template <typename Index, typename Scalar>
+  EIGEN_ALWAYS_INLINE void operator()(
+      const OutputKernel::OutputMapper<Index, Scalar>& output_mapper,
+      const TensorContractionParams&, Index, Index, Index num_rows,
+      Index num_cols) const {
+    for (int i = 0; i < num_rows; ++i) {
+      for (int j = 0; j < num_cols; ++j) {
+        output_mapper(i, j) = std::sqrt(output_mapper(i, j));
+      }
+    }
+  }
+};
+
+template <int DataLayout>
+static void test_large_contraction_with_output_kernel() {
+  Tensor<float, 4, DataLayout> t_left(30, 50, 8, 31);
+  Tensor<float, 5, DataLayout> t_right(8, 31, 7, 20, 10);
+  Tensor<float, 5, DataLayout> t_result(30, 50, 7, 20, 10);
+
+  t_left.setRandom();
+  t_right.setRandom();
+  // Put trash in t_result to verify contraction clears output memory.
+  t_result.setRandom();
+
+  // Add a little offset so that the results won't be close to zero.
+  t_left += t_left.constant(1.0f);
+  t_right += t_right.constant(1.0f);
+
+  typedef Map<Eigen::Matrix<float, Dynamic, Dynamic, DataLayout>> MapXf;
+  MapXf m_left(t_left.data(), 1500, 248);
+  MapXf m_right(t_right.data(), 248, 1400);
+  Eigen::Matrix<float, Dynamic, Dynamic, DataLayout> m_result(1500, 1400);
+
+  // this contraction should be equivalent to a single matrix multiplication
+  Eigen::array<DimPair, 2> dims({{DimPair(2, 0), DimPair(3, 1)}});
+
+  // compute results by separate methods
+  t_result = t_left.contract(t_right, dims, SqrtOutputKernel());
+
+  m_result = m_left * m_right;
+
+  for (size_t i = 0; i < t_result.dimensions().TotalSize(); i++) {
+    VERIFY(&t_result.data()[i] != &m_result.data()[i]);
+    VERIFY_IS_APPROX(t_result.data()[i], std::sqrt(m_result.data()[i]));
+  }
+}
+
 void test_cxx11_tensor_contraction()
 {
   CALL_SUBTEST(test_evals<ColMajor>());
@@ -542,4 +591,6 @@ void test_cxx11_tensor_contraction()
   CALL_SUBTEST(test_tensor_product<RowMajor>());
   CALL_SUBTEST(test_const_inputs<ColMajor>());
   CALL_SUBTEST(test_const_inputs<RowMajor>());
+  CALL_SUBTEST(test_large_contraction_with_output_kernel<ColMajor>());
+  CALL_SUBTEST(test_large_contraction_with_output_kernel<RowMajor>());
 }
@@ -232,6 +232,60 @@ void test_multithread_contraction_agrees_with_singlethread() {
   }
 }

+// Apply Sqrt to all output elements.
+struct SqrtOutputKernel {
+  template <typename Index, typename Scalar>
+  EIGEN_ALWAYS_INLINE void operator()(
+      const OutputKernel::OutputMapper<Index, Scalar>& output_mapper,
+      const TensorContractionParams&, Index, Index, Index num_rows,
+      Index num_cols) const {
+    for (int i = 0; i < num_rows; ++i) {
+      for (int j = 0; j < num_cols; ++j) {
+        output_mapper(i, j) = std::sqrt(output_mapper(i, j));
+      }
+    }
+  }
+};
+
+template <int DataLayout>
+static void test_multithread_contraction_with_output_kernel() {
+  typedef Tensor<float, 1>::DimensionPair DimPair;
+
+  const int num_threads = internal::random<int>(2, 11);
+  ThreadPool threads(num_threads);
+  Eigen::ThreadPoolDevice device(&threads, num_threads);
+
+  Tensor<float, 4, DataLayout> t_left(30, 50, 8, 31);
+  Tensor<float, 5, DataLayout> t_right(8, 31, 7, 20, 10);
+  Tensor<float, 5, DataLayout> t_result(30, 50, 7, 20, 10);
+
+  t_left.setRandom();
+  t_right.setRandom();
+  // Put trash in t_result to verify contraction clears output memory.
+  t_result.setRandom();
+
+  // Add a little offset so that the results won't be close to zero.
+  t_left += t_left.constant(1.0f);
+  t_right += t_right.constant(1.0f);
+
+  typedef Map<Eigen::Matrix<float, Dynamic, Dynamic, DataLayout>> MapXf;
+  MapXf m_left(t_left.data(), 1500, 248);
+  MapXf m_right(t_right.data(), 248, 1400);
+  Eigen::Matrix<float, Dynamic, Dynamic, DataLayout> m_result(1500, 1400);
+
+  // this contraction should be equivalent to a single matrix multiplication
+  Eigen::array<DimPair, 2> dims({{DimPair(2, 0), DimPair(3, 1)}});
+
+  // compute results by separate methods
+  t_result.device(device) = t_left.contract(t_right, dims, SqrtOutputKernel());
+
+  m_result = m_left * m_right;
+
+  for (size_t i = 0; i < t_result.dimensions().TotalSize(); i++) {
+    VERIFY(&t_result.data()[i] != &m_result.data()[i]);
+    VERIFY_IS_APPROX(t_result.data()[i], std::sqrt(m_result.data()[i]));
+  }
+}

 template<int DataLayout>
 void test_full_contraction() {
@@ -355,6 +409,8 @@ void test_cxx11_tensor_thread_pool()

   CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<ColMajor>());
   CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<RowMajor>());
+  CALL_SUBTEST_3(test_multithread_contraction_with_output_kernel<ColMajor>());
+  CALL_SUBTEST_3(test_multithread_contraction_with_output_kernel<RowMajor>());

   // Exercise various cases that have been problematic in the past.
   CALL_SUBTEST_4(test_contraction_corner_cases<ColMajor>());