Fuse computations into the Tensor contractions using output kernel

This commit is contained in:
Eugene Zhulenev 2018-07-10 13:16:38 -07:00
parent 5539587b1f
commit 01fd4096d3
6 changed files with 248 additions and 37 deletions

View File

@ -517,9 +517,15 @@ class TensorBase<Derived, ReadOnlyAccessors>
typedef Eigen::IndexPair<Index> DimensionPair; typedef Eigen::IndexPair<Index> DimensionPair;
template<typename OtherDerived, typename Dimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE template<typename OtherDerived, typename Dimensions> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorContractionOp<const Dimensions, const Derived, const OtherDerived> const TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const NoOpOutputKernel>
contract(const OtherDerived& other, const Dimensions& dims) const { contract(const OtherDerived& other, const Dimensions& dims) const {
return TensorContractionOp<const Dimensions, const Derived, const OtherDerived>(derived(), other.derived(), dims); return TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const NoOpOutputKernel>(derived(), other.derived(), dims);
}
template<typename OtherDerived, typename Dimensions, typename OutputKernel> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const OutputKernel>
contract(const OtherDerived& other, const Dimensions& dims, const OutputKernel& output_kernel) const {
return TensorContractionOp<const Dimensions, const Derived, const OtherDerived, const OutputKernel>(derived(), other.derived(), dims, output_kernel);
} }
// Convolutions. // Convolutions.

View File

