mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-24 02:29:33 +08:00
Simplify and speed up sparse * dense matrix products
This commit is contained in:
parent
85b358097d
commit
553a0ae924
@ -149,6 +149,102 @@ struct traits<SparseTimeDenseProduct<Lhs,Rhs> >
|
||||
typedef Dense StorageKind;
|
||||
typedef MatrixXpr XprKind;
|
||||
};
|
||||
|
||||
// Primary template of the sparse * dense product kernels. It is only declared
// here; the work is done by the partial specializations selected through the
// two defaulted parameters:
//  - LhsStorageOrder: RowMajor when SparseLhsType::IsRowMajor is set, ColMajor otherwise.
//  - ColPerCol: true when the dense rhs does not carry RowMajorBit in its Flags,
//    or when it has exactly one column at compile time.
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,
|
||||
int LhsStorageOrder = SparseLhsType::IsRowMajor?RowMajor:ColMajor,
|
||||
bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1>
|
||||
struct sparse_time_dense_product_impl;
|
||||
|
||||
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
|
||||
struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, true>
|
||||
{
|
||||
typedef typename internal::remove_all<SparseLhsType>::type Lhs;
|
||||
typedef typename internal::remove_all<DenseRhsType>::type Rhs;
|
||||
typedef typename internal::remove_all<DenseResType>::type Res;
|
||||
typedef typename Lhs::Index Index;
|
||||
typedef typename Lhs::InnerIterator LhsInnerIterator;
|
||||
static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
|
||||
{
|
||||
for(Index c=0; c<rhs.cols(); ++c)
|
||||
{
|
||||
Index j=0;
|
||||
for(j=0; j<lhs.outerSize(); ++j)
|
||||
{
|
||||
typename Res::Scalar tmp(0);
|
||||
for(LhsInnerIterator it(lhs,j); it ;++it)
|
||||
tmp += it.value() * rhs.coeff(it.index(),c);
|
||||
res.coeffRef(j,c) = alpha * tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
|
||||
struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, true>
|
||||
{
|
||||
typedef typename internal::remove_all<SparseLhsType>::type Lhs;
|
||||
typedef typename internal::remove_all<DenseRhsType>::type Rhs;
|
||||
typedef typename internal::remove_all<DenseResType>::type Res;
|
||||
typedef typename Lhs::InnerIterator LhsInnerIterator;
|
||||
typedef typename Lhs::Index Index;
|
||||
static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
|
||||
{
|
||||
for(Index c=0; c<rhs.cols(); ++c)
|
||||
{
|
||||
for(Index j=0; j<lhs.outerSize(); ++j)
|
||||
{
|
||||
typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c);
|
||||
for(LhsInnerIterator it(lhs,j); it ;++it)
|
||||
res.coeffRef(it.index(),c) += it.value() * rhs_j;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
|
||||
struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, false>
|
||||
{
|
||||
typedef typename internal::remove_all<SparseLhsType>::type Lhs;
|
||||
typedef typename internal::remove_all<DenseRhsType>::type Rhs;
|
||||
typedef typename internal::remove_all<DenseResType>::type Res;
|
||||
typedef typename Lhs::InnerIterator LhsInnerIterator;
|
||||
typedef typename Lhs::Index Index;
|
||||
static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
|
||||
{
|
||||
for(Index j=0; j<lhs.outerSize(); ++j)
|
||||
{
|
||||
typename Res::RowXpr res_j(res.row(j));
|
||||
for(LhsInnerIterator it(lhs,j); it ;++it)
|
||||
res_j += (alpha*it.value()) * rhs.row(it.index());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
|
||||
struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, false>
|
||||
{
|
||||
typedef typename internal::remove_all<SparseLhsType>::type Lhs;
|
||||
typedef typename internal::remove_all<DenseRhsType>::type Rhs;
|
||||
typedef typename internal::remove_all<DenseResType>::type Res;
|
||||
typedef typename Lhs::InnerIterator LhsInnerIterator;
|
||||
typedef typename Lhs::Index Index;
|
||||
static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
|
||||
{
|
||||
for(Index j=0; j<lhs.outerSize(); ++j)
|
||||
{
|
||||
typename Rhs::ConstRowXpr rhs_j(rhs.row(j));
|
||||
for(LhsInnerIterator it(lhs,j); it ;++it)
|
||||
res.row(it.index()) += (alpha*it.value()) * rhs_j;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType>
|
||||
inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
|
||||
{
|
||||
sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType>::run(lhs, rhs, res, alpha);
|
||||
}
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
template<typename Lhs, typename Rhs>
|
||||
@ -163,27 +259,7 @@ class SparseTimeDenseProduct
|
||||
|
||||
template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
|
||||
{
|
||||
typedef typename _LhsNested::InnerIterator LhsInnerIterator;
|
||||
enum {
|
||||
LhsIsRowMajor = (_LhsNested::Flags&RowMajorBit)==RowMajorBit,
|
||||
RhsIsVector = _RhsNested::ColsAtCompileTime==1
|
||||
};
|
||||
Index j=0;
|
||||
for(j=0; j<m_lhs.outerSize(); ++j)
|
||||
{
|
||||
typename Rhs::Scalar rhs_j = alpha * m_rhs.coeff(LhsIsRowMajor ? 0 : j,0);
|
||||
typename Dest::RowXpr dest_j(dest.row(LhsIsRowMajor ? j : 0));
|
||||
typename Dest::Scalar tmp(0);
|
||||
for(LhsInnerIterator it(m_lhs,j); it ;++it)
|
||||
{
|
||||
if(LhsIsRowMajor && RhsIsVector) tmp += (it.value()) * m_rhs.coeff(it.index());
|
||||
else if(LhsIsRowMajor) dest_j += (alpha*it.value()) * m_rhs.row(it.index());
|
||||
else if(RhsIsVector) dest.coeffRef(it.index()) += it.value() * rhs_j;
|
||||
else dest.row(it.index()) += (alpha*it.value()) * m_rhs.row(j);
|
||||
}
|
||||
if(LhsIsRowMajor && RhsIsVector)
|
||||
dest.coeffRef(LhsIsRowMajor ? j : 0) = alpha * tmp;
|
||||
}
|
||||
internal::sparse_time_dense_product(m_lhs, m_rhs, dest, alpha);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -213,11 +289,10 @@ class DenseTimeSparseProduct
|
||||
|
||||
template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
|
||||
{
|
||||
typedef typename _RhsNested::InnerIterator RhsInnerIterator;
|
||||
enum { RhsIsRowMajor = (_RhsNested::Flags&RowMajorBit)==RowMajorBit };
|
||||
for(Index j=0; j<m_rhs.outerSize(); ++j)
|
||||
for(RhsInnerIterator i(m_rhs,j); i; ++i)
|
||||
dest.col(RhsIsRowMajor ? i.index() : j) += (alpha*i.value()) * m_lhs.col(RhsIsRowMajor ? j : i.index());
|
||||
Transpose<const _LhsNested> lhs_t(m_lhs);
|
||||
Transpose<const _RhsNested> rhs_t(m_rhs);
|
||||
Transpose<Dest> dest_t(dest);
|
||||
internal::sparse_time_dense_product(rhs_t, lhs_t, dest_t, alpha);
|
||||
}
|
||||
|
||||
private:
|
||||
|
Loading…
x
Reference in New Issue
Block a user