Fix aliasing issue in sparse matrix assignment.

(m=-m; or m=m.transpose(); with m sparse work again)
This commit is contained in:
Gael Guennebaud 2012-07-25 09:33:50 +02:00
parent 7b34b5f6f9
commit e75b1eb883
3 changed files with 26 additions and 11 deletions

View File

@ -688,7 +688,6 @@ class SparseMatrix
template<typename OtherDerived> template<typename OtherDerived>
EIGEN_DONT_INLINE SparseMatrix& operator=(const SparseMatrixBase<OtherDerived>& other) EIGEN_DONT_INLINE SparseMatrix& operator=(const SparseMatrixBase<OtherDerived>& other)
{ {
initAssignment(other.derived());
const bool needToTranspose = (Flags & RowMajorBit) != (OtherDerived::Flags & RowMajorBit); const bool needToTranspose = (Flags & RowMajorBit) != (OtherDerived::Flags & RowMajorBit);
if (needToTranspose) if (needToTranspose)
{ {
@ -700,40 +699,45 @@ class SparseMatrix
typedef typename internal::remove_all<OtherCopy>::type _OtherCopy; typedef typename internal::remove_all<OtherCopy>::type _OtherCopy;
OtherCopy otherCopy(other.derived()); OtherCopy otherCopy(other.derived());
Eigen::Map<Matrix<Index, Dynamic, 1> > (m_outerIndex,outerSize()).setZero(); SparseMatrix dest(other.rows(),other.cols());
Eigen::Map<Matrix<Index, Dynamic, 1> > (dest.m_outerIndex,dest.outerSize()).setZero();
// pass 1 // pass 1
// FIXME the above copy could be merged with that pass // FIXME the above copy could be merged with that pass
for (Index j=0; j<otherCopy.outerSize(); ++j) for (Index j=0; j<otherCopy.outerSize(); ++j)
for (typename _OtherCopy::InnerIterator it(otherCopy, j); it; ++it) for (typename _OtherCopy::InnerIterator it(otherCopy, j); it; ++it)
++m_outerIndex[it.index()]; ++dest.m_outerIndex[it.index()];
// prefix sum // prefix sum
Index count = 0; Index count = 0;
VectorXi positions(outerSize()); VectorXi positions(dest.outerSize());
for (Index j=0; j<outerSize(); ++j) for (Index j=0; j<dest.outerSize(); ++j)
{ {
Index tmp = m_outerIndex[j]; Index tmp = dest.m_outerIndex[j];
m_outerIndex[j] = count; dest.m_outerIndex[j] = count;
positions[j] = count; positions[j] = count;
count += tmp; count += tmp;
} }
m_outerIndex[outerSize()] = count; dest.m_outerIndex[dest.outerSize()] = count;
// alloc // alloc
m_data.resize(count); dest.m_data.resize(count);
// pass 2 // pass 2
for (Index j=0; j<otherCopy.outerSize(); ++j) for (Index j=0; j<otherCopy.outerSize(); ++j)
{ {
for (typename _OtherCopy::InnerIterator it(otherCopy, j); it; ++it) for (typename _OtherCopy::InnerIterator it(otherCopy, j); it; ++it)
{ {
Index pos = positions[it.index()]++; Index pos = positions[it.index()]++;
m_data.index(pos) = j; dest.m_data.index(pos) = j;
m_data.value(pos) = it.value(); dest.m_data.value(pos) = it.value();
} }
} }
this->swap(dest);
return *this; return *this;
} }
else else
{ {
if(other.isRValue())
initAssignment(other.derived());
// there is no special optimization // there is no special optimization
return Base::operator=(other.derived()); return Base::operator=(other.derived());
} }

View File

@ -193,6 +193,12 @@ template<typename SparseMatrixType> void sparse_basic(const SparseMatrixType& re
// sparse cwise* dense // sparse cwise* dense
VERIFY_IS_APPROX(m3.cwiseProduct(refM4), refM3.cwiseProduct(refM4)); VERIFY_IS_APPROX(m3.cwiseProduct(refM4), refM3.cwiseProduct(refM4));
// VERIFY_IS_APPROX(m3.cwise()/refM4, refM3.cwise()/refM4); // VERIFY_IS_APPROX(m3.cwise()/refM4, refM3.cwise()/refM4);
// test aliasing
VERIFY_IS_APPROX((m1 = -m1), (refM1 = -refM1));
VERIFY_IS_APPROX((m1 = m1.transpose()), (refM1 = refM1.transpose().eval()));
VERIFY_IS_APPROX((m1 = -m1.transpose()), (refM1 = -refM1.transpose().eval()));
VERIFY_IS_APPROX((m1 += -m1), (refM1 += -refM1));
} }
// test transpose // test transpose

View File

@ -78,6 +78,11 @@ template<typename Scalar> void sparse_vector(int rows, int cols)
VERIFY_IS_APPROX(v1.squaredNorm(), refV1.squaredNorm()); VERIFY_IS_APPROX(v1.squaredNorm(), refV1.squaredNorm());
// test aliasing
VERIFY_IS_APPROX((v1 = -v1), (refV1 = -refV1));
VERIFY_IS_APPROX((v1 = v1.transpose()), (refV1 = refV1.transpose().eval()));
VERIFY_IS_APPROX((v1 += -v1), (refV1 += -refV1));
} }
void test_sparse_vector() void test_sparse_vector()