mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-14 12:46:00 +08:00
Add support to directly evaluate the product of two sparse matrices within a dense matrix.
This commit is contained in:
parent
a5324a131f
commit
e6f8c5c325
@ -1,7 +1,7 @@
|
|||||||
// This file is part of Eigen, a lightweight C++ template library
|
// This file is part of Eigen, a lightweight C++ template library
|
||||||
// for linear algebra.
|
// for linear algebra.
|
||||||
//
|
//
|
||||||
// Copyright (C) 2008-2014 Gael Guennebaud <gael.guennebaud@inria.fr>
|
// Copyright (C) 2008-2015 Gael Guennebaud <gael.guennebaud@inria.fr>
|
||||||
//
|
//
|
||||||
// This Source Code Form is subject to the terms of the Mozilla
|
// This Source Code Form is subject to the terms of the Mozilla
|
||||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||||
@ -255,6 +255,89 @@ struct conservative_sparse_sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,R
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
} // end namespace internal
|
||||||
|
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, typename ResultType>
|
||||||
|
static void sparse_sparse_to_dense_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res)
|
||||||
|
{
|
||||||
|
typedef typename remove_all<Lhs>::type::Scalar Scalar;
|
||||||
|
Index cols = rhs.outerSize();
|
||||||
|
eigen_assert(lhs.outerSize() == rhs.innerSize());
|
||||||
|
|
||||||
|
evaluator<Lhs> lhsEval(lhs);
|
||||||
|
evaluator<Rhs> rhsEval(rhs);
|
||||||
|
|
||||||
|
for (Index j=0; j<cols; ++j)
|
||||||
|
{
|
||||||
|
for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
|
||||||
|
{
|
||||||
|
Scalar y = rhsIt.value();
|
||||||
|
Index k = rhsIt.index();
|
||||||
|
for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, k); lhsIt; ++lhsIt)
|
||||||
|
{
|
||||||
|
Index i = lhsIt.index();
|
||||||
|
Scalar x = lhsIt.value();
|
||||||
|
res.coeffRef(i,j) += x * y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
} // end namespace internal
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, typename ResultType,
|
||||||
|
int LhsStorageOrder = (traits<Lhs>::Flags&RowMajorBit) ? RowMajor : ColMajor,
|
||||||
|
int RhsStorageOrder = (traits<Rhs>::Flags&RowMajorBit) ? RowMajor : ColMajor>
|
||||||
|
struct sparse_sparse_to_dense_product_selector;
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, typename ResultType>
|
||||||
|
struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor>
|
||||||
|
{
|
||||||
|
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
|
||||||
|
{
|
||||||
|
internal::sparse_sparse_to_dense_product_impl<Lhs,Rhs,ResultType>(lhs, rhs, res);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, typename ResultType>
|
||||||
|
struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor>
|
||||||
|
{
|
||||||
|
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
|
||||||
|
{
|
||||||
|
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrix;
|
||||||
|
ColMajorMatrix lhsCol(lhs);
|
||||||
|
internal::sparse_sparse_to_dense_product_impl<ColMajorMatrix,Rhs,ResultType>(lhsCol, rhs, res);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, typename ResultType>
|
||||||
|
struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor>
|
||||||
|
{
|
||||||
|
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
|
||||||
|
{
|
||||||
|
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename ResultType::StorageIndex> ColMajorMatrix;
|
||||||
|
ColMajorMatrix rhsCol(rhs);
|
||||||
|
internal::sparse_sparse_to_dense_product_impl<Lhs,ColMajorMatrix,ResultType>(lhs, rhsCol, res);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, typename ResultType>
|
||||||
|
struct sparse_sparse_to_dense_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor>
|
||||||
|
{
|
||||||
|
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
|
||||||
|
{
|
||||||
|
Transpose<ResultType> trRes(res);
|
||||||
|
internal::sparse_sparse_to_dense_product_impl<Rhs,Lhs,Transpose<ResultType> >(rhs, lhs, trRes);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
@ -133,8 +133,8 @@ struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Sparse, Scalar>
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Sparse to Dense assignment
|
// Sparse to Dense assignment
|
||||||
template< typename DstXprType, typename SrcXprType, typename Functor, typename Scalar>
|
template< typename DstXprType, typename SrcXprType, typename Functor>
|
||||||
struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Dense, Scalar>
|
struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Dense>
|
||||||
{
|
{
|
||||||
static void run(DstXprType &dst, const SrcXprType &src, const Functor &func)
|
static void run(DstXprType &dst, const SrcXprType &src, const Functor &func)
|
||||||
{
|
{
|
||||||
@ -149,8 +149,8 @@ struct Assignment<DstXprType, SrcXprType, Functor, Sparse2Dense, Scalar>
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template< typename DstXprType, typename SrcXprType, typename Scalar>
|
template< typename DstXprType, typename SrcXprType>
|
||||||
struct Assignment<DstXprType, SrcXprType, internal::assign_op<typename DstXprType::Scalar>, Sparse2Dense, Scalar>
|
struct Assignment<DstXprType, SrcXprType, internal::assign_op<typename DstXprType::Scalar>, Sparse2Dense>
|
||||||
{
|
{
|
||||||
static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar> &)
|
static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar> &)
|
||||||
{
|
{
|
||||||
|
@ -281,7 +281,7 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
|
|||||||
|
|
||||||
// sparse * sparse
|
// sparse * sparse
|
||||||
template<typename OtherDerived>
|
template<typename OtherDerived>
|
||||||
const Product<Derived,OtherDerived>
|
const Product<Derived,OtherDerived,AliasFreeProduct>
|
||||||
operator*(const SparseMatrixBase<OtherDerived> &other) const;
|
operator*(const SparseMatrixBase<OtherDerived> &other) const;
|
||||||
|
|
||||||
// sparse * dense
|
// sparse * dense
|
||||||
|
@ -25,10 +25,10 @@ namespace Eigen {
|
|||||||
* */
|
* */
|
||||||
template<typename Derived>
|
template<typename Derived>
|
||||||
template<typename OtherDerived>
|
template<typename OtherDerived>
|
||||||
inline const Product<Derived,OtherDerived>
|
inline const Product<Derived,OtherDerived,AliasFreeProduct>
|
||||||
SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const
|
SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const
|
||||||
{
|
{
|
||||||
return Product<Derived,OtherDerived>(derived(), other.derived());
|
return Product<Derived,OtherDerived,AliasFreeProduct>(derived(), other.derived());
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
@ -61,6 +61,36 @@ struct generic_product_impl<Lhs, Rhs, SparseTriangularShape, SparseShape, Produc
|
|||||||
: public generic_product_impl<Lhs, Rhs, SparseShape, SparseShape, ProductType>
|
: public generic_product_impl<Lhs, Rhs, SparseShape, SparseShape, ProductType>
|
||||||
{};
|
{};
|
||||||
|
|
||||||
|
// Dense = sparse * sparse
|
||||||
|
template< typename DstXprType, typename Lhs, typename Rhs, int Options/*, typename Scalar*/>
|
||||||
|
struct Assignment<DstXprType, Product<Lhs,Rhs,Options>, internal::assign_op<typename DstXprType::Scalar>, Sparse2Dense/*,
|
||||||
|
typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct),Scalar>::type*/>
|
||||||
|
{
|
||||||
|
typedef Product<Lhs,Rhs,Options> SrcXprType;
|
||||||
|
static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar> &)
|
||||||
|
{
|
||||||
|
dst.setZero();
|
||||||
|
dst += src;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Dense += sparse * sparse
|
||||||
|
template< typename DstXprType, typename Lhs, typename Rhs, int Options>
|
||||||
|
struct Assignment<DstXprType, Product<Lhs,Rhs,Options>, internal::add_assign_op<typename DstXprType::Scalar>, Sparse2Dense/*,
|
||||||
|
typename enable_if<(Options==DefaultProduct || Options==AliasFreeProduct),Scalar>::type*/>
|
||||||
|
{
|
||||||
|
typedef Product<Lhs,Rhs,Options> SrcXprType;
|
||||||
|
static void run(DstXprType &dst, const SrcXprType &src, const internal::add_assign_op<typename DstXprType::Scalar> &)
|
||||||
|
{
|
||||||
|
typedef typename nested_eval<Lhs,Dynamic>::type LhsNested;
|
||||||
|
typedef typename nested_eval<Rhs,Dynamic>::type RhsNested;
|
||||||
|
LhsNested lhsNested(src.lhs());
|
||||||
|
RhsNested rhsNested(src.rhs());
|
||||||
|
internal::sparse_sparse_to_dense_product_selector<typename remove_all<LhsNested>::type,
|
||||||
|
typename remove_all<RhsNested>::type, DstXprType>::run(lhsNested,rhsNested,dst);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs, int Options>
|
template<typename Lhs, typename Rhs, int Options>
|
||||||
struct evaluator<SparseView<Product<Lhs, Rhs, Options> > >
|
struct evaluator<SparseView<Product<Lhs, Rhs, Options> > >
|
||||||
: public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>
|
: public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>
|
||||||
|
@ -76,6 +76,17 @@ template<typename SparseMatrixType> void sparse_product()
|
|||||||
VERIFY_IS_APPROX(m4=(m2t.transpose()*m3t.transpose()).pruned(0), refMat4=refMat2t.transpose()*refMat3t.transpose());
|
VERIFY_IS_APPROX(m4=(m2t.transpose()*m3t.transpose()).pruned(0), refMat4=refMat2t.transpose()*refMat3t.transpose());
|
||||||
VERIFY_IS_APPROX(m4=(m2*m3t.transpose()).pruned(0), refMat4=refMat2*refMat3t.transpose());
|
VERIFY_IS_APPROX(m4=(m2*m3t.transpose()).pruned(0), refMat4=refMat2*refMat3t.transpose());
|
||||||
|
|
||||||
|
// dense ?= sparse * sparse
|
||||||
|
VERIFY_IS_APPROX(dm4 =m2*m3, refMat4 =refMat2*refMat3);
|
||||||
|
VERIFY_IS_APPROX(dm4+=m2*m3, refMat4+=refMat2*refMat3);
|
||||||
|
VERIFY_IS_APPROX(dm4 =m2t.transpose()*m3, refMat4 =refMat2t.transpose()*refMat3);
|
||||||
|
VERIFY_IS_APPROX(dm4+=m2t.transpose()*m3, refMat4+=refMat2t.transpose()*refMat3);
|
||||||
|
VERIFY_IS_APPROX(dm4 =m2t.transpose()*m3t.transpose(), refMat4 =refMat2t.transpose()*refMat3t.transpose());
|
||||||
|
VERIFY_IS_APPROX(dm4+=m2t.transpose()*m3t.transpose(), refMat4+=refMat2t.transpose()*refMat3t.transpose());
|
||||||
|
VERIFY_IS_APPROX(dm4 =m2*m3t.transpose(), refMat4 =refMat2*refMat3t.transpose());
|
||||||
|
VERIFY_IS_APPROX(dm4+=m2*m3t.transpose(), refMat4+=refMat2*refMat3t.transpose());
|
||||||
|
VERIFY_IS_APPROX(dm4 = m2*m3*s1, refMat4 = refMat2*refMat3*s1);
|
||||||
|
|
||||||
// test aliasing
|
// test aliasing
|
||||||
m4 = m2; refMat4 = refMat2;
|
m4 = m2; refMat4 = refMat2;
|
||||||
VERIFY_IS_APPROX(m4=m4*m3, refMat4=refMat4*refMat3);
|
VERIFY_IS_APPROX(m4=m4*m3, refMat4=refMat4*refMat3);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user