Added a constructor to simplify the construction of tensormap from tensor

This commit is contained in:
Benoit Steiner 2015-10-22 11:48:02 -07:00
parent e78bc111f1
commit 4cf7da63de
2 changed files with 110 additions and 2 deletions

View File

@ -82,15 +82,19 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
}
#endif
inline TensorMap(PointerArgType dataPtr, const array<Index, NumIndices>& dimensions)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const array<Index, NumIndices>& dimensions)
: m_data(dataPtr), m_dimensions(dimensions)
{ }
template <typename Dimensions>
EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const Dimensions& dimensions)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PointerArgType dataPtr, const Dimensions& dimensions)
: m_data(dataPtr), m_dimensions(dimensions)
{ }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorMap(PlainObjectType& tensor)
: m_data(tensor.data()), m_dimensions(tensor.dimensions())
{ }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Index rank() const { return m_dimensions.rank(); }
EIGEN_DEVICE_FUNC

View File

@ -139,9 +139,113 @@ static void test_3d()
}
static void test_from_tensor()
{
Tensor<int, 3> mat1(2,3,7);
Tensor<int, 3, RowMajor> mat2(2,3,7);
int val = 0;
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 7; ++k) {
mat1(i,j,k) = val;
mat2(i,j,k) = val;
val++;
}
}
}
TensorMap<Tensor<int, 3>> mat3(mat1);
TensorMap<Tensor<int, 3, RowMajor>> mat4(mat2);
VERIFY_IS_EQUAL(mat3.rank(), 3);
VERIFY_IS_EQUAL(mat3.size(), 2*3*7);
VERIFY_IS_EQUAL(mat3.dimension(0), 2);
VERIFY_IS_EQUAL(mat3.dimension(1), 3);
VERIFY_IS_EQUAL(mat3.dimension(2), 7);
VERIFY_IS_EQUAL(mat4.rank(), 3);
VERIFY_IS_EQUAL(mat4.size(), 2*3*7);
VERIFY_IS_EQUAL(mat4.dimension(0), 2);
VERIFY_IS_EQUAL(mat4.dimension(1), 3);
VERIFY_IS_EQUAL(mat4.dimension(2), 7);
val = 0;
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 7; ++k) {
VERIFY_IS_EQUAL(mat3(i,j,k), val);
VERIFY_IS_EQUAL(mat4(i,j,k), val);
val++;
}
}
}
TensorFixedSize<int, Sizes<2,3,7>> mat5;
val = 0;
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 7; ++k) {
mat5(i,j,k) = val;
val++;
}
}
}
TensorMap<TensorFixedSize<int, Sizes<2,3,7>>> mat6(mat5);
VERIFY_IS_EQUAL(mat6.rank(), 3);
VERIFY_IS_EQUAL(mat6.size(), 2*3*7);
VERIFY_IS_EQUAL(mat6.dimension(0), 2);
VERIFY_IS_EQUAL(mat6.dimension(1), 3);
VERIFY_IS_EQUAL(mat6.dimension(2), 7);
val = 0;
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 7; ++k) {
VERIFY_IS_EQUAL(mat6(i,j,k), val);
val++;
}
}
}
}
static int f(const TensorMap<Tensor<int, 3> >& tensor) {
Tensor<int, 1> result = tensor.sum();
return result(0);
}
static void test_casting()
{
Tensor<int, 3> tensor(2,3,7);
int val = 0;
for (int i = 0; i < 2; ++i) {
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 7; ++k) {
tensor(i,j,k) = val;
val++;
}
}
}
TensorMap<Tensor<int, 3>> map(tensor);
int sum1 = f(map);
int sum2 = f(tensor);
VERIFY_IS_EQUAL(sum1, sum2);
VERIFY_IS_EQUAL(sum1, 41);
}
void test_cxx11_tensor_map()
{
CALL_SUBTEST(test_1d());
CALL_SUBTEST(test_2d());
CALL_SUBTEST(test_3d());
CALL_SUBTEST(test_from_tensor());
CALL_SUBTEST(test_casting());
}