mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-13 20:26:03 +08:00
Added more tests to validate support for tensors laid out in RowMajor order.
This commit is contained in:
parent
1cfd51908c
commit
410070e5ab
@ -114,18 +114,19 @@ static void test_simple_slice()
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(andydavis) Add RowMajor support when TensorContract supports RowMajor.
|
template<int DataLayout>
|
||||||
static void test_slice_in_expr() {
|
static void test_slice_in_expr() {
|
||||||
MatrixXf m1(7,7);
|
typedef Matrix<float, Dynamic, Dynamic, DataLayout> Mtx;
|
||||||
MatrixXf m2(3,3);
|
Mtx m1(7,7);
|
||||||
|
Mtx m2(3,3);
|
||||||
m1.setRandom();
|
m1.setRandom();
|
||||||
m2.setRandom();
|
m2.setRandom();
|
||||||
|
|
||||||
MatrixXf m3 = m1.block(1, 2, 3, 3) * m2.block(0, 2, 3, 1);
|
Mtx m3 = m1.block(1, 2, 3, 3) * m2.block(0, 2, 3, 1);
|
||||||
|
|
||||||
TensorMap<Tensor<float, 2>> tensor1(m1.data(), 7, 7);
|
TensorMap<Tensor<float, 2, DataLayout>> tensor1(m1.data(), 7, 7);
|
||||||
TensorMap<Tensor<float, 2>> tensor2(m2.data(), 3, 3);
|
TensorMap<Tensor<float, 2, DataLayout>> tensor2(m2.data(), 3, 3);
|
||||||
Tensor<float, 2> tensor3(3,1);
|
Tensor<float, 2, DataLayout> tensor3(3,1);
|
||||||
typedef Tensor<float, 1>::DimensionPair DimPair;
|
typedef Tensor<float, 1>::DimensionPair DimPair;
|
||||||
array<DimPair, 1> contract_along{{DimPair(1, 0)}};
|
array<DimPair, 1> contract_along{{DimPair(1, 0)}};
|
||||||
|
|
||||||
@ -135,7 +136,7 @@ static void test_slice_in_expr() {
|
|||||||
Eigen::DSizes<ptrdiff_t, 2> sizes2(3,1);
|
Eigen::DSizes<ptrdiff_t, 2> sizes2(3,1);
|
||||||
tensor3 = tensor1.slice(indices1, sizes1).contract(tensor2.slice(indices2, sizes2), contract_along);
|
tensor3 = tensor1.slice(indices1, sizes1).contract(tensor2.slice(indices2, sizes2), contract_along);
|
||||||
|
|
||||||
Map<MatrixXf> res(tensor3.data(), 3, 1);
|
Map<Mtx> res(tensor3.data(), 3, 1);
|
||||||
for (int i = 0; i < 3; ++i) {
|
for (int i = 0; i < 3; ++i) {
|
||||||
for (int j = 0; j < 1; ++j) {
|
for (int j = 0; j < 1; ++j) {
|
||||||
VERIFY_IS_APPROX(res(i,j), m3(i,j));
|
VERIFY_IS_APPROX(res(i,j), m3(i,j));
|
||||||
@ -143,8 +144,8 @@ static void test_slice_in_expr() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Take an arbitrary slice of an arbitrarily sized tensor.
|
// Take an arbitrary slice of an arbitrarily sized tensor.
|
||||||
TensorMap<Tensor<const float, 2>> tensor4(m1.data(), 7, 7);
|
TensorMap<Tensor<const float, 2, DataLayout>> tensor4(m1.data(), 7, 7);
|
||||||
Tensor<float, 1> tensor6 = tensor4.reshape(DSizes<ptrdiff_t, 1>(7*7)).exp().slice(DSizes<ptrdiff_t, 1>(0), DSizes<ptrdiff_t, 1>(35));
|
Tensor<float, 1, DataLayout> tensor6 = tensor4.reshape(DSizes<ptrdiff_t, 1>(7*7)).exp().slice(DSizes<ptrdiff_t, 1>(0), DSizes<ptrdiff_t, 1>(35));
|
||||||
for (int i = 0; i < 35; ++i) {
|
for (int i = 0; i < 35; ++i) {
|
||||||
VERIFY_IS_APPROX(tensor6(i), expf(tensor4.data()[i]));
|
VERIFY_IS_APPROX(tensor6(i), expf(tensor4.data()[i]));
|
||||||
}
|
}
|
||||||
@ -304,14 +305,14 @@ static void test_slice_raw_data()
|
|||||||
VERIFY_IS_EQUAL(slice6.data(), tensor.data());
|
VERIFY_IS_EQUAL(slice6.data(), tensor.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<int DataLayout>
|
||||||
static void test_composition()
|
static void test_composition()
|
||||||
{
|
{
|
||||||
Eigen::Tensor<float, 2> matrix(7, 11);
|
Eigen::Tensor<float, 2, DataLayout> matrix(7, 11);
|
||||||
matrix.setRandom();
|
matrix.setRandom();
|
||||||
|
|
||||||
const DSizes<ptrdiff_t, 3> newDims{{1, 1, 11}};
|
const DSizes<ptrdiff_t, 3> newDims{{1, 1, 11}};
|
||||||
Eigen::Tensor<float, 3> tensor =
|
Eigen::Tensor<float, 3, DataLayout> tensor =
|
||||||
matrix.slice(DSizes<ptrdiff_t, 2>(2, 0), DSizes<ptrdiff_t, 2>(1, 11)).reshape(newDims);
|
matrix.slice(DSizes<ptrdiff_t, 2>(2, 0), DSizes<ptrdiff_t, 2>(1, 11)).reshape(newDims);
|
||||||
|
|
||||||
VERIFY_IS_EQUAL(tensor.dimensions().TotalSize(), 11ul);
|
VERIFY_IS_EQUAL(tensor.dimensions().TotalSize(), 11ul);
|
||||||
@ -332,11 +333,13 @@ void test_cxx11_tensor_morphing()
|
|||||||
|
|
||||||
CALL_SUBTEST(test_simple_slice<ColMajor>());
|
CALL_SUBTEST(test_simple_slice<ColMajor>());
|
||||||
CALL_SUBTEST(test_simple_slice<RowMajor>());
|
CALL_SUBTEST(test_simple_slice<RowMajor>());
|
||||||
CALL_SUBTEST(test_slice_in_expr());
|
CALL_SUBTEST(test_slice_in_expr<ColMajor>());
|
||||||
|
CALL_SUBTEST(test_slice_in_expr<RowMajor>());
|
||||||
CALL_SUBTEST(test_slice_as_lvalue<ColMajor>());
|
CALL_SUBTEST(test_slice_as_lvalue<ColMajor>());
|
||||||
CALL_SUBTEST(test_slice_as_lvalue<RowMajor>());
|
CALL_SUBTEST(test_slice_as_lvalue<RowMajor>());
|
||||||
CALL_SUBTEST(test_slice_raw_data<ColMajor>());
|
CALL_SUBTEST(test_slice_raw_data<ColMajor>());
|
||||||
CALL_SUBTEST(test_slice_raw_data<RowMajor>());
|
CALL_SUBTEST(test_slice_raw_data<RowMajor>());
|
||||||
|
|
||||||
CALL_SUBTEST(test_composition());
|
CALL_SUBTEST(test_composition<ColMajor>());
|
||||||
|
CALL_SUBTEST(test_composition<RowMajor>());
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user