mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-22 01:29:35 +08:00
Added a constructor to simplify the construction of tensormap from tensor
This commit is contained in:
parent
e78bc111f1
commit
4cf7da63de
@ -82,15 +82,19 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
|
||||
}
|
||||
#endif
|
||||
|
||||
inline TensorMap(PointerArgType dataPtr, const array<Index, NumIndices>& dimensions)
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const array<Index, NumIndices>& dimensions)
|
||||
: m_data(dataPtr), m_dimensions(dimensions)
|
||||
{ }
|
||||
|
||||
template <typename Dimensions>
|
||||
EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const Dimensions& dimensions)
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const Dimensions& dimensions)
|
||||
: m_data(dataPtr), m_dimensions(dimensions)
|
||||
{ }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PlainObjectType& tensor)
|
||||
: m_data(tensor.data()), m_dimensions(tensor.dimensions())
|
||||
{ }
|
||||
|
||||
EIGEN_DEVICE_FUNC
|
||||
EIGEN_STRONG_INLINE Index rank() const { return m_dimensions.rank(); }
|
||||
EIGEN_DEVICE_FUNC
|
||||
|
@ -139,9 +139,113 @@ static void test_3d()
|
||||
}
|
||||
|
||||
|
||||
static void test_from_tensor()
|
||||
{
|
||||
Tensor<int, 3> mat1(2,3,7);
|
||||
Tensor<int, 3, RowMajor> mat2(2,3,7);
|
||||
|
||||
int val = 0;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
mat1(i,j,k) = val;
|
||||
mat2(i,j,k) = val;
|
||||
val++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TensorMap<Tensor<int, 3>> mat3(mat1);
|
||||
TensorMap<Tensor<int, 3, RowMajor>> mat4(mat2);
|
||||
|
||||
VERIFY_IS_EQUAL(mat3.rank(), 3);
|
||||
VERIFY_IS_EQUAL(mat3.size(), 2*3*7);
|
||||
VERIFY_IS_EQUAL(mat3.dimension(0), 2);
|
||||
VERIFY_IS_EQUAL(mat3.dimension(1), 3);
|
||||
VERIFY_IS_EQUAL(mat3.dimension(2), 7);
|
||||
|
||||
VERIFY_IS_EQUAL(mat4.rank(), 3);
|
||||
VERIFY_IS_EQUAL(mat4.size(), 2*3*7);
|
||||
VERIFY_IS_EQUAL(mat4.dimension(0), 2);
|
||||
VERIFY_IS_EQUAL(mat4.dimension(1), 3);
|
||||
VERIFY_IS_EQUAL(mat4.dimension(2), 7);
|
||||
|
||||
val = 0;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
VERIFY_IS_EQUAL(mat3(i,j,k), val);
|
||||
VERIFY_IS_EQUAL(mat4(i,j,k), val);
|
||||
val++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TensorFixedSize<int, Sizes<2,3,7>> mat5;
|
||||
|
||||
val = 0;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
mat5(i,j,k) = val;
|
||||
val++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TensorMap<TensorFixedSize<int, Sizes<2,3,7>>> mat6(mat5);
|
||||
|
||||
VERIFY_IS_EQUAL(mat6.rank(), 3);
|
||||
VERIFY_IS_EQUAL(mat6.size(), 2*3*7);
|
||||
VERIFY_IS_EQUAL(mat6.dimension(0), 2);
|
||||
VERIFY_IS_EQUAL(mat6.dimension(1), 3);
|
||||
VERIFY_IS_EQUAL(mat6.dimension(2), 7);
|
||||
|
||||
val = 0;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
VERIFY_IS_EQUAL(mat6(i,j,k), val);
|
||||
val++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static int f(const TensorMap<Tensor<int, 3> >& tensor) {
|
||||
Tensor<int, 1> result = tensor.sum();
|
||||
return result(0);
|
||||
}
|
||||
|
||||
static void test_casting()
|
||||
{
|
||||
Tensor<int, 3> tensor(2,3,7);
|
||||
|
||||
int val = 0;
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
for (int j = 0; j < 3; ++j) {
|
||||
for (int k = 0; k < 7; ++k) {
|
||||
tensor(i,j,k) = val;
|
||||
val++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TensorMap<Tensor<int, 3>> map(tensor);
|
||||
int sum1 = f(map);
|
||||
int sum2 = f(tensor);
|
||||
|
||||
VERIFY_IS_EQUAL(sum1, sum2);
|
||||
VERIFY_IS_EQUAL(sum1, 41);
|
||||
}
|
||||
|
||||
void test_cxx11_tensor_map()
|
||||
{
|
||||
CALL_SUBTEST(test_1d());
|
||||
CALL_SUBTEST(test_2d());
|
||||
CALL_SUBTEST(test_3d());
|
||||
|
||||
CALL_SUBTEST(test_from_tensor());
|
||||
CALL_SUBTEST(test_casting());
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user