diff --git a/Eigen/src/SparseCore/SparseMatrix.h b/Eigen/src/SparseCore/SparseMatrix.h index 22a6bd803..5e2b14554 100644 --- a/Eigen/src/SparseCore/SparseMatrix.h +++ b/Eigen/src/SparseCore/SparseMatrix.h @@ -437,11 +437,13 @@ class SparseMatrix template void setFromTriplets(const InputIterators& begin, const InputIterators& end); - template - void setFromTriplets(const InputIterators& begin, const InputIterators& end); + template + void setFromTriplets(const InputIterators& begin, const InputIterators& end, DupFunctor dup_func); + + void sumupDuplicates() { collapseDuplicates(internal::scalar_sum_op()); } template - void sumupDuplicates(); + void collapseDuplicates(DupFunctor dup_func = DupFunctor()); //--- @@ -894,9 +896,8 @@ private: namespace internal { template -void set_from_triplets(const InputIterator& begin, const InputIterator& end, SparseMatrixType& mat, int Options = 0) +void set_from_triplets(const InputIterator& begin, const InputIterator& end, SparseMatrixType& mat, DupFunctor dup_func) { - EIGEN_UNUSED_VARIABLE(Options); enum { IsRowMajor = SparseMatrixType::IsRowMajor }; typedef typename SparseMatrixType::Scalar Scalar; typedef typename SparseMatrixType::StorageIndex StorageIndex; @@ -919,7 +920,7 @@ void set_from_triplets(const InputIterator& begin, const InputIterator& end, Spa trMat.insertBackUncompressed(it->row(),it->col()) = it->value(); // pass 3: - trMat.template sumupDuplicates(); + trMat.collapseDuplicates(dup_func); } // pass 4: transposed copy -> implicit sorting @@ -970,25 +971,29 @@ template template void SparseMatrix::setFromTriplets(const InputIterators& begin, const InputIterators& end) { - internal::set_from_triplets, internal::scalar_sum_op >(begin, end, *this); + internal::set_from_triplets >(begin, end, *this, internal::scalar_sum_op()); } -/** The same as setFromTriplets but when duplicates are met the functor \a DupFunctor is applied: +/** The same as setFromTriplets but when duplicates are met the functor \a dup_func is applied: * \code - * value = DupFunctor()(OldValue, NewValue) + * value = dup_func(OldValue, NewValue) * \endcode - */ + * Here is a C++11 example keeping the latest entry only: + * \code + * mat.setFromTriplets(triplets.begin(), triplets.end(), [] (const Scalar&,const Scalar &b) { return b; }); + * \endcode + */ template -template -void SparseMatrix::setFromTriplets(const InputIterators& begin, const InputIterators& end) +template +void SparseMatrix::setFromTriplets(const InputIterators& begin, const InputIterators& end, DupFunctor dup_func) { - internal::set_from_triplets, DupFunctor>(begin, end, *this); + internal::set_from_triplets, DupFunctor>(begin, end, *this, dup_func); } /** \internal */ template template -void SparseMatrix::sumupDuplicates() +void SparseMatrix::collapseDuplicates(DupFunctor dup_func) { eigen_assert(!isCompressed()); // TODO, in practice we should be able to use m_innerNonZeros for that task @@ -1006,7 +1011,7 @@ void SparseMatrix::sumupDuplicates() if(wi(i)>=start) { // we already meet this entry => accumulate it - m_data.value(wi(i)) = DupFunctor()(m_data.value(wi(i)), m_data.value(k)); + m_data.value(wi(i)) = dup_func(m_data.value(wi(i)), m_data.value(k)); } else { diff --git a/test/sparse_basic.cpp b/test/sparse_basic.cpp index 95bbfab0e..993f7840c 100644 --- a/test/sparse_basic.cpp +++ b/test/sparse_basic.cpp @@ -258,19 +258,33 @@ template void sparse_basic(const SparseMatrixType& re std::vector triplets; Index ntriplets = rows*cols; triplets.reserve(ntriplets); - DenseMatrix refMat(rows,cols); - refMat.setZero(); + DenseMatrix refMat_sum = DenseMatrix::Zero(rows,cols); + DenseMatrix refMat_prod = DenseMatrix::Zero(rows,cols); + DenseMatrix refMat_last = DenseMatrix::Zero(rows,cols); + for(Index i=0;i(0,StorageIndex(rows-1)); StorageIndex c = internal::random(0,StorageIndex(cols-1)); Scalar v = internal::random(); triplets.push_back(TripletType(r,c,v)); - refMat(r,c) += v; + refMat_sum(r,c) += v; + if(std::abs(refMat_prod(r,c))==0) + refMat_prod(r,c) = v; + else + refMat_prod(r,c) *= v; + refMat_last(r,c) = v; } SparseMatrixType m(rows,cols); m.setFromTriplets(triplets.begin(), triplets.end()); - VERIFY_IS_APPROX(m, refMat); + VERIFY_IS_APPROX(m, refMat_sum); + + m.setFromTriplets(triplets.begin(), triplets.end(), std::multiplies()); + VERIFY_IS_APPROX(m, refMat_prod); +#if (defined(__cplusplus) && __cplusplus >= 201103L) + m.setFromTriplets(triplets.begin(), triplets.end(), [] (Scalar,Scalar b) { return b; }); + VERIFY_IS_APPROX(m, refMat_last); +#endif } // test Map