diff --git a/Eigen/SparseCore b/Eigen/SparseCore index 340ff7f52..b338950ca 100644 --- a/Eigen/SparseCore +++ b/Eigen/SparseCore @@ -48,6 +48,7 @@ struct Sparse {}; #include "src/SparseCore/SparseDot.h" #include "src/SparseCore/SparseRedux.h" #include "src/SparseCore/SparseView.h" +#include "src/SparseCore/SparseDiagonalProduct.h" #ifndef EIGEN_TEST_EVALUATORS #include "src/SparseCore/SparsePermutation.h" #include "src/SparseCore/SparseFuzzy.h" @@ -55,7 +56,6 @@ struct Sparse {}; #include "src/SparseCore/SparseSparseProductWithPruning.h" #include "src/SparseCore/SparseProduct.h" #include "src/SparseCore/SparseDenseProduct.h" -#include "src/SparseCore/SparseDiagonalProduct.h" #include "src/SparseCore/SparseTriangularView.h" #include "src/SparseCore/SparseSelfAdjointView.h" #include "src/SparseCore/TriangularSolver.h" diff --git a/Eigen/src/SparseCore/SparseDiagonalProduct.h b/Eigen/src/SparseCore/SparseDiagonalProduct.h index 1bb590e64..cf0f77342 100644 --- a/Eigen/src/SparseCore/SparseDiagonalProduct.h +++ b/Eigen/src/SparseCore/SparseDiagonalProduct.h @@ -24,8 +24,10 @@ namespace Eigen { // for that particular case // The two other cases are symmetric. +#ifndef EIGEN_TEST_EVALUATORS + namespace internal { - + template struct traits > { @@ -100,9 +102,14 @@ class SparseDiagonalProduct LhsNested m_lhs; RhsNested m_rhs; }; +#endif namespace internal { +#ifndef EIGEN_TEST_EVALUATORS + + + template class sparse_diagonal_product_inner_iterator_selector @@ -179,10 +186,124 @@ class sparse_diagonal_product_inner_iterator_selector inline Index row() const { return m_outer; } }; +#else // EIGEN_TEST_EVALUATORS +enum { + SDP_AsScalarProduct, + SDP_AsCwiseProduct +}; + +template +struct sparse_diagonal_product_evaluator; + +template +struct product_evaluator, ProductTag, DiagonalShape, SparseShape, typename Lhs::Scalar, typename Rhs::Scalar> + : public sparse_diagonal_product_evaluator +{ + typedef Product XprType; + typedef evaluator type; + typedef evaluator nestedType; + enum { CoeffReadCost = Dynamic, Flags = Rhs::Flags&RowMajorBit }; // FIXME CoeffReadCost & Flags + + typedef sparse_diagonal_product_evaluator Base; + product_evaluator(const XprType& xpr) : Base(xpr.rhs(), xpr.lhs().diagonal()) {} +}; + +template +struct product_evaluator, ProductTag, SparseShape, DiagonalShape, typename Lhs::Scalar, typename Rhs::Scalar> + : public sparse_diagonal_product_evaluator, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> +{ + typedef Product XprType; + typedef evaluator type; + typedef evaluator nestedType; + enum { CoeffReadCost = Dynamic, Flags = Lhs::Flags&RowMajorBit }; // FIXME CoeffReadCost & Flags + + typedef sparse_diagonal_product_evaluator, Lhs::Flags&RowMajorBit?SDP_AsCwiseProduct:SDP_AsScalarProduct> Base; + product_evaluator(const XprType& xpr) : Base(xpr.lhs(), xpr.rhs().diagonal()) {} +}; + +template +struct sparse_diagonal_product_evaluator +{ +protected: + typedef typename evaluator::InnerIterator SparseXprInnerIterator; + typedef typename SparseXprType::Scalar Scalar; + typedef typename SparseXprType::Index Index; + +public: + class InnerIterator : public SparseXprInnerIterator + { + public: + InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer) + : SparseXprInnerIterator(xprEval.m_sparseXprImpl, outer), + m_coeff(xprEval.m_diagCoeffImpl.coeff(outer)) + {} + + EIGEN_STRONG_INLINE Scalar value() const { return m_coeff * SparseXprInnerIterator::value(); } + protected: + typename DiagonalCoeffType::Scalar m_coeff; + }; + + sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagonalCoeffType &diagCoeff) + : m_sparseXprImpl(sparseXpr), m_diagCoeffImpl(diagCoeff) + {} + +protected: + typename evaluator::nestedType m_sparseXprImpl; + typename evaluator::nestedType m_diagCoeffImpl; +}; + + +template +struct sparse_diagonal_product_evaluator +{ + typedef typename SparseXprType::Scalar Scalar; + typedef typename SparseXprType::Index Index; + + typedef CwiseBinaryOp, + const typename SparseXprType::ConstInnerVectorReturnType, + const DiagCoeffType> CwiseProductType; + + typedef typename evaluator::type CwiseProductEval; + typedef typename evaluator::InnerIterator CwiseProductIterator; + + class InnerIterator : public CwiseProductIterator + { + public: + InnerIterator(const sparse_diagonal_product_evaluator &xprEval, Index outer) + : CwiseProductIterator(CwiseProductEval(xprEval.m_sparseXprNested.innerVector(outer).cwiseProduct(xprEval.m_diagCoeffNested)),0), + m_cwiseEval(xprEval.m_sparseXprNested.innerVector(outer).cwiseProduct(xprEval.m_diagCoeffNested)), + m_outer(outer) + { + ::new (static_cast(this)) CwiseProductIterator(m_cwiseEval,0); + } + + inline Index outer() const { return m_outer; } + inline Index col() const { return SparseXprType::IsRowMajor ? CwiseProductIterator::index() : m_outer; } + inline Index row() const { return SparseXprType::IsRowMajor ? m_outer : CwiseProductIterator::index(); } + + protected: + Index m_outer; + CwiseProductEval m_cwiseEval; + }; + + sparse_diagonal_product_evaluator(const SparseXprType &sparseXpr, const DiagCoeffType &diagCoeff) + : m_sparseXprNested(sparseXpr), m_diagCoeffNested(diagCoeff) + {} + +protected: + typename nested_eval::type m_sparseXprNested; + typename nested_eval::type m_diagCoeffNested; +}; + +#endif // EIGEN_TEST_EVALUATORS + + } // end namespace internal -// SparseMatrixBase functions +#ifndef EIGEN_TEST_EVALUATORS +// SparseMatrixBase functions template template const SparseDiagonalProduct @@ -190,6 +311,7 @@ SparseMatrixBase::operator*(const DiagonalBase &other) co { return SparseDiagonalProduct(this->derived(), other.derived()); } +#endif // EIGEN_TEST_EVALUATORS } // end namespace Eigen diff --git a/Eigen/src/SparseCore/SparseMatrixBase.h b/Eigen/src/SparseCore/SparseMatrixBase.h index 4c46008ae..c71244d3e 100644 --- a/Eigen/src/SparseCore/SparseMatrixBase.h +++ b/Eigen/src/SparseCore/SparseMatrixBase.h @@ -269,6 +269,7 @@ template class SparseMatrixBase : public EigenBase const typename SparseSparseProductReturnType::Type operator*(const SparseMatrixBase &other) const; +#ifndef EIGEN_TEST_EVALUATORS // sparse * diagonal template const SparseDiagonalProduct @@ -279,6 +280,19 @@ template class SparseMatrixBase : public EigenBase const SparseDiagonalProduct operator*(const DiagonalBase &lhs, const SparseMatrixBase& rhs) { return SparseDiagonalProduct(lhs.derived(), rhs.derived()); } +#else // EIGEN_TEST_EVALUATORS + // sparse * diagonal + template + const Product + operator*(const DiagonalBase &other) const + { return Product(derived(), other.derived()); } + + // diagonal * sparse + template friend + const Product + operator*(const DiagonalBase &lhs, const SparseMatrixBase& rhs) + { return Product(lhs.derived(), rhs.derived()); } +#endif // EIGEN_TEST_EVALUATORS /** dense * sparse (return a dense object unless it is an outer product) */ template friend