mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Improved tensor references
This commit is contained in:
parent
91dd53e54d
commit
c94174b4fe
@ -64,7 +64,7 @@ class TensorLazyEvaluatorReadOnly : public TensorLazyBaseEvaluator<Dimensions, t
|
|||||||
virtual const Scalar coeff(DenseIndex index) const {
|
virtual const Scalar coeff(DenseIndex index) const {
|
||||||
return m_impl.coeff(index);
|
return m_impl.coeff(index);
|
||||||
}
|
}
|
||||||
virtual Scalar& coeffRef(DenseIndex) {
|
virtual Scalar& coeffRef(DenseIndex /*index*/) {
|
||||||
eigen_assert(false && "can't reference the coefficient of a rvalue");
|
eigen_assert(false && "can't reference the coefficient of a rvalue");
|
||||||
return *reinterpret_cast<Scalar*>(dummy);
|
return *reinterpret_cast<Scalar*>(dummy);
|
||||||
};
|
};
|
||||||
@ -137,6 +137,8 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
|
|||||||
enum {
|
enum {
|
||||||
IsAligned = false,
|
IsAligned = false,
|
||||||
PacketAccess = false,
|
PacketAccess = false,
|
||||||
|
Layout = PlainObjectType::Layout,
|
||||||
|
CoordAccess = false, // to be implemented
|
||||||
};
|
};
|
||||||
|
|
||||||
EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) {
|
EIGEN_STRONG_INLINE TensorRef() : m_evaluator(NULL) {
|
||||||
@ -174,6 +176,8 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE Index rank() const { return m_evaluator->dimensions().size(); }
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; }
|
EIGEN_STRONG_INLINE Index dimension(Index n) const { return m_evaluator->dimensions()[n]; }
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
@ -197,6 +201,13 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
|
|||||||
const array<Index, NumIndices> indices{{firstIndex, otherIndices...}};
|
const array<Index, NumIndices> indices{{firstIndex, otherIndices...}};
|
||||||
return coeff(indices);
|
return coeff(indices);
|
||||||
}
|
}
|
||||||
|
template<typename... IndexTypes> EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE Scalar& coeffRef(Index firstIndex, IndexTypes... otherIndices)
|
||||||
|
{
|
||||||
|
const std::size_t NumIndices = (sizeof...(otherIndices) + 1);
|
||||||
|
const array<Index, NumIndices> indices{{firstIndex, otherIndices...}};
|
||||||
|
return coeffRef(indices);
|
||||||
|
}
|
||||||
#else
|
#else
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
@ -237,6 +248,44 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
|
|||||||
indices[4] = i4;
|
indices[4] = i4;
|
||||||
return coeff(indices);
|
return coeff(indices);
|
||||||
}
|
}
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1)
|
||||||
|
{
|
||||||
|
array<Index, 2> indices;
|
||||||
|
indices[0] = i0;
|
||||||
|
indices[1] = i1;
|
||||||
|
return coeffRef(indices);
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2)
|
||||||
|
{
|
||||||
|
array<Index, 3> indices;
|
||||||
|
indices[0] = i0;
|
||||||
|
indices[1] = i1;
|
||||||
|
indices[2] = i2;
|
||||||
|
return coeffRef(indices);
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2, Index i3)
|
||||||
|
{
|
||||||
|
array<Index, 4> indices;
|
||||||
|
indices[0] = i0;
|
||||||
|
indices[1] = i1;
|
||||||
|
indices[2] = i2;
|
||||||
|
indices[3] = i3;
|
||||||
|
return coeffRef(indices);
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE Scalar& coeffRef(Index i0, Index i1, Index i2, Index i3, Index i4)
|
||||||
|
{
|
||||||
|
array<Index, 5> indices;
|
||||||
|
indices[0] = i0;
|
||||||
|
indices[1] = i1;
|
||||||
|
indices[2] = i2;
|
||||||
|
indices[3] = i3;
|
||||||
|
indices[4] = i4;
|
||||||
|
return coeffRef(indices);
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
|
template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
|
||||||
@ -244,7 +293,7 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
|
|||||||
{
|
{
|
||||||
const Dimensions& dims = this->dimensions();
|
const Dimensions& dims = this->dimensions();
|
||||||
Index index = 0;
|
Index index = 0;
|
||||||
if (PlainObjectType::Options&RowMajor) {
|
if (PlainObjectType::Options & RowMajor) {
|
||||||
index += indices[0];
|
index += indices[0];
|
||||||
for (int i = 1; i < NumIndices; ++i) {
|
for (int i = 1; i < NumIndices; ++i) {
|
||||||
index = index * dims[i] + indices[i];
|
index = index * dims[i] + indices[i];
|
||||||
@ -257,6 +306,24 @@ template<typename PlainObjectType> class TensorRef : public TensorBase<TensorRef
|
|||||||
}
|
}
|
||||||
return m_evaluator->coeff(index);
|
return m_evaluator->coeff(index);
|
||||||
}
|
}
|
||||||
|
template <std::size_t NumIndices> EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE Scalar& coeffRef(const array<Index, NumIndices>& indices)
|
||||||
|
{
|
||||||
|
const Dimensions& dims = this->dimensions();
|
||||||
|
Index index = 0;
|
||||||
|
if (PlainObjectType::Options & RowMajor) {
|
||||||
|
index += indices[0];
|
||||||
|
for (int i = 1; i < NumIndices; ++i) {
|
||||||
|
index = index * dims[i] + indices[i];
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
index += indices[NumIndices-1];
|
||||||
|
for (int i = NumIndices-2; i >= 0; --i) {
|
||||||
|
index = index * dims[i] + indices[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return m_evaluator->coeffRef(index);
|
||||||
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
|
EIGEN_STRONG_INLINE const Scalar coeff(Index index) const
|
||||||
@ -298,6 +365,8 @@ struct TensorEvaluator<const TensorRef<Derived>, Device>
|
|||||||
enum {
|
enum {
|
||||||
IsAligned = false,
|
IsAligned = false,
|
||||||
PacketAccess = false,
|
PacketAccess = false,
|
||||||
|
Layout = TensorRef<Derived>::Layout,
|
||||||
|
CoordAccess = false, // to be implemented
|
||||||
};
|
};
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef<Derived>& m, const Device&)
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const TensorRef<Derived>& m, const Device&)
|
||||||
|
@ -181,6 +181,21 @@ static void test_ref_in_expr()
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static void test_coeff_ref()
|
||||||
|
{
|
||||||
|
Tensor<float, 5> tensor(2,3,5,7,11);
|
||||||
|
tensor.setRandom();
|
||||||
|
Tensor<float, 5> original = tensor;
|
||||||
|
|
||||||
|
TensorRef<Tensor<float, 4>> slice = tensor.chip(7, 4);
|
||||||
|
slice.coeffRef(0, 0, 0, 0) = 1.0f;
|
||||||
|
slice.coeffRef(1, 0, 0, 0) += 2.0f;
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(tensor(0,0,0,0,7), 1.0f);
|
||||||
|
VERIFY_IS_EQUAL(tensor(1,0,0,0,7), original(1,0,0,0,7) + 2.0f);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
void test_cxx11_tensor_ref()
|
void test_cxx11_tensor_ref()
|
||||||
{
|
{
|
||||||
CALL_SUBTEST(test_simple_lvalue_ref());
|
CALL_SUBTEST(test_simple_lvalue_ref());
|
||||||
@ -189,4 +204,5 @@ void test_cxx11_tensor_ref()
|
|||||||
CALL_SUBTEST(test_slice());
|
CALL_SUBTEST(test_slice());
|
||||||
CALL_SUBTEST(test_ref_of_ref());
|
CALL_SUBTEST(test_ref_of_ref());
|
||||||
CALL_SUBTEST(test_ref_in_expr());
|
CALL_SUBTEST(test_ref_in_expr());
|
||||||
|
CALL_SUBTEST(test_coeff_ref());
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user