extend unit test for SparseMatrix::prune

This commit is contained in:
Gael Guennebaud 2015-10-13 10:53:38 +02:00
parent ac22b66f1c
commit a44d91a0b2
2 changed files with 10 additions and 12 deletions

View File

@ -509,7 +509,6 @@ class SparseMatrix
void prune(const KeepFunc& keep = KeepFunc())
{
// TODO optimize the uncompressed mode to avoid moving and allocating the data twice
// TODO also implement a unit test
makeCompressed();
StorageIndex k = 0;

View File

@ -219,10 +219,10 @@ template<typename SparseMatrixType> void sparse_basic(const SparseMatrixType& re
refM2.setZero();
int countFalseNonZero = 0;
int countTrueNonZero = 0;
for (Index j=0; j<m2.outerSize(); ++j)
m2.reserve(VectorXi::Constant(m2.outerSize(), int(m2.innerSize())));
for (Index j=0; j<m2.cols(); ++j)
{
m2.startVec(j);
for (Index i=0; i<m2.innerSize(); ++i)
for (Index i=0; i<m2.rows(); ++i)
{
float x = internal::random<float>(0,1);
if (x<0.1)
@ -232,22 +232,21 @@ template<typename SparseMatrixType> void sparse_basic(const SparseMatrixType& re
else if (x<0.5)
{
countFalseNonZero++;
m2.insertBackByOuterInner(j,i) = Scalar(0);
m2.insert(i,j) = Scalar(0);
}
else
{
countTrueNonZero++;
m2.insertBackByOuterInner(j,i) = Scalar(1);
if(SparseMatrixType::IsRowMajor)
refM2(j,i) = Scalar(1);
else
refM2(i,j) = Scalar(1);
m2.insert(i,j) = Scalar(1);
refM2(i,j) = Scalar(1);
}
}
}
m2.finalize();
if(internal::random<bool>())
m2.makeCompressed();
VERIFY(countFalseNonZero+countTrueNonZero == m2.nonZeros());
VERIFY_IS_APPROX(m2, refM2);
if(countTrueNonZero>0)
VERIFY_IS_APPROX(m2, refM2);
m2.prune(Scalar(1));
VERIFY(countTrueNonZero==m2.nonZeros());
VERIFY_IS_APPROX(m2, refM2);