Specify default output kernel for TensorContractionOp

This commit is contained in:
Eugene Zhulenev 2018-07-18 14:21:01 -07:00
parent 6e5a3b898f
commit 79d4129cce

View File

@ -182,7 +182,7 @@ struct NoOpOutputKernel {
Index num_cols) const {} Index num_cols) const {}
}; };
template<typename Indices, typename LhsXprType, typename RhsXprType, typename OutputKernelType> template<typename Indices, typename LhsXprType, typename RhsXprType, typename OutputKernelType = const NoOpOutputKernel>
class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>, ReadOnlyAccessors> class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType, OutputKernelType>, ReadOnlyAccessors>
{ {
public: public:
@ -507,7 +507,7 @@ struct TensorContractionEvaluatorBase
} }
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
#if !defined(EIGEN_HIPCC) #if !defined(EIGEN_HIPCC)
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
#endif #endif
void evalGemv(Scalar* buffer) const { void evalGemv(Scalar* buffer) const {
@ -556,7 +556,7 @@ struct TensorContractionEvaluatorBase
} }
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
#if !defined(EIGEN_HIPCC) #if !defined(EIGEN_HIPCC)
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
#endif #endif
void evalGemm(Scalar* buffer) const { void evalGemm(Scalar* buffer) const {