mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 11:19:02 +08:00
Misc improvements for fixed size tensors
This commit is contained in:
parent
71676eaddd
commit
b12dd1ae3c
@ -42,7 +42,9 @@ class TensorFixedSize : public TensorBase<TensorFixedSize<Scalar_, Dimensions_,
|
|||||||
enum {
|
enum {
|
||||||
IsAligned = bool(EIGEN_ALIGN),
|
IsAligned = bool(EIGEN_ALIGN),
|
||||||
PacketAccess = (internal::packet_traits<Scalar>::size > 1),
|
PacketAccess = (internal::packet_traits<Scalar>::size > 1),
|
||||||
};
|
Layout = Options_ & RowMajor ? RowMajor : ColMajor,
|
||||||
|
CoordAccess = true,
|
||||||
|
};
|
||||||
|
|
||||||
typedef Dimensions_ Dimensions;
|
typedef Dimensions_ Dimensions;
|
||||||
static const std::size_t NumIndices = Dimensions::count;
|
static const std::size_t NumIndices = Dimensions::count;
|
||||||
@ -51,11 +53,12 @@ class TensorFixedSize : public TensorBase<TensorFixedSize<Scalar_, Dimensions_,
|
|||||||
TensorStorage<Scalar, NumIndices, Dimensions::total_size, Options, Dimensions> m_storage;
|
TensorStorage<Scalar, NumIndices, Dimensions::total_size, Options, Dimensions> m_storage;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
EIGEN_STRONG_INLINE Index dimension(std::size_t n) const { return m_storage.dimensions()[n]; }
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index rank() const { return NumIndices; }
|
||||||
EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_storage.dimensions(); }
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(std::size_t n) const { return m_storage.dimensions()[n]; }
|
||||||
EIGEN_STRONG_INLINE Index size() const { return m_storage.size(); }
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_storage.dimensions(); }
|
||||||
EIGEN_STRONG_INLINE Scalar *data() { return m_storage.data(); }
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const { return m_storage.size(); }
|
||||||
EIGEN_STRONG_INLINE const Scalar *data() const { return m_storage.data(); }
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar *data() { return m_storage.data(); }
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar *data() const { return m_storage.data(); }
|
||||||
|
|
||||||
// This makes EIGEN_INITIALIZE_COEFFS_IF_THAT_OPTION_IS_ENABLED
|
// This makes EIGEN_INITIALIZE_COEFFS_IF_THAT_OPTION_IS_ENABLED
|
||||||
// work, because that uses base().coeffRef() - and we don't yet
|
// work, because that uses base().coeffRef() - and we don't yet
|
||||||
@ -187,6 +190,23 @@ class TensorFixedSize : public TensorBase<TensorFixedSize<Scalar_, Dimensions_,
|
|||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef EIGEN_HAVE_RVALUE_REFERENCES
|
||||||
|
inline TensorFixedSize(Self&& other)
|
||||||
|
: m_storage(other.m_storage)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
EIGEN_STRONG_INLINE TensorFixedSize& operator=(const TensorFixedSize& other)
|
||||||
|
{
|
||||||
|
// FIXME: check that the dimensions of other match the dimensions of *this.
|
||||||
|
// Unfortunately this isn't possible yet when the rhs is an expression.
|
||||||
|
typedef TensorAssignOp<Self, const TensorFixedSize> Assign;
|
||||||
|
Assign assign(*this, other);
|
||||||
|
internal::TensorExecutor<const Assign, DefaultDevice>::run(assign, DefaultDevice());
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
template<typename OtherDerived>
|
template<typename OtherDerived>
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE TensorFixedSize& operator=(const OtherDerived& other)
|
EIGEN_STRONG_INLINE TensorFixedSize& operator=(const OtherDerived& other)
|
||||||
|
@ -32,13 +32,14 @@ static void test_1d()
|
|||||||
vec1(5) = 42.0; vec2(5) = 5.0;
|
vec1(5) = 42.0; vec2(5) = 5.0;
|
||||||
|
|
||||||
float data3[6];
|
float data3[6];
|
||||||
TensorMap<TensorFixedSize<float, Sizes<6> > > vec3(data3, Sizes<6>());
|
TensorMap<TensorFixedSize<float, Sizes<6> > > vec3(data3, 6);
|
||||||
vec3 = vec1.sqrt();
|
vec3 = vec1.sqrt();
|
||||||
float data4[6];
|
float data4[6];
|
||||||
TensorMap<TensorFixedSize<float, Sizes<6>, RowMajor> > vec4(data4, Sizes<6>());
|
TensorMap<TensorFixedSize<float, Sizes<6>, RowMajor> > vec4(data4, 6);
|
||||||
vec4 = vec2.sqrt();
|
vec4 = vec2.sqrt();
|
||||||
|
|
||||||
VERIFY_IS_EQUAL((vec3.size()), 6);
|
VERIFY_IS_EQUAL((vec3.size()), 6);
|
||||||
|
VERIFY_IS_EQUAL(vec3.rank(), 1);
|
||||||
// VERIFY_IS_EQUAL((vec3.dimensions()[0]), 6);
|
// VERIFY_IS_EQUAL((vec3.dimensions()[0]), 6);
|
||||||
// VERIFY_IS_EQUAL((vec3.dimension(0)), 6);
|
// VERIFY_IS_EQUAL((vec3.dimension(0)), 6);
|
||||||
|
|
||||||
@ -68,11 +69,12 @@ static void test_1d()
|
|||||||
static void test_2d()
|
static void test_2d()
|
||||||
{
|
{
|
||||||
float data1[6];
|
float data1[6];
|
||||||
TensorMap<TensorFixedSize<float, Sizes<2, 3> >> mat1(data1, Sizes<2, 3>());
|
TensorMap<TensorFixedSize<float, Sizes<2, 3> >> mat1(data1,2,3);
|
||||||
float data2[6];
|
float data2[6];
|
||||||
TensorMap<TensorFixedSize<float, Sizes<2, 3>, RowMajor>> mat2(data2, Sizes<2, 3>());
|
TensorMap<TensorFixedSize<float, Sizes<2, 3>, RowMajor>> mat2(data2,2,3);
|
||||||
|
|
||||||
VERIFY_IS_EQUAL((mat1.size()), 2*3);
|
VERIFY_IS_EQUAL((mat1.size()), 2*3);
|
||||||
|
VERIFY_IS_EQUAL(mat1.rank(), 2);
|
||||||
// VERIFY_IS_EQUAL((mat1.dimension(0)), 2);
|
// VERIFY_IS_EQUAL((mat1.dimension(0)), 2);
|
||||||
// VERIFY_IS_EQUAL((mat1.dimension(1)), 3);
|
// VERIFY_IS_EQUAL((mat1.dimension(1)), 3);
|
||||||
|
|
||||||
@ -120,6 +122,7 @@ static void test_3d()
|
|||||||
TensorFixedSize<float, Sizes<2, 3, 7>, RowMajor> mat2;
|
TensorFixedSize<float, Sizes<2, 3, 7>, RowMajor> mat2;
|
||||||
|
|
||||||
VERIFY_IS_EQUAL((mat1.size()), 2*3*7);
|
VERIFY_IS_EQUAL((mat1.size()), 2*3*7);
|
||||||
|
VERIFY_IS_EQUAL(mat1.rank(), 3);
|
||||||
// VERIFY_IS_EQUAL((mat1.dimension(0)), 2);
|
// VERIFY_IS_EQUAL((mat1.dimension(0)), 2);
|
||||||
// VERIFY_IS_EQUAL((mat1.dimension(1)), 3);
|
// VERIFY_IS_EQUAL((mat1.dimension(1)), 3);
|
||||||
// VERIFY_IS_EQUAL((mat1.dimension(2)), 7);
|
// VERIFY_IS_EQUAL((mat1.dimension(2)), 7);
|
||||||
@ -166,7 +169,7 @@ static void test_array()
|
|||||||
for (int i = 0; i < 2; ++i) {
|
for (int i = 0; i < 2; ++i) {
|
||||||
for (int j = 0; j < 3; ++j) {
|
for (int j = 0; j < 3; ++j) {
|
||||||
for (int k = 0; k < 7; ++k) {
|
for (int k = 0; k < 7; ++k) {
|
||||||
mat1(array<ptrdiff_t, 3>{{i,j,k}}) = val;
|
mat1(i,j,k) = val;
|
||||||
val += 1.0;
|
val += 1.0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user