mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-22 17:49:36 +08:00
Added a constructor to simplify the construction of tensormap from tensor
This commit is contained in:
parent
e78bc111f1
commit
4cf7da63de
@ -82,15 +82,19 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
inline TensorMap(PointerArgType dataPtr, const array<Index, NumIndices>& dimensions)
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const array<Index, NumIndices>& dimensions)
|
||||||
: m_data(dataPtr), m_dimensions(dimensions)
|
: m_data(dataPtr), m_dimensions(dimensions)
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
template <typename Dimensions>
|
template <typename Dimensions>
|
||||||
EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const Dimensions& dimensions)
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const Dimensions& dimensions)
|
||||||
: m_data(dataPtr), m_dimensions(dimensions)
|
: m_data(dataPtr), m_dimensions(dimensions)
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PlainObjectType& tensor)
|
||||||
|
: m_data(tensor.data()), m_dimensions(tensor.dimensions())
|
||||||
|
{ }
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE Index rank() const { return m_dimensions.rank(); }
|
EIGEN_STRONG_INLINE Index rank() const { return m_dimensions.rank(); }
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
|
@ -139,9 +139,113 @@ static void test_3d()
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static void test_from_tensor()
|
||||||
|
{
|
||||||
|
Tensor<int, 3> mat1(2,3,7);
|
||||||
|
Tensor<int, 3, RowMajor> mat2(2,3,7);
|
||||||
|
|
||||||
|
int val = 0;
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
for (int j = 0; j < 3; ++j) {
|
||||||
|
for (int k = 0; k < 7; ++k) {
|
||||||
|
mat1(i,j,k) = val;
|
||||||
|
mat2(i,j,k) = val;
|
||||||
|
val++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorMap<Tensor<int, 3>> mat3(mat1);
|
||||||
|
TensorMap<Tensor<int, 3, RowMajor>> mat4(mat2);
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(mat3.rank(), 3);
|
||||||
|
VERIFY_IS_EQUAL(mat3.size(), 2*3*7);
|
||||||
|
VERIFY_IS_EQUAL(mat3.dimension(0), 2);
|
||||||
|
VERIFY_IS_EQUAL(mat3.dimension(1), 3);
|
||||||
|
VERIFY_IS_EQUAL(mat3.dimension(2), 7);
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(mat4.rank(), 3);
|
||||||
|
VERIFY_IS_EQUAL(mat4.size(), 2*3*7);
|
||||||
|
VERIFY_IS_EQUAL(mat4.dimension(0), 2);
|
||||||
|
VERIFY_IS_EQUAL(mat4.dimension(1), 3);
|
||||||
|
VERIFY_IS_EQUAL(mat4.dimension(2), 7);
|
||||||
|
|
||||||
|
val = 0;
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
for (int j = 0; j < 3; ++j) {
|
||||||
|
for (int k = 0; k < 7; ++k) {
|
||||||
|
VERIFY_IS_EQUAL(mat3(i,j,k), val);
|
||||||
|
VERIFY_IS_EQUAL(mat4(i,j,k), val);
|
||||||
|
val++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorFixedSize<int, Sizes<2,3,7>> mat5;
|
||||||
|
|
||||||
|
val = 0;
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
for (int j = 0; j < 3; ++j) {
|
||||||
|
for (int k = 0; k < 7; ++k) {
|
||||||
|
mat5(i,j,k) = val;
|
||||||
|
val++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorMap<TensorFixedSize<int, Sizes<2,3,7>>> mat6(mat5);
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(mat6.rank(), 3);
|
||||||
|
VERIFY_IS_EQUAL(mat6.size(), 2*3*7);
|
||||||
|
VERIFY_IS_EQUAL(mat6.dimension(0), 2);
|
||||||
|
VERIFY_IS_EQUAL(mat6.dimension(1), 3);
|
||||||
|
VERIFY_IS_EQUAL(mat6.dimension(2), 7);
|
||||||
|
|
||||||
|
val = 0;
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
for (int j = 0; j < 3; ++j) {
|
||||||
|
for (int k = 0; k < 7; ++k) {
|
||||||
|
VERIFY_IS_EQUAL(mat6(i,j,k), val);
|
||||||
|
val++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
static int f(const TensorMap<Tensor<int, 3> >& tensor) {
|
||||||
|
Tensor<int, 1> result = tensor.sum();
|
||||||
|
return result(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void test_casting()
|
||||||
|
{
|
||||||
|
Tensor<int, 3> tensor(2,3,7);
|
||||||
|
|
||||||
|
int val = 0;
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
for (int j = 0; j < 3; ++j) {
|
||||||
|
for (int k = 0; k < 7; ++k) {
|
||||||
|
tensor(i,j,k) = val;
|
||||||
|
val++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorMap<Tensor<int, 3>> map(tensor);
|
||||||
|
int sum1 = f(map);
|
||||||
|
int sum2 = f(tensor);
|
||||||
|
|
||||||
|
VERIFY_IS_EQUAL(sum1, sum2);
|
||||||
|
VERIFY_IS_EQUAL(sum1, 41);
|
||||||
|
}
|
||||||
|
|
||||||
void test_cxx11_tensor_map()
|
void test_cxx11_tensor_map()
|
||||||
{
|
{
|
||||||
CALL_SUBTEST(test_1d());
|
CALL_SUBTEST(test_1d());
|
||||||
CALL_SUBTEST(test_2d());
|
CALL_SUBTEST(test_2d());
|
||||||
CALL_SUBTEST(test_3d());
|
CALL_SUBTEST(test_3d());
|
||||||
|
|
||||||
|
CALL_SUBTEST(test_from_tensor());
|
||||||
|
CALL_SUBTEST(test_casting());
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user