@ -85,8 +85,8 @@ template<typename LhsScalar, typename RhsScalar, typename Scalar>
#endif #endif
template<typename Dimensions, typename LhsXprType, typename RhsXprType> template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> > struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> >
{ {
// Type promotion to handle the case where the types of the lhs and the rhs are different. // Type promotion to handle the case where the types of the lhs and the rhs are different.
typedef typename gebp_traits<typename remove_const<typename LhsXprType::Scalar>::type, typedef typename gebp_traits<typename remove_const<typename LhsXprType::Scalar>::type,
@ -112,23 +112,24 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
}; };
}; };
template<typename Dimensions, typename LhsXprType, typename RhsXprType> template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, Eigen::Dense> struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>, Eigen::Dense>
{ {
typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType>& type; typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>& type;
}; };
template<typename Dimensions, typename LhsXprType, typename RhsXprType> template<typename Dimensions, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >::type> struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType>, 1, typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> >::type>
{ {
typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType> type; typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType, OutputKernelType> type;
}; };
template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename Device_> template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename OutputKernelType_, typename Device_>
struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_>, Device_> > { struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_, OutputKernelType_>, Device_> > {
typedef Indices_ Indices; typedef Indices_ Indices;
typedef LeftArgType_ LeftArgType; typedef LeftArgType_ LeftArgType;
typedef RightArgType_ RightArgType; typedef RightArgType_ RightArgType;
typedef OutputKernelType_ OutputKernelType;
typedef Device_ Device; typedef Device_ Device;
// From NumDims below. // From NumDims below.
@ -137,8 +138,52 @@ struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_,
} // end namespace internal } // end namespace internal
template<typename Indices, typename LhsXprType, typename RhsXprType> // Tensor contraction params that should enable to get from output matrix
class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType>, ReadOnlyAccessors> // 2-dimensional coordinates to the output tensor dimensions.
struct TensorContractionParams {
// TensorContraction evaluator assumes that both tensors are in ColMajor
// layout, if tensors are in RowMajor evaluator swap lhs with rhs.
bool swapped_arguments;
};
// Output kernel allows to fuse operations into the tensor contraction.
//
// Examples:
// 1. Elementwise Relu transformation following Conv2D.
// 2. AddBias to the Conv2D output channels dimension.
//
// See expected implementation in NoOpOutputKernel.
struct OutputKernel {
template <typename Index, typename Scalar>
using OutputMapper = internal::blas_data_mapper<Scalar, Index, ColMajor>;
};
// Output kernel that does absolutely nothing.
struct NoOpOutputKernel {
/**
* Tensor contraction evaluator calls this kernel after finishing each block
* of output matrix. Output blocks belong to the 2-dimensional output tensor.
*
* TensorContractionParams contains contraction dimensions information
* required to map output 2-d space into the expected output tensor space
* (potentially higher dimensional).
*
* \param[in] output_mapper Access to output tensor memory
* \param[in] params Tensor contraction parameters
* \param[in] i Index of a first row available through output_mapper
* \param[in] j Index of a first column available through output_mapper
* \param[in] num_rows Number of available rows
* \param[in] num_cols Number of available columns
*/
template <typename Index, typename Scalar>
EIGEN_ALWAYS_INLINE void operator()(
const OutputKernel::OutputMapper<Index, Scalar>& output_mapper,
const TensorContractionParams& params, Index i, Index j, Index num_rows,
Index num_cols) const {}
};
template<typename Indices, typename LhsXprType, typename RhsXprType, typename OutputKernelType>
class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>, ReadOnlyAccessors>
{ {
public: public:
typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar; typedef typename Eigen::internal::traits<TensorContractionOp>::Scalar Scalar;
@ -149,8 +194,10 @@ class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXp
typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index; typedef typename Eigen::internal::traits<TensorContractionOp>::Index Index;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp( EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(
const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims) const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims,
: m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {} const OutputKernelType& output_kernel = OutputKernelType())
: m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims),
m_output_kernel(output_kernel) {}
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
const Indices& indices() const { return m_indices; } const Indices& indices() const { return m_indices; }
@ -164,10 +211,14 @@ class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXp
const typename internal::remove_all<typename RhsXprType::Nested>::type& const typename internal::remove_all<typename RhsXprType::Nested>::type&
rhsExpression() const { return m_rhs_xpr; } rhsExpression() const { return m_rhs_xpr; }
EIGEN_DEVICE_FUNC
const OutputKernelType& outputKernel() const { return m_output_kernel; }
protected: protected:
typename LhsXprType::Nested m_lhs_xpr; typename LhsXprType::Nested m_lhs_xpr;
typename RhsXprType::Nested m_rhs_xpr; typename RhsXprType::Nested m_rhs_xpr;
const Indices m_indices; const Indices m_indices;
const OutputKernelType m_output_kernel;
}; };
@ -177,9 +228,10 @@ struct TensorContractionEvaluatorBase
typedef typename internal::traits<Derived>::Indices Indices; typedef typename internal::traits<Derived>::Indices Indices;
typedef typename internal::traits<Derived>::LeftArgType LeftArgType; typedef typename internal::traits<Derived>::LeftArgType LeftArgType;
typedef typename internal::traits<Derived>::RightArgType RightArgType; typedef typename internal::traits<Derived>::RightArgType RightArgType;
typedef typename internal::traits<Derived>::OutputKernelType OutputKernelType;
typedef typename internal::traits<Derived>::Device Device; typedef typename internal::traits<Derived>::Device Device;
typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType; typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar; typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
typedef typename XprType::Index Index; typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::CoeffReturnType CoeffReturnType;
@ -221,6 +273,7 @@ struct TensorContractionEvaluatorBase
op.lhsExpression(), op.rhsExpression()), device), op.lhsExpression(), op.rhsExpression()), device),
m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(), m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
op.rhsExpression(), op.lhsExpression()), device), op.rhsExpression(), op.lhsExpression()), device),
m_output_kernel(op.outputKernel()),
m_device(device), m_device(device),
m_result(NULL) { m_result(NULL) {
EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) == EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
@ -391,6 +444,13 @@ struct TensorContractionEvaluatorBase
numext::swap(m_dimensions[i], m_dimensions[j]); numext::swap(m_dimensions[i], m_dimensions[j]);
} }
} }
// A set of parameters that will allow output kernel to get from output
// tensor dimensions (i, j) into the original tensor dimensions.
// TODO(ezhulenev): Add parameters required to infer output tensor index for
// more complex contractions than 2x2 on internal dimension.
m_tensor_contraction_params = {
/**swapped_arguments=*/static_cast<int>(Layout) == RowMajor};
} }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
@ -585,7 +645,15 @@ struct TensorContractionEvaluatorBase
// call gebp (matrix kernel) // call gebp (matrix kernel)
// The parameters here are copied from Eigen's GEMM implementation // The parameters here are copied from Eigen's GEMM implementation
gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, Scalar(1), -1, -1, 0, 0); const auto output_mapper = output.getSubMapper(i2, j2);
gebp(output_mapper, blockA, blockB, actual_mc, actual_kc, actual_nc,
Scalar(1), -1, -1, 0, 0);
// We are done with this [i2, j2] output block.
if (k2 + kc >= k) {
m_output_kernel(output_mapper, m_tensor_contraction_params, i2, j2,
actual_mc, actual_nc);
}
} }
} }
} }
@ -848,23 +916,26 @@ protected:
Index m_j_size; Index m_j_size;
Index m_k_size; Index m_k_size;
TensorContractionParams m_tensor_contraction_params;
TensorEvaluator<EvalLeftArgType, Device> m_leftImpl; TensorEvaluator<EvalLeftArgType, Device> m_leftImpl;
TensorEvaluator<EvalRightArgType, Device> m_rightImpl; TensorEvaluator<EvalRightArgType, Device> m_rightImpl;
const Device& m_device; const Device& m_device;
OutputKernelType m_output_kernel;
Scalar* m_result; Scalar* m_result;
bool m_can_use_xsmm; bool m_can_use_xsmm;
}; };
// evaluator for default device // evaluator for default device
template<typename Indices, typename LeftArgType, typename RightArgType, typename Device> template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType, typename Device>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> : struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> :
public TensorContractionEvaluatorBase< public TensorContractionEvaluatorBase<
TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> > { TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> > {
typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self; typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
typedef TensorContractionEvaluatorBase<Self> Base; typedef TensorContractionEvaluatorBase<Self> Base;
typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType; typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar; typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
typedef typename XprType::Index Index; typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::CoeffReturnType CoeffReturnType;

View File

@ -56,16 +56,16 @@ struct packRhsAndKernelArg {
} // end namespace internal } // end namespace internal
#endif // EIGEN_USE_SIMPLE_THREAD_POOL #endif // EIGEN_USE_SIMPLE_THREAD_POOL
template<typename Indices, typename LeftArgType, typename RightArgType> template<typename Indices, typename LeftArgType, typename RightArgType, typename OutputKernelType>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> : struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> :
public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, ThreadPoolDevice> > { public TensorContractionEvaluatorBase<TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, ThreadPoolDevice> > {
typedef ThreadPoolDevice Device; typedef ThreadPoolDevice Device;
typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self; typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType>, Device> Self;
typedef TensorContractionEvaluatorBase<Self> Base; typedef TensorContractionEvaluatorBase<Self> Base;
typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType; typedef TensorContractionOp<Indices, LeftArgType, RightArgType, OutputKernelType> XprType;
typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar; typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
typedef typename XprType::Index Index; typedef typename XprType::Index Index;
typedef typename XprType::CoeffReturnType CoeffReturnType; typedef typename XprType::CoeffReturnType CoeffReturnType;
@ -308,7 +308,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
this->m_k_strides); this->m_k_strides);
Context<LhsPacker, RhsPacker, GebpKernel, LhsMapper, RhsMapper, Context<LhsPacker, RhsPacker, GebpKernel, LhsMapper, RhsMapper,
OutputMapper>(this->m_device, num_threads, lhs, rhs, buffer, m, n, OutputMapper>(this, num_threads, lhs, rhs, buffer, m, n,
k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0, k, bm, bn, bk, nm, nn, nk, gm, gn, nm0, nn0,
shard_by_col, parallel_pack) shard_by_col, parallel_pack)
.run(); .run();
@ -319,16 +319,18 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
typename LhsMapper, typename RhsMapper, typename OutputMapper> typename LhsMapper, typename RhsMapper, typename OutputMapper>
class Context { class Context {
public: public:
Context(const Device& device, int num_threads, LhsMapper& lhs, Context(const Self* self, int num_threads, LhsMapper& lhs,
RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk, Index bm, RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk, Index bm,
Index bn, Index bk, Index nm, Index nn, Index nk, Index gm, Index bn, Index bk, Index nm, Index nn, Index nk, Index gm,
Index gn, Index nm0, Index nn0, bool shard_by_col, Index gn, Index nm0, Index nn0, bool shard_by_col,
bool parallel_pack) bool parallel_pack)
: device_(device), : device_(self->m_device),
lhs_(lhs), lhs_(lhs),
rhs_(rhs), rhs_(rhs),
buffer_(buffer), buffer_(buffer),
output_(buffer, tm), output_(buffer, tm),
output_kernel_(self->m_output_kernel),
tensor_contraction_params_(self->m_tensor_contraction_params),
num_threads_(num_threads), num_threads_(num_threads),
shard_by_col_(shard_by_col), shard_by_col_(shard_by_col),
parallel_pack_(parallel_pack), parallel_pack_(parallel_pack),
@ -420,6 +422,8 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
RhsMapper& rhs_; RhsMapper& rhs_;
Scalar* const buffer_; Scalar* const buffer_;
OutputMapper output_; OutputMapper output_;
OutputKernelType output_kernel_;
TensorContractionParams tensor_contraction_params_;
const int num_threads_; const int num_threads_;
const bool shard_by_col_; const bool shard_by_col_;
const bool parallel_pack_; const bool parallel_pack_;
@ -536,19 +540,32 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
const Index mend = m * gm_ + gm(m); const Index mend = m * gm_ + gm(m);
if (shard_by_col_) { if (shard_by_col_) {
for (Index n1 = n * gn_; n1 < nend; n1++) { for (Index n1 = n * gn_; n1 < nend; n1++) {
for (Index m1 = m * gm_; m1 < mend; m1++) for (Index m1 = m * gm_; m1 < mend; m1++) {
GebpKernel()(output_.getSubMapper(m1 * bm_, n1 * bn_), const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
packed_lhs_[k % (P - 1)][m1], GebpKernel()(output_mapper, packed_lhs_[k % (P - 1)][m1],
packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1), packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1),
Scalar(1), -1, -1, 0, 0); Scalar(1), -1, -1, 0, 0);
// We are done with the last task for the [m1, n1] block.
if (k + 1 == nk_) {
output_kernel_(output_mapper, tensor_contraction_params_,
m1 * bm_, n1 * bn_, bm(m1), bn(n1));
}
}
} }
} else { } else {
for (Index m1 = m * gm_; m1 < mend; m1++) for (Index m1 = m * gm_; m1 < mend; m1++)
for (Index n1 = n * gn_; n1 < nend; n1++) { for (Index n1 = n * gn_; n1 < nend; n1++) {
GebpKernel()(output_.getSubMapper(m1 * bm_, n1 * bn_), const auto output_mapper = output_.getSubMapper(m1 * bm_, n1 * bn_);
packed_lhs_[k % (P - 1)][m1], GebpKernel()(output_mapper, packed_lhs_[k % (P - 1)][m1],
packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1), packed_rhs_[k % (P - 1)][n1], bm(m1), bk(k), bn(n1),
Scalar(1), -1, -1, 0, 0); Scalar(1), -1, -1, 0, 0);
// We are done with the last task for the [m1, n1] block.
if (k + 1 == nk_) {
output_kernel_(output_mapper, tensor_contraction_params_,
m1 * bm_, n1 * bn_, bm(m1), bn(n1));
}
} }
} }
signal_kernel(m, n, k + 1, false); signal_kernel(m, n, k + 1, false);
@ -747,6 +764,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
} }
#else // EIGEN_USE_SIMPLE_THREAD_POOL #else // EIGEN_USE_SIMPLE_THREAD_POOL
// TODO(ezhulenev): SimpleThreadPool will be removed in the future, and seems
// like it's not worth adding output kernel support here.
static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
"SimpleThreadPool does not support contraction output kernels.");
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
void evalProduct(Scalar* buffer) const { void evalProduct(Scalar* buffer) const {
@ -1065,6 +1086,10 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
} }
#if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM) #if defined(EIGEN_VECTORIZE_AVX) && defined(EIGEN_USE_LIBXSMM)
// TODO(ezhulenev): Add support for output kernels and LIBXSMM.
static_assert(std::is_same<OutputKernelType, const NoOpOutputKernel>::value,
"XSMM does not support contraction output kernels.");
template<int Alignment> template<int Alignment>
class ContextXsmm { class ContextXsmm {
public: public:

View File

@ -65,7 +65,7 @@ template<typename Op, typename Dims, typename XprType, template <class> class Ma
template<typename XprType> class TensorIndexTupleOp; template<typename XprType> class TensorIndexTupleOp;
template<typename ReduceOp, typename Dims, typename XprType> class TensorTupleReducerOp; template<typename ReduceOp, typename Dims, typename XprType> class TensorTupleReducerOp;
template<typename Axis, typename LeftXprType, typename RightXprType> class TensorConcatenationOp; template<typename Axis, typename LeftXprType, typename RightXprType> class TensorConcatenationOp;
template<typename Dimensions, typename LeftXprType, typename RightXprType> class TensorContractionOp; template<typename Dimensions, typename LeftXprType, typename RightXprType, typename OutputKernelType> class TensorContractionOp;
template<typename TargetType, typename XprType> class TensorConversionOp; template<typename TargetType, typename XprType> class TensorConversionOp;
template<typename Dimensions, typename InputXprType, typename KernelXprType> class TensorConvolutionOp; template<typename Dimensions, typename InputXprType, typename KernelXprType> class TensorConvolutionOp;
template<typename FFT, typename XprType, int FFTDataType, int FFTDirection> class TensorFFTOp; template<typename FFT, typename XprType, int FFTDataType, int FFTDirection> class TensorFFTOp;
@ -97,6 +97,8 @@ template<typename XprType> class TensorForcedEvalOp;
template<typename ExpressionType, typename DeviceType> class TensorDevice; template<typename ExpressionType, typename DeviceType> class TensorDevice;
template<typename Derived, typename Device> struct TensorEvaluator; template<typename Derived, typename Device> struct TensorEvaluator;
class NoOpOutputKernel;
struct DefaultDevice; struct DefaultDevice;
struct ThreadPoolDevice; struct ThreadPoolDevice;
struct GpuDevice; struct GpuDevice;

View File

@ -510,6 +510,55 @@ static void test_const_inputs()
VERIFY_IS_APPROX(mat3(1,1), mat1(1,0)*mat2(0,1) + mat1(1,1)*mat2(1,1) + mat1(1,2)*mat2(2,1)); VERIFY_IS_APPROX(mat3(1,1), mat1(1,0)*mat2(0,1) + mat1(1,1)*mat2(1,1) + mat1(1,2)*mat2(2,1));
} }
// Apply Sqrt to all output elements.
struct SqrtOutputKernel {
template <typename Index, typename Scalar>
EIGEN_ALWAYS_INLINE void operator()(
const OutputKernel::OutputMapper<Index, Scalar>& output_mapper,
const TensorContractionParams&, Index, Index, Index num_rows,
Index num_cols) const {
for (int i = 0; i < num_rows; ++i) {
for (int j = 0; j < num_cols; ++j) {
output_mapper(i, j) = std::sqrt(output_mapper(i, j));
}
}
}
};
template <int DataLayout>
static void test_large_contraction_with_output_kernel() {
Tensor<float, 4, DataLayout> t_left(30, 50, 8, 31);
Tensor<float, 5, DataLayout> t_right(8, 31, 7, 20, 10);
Tensor<float, 5, DataLayout> t_result(30, 50, 7, 20, 10);
t_left.setRandom();
t_right.setRandom();
// Put trash in mat4 to verify contraction clears output memory.
t_result.setRandom();
// Add a little offset so that the results won't be close to zero.
t_left += t_left.constant(1.0f);
t_right += t_right.constant(1.0f);
typedef Map<Eigen::Matrix<float, Dynamic, Dynamic, DataLayout>> MapXf;
MapXf m_left(t_left.data(), 1500, 248);
MapXf m_right(t_right.data(), 248, 1400);
Eigen::Matrix<float, Dynamic, Dynamic, DataLayout> m_result(1500, 1400);
// this contraction should be equivalent to a single matrix multiplication
Eigen::array<DimPair, 2> dims({{DimPair(2, 0), DimPair(3, 1)}});
// compute results by separate methods
t_result = t_left.contract(t_right, dims, SqrtOutputKernel());
m_result = m_left * m_right;
for (size_t i = 0; i < t_result.dimensions().TotalSize(); i++) {
VERIFY(&t_result.data()[i] != &m_result.data()[i]);
VERIFY_IS_APPROX(t_result.data()[i], std::sqrt(m_result.data()[i]));
}
}
void test_cxx11_tensor_contraction() void test_cxx11_tensor_contraction()
{ {
CALL_SUBTEST(test_evals<ColMajor>()); CALL_SUBTEST(test_evals<ColMajor>());
@ -542,4 +591,6 @@ void test_cxx11_tensor_contraction()
CALL_SUBTEST(test_tensor_product<RowMajor>()); CALL_SUBTEST(test_tensor_product<RowMajor>());
CALL_SUBTEST(test_const_inputs<ColMajor>()); CALL_SUBTEST(test_const_inputs<ColMajor>());
CALL_SUBTEST(test_const_inputs<RowMajor>()); CALL_SUBTEST(test_const_inputs<RowMajor>());
CALL_SUBTEST(test_large_contraction_with_output_kernel<ColMajor>());
CALL_SUBTEST(test_large_contraction_with_output_kernel<RowMajor>());
} }

