mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Improved tensor references
This commit is contained in:
parent
91dd53e54d
commit
c94174b4fe
@ -64,7 +64,7 @@ class TensorLazyEvaluatorReadOnly : public TensorLazyBaseEvaluator<Dimensions, t
|
||||
virtual const Scalar coeff(DenseIndex index) const {
|
||||
return m_impl.coeff(index);
|
||||
}
|
||||
virtual Scalar& coeffRef(DenseIndex) {
|
||||
virtual Scalar& coeffRef(DenseIndex /*index*/) {
|
||||
eigen_assert(false && "can't reference the coefficient of a rvalue");
|
||||
return *reinterpret_cast<Scalar*>(dummy);
|
||||
};
|
||||
@ -137,6 +137,8 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
|
||||
enum {
|
||||
IsAligned = false,
|
||||
PacketAccess = false,
|
||||
Layout = PlainObjectType::Layout,
|
||||
CoordAccess = false, // to be implemented
|
||||
};
|
||||
|
||||
EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) {
|
||||
@ -174,6 +176,8 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
|
||||
return *this;
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE Index rank() const { return m_evaluator->dimensions().size(); }
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; }
|
||||
EIGEN_DEVICE_FUNC
|
||||
@ -197,6 +201,13 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
|
||||
const array<Index, NumIndices> indices{{firstIndex, otherIndices...}};
|
||||
return coeff(indices);
|
||||
}
|
||||
template<typename... IndexTypes> EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices)
|
||||
{
|
||||
const std::size_t NumIndices = (sizeof...(otherIndices) + 1);
|
||||
const array<Index, NumIndices> indices{{firstIndex, otherIndices...}};
|
||||
return coeffRef(indices);
|
||||
}
|
||||
#else
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
@ -237,6 +248,44 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
|
||||
indices[4] = i4;
|
||||
return coeff(indices);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1)
|
||||
{
|
||||
array<Index, 2> indices;
|
||||
indices[0] = i0;
|
||||
indices[1] = i1;
|
||||
return coeffRef(indices);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2)
|
||||
{
|
||||
array<Index, 3> indices;
|
||||
indices[0] = i0;
|
||||
indices[1] = i1;
|
||||
indices[2] = i2;
|
||||
return coeffRef(indices);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3)
|
||||
{
|
||||
array<Index, 4> indices;
|
||||
indices[0] = i0;
|
||||
indices[1] = i1;
|
||||
indices[2] = i2;
|
||||
indices[3] = i3;
|
||||
return coeffRef(indices);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2, Index i3, Index i4)
|
||||
{
|
||||
array<Index, 5> indices;
|
||||
indices[0] = i0;
|
||||
indices[1] = i1;
|
||||
indices[2] = i2;
|
||||
indices[3] = i3;
|
||||
indices[4] = i4;
|
||||
return coeffRef(indices);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
|
||||
@ -244,7 +293,7 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
|
||||
{
|
||||
const Dimensions& dims = this->dimensions();
|
||||
Index index = 0;
|
||||
if (PlainObjectType::Options&RowMajor) {
|
||||
if (PlainObjectType::Options & RowMajor) {
|
||||
index += indices[0];
|
||||
for (int i = 1; i < NumIndices; ++i) {
|
||||
index = index * dims[i] + indices[i];
|
||||
@ -257,6 +306,24 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
|
||||
}
|
||||
return m_evaluator->coeff(index);
|
||||
}
|
||||
template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices)
|
||||
{
|
||||
const Dimensions& dims = this->dimensions();
|
||||
Index index = 0;
|
||||
if (PlainObjectType::Options & RowMajor) {
|
||||
index += indices[0];
|
||||
for (int i = 1; i < NumIndices; ++i) {
|
||||
index = index * dims[i] + indices[i];
|
||||
}
|
||||
} else {
|
||||
index += indices[NumIndices-1];
|
||||
for (int i = NumIndices-2; i >= 0; --i) {
|
||||
index = index * dims[i] + indices[i];
|
||||
}
|
||||
}
|
||||
return m_evaluator->coeffRef(index);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
|
||||
@ -298,6 +365,8 @@ struct TensorEvaluator<const TensorRef<Derived>, Device>
|
||||
enum {
|
||||
IsAligned = false,
|
||||
PacketAccess = false,
|
||||
Layout = TensorRef<Derived>::Layout,
|
||||
CoordAccess = false, // to be implemented
|
||||
};
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef<Derived>& m, const Device&)
|
||||
|
@ -181,6 +181,21 @@ static void test_ref_in_expr()
|
||||
}
|
||||
|
||||
|
||||
static void test_coeff_ref()
|
||||
{
|
||||
Tensor<float, 5> tensor(2,3,5,7,11);
|
||||
tensor.setRandom();
|
||||
Tensor<float, 5> original = tensor;
|
||||
|
||||
TensorRef<Tensor<float, 4>> slice = tensor.chip(7, 4);
|
||||
slice.coeffRef(0, 0, 0, 0) = 1.0f;
|
||||
slice.coeffRef(1, 0, 0, 0) += 2.0f;
|
||||
|
||||
VERIFY_IS_EQUAL(tensor(0,0,0,0,7), 1.0f);
|
||||
VERIFY_IS_EQUAL(tensor(1,0,0,0,7), original(1,0,0,0,7) + 2.0f);
|
||||
}
|
||||
|
||||
|
||||
void test_cxx11_tensor_ref()
|
||||
{
|
||||
CALL_SUBTEST(test_simple_lvalue_ref());
|
||||
@ -189,4 +204,5 @@ void test_cxx11_tensor_ref()
|
||||
CALL_SUBTEST(test_slice());
|
||||
CALL_SUBTEST(test_ref_of_ref());
|
||||
CALL_SUBTEST(test_ref_in_expr());
|
||||
CALL_SUBTEST(test_coeff_ref());
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user