Marked several methods EIGEN_DEVICE_FUNC

This commit is contained in:
Benoit Steiner 2016-01-28 23:37:48 -08:00
parent c5d25bf1d0
commit 963f2d2a8f
2 changed files with 6 additions and 6 deletions

View File

@ -378,7 +378,7 @@ struct TensorContractionEvaluatorBase
} }
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
void evalGemv(Scalar* buffer) const { EIGEN_DEVICE_FUNC void evalGemv(Scalar* buffer) const {
const Index rows = m_i_size; const Index rows = m_i_size;
const Index cols = m_k_size; const Index cols = m_k_size;
@ -516,7 +516,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
Base(op, device) { } Base(op, device) { }
template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment> template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous, bool rhs_inner_dim_reordered, int Alignment>
void evalProduct(Scalar* buffer) const { EIGEN_DEVICE_FUNC void evalProduct(Scalar* buffer) const {
if (this->m_j_size == 1) { if (this->m_j_size == 1) {
this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer); this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous, rhs_inner_dim_reordered, Alignment>(buffer);
return; return;

View File

@ -28,7 +28,7 @@ class TensorContractionBlocking {
typedef typename LhsMapper::Scalar LhsScalar; typedef typename LhsMapper::Scalar LhsScalar;
typedef typename RhsMapper::Scalar RhsScalar; typedef typename RhsMapper::Scalar RhsScalar;
TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1) : EIGEN_DEVICE_FUNC TensorContractionBlocking(Index k, Index m, Index n, Index num_threads = 1) :
kc_(k), mc_(m), nc_(n) kc_(k), mc_(m), nc_(n)
{ {
if (ShardingType == ShardByCol) { if (ShardingType == ShardByCol) {
@ -41,9 +41,9 @@ class TensorContractionBlocking {
} }
} }
EIGEN_ALWAYS_INLINE Index kc() const { return kc_; } EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index kc() const { return kc_; }
EIGEN_ALWAYS_INLINE Index mc() const { return mc_; } EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index mc() const { return mc_; }
EIGEN_ALWAYS_INLINE Index nc() const { return nc_; } EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index nc() const { return nc_; }
private: private:
Index kc_; Index kc_;