mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-14 18:33:16 +08:00
Updated the contraction code to support constant inputs.
This commit is contained in:
parent
ef54723dbe
commit
c53f783705
@ -25,7 +25,8 @@ template<typename Dimensions, typename LhsXprType, typename RhsXprType>
|
|||||||
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
|
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
|
||||||
{
|
{
|
||||||
// Type promotion to handle the case where the types of the lhs and the rhs are different.
|
// Type promotion to handle the case where the types of the lhs and the rhs are different.
|
||||||
typedef typename gebp_traits<typename LhsXprType::Scalar, typename RhsXprType::Scalar>::ResScalar Scalar;
|
typedef typename gebp_traits<typename remove_const<typename LhsXprType::Scalar>::type,
|
||||||
|
typename remove_const<typename RhsXprType::Scalar>::type>::ResScalar Scalar;
|
||||||
|
|
||||||
typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
|
typedef typename promote_storage_type<typename traits<LhsXprType>::StorageKind,
|
||||||
typename traits<RhsXprType>::StorageKind>::ret StorageKind;
|
typename traits<RhsXprType>::StorageKind>::ret StorageKind;
|
||||||
|
@ -489,6 +489,27 @@ static void test_tensor_product()
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template<int DataLayout>
|
||||||
|
static void test_const_inputs()
|
||||||
|
{
|
||||||
|
Tensor<float, 2, DataLayout> in1(2, 3);
|
||||||
|
Tensor<float, 2, DataLayout> in2(3, 2);
|
||||||
|
in1.setRandom();
|
||||||
|
in2.setRandom();
|
||||||
|
|
||||||
|
TensorMap<Tensor<const float, 2, DataLayout> > mat1(in1.data(), 2, 3);
|
||||||
|
TensorMap<Tensor<const float, 2, DataLayout> > mat2(in2.data(), 3, 2);
|
||||||
|
Tensor<float, 2, DataLayout> mat3(2,2);
|
||||||
|
|
||||||
|
Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
|
||||||
|
mat3 = mat1.contract(mat2, dims);
|
||||||
|
|
||||||
|
VERIFY_IS_APPROX(mat3(0,0), mat1(0,0)*mat2(0,0) + mat1(0,1)*mat2(1,0) + mat1(0,2)*mat2(2,0));
|
||||||
|
VERIFY_IS_APPROX(mat3(0,1), mat1(0,0)*mat2(0,1) + mat1(0,1)*mat2(1,1) + mat1(0,2)*mat2(2,1));
|
||||||
|
VERIFY_IS_APPROX(mat3(1,0), mat1(1,0)*mat2(0,0) + mat1(1,1)*mat2(1,0) + mat1(1,2)*mat2(2,0));
|
||||||
|
VERIFY_IS_APPROX(mat3(1,1), mat1(1,0)*mat2(0,1) + mat1(1,1)*mat2(1,1) + mat1(1,2)*mat2(2,1));
|
||||||
|
}
|
||||||
|
|
||||||
void test_cxx11_tensor_contraction()
|
void test_cxx11_tensor_contraction()
|
||||||
{
|
{
|
||||||
CALL_SUBTEST(test_evals<ColMajor>());
|
CALL_SUBTEST(test_evals<ColMajor>());
|
||||||
@ -519,4 +540,6 @@ void test_cxx11_tensor_contraction()
|
|||||||
CALL_SUBTEST(test_small_blocking_factors<RowMajor>());
|
CALL_SUBTEST(test_small_blocking_factors<RowMajor>());
|
||||||
CALL_SUBTEST(test_tensor_product<ColMajor>());
|
CALL_SUBTEST(test_tensor_product<ColMajor>());
|
||||||
CALL_SUBTEST(test_tensor_product<RowMajor>());
|
CALL_SUBTEST(test_tensor_product<RowMajor>());
|
||||||
|
CALL_SUBTEST(test_const_inputs<ColMajor>());
|
||||||
|
CALL_SUBTEST(test_const_inputs<RowMajor>());
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user