From c53f783705e05c07d9f1c02ab12fdb5d57f1a7a9 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 1 Sep 2016 11:41:27 -0700 Subject: [PATCH] Updated the contraction code to support constant inputs. --- .../CXX11/src/Tensor/TensorContraction.h | 3 ++- unsupported/test/cxx11_tensor_contraction.cpp | 23 +++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h index a6001074b..20b29e5fd 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h @@ -25,7 +25,8 @@ template struct traits > { // Type promotion to handle the case where the types of the lhs and the rhs are different. - typedef typename gebp_traits::ResScalar Scalar; + typedef typename gebp_traits::type, + typename remove_const::type>::ResScalar Scalar; typedef typename promote_storage_type::StorageKind, typename traits::StorageKind>::ret StorageKind; diff --git a/unsupported/test/cxx11_tensor_contraction.cpp b/unsupported/test/cxx11_tensor_contraction.cpp index 73623b2ed..ace97057f 100644 --- a/unsupported/test/cxx11_tensor_contraction.cpp +++ b/unsupported/test/cxx11_tensor_contraction.cpp @@ -489,6 +489,27 @@ static void test_tensor_product() } +template +static void test_const_inputs() +{ + Tensor in1(2, 3); + Tensor in2(3, 2); + in1.setRandom(); + in2.setRandom(); + + TensorMap > mat1(in1.data(), 2, 3); + TensorMap > mat2(in2.data(), 3, 2); + Tensor mat3(2,2); + + Eigen::array dims = {{DimPair(1, 0)}}; + mat3 = mat1.contract(mat2, dims); + + VERIFY_IS_APPROX(mat3(0,0), mat1(0,0)*mat2(0,0) + mat1(0,1)*mat2(1,0) + mat1(0,2)*mat2(2,0)); + VERIFY_IS_APPROX(mat3(0,1), mat1(0,0)*mat2(0,1) + mat1(0,1)*mat2(1,1) + mat1(0,2)*mat2(2,1)); + VERIFY_IS_APPROX(mat3(1,0), mat1(1,0)*mat2(0,0) + mat1(1,1)*mat2(1,0) + mat1(1,2)*mat2(2,0)); + VERIFY_IS_APPROX(mat3(1,1), mat1(1,0)*mat2(0,1) + mat1(1,1)*mat2(1,1) + mat1(1,2)*mat2(2,1)); +} + void test_cxx11_tensor_contraction() { CALL_SUBTEST(test_evals()); @@ -519,4 +540,6 @@ void test_cxx11_tensor_contraction() CALL_SUBTEST(test_small_blocking_factors()); CALL_SUBTEST(test_tensor_product()); CALL_SUBTEST(test_tensor_product()); + CALL_SUBTEST(test_const_inputs()); + CALL_SUBTEST(test_const_inputs()); }