diff --git a/Eigen/SparseCore b/Eigen/SparseCore index 0c91c3b59..f74df3038 100644 --- a/Eigen/SparseCore +++ b/Eigen/SparseCore @@ -50,11 +50,11 @@ struct Sparse {}; #include "src/SparseCore/SparseView.h" #include "src/SparseCore/SparseDiagonalProduct.h" #include "src/SparseCore/ConservativeSparseSparseProduct.h" +#include "src/SparseCore/SparseSparseProductWithPruning.h" #include "src/SparseCore/SparseProduct.h" #ifndef EIGEN_TEST_EVALUATORS #include "src/SparseCore/SparsePermutation.h" #include "src/SparseCore/SparseFuzzy.h" -#include "src/SparseCore/SparseSparseProductWithPruning.h" #include "src/SparseCore/SparseDenseProduct.h" #include "src/SparseCore/SparseTriangularView.h" #include "src/SparseCore/SparseSelfAdjointView.h" diff --git a/Eigen/src/SparseCore/SparseMatrixBase.h b/Eigen/src/SparseCore/SparseMatrixBase.h index cebf3990d..3a81916fb 100644 --- a/Eigen/src/SparseCore/SparseMatrixBase.h +++ b/Eigen/src/SparseCore/SparseMatrixBase.h @@ -394,6 +394,9 @@ template class SparseMatrixBase : public EigenBase { return typename internal::eval::type(derived()); } Scalar sum() const; + + inline const SparseView + pruned(const Scalar& reference = Scalar(0), const RealScalar& epsilon = NumTraits::dummy_precision()) const; protected: diff --git a/Eigen/src/SparseCore/SparseProduct.h b/Eigen/src/SparseCore/SparseProduct.h index 52c452f92..8b9578836 100644 --- a/Eigen/src/SparseCore/SparseProduct.h +++ b/Eigen/src/SparseCore/SparseProduct.h @@ -242,7 +242,37 @@ struct product_evaluator, ProductTag, SparseSh protected: PlainObject m_result; }; + +template +struct evaluator > > + : public evaluator::PlainObject>::type +{ + typedef SparseView > XprType; + typedef typename XprType::PlainObject PlainObject; + typedef typename evaluator::type Base; + typedef evaluator type; + typedef evaluator nestedType; + + evaluator(const XprType& xpr) + : m_result(xpr.rows(), xpr.cols()) + { + using std::abs; + ::new (static_cast(this)) Base(m_result); + typedef typename nested_eval::type LhsNested; + typedef typename nested_eval::type RhsNested; + LhsNested lhsNested(xpr.nestedExpression().lhs()); + RhsNested rhsNested(xpr.nestedExpression().rhs()); + + internal::sparse_sparse_product_with_pruning_selector::type, + typename remove_all::type, PlainObject>::run(lhsNested,rhsNested,m_result, + abs(xpr.reference())*xpr.epsilon()); + } + +protected: + PlainObject m_result; +}; + } // end namespace internal #endif // EIGEN_TEST_EVALUATORS diff --git a/Eigen/src/SparseCore/SparseSparseProductWithPruning.h b/Eigen/src/SparseCore/SparseSparseProductWithPruning.h index fcc18f5c9..c33ec6bfd 100644 --- a/Eigen/src/SparseCore/SparseSparseProductWithPruning.h +++ b/Eigen/src/SparseCore/SparseSparseProductWithPruning.h @@ -46,6 +46,11 @@ static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& r res.resize(cols, rows); else res.resize(rows, cols); + + #ifdef EIGEN_TEST_EVALUATORS + typename evaluator::type lhsEval(lhs); + typename evaluator::type rhsEval(rhs); + #endif res.reserve(estimated_nnz_prod); double ratioColRes = double(estimated_nnz_prod)/double(lhs.rows()*rhs.cols()); @@ -56,12 +61,20 @@ static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& r // let's do a more accurate determination of the nnz ratio for the current column j of res tempVector.init(ratioColRes); tempVector.setZero(); +#ifndef EIGEN_TEST_EVALUATORS for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt) +#else + for (typename evaluator::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt) +#endif { // FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index()) tempVector.restart(); Scalar x = rhsIt.value(); +#ifndef EIGEN_TEST_EVALUATORS for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt) +#else + for (typename evaluator::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt) +#endif { tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x; } @@ -140,8 +153,58 @@ struct sparse_sparse_product_with_pruning_selector +struct sparse_sparse_product_with_pruning_selector +{ + typedef typename ResultType::RealScalar RealScalar; + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) + { + typedef SparseMatrix RowMajorMatrixLhs; + RowMajorMatrixLhs rowLhs(lhs); + sparse_sparse_product_with_pruning_selector(rowLhs,rhs,res,tolerance); + } +}; + +template +struct sparse_sparse_product_with_pruning_selector +{ + typedef typename ResultType::RealScalar RealScalar; + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) + { + typedef SparseMatrix RowMajorMatrixRhs; + RowMajorMatrixRhs rowRhs(rhs); + sparse_sparse_product_with_pruning_selector(lhs,rowRhs,res,tolerance); + } +}; + +template +struct sparse_sparse_product_with_pruning_selector +{ + typedef typename ResultType::RealScalar RealScalar; + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) + { + typedef SparseMatrix ColMajorMatrixRhs; + ColMajorMatrixRhs colRhs(rhs); + internal::sparse_sparse_product_with_pruning_impl(lhs, colRhs, res, tolerance); + } +}; + +template +struct sparse_sparse_product_with_pruning_selector +{ + typedef typename ResultType::RealScalar RealScalar; + static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance) + { + typedef SparseMatrix ColMajorMatrixLhs; + ColMajorMatrixLhs colLhs(lhs); + internal::sparse_sparse_product_with_pruning_impl(colLhs, rhs, res, tolerance); + } +}; +#endif } // end namespace internal diff --git a/Eigen/src/SparseCore/SparseView.h b/Eigen/src/SparseCore/SparseView.h index 96d0a849c..7bffbb9cd 100644 --- a/Eigen/src/SparseCore/SparseView.h +++ b/Eigen/src/SparseCore/SparseView.h @@ -233,10 +233,30 @@ struct unary_evaluator, IndexBased> #endif // EIGEN_TEST_EVALUATORS template -const SparseView MatrixBase::sparseView(const Scalar& m_reference, - const typename NumTraits::Real& m_epsilon) const +const SparseView MatrixBase::sparseView(const Scalar& reference, + const typename NumTraits::Real& epsilon) const { - return SparseView(derived(), m_reference, m_epsilon); + return SparseView(derived(), reference, epsilon); +} + +/** \returns an expression of \c *this with values smaller than + * \a reference * \a epsilon are removed. + * + * This method is typically used in conjunction with the product of two sparse matrices + * to automatically prune the smallest values as follows: + * \code + * C = (A*B).pruned(); // suppress numerical zeros (exact) + * C = (A*B).pruned(ref); + * C = (A*B).pruned(ref,epsilon); + * \endcode + * where \c ref is a meaningful non zero reference value. + * */ +template +const SparseView +SparseMatrixBase::pruned(const Scalar& reference, + const RealScalar& epsilon) const +{ + return SparseView(derived(), reference, epsilon); } } // end namespace Eigen