mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-12 09:23:12 +08:00
Implement evaluators for sparse * sparse with auto pruning.
This commit is contained in:
parent
441f97b2df
commit
746d2db6ed
@ -50,11 +50,11 @@ struct Sparse {};
|
|||||||
#include "src/SparseCore/SparseView.h"
|
#include "src/SparseCore/SparseView.h"
|
||||||
#include "src/SparseCore/SparseDiagonalProduct.h"
|
#include "src/SparseCore/SparseDiagonalProduct.h"
|
||||||
#include "src/SparseCore/ConservativeSparseSparseProduct.h"
|
#include "src/SparseCore/ConservativeSparseSparseProduct.h"
|
||||||
|
#include "src/SparseCore/SparseSparseProductWithPruning.h"
|
||||||
#include "src/SparseCore/SparseProduct.h"
|
#include "src/SparseCore/SparseProduct.h"
|
||||||
#ifndef EIGEN_TEST_EVALUATORS
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
#include "src/SparseCore/SparsePermutation.h"
|
#include "src/SparseCore/SparsePermutation.h"
|
||||||
#include "src/SparseCore/SparseFuzzy.h"
|
#include "src/SparseCore/SparseFuzzy.h"
|
||||||
#include "src/SparseCore/SparseSparseProductWithPruning.h"
|
|
||||||
#include "src/SparseCore/SparseDenseProduct.h"
|
#include "src/SparseCore/SparseDenseProduct.h"
|
||||||
#include "src/SparseCore/SparseTriangularView.h"
|
#include "src/SparseCore/SparseTriangularView.h"
|
||||||
#include "src/SparseCore/SparseSelfAdjointView.h"
|
#include "src/SparseCore/SparseSelfAdjointView.h"
|
||||||
|
@ -394,6 +394,9 @@ template<typename Derived> class SparseMatrixBase : public EigenBase<Derived>
|
|||||||
{ return typename internal::eval<Derived>::type(derived()); }
|
{ return typename internal::eval<Derived>::type(derived()); }
|
||||||
|
|
||||||
Scalar sum() const;
|
Scalar sum() const;
|
||||||
|
|
||||||
|
inline const SparseView<Derived>
|
||||||
|
pruned(const Scalar& reference = Scalar(0), const RealScalar& epsilon = NumTraits<Scalar>::dummy_precision()) const;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|
||||||
|
@ -242,7 +242,37 @@ struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, ProductTag, SparseSh
|
|||||||
protected:
|
protected:
|
||||||
PlainObject m_result;
|
PlainObject m_result;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, int Options>
|
||||||
|
struct evaluator<SparseView<Product<Lhs, Rhs, Options> > >
|
||||||
|
: public evaluator<typename Product<Lhs, Rhs, DefaultProduct>::PlainObject>::type
|
||||||
|
{
|
||||||
|
typedef SparseView<Product<Lhs, Rhs, Options> > XprType;
|
||||||
|
typedef typename XprType::PlainObject PlainObject;
|
||||||
|
typedef typename evaluator<PlainObject>::type Base;
|
||||||
|
|
||||||
|
typedef evaluator type;
|
||||||
|
typedef evaluator nestedType;
|
||||||
|
|
||||||
|
evaluator(const XprType& xpr)
|
||||||
|
: m_result(xpr.rows(), xpr.cols())
|
||||||
|
{
|
||||||
|
using std::abs;
|
||||||
|
::new (static_cast<Base*>(this)) Base(m_result);
|
||||||
|
typedef typename nested_eval<Lhs,Dynamic>::type LhsNested;
|
||||||
|
typedef typename nested_eval<Rhs,Dynamic>::type RhsNested;
|
||||||
|
LhsNested lhsNested(xpr.nestedExpression().lhs());
|
||||||
|
RhsNested rhsNested(xpr.nestedExpression().rhs());
|
||||||
|
|
||||||
|
internal::sparse_sparse_product_with_pruning_selector<typename remove_all<LhsNested>::type,
|
||||||
|
typename remove_all<RhsNested>::type, PlainObject>::run(lhsNested,rhsNested,m_result,
|
||||||
|
abs(xpr.reference())*xpr.epsilon());
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
PlainObject m_result;
|
||||||
|
};
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
#endif // EIGEN_TEST_EVALUATORS
|
#endif // EIGEN_TEST_EVALUATORS
|
||||||
|
@ -46,6 +46,11 @@ static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& r
|
|||||||
res.resize(cols, rows);
|
res.resize(cols, rows);
|
||||||
else
|
else
|
||||||
res.resize(rows, cols);
|
res.resize(rows, cols);
|
||||||
|
|
||||||
|
#ifdef EIGEN_TEST_EVALUATORS
|
||||||
|
typename evaluator<Lhs>::type lhsEval(lhs);
|
||||||
|
typename evaluator<Rhs>::type rhsEval(rhs);
|
||||||
|
#endif
|
||||||
|
|
||||||
res.reserve(estimated_nnz_prod);
|
res.reserve(estimated_nnz_prod);
|
||||||
double ratioColRes = double(estimated_nnz_prod)/double(lhs.rows()*rhs.cols());
|
double ratioColRes = double(estimated_nnz_prod)/double(lhs.rows()*rhs.cols());
|
||||||
@ -56,12 +61,20 @@ static void sparse_sparse_product_with_pruning_impl(const Lhs& lhs, const Rhs& r
|
|||||||
// let's do a more accurate determination of the nnz ratio for the current column j of res
|
// let's do a more accurate determination of the nnz ratio for the current column j of res
|
||||||
tempVector.init(ratioColRes);
|
tempVector.init(ratioColRes);
|
||||||
tempVector.setZero();
|
tempVector.setZero();
|
||||||
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
|
for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
|
||||||
|
#else
|
||||||
|
for (typename evaluator<Rhs>::InnerIterator rhsIt(rhsEval, j); rhsIt; ++rhsIt)
|
||||||
|
#endif
|
||||||
{
|
{
|
||||||
// FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
|
// FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
|
||||||
tempVector.restart();
|
tempVector.restart();
|
||||||
Scalar x = rhsIt.value();
|
Scalar x = rhsIt.value();
|
||||||
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt)
|
for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt)
|
||||||
|
#else
|
||||||
|
for (typename evaluator<Lhs>::InnerIterator lhsIt(lhsEval, rhsIt.index()); lhsIt; ++lhsIt)
|
||||||
|
#endif
|
||||||
{
|
{
|
||||||
tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
|
tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
|
||||||
}
|
}
|
||||||
@ -140,8 +153,58 @@ struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,R
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#ifndef EIGEN_TEST_EVALUATORS
|
||||||
// NOTE the 2 others cases (col row *) must never occur since they are caught
|
// NOTE the 2 others cases (col row *) must never occur since they are caught
|
||||||
// by ProductReturnType which transforms it to (col col *) by evaluating rhs.
|
// by ProductReturnType which transforms it to (col col *) by evaluating rhs.
|
||||||
|
#else
|
||||||
|
template<typename Lhs, typename Rhs, typename ResultType>
|
||||||
|
struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor>
|
||||||
|
{
|
||||||
|
typedef typename ResultType::RealScalar RealScalar;
|
||||||
|
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
||||||
|
{
|
||||||
|
typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename Lhs::Index> RowMajorMatrixLhs;
|
||||||
|
RowMajorMatrixLhs rowLhs(lhs);
|
||||||
|
sparse_sparse_product_with_pruning_selector<RowMajorMatrixLhs,Rhs,ResultType,RowMajor,RowMajor>(rowLhs,rhs,res,tolerance);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, typename ResultType>
|
||||||
|
struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor>
|
||||||
|
{
|
||||||
|
typedef typename ResultType::RealScalar RealScalar;
|
||||||
|
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
||||||
|
{
|
||||||
|
typedef SparseMatrix<typename ResultType::Scalar,RowMajor,typename Lhs::Index> RowMajorMatrixRhs;
|
||||||
|
RowMajorMatrixRhs rowRhs(rhs);
|
||||||
|
sparse_sparse_product_with_pruning_selector<Lhs,RowMajorMatrixRhs,ResultType,RowMajor,RowMajor,RowMajor>(lhs,rowRhs,res,tolerance);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, typename ResultType>
|
||||||
|
struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor>
|
||||||
|
{
|
||||||
|
typedef typename ResultType::RealScalar RealScalar;
|
||||||
|
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
||||||
|
{
|
||||||
|
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixRhs;
|
||||||
|
ColMajorMatrixRhs colRhs(rhs);
|
||||||
|
internal::sparse_sparse_product_with_pruning_impl<Lhs,ColMajorMatrixRhs,ResultType>(lhs, colRhs, res, tolerance);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Lhs, typename Rhs, typename ResultType>
|
||||||
|
struct sparse_sparse_product_with_pruning_selector<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor>
|
||||||
|
{
|
||||||
|
typedef typename ResultType::RealScalar RealScalar;
|
||||||
|
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res, const RealScalar& tolerance)
|
||||||
|
{
|
||||||
|
typedef SparseMatrix<typename ResultType::Scalar,ColMajor,typename Lhs::Index> ColMajorMatrixLhs;
|
||||||
|
ColMajorMatrixLhs colLhs(lhs);
|
||||||
|
internal::sparse_sparse_product_with_pruning_impl<ColMajorMatrixLhs,Rhs,ResultType>(colLhs, rhs, res, tolerance);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
|
@ -233,10 +233,30 @@ struct unary_evaluator<SparseView<ArgType>, IndexBased>
|
|||||||
#endif // EIGEN_TEST_EVALUATORS
|
#endif // EIGEN_TEST_EVALUATORS
|
||||||
|
|
||||||
template<typename Derived>
|
template<typename Derived>
|
||||||
const SparseView<Derived> MatrixBase<Derived>::sparseView(const Scalar& m_reference,
|
const SparseView<Derived> MatrixBase<Derived>::sparseView(const Scalar& reference,
|
||||||
const typename NumTraits<Scalar>::Real& m_epsilon) const
|
const typename NumTraits<Scalar>::Real& epsilon) const
|
||||||
{
|
{
|
||||||
return SparseView<Derived>(derived(), m_reference, m_epsilon);
|
return SparseView<Derived>(derived(), reference, epsilon);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** \returns an expression of \c *this with values smaller than
|
||||||
|
* \a reference * \a epsilon are removed.
|
||||||
|
*
|
||||||
|
* This method is typically used in conjunction with the product of two sparse matrices
|
||||||
|
* to automatically prune the smallest values as follows:
|
||||||
|
* \code
|
||||||
|
* C = (A*B).pruned(); // suppress numerical zeros (exact)
|
||||||
|
* C = (A*B).pruned(ref);
|
||||||
|
* C = (A*B).pruned(ref,epsilon);
|
||||||
|
* \endcode
|
||||||
|
* where \c ref is a meaningful non zero reference value.
|
||||||
|
* */
|
||||||
|
template<typename Derived>
|
||||||
|
const SparseView<Derived>
|
||||||
|
SparseMatrixBase<Derived>::pruned(const Scalar& reference,
|
||||||
|
const RealScalar& epsilon) const
|
||||||
|
{
|
||||||
|
return SparseView<Derived>(derived(), reference, epsilon);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // end namespace Eigen
|
} // end namespace Eigen
|
||||||
|
Loading…
x
Reference in New Issue
Block a user