Improved tensor references

This commit is contained in:
Benoit Steiner 2015-01-14 10:13:08 -08:00
parent 91dd53e54d
commit c94174b4fe
2 changed files with 87 additions and 2 deletions

View File

@ -64,7 +64,7 @@ class TensorLazyEvaluatorReadOnly : public TensorLazyBaseEvaluator<Dimensions, t
virtual const Scalar coeff(DenseIndex index) const {
return m_impl.coeff(index);
}
virtual Scalar& coeffRef(DenseIndex) {
virtual Scalar& coeffRef(DenseIndex /*index*/) {
eigen_assert(false && "can't reference the coefficient of a rvalue");
return *reinterpret_cast<Scalar*>(dummy);
};
@ -137,6 +137,8 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
enum {
IsAligned = false,
PacketAccess = false,
Layout = PlainObjectType::Layout,
CoordAccess = false, // to be implemented
};
EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) {
@ -174,6 +176,8 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
return *this;
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Index rank() const { return m_evaluator->dimensions().size(); }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; }
EIGEN_DEVICE_FUNC
@ -197,6 +201,13 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
const array<Index, NumIndices> indices{{firstIndex, otherIndices...}};
return coeff(indices);
}
template<typename... IndexTypes> EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices)
{
const std::size_t NumIndices = (sizeof...(otherIndices) + 1);
const array<Index, NumIndices> indices{{firstIndex, otherIndices...}};
return coeffRef(indices);
}
#else
EIGEN_DEVICE_FUNC
@ -237,6 +248,44 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
indices[4] = i4;
return coeff(indices);
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1)
{
array<Index, 2> indices;
indices[0] = i0;
indices[1] = i1;
return coeffRef(indices);
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2)
{
array<Index, 3> indices;
indices[0] = i0;
indices[1] = i1;
indices[2] = i2;
return coeffRef(indices);
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3)
{
array<Index, 4> indices;
indices[0] = i0;
indices[1] = i1;
indices[2] = i2;
indices[3] = i3;
return coeffRef(indices);
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2, Index i3, Index i4)
{
array<Index, 5> indices;
indices[0] = i0;
indices[1] = i1;
indices[2] = i2;
indices[3] = i3;
indices[4] = i4;
return coeffRef(indices);
}
#endif
template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
@ -244,7 +293,7 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
{
const Dimensions& dims = this->dimensions();
Index index = 0;
if (PlainObjectType::Options&RowMajor) {
if (PlainObjectType::Options & RowMajor) {
index += indices[0];
for (int i = 1; i < NumIndices; ++i) {
index = index * dims[i] + indices[i];
@ -257,6 +306,24 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
}
return m_evaluator->coeff(index);
}
template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices)
{
const Dimensions& dims = this->dimensions();
Index index = 0;
if (PlainObjectType::Options & RowMajor) {
index += indices[0];
for (int i = 1; i < NumIndices; ++i) {
index = index * dims[i] + indices[i];
}
} else {
index += indices[NumIndices-1];
for (int i = NumIndices-2; i >= 0; --i) {
index = index * dims[i] + indices[i];
}
}
return m_evaluator->coeffRef(index);
}
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
@ -298,6 +365,8 @@ struct TensorEvaluator<const TensorRef<Derived>, Device>
enum {
IsAligned = false,
PacketAccess = false,
Layout = TensorRef<Derived>::Layout,
CoordAccess = false, // to be implemented
};
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef<Derived>& m, const Device&)

View File

@ -181,6 +181,21 @@ static void test_ref_in_expr()
}
static void test_coeff_ref()
{
Tensor<float, 5> tensor(2,3,5,7,11);
tensor.setRandom();
Tensor<float, 5> original = tensor;
TensorRef<Tensor<float, 4>> slice = tensor.chip(7, 4);
slice.coeffRef(0, 0, 0, 0) = 1.0f;
slice.coeffRef(1, 0, 0, 0) += 2.0f;
VERIFY_IS_EQUAL(tensor(0,0,0,0,7), 1.0f);
VERIFY_IS_EQUAL(tensor(1,0,0,0,7), original(1,0,0,0,7) + 2.0f);
}
void test_cxx11_tensor_ref()
{
CALL_SUBTEST(test_simple_lvalue_ref());
@ -189,4 +204,5 @@ void test_cxx11_tensor_ref()
CALL_SUBTEST(test_slice());
CALL_SUBTEST(test_ref_of_ref());
CALL_SUBTEST(test_ref_in_expr());
CALL_SUBTEST(test_coeff_ref());
}