View File

@ -232,6 +232,60 @@ void test_multithread_contraction_agrees_with_singlethread() {
} }
} }
// Apply Sqrt to all output elements.
struct SqrtOutputKernel {
template <typename Index, typename Scalar>
EIGEN_ALWAYS_INLINE void operator()(
const OutputKernel::OutputMapper<Index, Scalar>& output_mapper,
const TensorContractionParams&, Index, Index, Index num_rows,
Index num_cols) const {
for (int i = 0; i < num_rows; ++i) {
for (int j = 0; j < num_cols; ++j) {
output_mapper(i, j) = std::sqrt(output_mapper(i, j));
}
}
}
};
template <int DataLayout>
static void test_multithread_contraction_with_output_kernel() {
typedef Tensor<float, 1>::DimensionPair DimPair;
const int num_threads = internal::random<int>(2, 11);
ThreadPool threads(num_threads);
Eigen::ThreadPoolDevice device(&threads, num_threads);
Tensor<float, 4, DataLayout> t_left(30, 50, 8, 31);
Tensor<float, 5, DataLayout> t_right(8, 31, 7, 20, 10);
Tensor<float, 5, DataLayout> t_result(30, 50, 7, 20, 10);
t_left.setRandom();
t_right.setRandom();
// Put trash in mat4 to verify contraction clears output memory.
t_result.setRandom();
// Add a little offset so that the results won't be close to zero.
t_left += t_left.constant(1.0f);
t_right += t_right.constant(1.0f);
typedef Map<Eigen::Matrix<float, Dynamic, Dynamic, DataLayout>> MapXf;
MapXf m_left(t_left.data(), 1500, 248);
MapXf m_right(t_right.data(), 248, 1400);
Eigen::Matrix<float, Dynamic, Dynamic, DataLayout> m_result(1500, 1400);
// this contraction should be equivalent to a single matrix multiplication
Eigen::array<DimPair, 2> dims({{DimPair(2, 0), DimPair(3, 1)}});
// compute results by separate methods
t_result.device(device) = t_left.contract(t_right, dims, SqrtOutputKernel());
m_result = m_left * m_right;
for (size_t i = 0; i < t_result.dimensions().TotalSize(); i++) {
VERIFY(&t_result.data()[i] != &m_result.data()[i]);
VERIFY_IS_APPROX(t_result.data()[i], std::sqrt(m_result.data()[i]));
}
}
template<int DataLayout> template<int DataLayout>
void test_full_contraction() { void test_full_contraction() {
@ -355,6 +409,8 @@ void test_cxx11_tensor_thread_pool()
CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<ColMajor>()); CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<ColMajor>());
CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<RowMajor>()); CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<RowMajor>());
CALL_SUBTEST_3(test_multithread_contraction_with_output_kernel<ColMajor>());
CALL_SUBTEST_3(test_multithread_contraction_with_output_kernel<RowMajor>());
// Exercise various cases that have been problematic in the past. // Exercise various cases that have been problematic in the past.
CALL_SUBTEST_4(test_contraction_corner_cases<ColMajor>()); CALL_SUBTEST_4(test_contraction_corner_cases<ColMajor>());