mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 03:39:01 +08:00
Add support for non-trivial scalar factors in sparse selfadjoint * dense products, and enable +=/-= assignment for such products.
This changeset also improves performance by working on one column of the result at a time.
This commit is contained in:
parent
8132a12625
commit
441b7eaab2
@ -250,11 +250,11 @@ template<int Mode, typename SparseLhsType, typename DenseRhsType, typename Dense
|
||||
inline void sparse_selfadjoint_time_dense_product(const SparseLhsType& lhs, const DenseRhsType& rhs, DenseResType& res, const AlphaType& alpha)
|
||||
{
|
||||
EIGEN_ONLY_USED_FOR_DEBUG(alpha);
|
||||
// TODO use alpha
|
||||
eigen_assert(alpha==AlphaType(1) && "alpha != 1 is not implemented yet, sorry");
|
||||
|
||||
typedef evaluator<SparseLhsType> LhsEval;
|
||||
typedef typename evaluator<SparseLhsType>::InnerIterator LhsIterator;
|
||||
typedef typename internal::nested_eval<SparseLhsType,DenseRhsType::MaxColsAtCompileTime>::type SparseLhsTypeNested;
|
||||
typedef typename internal::remove_all<SparseLhsTypeNested>::type SparseLhsTypeNestedCleaned;
|
||||
typedef evaluator<SparseLhsTypeNestedCleaned> LhsEval;
|
||||
typedef typename LhsEval::InnerIterator LhsIterator;
|
||||
typedef typename SparseLhsType::Scalar LhsScalar;
|
||||
|
||||
enum {
|
||||
@ -266,39 +266,53 @@ inline void sparse_selfadjoint_time_dense_product(const SparseLhsType& lhs, cons
|
||||
ProcessSecondHalf = !ProcessFirstHalf
|
||||
};
|
||||
|
||||
LhsEval lhsEval(lhs);
|
||||
|
||||
for (Index j=0; j<lhs.outerSize(); ++j)
|
||||
SparseLhsTypeNested lhs_nested(lhs);
|
||||
LhsEval lhsEval(lhs_nested);
|
||||
|
||||
// work on one column at once
|
||||
for (Index k=0; k<rhs.cols(); ++k)
|
||||
{
|
||||
LhsIterator i(lhsEval,j);
|
||||
if (ProcessSecondHalf)
|
||||
for (Index j=0; j<lhs.outerSize(); ++j)
|
||||
{
|
||||
while (i && i.index()<j) ++i;
|
||||
if(i && i.index()==j)
|
||||
LhsIterator i(lhsEval,j);
|
||||
// handle diagonal coeff
|
||||
if (ProcessSecondHalf)
|
||||
{
|
||||
res.row(j) += i.value() * rhs.row(j);
|
||||
++i;
|
||||
while (i && i.index()<j) ++i;
|
||||
if(i && i.index()==j)
|
||||
{
|
||||
res(j,k) += alpha * i.value() * rhs(j,k);
|
||||
++i;
|
||||
}
|
||||
}
|
||||
|
||||
// premultiplied rhs for scatters
|
||||
typename ScalarBinaryOpTraits<AlphaType, typename DenseRhsType::Scalar>::ReturnType rhs_j(alpha*rhs(j,k));
|
||||
// accumulator for partial scalar product
|
||||
typename DenseResType::Scalar res_j(0);
|
||||
for(; (ProcessFirstHalf ? i && i.index() < j : i) ; ++i)
|
||||
{
|
||||
LhsScalar lhs_ij = i.value();
|
||||
if(!LhsIsRowMajor) lhs_ij = numext::conj(lhs_ij);
|
||||
res_j += lhs_ij * rhs(i.index(),k);
|
||||
res(i.index(),k) += numext::conj(lhs_ij) * rhs_j;
|
||||
}
|
||||
res(j,k) += alpha * res_j;
|
||||
|
||||
// handle diagonal coeff
|
||||
if (ProcessFirstHalf && i && (i.index()==j))
|
||||
res(j,k) += alpha * i.value() * rhs(j,k);
|
||||
}
|
||||
for(; (ProcessFirstHalf ? i && i.index() < j : i) ; ++i)
|
||||
{
|
||||
Index a = LhsIsRowMajor ? j : i.index();
|
||||
Index b = LhsIsRowMajor ? i.index() : j;
|
||||
LhsScalar v = i.value();
|
||||
res.row(a) += (v) * rhs.row(b);
|
||||
res.row(b) += numext::conj(v) * rhs.row(a);
|
||||
}
|
||||
if (ProcessFirstHalf && i && (i.index()==j))
|
||||
res.row(j) += i.value() * rhs.row(j);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<typename LhsView, typename Rhs, int ProductType>
|
||||
struct generic_product_impl<LhsView, Rhs, SparseSelfAdjointShape, DenseShape, ProductType>
|
||||
: generic_product_impl_base<LhsView, Rhs, generic_product_impl<LhsView, Rhs, SparseSelfAdjointShape, DenseShape, ProductType> >
|
||||
{
|
||||
template<typename Dest>
|
||||
static void evalTo(Dest& dst, const LhsView& lhsView, const Rhs& rhs)
|
||||
static void scaleAndAddTo(Dest& dst, const LhsView& lhsView, const Rhs& rhs, const typename Dest::Scalar& alpha)
|
||||
{
|
||||
typedef typename LhsView::_MatrixTypeNested Lhs;
|
||||
typedef typename nested_eval<Lhs,Dynamic>::type LhsNested;
|
||||
@ -306,16 +320,16 @@ struct generic_product_impl<LhsView, Rhs, SparseSelfAdjointShape, DenseShape, Pr
|
||||
LhsNested lhsNested(lhsView.matrix());
|
||||
RhsNested rhsNested(rhs);
|
||||
|
||||
dst.setZero();
|
||||
internal::sparse_selfadjoint_time_dense_product<LhsView::Mode>(lhsNested, rhsNested, dst, typename Dest::Scalar(1));
|
||||
internal::sparse_selfadjoint_time_dense_product<LhsView::Mode>(lhsNested, rhsNested, dst, alpha);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Lhs, typename RhsView, int ProductType>
|
||||
struct generic_product_impl<Lhs, RhsView, DenseShape, SparseSelfAdjointShape, ProductType>
|
||||
: generic_product_impl_base<Lhs, RhsView, generic_product_impl<Lhs, RhsView, DenseShape, SparseSelfAdjointShape, ProductType> >
|
||||
{
|
||||
template<typename Dest>
|
||||
static void evalTo(Dest& dst, const Lhs& lhs, const RhsView& rhsView)
|
||||
static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const RhsView& rhsView, const typename Dest::Scalar& alpha)
|
||||
{
|
||||
typedef typename RhsView::_MatrixTypeNested Rhs;
|
||||
typedef typename nested_eval<Lhs,Dynamic>::type LhsNested;
|
||||
@ -323,10 +337,9 @@ struct generic_product_impl<Lhs, RhsView, DenseShape, SparseSelfAdjointShape, Pr
|
||||
LhsNested lhsNested(lhs);
|
||||
RhsNested rhsNested(rhsView.matrix());
|
||||
|
||||
dst.setZero();
|
||||
// transpoe everything
|
||||
// transpose everything
|
||||
Transpose<Dest> dstT(dst);
|
||||
internal::sparse_selfadjoint_time_dense_product<RhsView::Mode>(rhsNested.transpose(), lhsNested.transpose(), dstT, typename Dest::Scalar(1));
|
||||
internal::sparse_selfadjoint_time_dense_product<RhsView::Mode>(rhsNested.transpose(), lhsNested.transpose(), dstT, alpha);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -292,6 +292,10 @@ template<typename SparseMatrixType> void sparse_product()
|
||||
VERIFY_IS_APPROX(x=mUp.template selfadjointView<Upper>()*b, refX=refS*b);
|
||||
VERIFY_IS_APPROX(x=mLo.template selfadjointView<Lower>()*b, refX=refS*b);
|
||||
VERIFY_IS_APPROX(x=mS.template selfadjointView<Upper|Lower>()*b, refX=refS*b);
|
||||
|
||||
VERIFY_IS_APPROX(x.noalias()+=mUp.template selfadjointView<Upper>()*b, refX+=refS*b);
|
||||
VERIFY_IS_APPROX(x.noalias()-=mLo.template selfadjointView<Lower>()*b, refX-=refS*b);
|
||||
VERIFY_IS_APPROX(x.noalias()+=mS.template selfadjointView<Upper|Lower>()*b, refX+=refS*b);
|
||||
|
||||
// sparse selfadjointView with sparse matrices
|
||||
SparseMatrixType mSres(rows,rows);
|
||||
|
Loading…
x
Reference in New Issue
Block a user