mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-07-19 19:34:29 +08:00
Update custom setFromTripplets API to allow passing a functor object, and add a collapseDuplicates method to cleanup the API. Also add respective unit test
This commit is contained in:
parent
b9d81c9150
commit
b4c79ee1d3
@ -437,11 +437,13 @@ class SparseMatrix
|
|||||||
template<typename InputIterators>
|
template<typename InputIterators>
|
||||||
void setFromTriplets(const InputIterators& begin, const InputIterators& end);
|
void setFromTriplets(const InputIterators& begin, const InputIterators& end);
|
||||||
|
|
||||||
template<typename DupFunctor, typename InputIterators>
|
template<typename InputIterators,typename DupFunctor>
|
||||||
void setFromTriplets(const InputIterators& begin, const InputIterators& end);
|
void setFromTriplets(const InputIterators& begin, const InputIterators& end, DupFunctor dup_func);
|
||||||
|
|
||||||
|
void sumupDuplicates() { collapseDuplicates(internal::scalar_sum_op<Scalar>()); }
|
||||||
|
|
||||||
template<typename DupFunctor>
|
template<typename DupFunctor>
|
||||||
void sumupDuplicates();
|
void collapseDuplicates(DupFunctor dup_func = DupFunctor());
|
||||||
|
|
||||||
//---
|
//---
|
||||||
|
|
||||||
@ -894,9 +896,8 @@ private:
|
|||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
template<typename InputIterator, typename SparseMatrixType, typename DupFunctor>
|
template<typename InputIterator, typename SparseMatrixType, typename DupFunctor>
|
||||||
void set_from_triplets(const InputIterator& begin, const InputIterator& end, SparseMatrixType& mat, int Options = 0)
|
void set_from_triplets(const InputIterator& begin, const InputIterator& end, SparseMatrixType& mat, DupFunctor dup_func)
|
||||||
{
|
{
|
||||||
EIGEN_UNUSED_VARIABLE(Options);
|
|
||||||
enum { IsRowMajor = SparseMatrixType::IsRowMajor };
|
enum { IsRowMajor = SparseMatrixType::IsRowMajor };
|
||||||
typedef typename SparseMatrixType::Scalar Scalar;
|
typedef typename SparseMatrixType::Scalar Scalar;
|
||||||
typedef typename SparseMatrixType::StorageIndex StorageIndex;
|
typedef typename SparseMatrixType::StorageIndex StorageIndex;
|
||||||
@ -919,7 +920,7 @@ void set_from_triplets(const InputIterator& begin, const InputIterator& end, Spa
|
|||||||
trMat.insertBackUncompressed(it->row(),it->col()) = it->value();
|
trMat.insertBackUncompressed(it->row(),it->col()) = it->value();
|
||||||
|
|
||||||
// pass 3:
|
// pass 3:
|
||||||
trMat.template sumupDuplicates<DupFunctor>();
|
trMat.collapseDuplicates(dup_func);
|
||||||
}
|
}
|
||||||
|
|
||||||
// pass 4: transposed copy -> implicit sorting
|
// pass 4: transposed copy -> implicit sorting
|
||||||
@ -970,25 +971,29 @@ template<typename Scalar, int _Options, typename _Index>
|
|||||||
template<typename InputIterators>
|
template<typename InputIterators>
|
||||||
void SparseMatrix<Scalar,_Options,_Index>::setFromTriplets(const InputIterators& begin, const InputIterators& end)
|
void SparseMatrix<Scalar,_Options,_Index>::setFromTriplets(const InputIterators& begin, const InputIterators& end)
|
||||||
{
|
{
|
||||||
internal::set_from_triplets<InputIterators, SparseMatrix<Scalar,_Options,_Index>, internal::scalar_sum_op<Scalar> >(begin, end, *this);
|
internal::set_from_triplets<InputIterators, SparseMatrix<Scalar,_Options,_Index> >(begin, end, *this, internal::scalar_sum_op<Scalar>());
|
||||||
}
|
}
|
||||||
|
|
||||||
/** The same as setFromTriplets but when duplicates are met the functor \a DupFunctor is applied:
|
/** The same as setFromTriplets but when duplicates are met the functor \a dup_func is applied:
|
||||||
* \code
|
* \code
|
||||||
* value = DupFunctor()(OldValue, NewValue)
|
* value = dup_func(OldValue, NewValue)
|
||||||
* \endcode
|
* \endcode
|
||||||
*/
|
* Here is a C++11 example keeping the latest entry only:
|
||||||
|
* \code
|
||||||
|
* mat.setFromTriplets(triplets.begin(), triplets.end(), [] (const Scalar&,const Scalar &b) { return b; });
|
||||||
|
* \endcode
|
||||||
|
*/
|
||||||
template<typename Scalar, int _Options, typename _Index>
|
template<typename Scalar, int _Options, typename _Index>
|
||||||
template<typename DupFunctor, typename InputIterators>
|
template<typename InputIterators,typename DupFunctor>
|
||||||
void SparseMatrix<Scalar,_Options,_Index>::setFromTriplets(const InputIterators& begin, const InputIterators& end)
|
void SparseMatrix<Scalar,_Options,_Index>::setFromTriplets(const InputIterators& begin, const InputIterators& end, DupFunctor dup_func)
|
||||||
{
|
{
|
||||||
internal::set_from_triplets<InputIterators, SparseMatrix<Scalar,_Options,_Index>, DupFunctor>(begin, end, *this);
|
internal::set_from_triplets<InputIterators, SparseMatrix<Scalar,_Options,_Index>, DupFunctor>(begin, end, *this, dup_func);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** \internal */
|
/** \internal */
|
||||||
template<typename Scalar, int _Options, typename _Index>
|
template<typename Scalar, int _Options, typename _Index>
|
||||||
template<typename DupFunctor>
|
template<typename DupFunctor>
|
||||||
void SparseMatrix<Scalar,_Options,_Index>::sumupDuplicates()
|
void SparseMatrix<Scalar,_Options,_Index>::collapseDuplicates(DupFunctor dup_func)
|
||||||
{
|
{
|
||||||
eigen_assert(!isCompressed());
|
eigen_assert(!isCompressed());
|
||||||
// TODO, in practice we should be able to use m_innerNonZeros for that task
|
// TODO, in practice we should be able to use m_innerNonZeros for that task
|
||||||
@ -1006,7 +1011,7 @@ void SparseMatrix<Scalar,_Options,_Index>::sumupDuplicates()
|
|||||||
if(wi(i)>=start)
|
if(wi(i)>=start)
|
||||||
{
|
{
|
||||||
// we already meet this entry => accumulate it
|
// we already meet this entry => accumulate it
|
||||||
m_data.value(wi(i)) = DupFunctor()(m_data.value(wi(i)), m_data.value(k));
|
m_data.value(wi(i)) = dup_func(m_data.value(wi(i)), m_data.value(k));
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
@ -258,19 +258,33 @@ template<typename SparseMatrixType> void sparse_basic(const SparseMatrixType& re
|
|||||||
std::vector<TripletType> triplets;
|
std::vector<TripletType> triplets;
|
||||||
Index ntriplets = rows*cols;
|
Index ntriplets = rows*cols;
|
||||||
triplets.reserve(ntriplets);
|
triplets.reserve(ntriplets);
|
||||||
DenseMatrix refMat(rows,cols);
|
DenseMatrix refMat_sum = DenseMatrix::Zero(rows,cols);
|
||||||
refMat.setZero();
|
DenseMatrix refMat_prod = DenseMatrix::Zero(rows,cols);
|
||||||
|
DenseMatrix refMat_last = DenseMatrix::Zero(rows,cols);
|
||||||
|
|
||||||
for(Index i=0;i<ntriplets;++i)
|
for(Index i=0;i<ntriplets;++i)
|
||||||
{
|
{
|
||||||
StorageIndex r = internal::random<StorageIndex>(0,StorageIndex(rows-1));
|
StorageIndex r = internal::random<StorageIndex>(0,StorageIndex(rows-1));
|
||||||
StorageIndex c = internal::random<StorageIndex>(0,StorageIndex(cols-1));
|
StorageIndex c = internal::random<StorageIndex>(0,StorageIndex(cols-1));
|
||||||
Scalar v = internal::random<Scalar>();
|
Scalar v = internal::random<Scalar>();
|
||||||
triplets.push_back(TripletType(r,c,v));
|
triplets.push_back(TripletType(r,c,v));
|
||||||
refMat(r,c) += v;
|
refMat_sum(r,c) += v;
|
||||||
|
if(std::abs(refMat_prod(r,c))==0)
|
||||||
|
refMat_prod(r,c) = v;
|
||||||
|
else
|
||||||
|
refMat_prod(r,c) *= v;
|
||||||
|
refMat_last(r,c) = v;
|
||||||
}
|
}
|
||||||
SparseMatrixType m(rows,cols);
|
SparseMatrixType m(rows,cols);
|
||||||
m.setFromTriplets(triplets.begin(), triplets.end());
|
m.setFromTriplets(triplets.begin(), triplets.end());
|
||||||
VERIFY_IS_APPROX(m, refMat);
|
VERIFY_IS_APPROX(m, refMat_sum);
|
||||||
|
|
||||||
|
m.setFromTriplets(triplets.begin(), triplets.end(), std::multiplies<Scalar>());
|
||||||
|
VERIFY_IS_APPROX(m, refMat_prod);
|
||||||
|
#if (defined(__cplusplus) && __cplusplus >= 201103L)
|
||||||
|
m.setFromTriplets(triplets.begin(), triplets.end(), [] (Scalar,Scalar b) { return b; });
|
||||||
|
VERIFY_IS_APPROX(m, refMat_last);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// test Map
|
// test Map
|
||||||
|
Loading…
x
Reference in New Issue
Block a user