simplify and speedup sparse * dense matrix products

This commit is contained in:
Gael Guennebaud 2012-03-01 10:13:13 +01:00
parent 85b358097d
commit 553a0ae924

View File

@ -149,6 +149,102 @@ struct traits<SparseTimeDenseProduct<Lhs,Rhs> >
typedef Dense StorageKind;
typedef MatrixXpr XprKind;
};
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,
int LhsStorageOrder = SparseLhsType::IsRowMajor?RowMajor:ColMajor,
bool ColPerCol = ((DenseRhsType::Flags&RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1>
struct sparse_time_dense_product_impl;
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, true>
{
typedef typename internal::remove_all<SparseLhsType>::type Lhs;
typedef typename internal::remove_all<DenseRhsType>::type Rhs;
typedef typename internal::remove_all<DenseResType>::type Res;
typedef typename Lhs::Index Index;
typedef typename Lhs::InnerIterator LhsInnerIterator;
static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
{
for(Index c=0; c<rhs.cols(); ++c)
{
Index j=0;
for(j=0; j<lhs.outerSize(); ++j)
{
typename Res::Scalar tmp(0);
for(LhsInnerIterator it(lhs,j); it ;++it)
tmp += it.value() * rhs.coeff(it.index(),c);
res.coeffRef(j,c) = alpha * tmp;
}
}
}
};
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, true>
{
typedef typename internal::remove_all<SparseLhsType>::type Lhs;
typedef typename internal::remove_all<DenseRhsType>::type Rhs;
typedef typename internal::remove_all<DenseResType>::type Res;
typedef typename Lhs::InnerIterator LhsInnerIterator;
typedef typename Lhs::Index Index;
static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
{
for(Index c=0; c<rhs.cols(); ++c)
{
for(Index j=0; j<lhs.outerSize(); ++j)
{
typename Res::Scalar rhs_j = alpha * rhs.coeff(j,c);
for(LhsInnerIterator it(lhs,j); it ;++it)
res.coeffRef(it.index(),c) += it.value() * rhs_j;
}
}
}
};
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, RowMajor, false>
{
typedef typename internal::remove_all<SparseLhsType>::type Lhs;
typedef typename internal::remove_all<DenseRhsType>::type Rhs;
typedef typename internal::remove_all<DenseResType>::type Res;
typedef typename Lhs::InnerIterator LhsInnerIterator;
typedef typename Lhs::Index Index;
static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
{
for(Index j=0; j<lhs.outerSize(); ++j)
{
typename Res::RowXpr res_j(res.row(j));
for(LhsInnerIterator it(lhs,j); it ;++it)
res_j += (alpha*it.value()) * rhs.row(it.index());
}
}
};
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType>
struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, ColMajor, false>
{
typedef typename internal::remove_all<SparseLhsType>::type Lhs;
typedef typename internal::remove_all<DenseRhsType>::type Rhs;
typedef typename internal::remove_all<DenseResType>::type Res;
typedef typename Lhs::InnerIterator LhsInnerIterator;
typedef typename Lhs::Index Index;
static void run(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, typename Res::Scalar alpha)
{
for(Index j=0; j<lhs.outerSize(); ++j)
{
typename Rhs::ConstRowXpr rhs_j(rhs.row(j));
for(LhsInnerIterator it(lhs,j); it ;++it)
res.row(it.index()) += (alpha*it.value()) * rhs_j;
}
}
};
template<typename SparseLhsType, typename DenseRhsType, typename DenseResType,typename AlphaType>
inline void sparse_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
{
sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType>::run(lhs, rhs, res, alpha);
}
} // end namespace internal
template<typename Lhs, typename Rhs>
@ -163,27 +259,7 @@ class SparseTimeDenseProduct
template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
{
typedef typename _LhsNested::InnerIterator LhsInnerIterator;
enum {
LhsIsRowMajor = (_LhsNested::Flags&RowMajorBit)==RowMajorBit,
RhsIsVector = _RhsNested::ColsAtCompileTime==1
};
Index j=0;
for(j=0; j<m_lhs.outerSize(); ++j)
{
typename Rhs::Scalar rhs_j = alpha * m_rhs.coeff(LhsIsRowMajor ? 0 : j,0);
typename Dest::RowXpr dest_j(dest.row(LhsIsRowMajor ? j : 0));
typename Dest::Scalar tmp(0);
for(LhsInnerIterator it(m_lhs,j); it ;++it)
{
if(LhsIsRowMajor && RhsIsVector) tmp += (it.value()) * m_rhs.coeff(it.index());
else if(LhsIsRowMajor) dest_j += (alpha*it.value()) * m_rhs.row(it.index());
else if(RhsIsVector) dest.coeffRef(it.index()) += it.value() * rhs_j;
else dest.row(it.index()) += (alpha*it.value()) * m_rhs.row(j);
}
if(LhsIsRowMajor && RhsIsVector)
dest.coeffRef(LhsIsRowMajor ? j : 0) = alpha * tmp;
}
internal::sparse_time_dense_product(m_lhs, m_rhs, dest, alpha);
}
private:
@ -213,11 +289,10 @@ class DenseTimeSparseProduct
template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
{
typedef typename _RhsNested::InnerIterator RhsInnerIterator;
enum { RhsIsRowMajor = (_RhsNested::Flags&RowMajorBit)==RowMajorBit };
for(Index j=0; j<m_rhs.outerSize(); ++j)
for(RhsInnerIterator i(m_rhs,j); i; ++i)
dest.col(RhsIsRowMajor ? i.index() : j) += (alpha*i.value()) * m_lhs.col(RhsIsRowMajor ? j : i.index());
Transpose<const _LhsNested> lhs_t(m_lhs);
Transpose<const _RhsNested> rhs_t(m_rhs);
Transpose<Dest> dest_t(dest);
internal::sparse_time_dense_product(rhs_t, lhs_t, dest_t, alpha);
}
private: