mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 19:29:02 +08:00
Implement evaluators for sparse times diagonal products.
This commit is contained in:
parent
ae039dde13
commit
73e686c6a4
@ -48,6 +48,7 @@ struct Sparse {};
|
|||||||
#include "src/SparseCore/SparseDot.h"
|
#include "src/SparseCore/SparseDot.h"
|
||||||
#include "src/SparseCore/SparseRedux.h"
|
#include "src/SparseCore/SparseRedux.h"
|
||||||
#include "src/SparseCore/SparseView.h"
|
#include "src/SparseCore/SparseView.h"
|
||||||
|
#include "src/SparseCore/SparseDiagonalProduct.h"
|
||||||
#ifndef EIGEN_TEST_EVALUATORS
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
#include "src/SparseCore/SparsePermutation.h"
|
#include "src/SparseCore/SparsePermutation.h"
|
||||||
#include "src/SparseCore/SparseFuzzy.h"
|
#include "src/SparseCore/SparseFuzzy.h"
|
||||||
@ -55,7 +56,6 @@ struct Sparse {};
|
|||||||
#include "src/SparseCore/SparseSparseProductWithPruning.h"
|
#include "src/SparseCore/SparseSparseProductWithPruning.h"
|
||||||
#include "src/SparseCore/SparseProduct.h"
|
#include "src/SparseCore/SparseProduct.h"
|
||||||
#include "src/SparseCore/SparseDenseProduct.h"
|
#include "src/SparseCore/SparseDenseProduct.h"
|
||||||
#include "src/SparseCore/SparseDiagonalProduct.h"
|
|
||||||
#include "src/SparseCore/SparseTriangularView.h"
|
#include "src/SparseCore/SparseTriangularView.h"
|
||||||
#include "src/SparseCore/SparseSelfAdjointView.h"
|
#include "src/SparseCore/SparseSelfAdjointView.h"
|
||||||
#include "src/SparseCore/TriangularSolver.h"
|
#include "src/SparseCore/TriangularSolver.h"
|
||||||
|
@ -24,8 +24,10 @@ namespace Eigen {
|
|||||||
// for that particular case
|
// for that particular case
|
||||||
// The two other cases are symmetric.
|
// The two other cases are symmetric.
|
||||||
|
|
||||||
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs>
|
template<typename Lhs, typename Rhs>
|
||||||
struct traits<SparseDiagonalProduct<Lhs, Rhs> >
|
struct traits<SparseDiagonalProduct<Lhs, Rhs> >
|
||||||
{
|
{
|
||||||
@ -100,9 +102,14 @@ class SparseDiagonalProduct
|
|||||||
LhsNested m_lhs;
|
LhsNested m_lhs;
|
||||||
RhsNested m_rhs;
|
RhsNested m_rhs;
|
||||||
};
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
|
template<typename Lhs, typename Rhs, typename SparseDiagonalProductType>
|
||||||
class sparse_diagonal_product_inner_iterator_selector
|
class sparse_diagonal_product_inner_iterator_selector
|
||||||
<Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseRowMajor>
|
<Lhs,Rhs,SparseDiagonalProductType,SDP_IsDiagonal,SDP_IsSparseRowMajor>
|
||||||
@ -179,10 +186,124 @@ class sparse_diagonal_product_inner_iterator_selector
|
|||||||
inline Index row() const { return m_outer; }
|
inline Index row() const { return m_outer; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#else // EIGEN_TEST_EVALUATORS
|
||||||
|
enum {
|
||||||
|
SDP_AsScalarProduct,
|
||||||
|
SDP_AsCwiseProduct
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename SparseXprType, typename DiagonalCoeffType, int SDP_Tag>
|
||||||
|
struct sparse_diagonal_product_evaluator;
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, int Options, int ProductTag>
|
||||||
|
struct product_evaluator<Product<Lhs, Rhs, Options>, ProductTag, DiagonalShape, SparseShape, typename Lhs::Scalar, typename Rhs::Scalar>
|
||||||
|
: public sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct>
|
||||||
|
{
|
||||||
|
typedef Product<Lhs, Rhs, Options> XprType;
|
||||||
|
typedef evaluator<XprType> type;
|
||||||
|
typedef evaluator<XprType> nestedType;
|
||||||
|
enum { CoeffReadCost = Dynamic, Flags = Rhs::Flags&RowMajorBit }; // FIXME CoeffReadCost & Flags
|
||||||
|
|
||||||
|
typedef sparse_diagonal_product_evaluator<Rhs, typename Lhs::DiagonalVectorType, Rhs::Flags&RowMajorBit?SDP_AsScalarProduct:SDP_AsCwiseProduct> Base;
|
||||||
|
product_evaluator(const XprType& xpr) : Base(xpr.rhs(), xpr.lhs().diagonal()) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, int Options, int ProductTag>
|
||||||
|
struct product_evaluator<Product<Lhs, Rhs, Options>, ProductTag, SparseShape, DiagonalShape, typename Lhs::Scalar, typename Rhs::Scalar>
|
||||||
|
: public sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct>
|
||||||
|
{
|
||||||
|
typedef Product<Lhs, Rhs, Options> XprType;
|
||||||
|
typedef evaluator<XprType> type;
|
||||||
|
typedef evaluator<XprType> nestedType;
|
||||||
|
enum { CoeffReadCost = Dynamic, Flags = Lhs::Flags&RowMajorBit }; // FIXME CoeffReadCost & Flags
|
||||||
|
|
||||||
|
typedef sparse_diagonal_product_evaluator<Lhs, Transpose<const typename Rhs::DiagonalVectorType>, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> Base;
|
||||||
|
product_evaluator(const XprType& xpr) : Base(xpr.lhs(), xpr.rhs().diagonal()) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename SparseXprType, typename DiagonalCoeffType>
|
||||||
|
struct sparse_diagonal_product_evaluator<SparseXprType, DiagonalCoeffType, SDP_AsScalarProduct>
|
||||||
|
{
|
||||||
|
protected:
|
||||||
|
typedef typename evaluator<SparseXprType>::InnerIterator SparseXprInnerIterator;
|
||||||
|
typedef typename SparseXprType::Scalar Scalar;
|
||||||
|
typedef typename SparseXprType::Index Index;
|
||||||
|
|
||||||
|
public:
|
||||||
|
class InnerIterator : public SparseXprInnerIterator
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer)
|
||||||
|
: SparseXprInnerIterator(xprEval.m_sparseXprImpl, outer),
|
||||||
|
m_coeff(xprEval.m_diagCoeffImpl.coeff(outer))
|
||||||
|
{}
|
||||||
|
|
||||||
|
EIGEN_STRONG_INLINE Scalar value() const { return m_coeff * SparseXprInnerIterator::value(); }
|
||||||
|
protected:
|
||||||
|
typename DiagonalCoeffType::Scalar m_coeff;
|
||||||
|
};
|
||||||
|
|
||||||
|
sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagonalCoeffType &diagCoeff)
|
||||||
|
: m_sparseXprImpl(sparseXpr), m_diagCoeffImpl(diagCoeff)
|
||||||
|
{}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
typename evaluator<SparseXprType>::nestedType m_sparseXprImpl;
|
||||||
|
typename evaluator<DiagonalCoeffType>::nestedType m_diagCoeffImpl;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
template<typename SparseXprType, typename DiagCoeffType>
|
||||||
|
struct sparse_diagonal_product_evaluator<SparseXprType, DiagCoeffType, SDP_AsCwiseProduct>
|
||||||
|
{
|
||||||
|
typedef typename SparseXprType::Scalar Scalar;
|
||||||
|
typedef typename SparseXprType::Index Index;
|
||||||
|
|
||||||
|
typedef CwiseBinaryOp<scalar_product_op<Scalar>,
|
||||||
|
const typename SparseXprType::ConstInnerVectorReturnType,
|
||||||
|
const DiagCoeffType> CwiseProductType;
|
||||||
|
|
||||||
|
typedef typename evaluator<CwiseProductType>::type CwiseProductEval;
|
||||||
|
typedef typename evaluator<CwiseProductType>::InnerIterator CwiseProductIterator;
|
||||||
|
|
||||||
|
class InnerIterator : public CwiseProductIterator
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer)
|
||||||
|
: CwiseProductIterator(CwiseProductEval(xprEval.m_sparseXprNested.innerVector(outer).cwiseProduct(xprEval.m_diagCoeffNested)),0),
|
||||||
|
m_cwiseEval(xprEval.m_sparseXprNested.innerVector(outer).cwiseProduct(xprEval.m_diagCoeffNested)),
|
||||||
|
m_outer(outer)
|
||||||
|
{
|
||||||
|
::new (static_cast<CwiseProductIterator*>(this)) CwiseProductIterator(m_cwiseEval,0);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Index outer() const { return m_outer; }
|
||||||
|
inline Index col() const { return SparseXprType::IsRowMajor ? CwiseProductIterator::index() : m_outer; }
|
||||||
|
inline Index row() const { return SparseXprType::IsRowMajor ? m_outer : CwiseProductIterator::index(); }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
Index m_outer;
|
||||||
|
CwiseProductEval m_cwiseEval;
|
||||||
|
};
|
||||||
|
|
||||||
|
sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagCoeffType &diagCoeff)
|
||||||
|
: m_sparseXprNested(sparseXpr), m_diagCoeffNested(diagCoeff)
|
||||||
|
{}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
typename nested_eval<SparseXprType,1>::type m_sparseXprNested;
|
||||||
|
typename nested_eval<DiagCoeffType,SparseXprType::IsRowMajor ? SparseXprType::RowsAtCompileTime
|
||||||
|
: SparseXprType::ColsAtCompileTime>::type m_diagCoeffNested;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif // EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
// SparseMatrixBase functions
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
|
// SparseMatrixBase functions
|
||||||
template<typename Derived>
|
template<typename Derived>
|
||||||
template<typename OtherDerived>
|
template<typename OtherDerived>
|
||||||
const SparseDiagonalProduct<Derived,OtherDerived>
|
const SparseDiagonalProduct<Derived,OtherDerived>
|
||||||
@ -190,6 +311,7 @@ SparseMatrixBase<Derived>::operator*(const DiagonalBase<OtherDerived> &other) co
|
|||||||
{
|
{
|
||||||
return SparseDiagonalProduct<Derived,OtherDerived>(this->derived(), other.derived());
|
return SparseDiagonalProduct<Derived,OtherDerived>(this->derived(), other.derived());
|
||||||
}
|
}
|
||||||
|
#endif // EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
|
||||||
|
@ -269,6 +269,7 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
|
|||||||
const typename SparseSparseProductReturnType<Derived,OtherDerived>::Type
|
const typename SparseSparseProductReturnType<Derived,OtherDerived>::Type
|
||||||
operator*(const SparseMatrixBase<OtherDerived> &other) const;
|
operator*(const SparseMatrixBase<OtherDerived> &other) const;
|
||||||
|
|
||||||
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
// sparse * diagonal
|
// sparse * diagonal
|
||||||
template<typename OtherDerived>
|
template<typename OtherDerived>
|
||||||
const SparseDiagonalProduct<Derived,OtherDerived>
|
const SparseDiagonalProduct<Derived,OtherDerived>
|
||||||
@ -279,6 +280,19 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
|
|||||||
const SparseDiagonalProduct<OtherDerived,Derived>
|
const SparseDiagonalProduct<OtherDerived,Derived>
|
||||||
operator*(const DiagonalBase<OtherDerived> &lhs, const SparseMatrixBase& rhs)
|
operator*(const DiagonalBase<OtherDerived> &lhs, const SparseMatrixBase& rhs)
|
||||||
{ return SparseDiagonalProduct<OtherDerived,Derived>(lhs.derived(), rhs.derived()); }
|
{ return SparseDiagonalProduct<OtherDerived,Derived>(lhs.derived(), rhs.derived()); }
|
||||||
|
#else // EIGEN_TEST_EVALUATORS
|
||||||
|
// sparse * diagonal
|
||||||
|
template<typename OtherDerived>
|
||||||
|
const Product<Derived,OtherDerived>
|
||||||
|
operator*(const DiagonalBase<OtherDerived> &other) const
|
||||||
|
{ return Product<Derived,OtherDerived>(derived(), other.derived()); }
|
||||||
|
|
||||||
|
// diagonal * sparse
|
||||||
|
template<typename OtherDerived> friend
|
||||||
|
const Product<OtherDerived,Derived>
|
||||||
|
operator*(const DiagonalBase<OtherDerived> &lhs, const SparseMatrixBase& rhs)
|
||||||
|
{ return Product<OtherDerived,Derived>(lhs.derived(), rhs.derived()); }
|
||||||
|
#endif // EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
/** dense * sparse (return a dense object unless it is an outer product) */
|
/** dense * sparse (return a dense object unless it is an outer product) */
|
||||||
template<typename OtherDerived> friend
|
template<typename OtherDerived> friend
|
||||||
|
Loading…
x
Reference in New Issue
Block a user