diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h b/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h index d43fb286e..0a87e67eb 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorRef.h @@ -64,7 +64,7 @@ class TensorLazyEvaluatorReadOnly : public TensorLazyBaseEvaluator(dummy); }; @@ -137,6 +137,8 @@ template class TensorRef : public TensorBase class TensorRef : public TensorBasedimensions().size(); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; } EIGEN_DEVICE_FUNC @@ -197,6 +201,13 @@ template class TensorRef : public TensorBase indices{{firstIndex, otherIndices...}}; return coeff(indices); } + template EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices) + { + const std::size_t NumIndices = (sizeof...(otherIndices) + 1); + const array indices{{firstIndex, otherIndices...}}; + return coeffRef(indices); + } #else EIGEN_DEVICE_FUNC @@ -237,6 +248,44 @@ template class TensorRef : public TensorBase indices; + indices[0] = i0; + indices[1] = i1; + return coeffRef(indices); + } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2) + { + array indices; + indices[0] = i0; + indices[1] = i1; + indices[2] = i2; + return coeffRef(indices); + } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3) + { + array indices; + indices[0] = i0; + indices[1] = i1; + indices[2] = i2; + indices[3] = i3; + return coeffRef(indices); + } + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2, Index i3, Index i4) + { + array indices; + indices[0] = i0; + indices[1] = i1; + indices[2] = i2; + indices[3] = i3; + indices[4] = i4; + return coeffRef(indices); + } #endif template EIGEN_DEVICE_FUNC @@ -244,7 +293,7 @@ template class TensorRef : public TensorBasedimensions(); Index index = 0; - if (PlainObjectType::Options&RowMajor) { + if (PlainObjectType::Options & RowMajor) { index += indices[0]; for (int i = 1; i < NumIndices; ++i) { index = index * dims[i] + indices[i]; @@ -257,6 +306,24 @@ template class TensorRef : public TensorBasecoeff(index); } + template EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar& coeffRef(const array& indices) + { + const Dimensions& dims = this->dimensions(); + Index index = 0; + if (PlainObjectType::Options & RowMajor) { + index += indices[0]; + for (int i = 1; i < NumIndices; ++i) { + index = index * dims[i] + indices[i]; + } + } else { + index += indices[NumIndices-1]; + for (int i = NumIndices-2; i >= 0; --i) { + index = index * dims[i] + indices[i]; + } + } + return m_evaluator->coeffRef(index); + } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar coeff(Index index) const @@ -298,6 +365,8 @@ struct TensorEvaluator, Device> enum { IsAligned = false, PacketAccess = false, + Layout = TensorRef::Layout, + CoordAccess = false, // to be implemented }; EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef& m, const Device&) diff --git a/unsupported/test/cxx11_tensor_ref.cpp b/unsupported/test/cxx11_tensor_ref.cpp index 4ff94a059..aa369f278 100644 --- a/unsupported/test/cxx11_tensor_ref.cpp +++ b/unsupported/test/cxx11_tensor_ref.cpp @@ -181,6 +181,21 @@ static void test_ref_in_expr() } +static void test_coeff_ref() +{ + Tensor tensor(2,3,5,7,11); + tensor.setRandom(); + Tensor original = tensor; + + TensorRef> slice = tensor.chip(7, 4); + slice.coeffRef(0, 0, 0, 0) = 1.0f; + slice.coeffRef(1, 0, 0, 0) += 2.0f; + + VERIFY_IS_EQUAL(tensor(0,0,0,0,7), 1.0f); + VERIFY_IS_EQUAL(tensor(1,0,0,0,7), original(1,0,0,0,7) + 2.0f); +} + + void test_cxx11_tensor_ref() { CALL_SUBTEST(test_simple_lvalue_ref()); @@ -189,4 +204,5 @@ void test_cxx11_tensor_ref() CALL_SUBTEST(test_slice()); CALL_SUBTEST(test_ref_of_ref()); CALL_SUBTEST(test_ref_in_expr()); + CALL_SUBTEST(test_coeff_ref()); }