Added more tests to validate support for tensors laid out in RowMajor order.

This commit is contained in:
Benoit Steiner 2015-02-25 16:14:59 -08:00
parent 1cfd51908c
commit 410070e5ab

View File

@ -114,18 +114,19 @@ static void test_simple_slice()
}
}
// TODO(andydavis) Add RowMajor support when TensorContract supports RowMajor.
template<int DataLayout>
static void test_slice_in_expr() {
MatrixXf m1(7,7);
MatrixXf m2(3,3);
typedef Matrix<float, Dynamic, Dynamic, DataLayout> Mtx;
Mtx m1(7,7);
Mtx m2(3,3);
m1.setRandom();
m2.setRandom();
MatrixXf m3 = m1.block(1, 2, 3, 3) * m2.block(0, 2, 3, 1);
Mtx m3 = m1.block(1, 2, 3, 3) * m2.block(0, 2, 3, 1);
TensorMap<Tensor<float, 2>> tensor1(m1.data(), 7, 7);
TensorMap<Tensor<float, 2>> tensor2(m2.data(), 3, 3);
Tensor<float, 2> tensor3(3,1);
TensorMap<Tensor<float, 2, DataLayout>> tensor1(m1.data(), 7, 7);
TensorMap<Tensor<float, 2, DataLayout>> tensor2(m2.data(), 3, 3);
Tensor<float, 2, DataLayout> tensor3(3,1);
typedef Tensor<float, 1>::DimensionPair DimPair;
array<DimPair, 1> contract_along{{DimPair(1, 0)}};
@ -135,7 +136,7 @@ static void test_slice_in_expr() {
Eigen::DSizes<ptrdiff_t, 2> sizes2(3,1);
tensor3 = tensor1.slice(indices1, sizes1).contract(tensor2.slice(indices2, sizes2), contract_along);
Map<MatrixXf> res(tensor3.data(), 3, 1);
Map<Mtx> res(tensor3.data(), 3, 1);
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 1; ++j) {
VERIFY_IS_APPROX(res(i,j), m3(i,j));
@ -143,8 +144,8 @@ static void test_slice_in_expr() {
}
// Take an arbitrary slice of an arbitrarily sized tensor.
TensorMap<Tensor<const float, 2>> tensor4(m1.data(), 7, 7);
Tensor<float, 1> tensor6 = tensor4.reshape(DSizes<ptrdiff_t, 1>(7*7)).exp().slice(DSizes<ptrdiff_t, 1>(0), DSizes<ptrdiff_t, 1>(35));
TensorMap<Tensor<const float, 2, DataLayout>> tensor4(m1.data(), 7, 7);
Tensor<float, 1, DataLayout> tensor6 = tensor4.reshape(DSizes<ptrdiff_t, 1>(7*7)).exp().slice(DSizes<ptrdiff_t, 1>(0), DSizes<ptrdiff_t, 1>(35));
for (int i = 0; i < 35; ++i) {
VERIFY_IS_APPROX(tensor6(i), expf(tensor4.data()[i]));
}
@ -304,14 +305,14 @@ static void test_slice_raw_data()
VERIFY_IS_EQUAL(slice6.data(), tensor.data());
}
template<int DataLayout>
static void test_composition()
{
Eigen::Tensor<float, 2> matrix(7, 11);
Eigen::Tensor<float, 2, DataLayout> matrix(7, 11);
matrix.setRandom();
const DSizes<ptrdiff_t, 3> newDims{{1, 1, 11}};
Eigen::Tensor<float, 3> tensor =
Eigen::Tensor<float, 3, DataLayout> tensor =
matrix.slice(DSizes<ptrdiff_t, 2>(2, 0), DSizes<ptrdiff_t, 2>(1, 11)).reshape(newDims);
VERIFY_IS_EQUAL(tensor.dimensions().TotalSize(), 11ul);
@ -332,11 +333,13 @@ void test_cxx11_tensor_morphing()
CALL_SUBTEST(test_simple_slice<ColMajor>());
CALL_SUBTEST(test_simple_slice<RowMajor>());
CALL_SUBTEST(test_slice_in_expr());
CALL_SUBTEST(test_slice_in_expr<ColMajor>());
CALL_SUBTEST(test_slice_in_expr<RowMajor>());
CALL_SUBTEST(test_slice_as_lvalue<ColMajor>());
CALL_SUBTEST(test_slice_as_lvalue<RowMajor>());
CALL_SUBTEST(test_slice_raw_data<ColMajor>());
CALL_SUBTEST(test_slice_raw_data<RowMajor>());
CALL_SUBTEST(test_composition());
CALL_SUBTEST(test_composition<ColMajor>());
CALL_SUBTEST(test_composition<RowMajor>());
}