fix issue 135 (SparseBlock::operator= for SparseMatrix)

This commit is contained in:
Gael Guennebaud 2010-06-14 16:26:33 +02:00
parent 2d65f5d3cd
commit 3cabd0c417
2 changed files with 70 additions and 12 deletions

View File

@ -201,8 +201,8 @@ class SparseInnerVectorSet<DynamicSparseMatrix<_Scalar, _Options>, Size>
* specialisation for SparseMatrix
***************************************************************************/
template<typename _Scalar, int _Options, int Size>
class SparseInnerVectorSet<SparseMatrix<_Scalar, _Options>, Size>
template<typename _Scalar, int _Options, typename _Index, int Size>
class SparseInnerVectorSet<SparseMatrix<_Scalar, _Options, _Index>, Size>
: public SparseMatrixBase<SparseInnerVectorSet<SparseMatrix<_Scalar, _Options>, Size> >
{
typedef SparseMatrix<_Scalar, _Options> MatrixType;
@ -239,21 +239,74 @@ class SparseInnerVectorSet<SparseMatrix<_Scalar, _Options>, Size>
template<typename OtherDerived>
inline SparseInnerVectorSet& operator=(const SparseMatrixBase<OtherDerived>& other)
{
if (IsRowMajor != ((OtherDerived::Flags&RowMajorBit)==RowMajorBit))
typedef typename ei_cleantype<typename MatrixType::Nested>::type _NestedMatrixType;
_NestedMatrixType& matrix = const_cast<_NestedMatrixType&>(m_matrix);;
// This assignement is slow if this vector set not empty
// and/or it is not at the end of the nonzeros of the underlying matrix.
// 1 - eval to a temporary to avoid transposition and/or aliasing issues
SparseMatrix<Scalar, IsRowMajor ? RowMajor : ColMajor, Index> tmp(other);
// 2 - let's check whether there is enough allocated memory
Index nnz = tmp.nonZeros();
Index nnz_previous = nonZeros();
Index free_size = matrix.data().allocatedSize() - nnz_previous;
std::size_t nnz_head = m_outerStart==0 ? 0 : matrix._outerIndexPtr()[m_outerStart];
std::size_t tail = m_matrix._outerIndexPtr()[m_outerStart+m_outerSize.value()];
std::size_t nnz_tail = matrix.nonZeros() - tail;
if(nnz>free_size)
{
// need to transpose => perform a block evaluation followed by a big swap
DynamicSparseMatrix<Scalar,IsRowMajor?RowMajorBit:0> aux(other);
*this = aux.markAsRValue();
// realloc manually to reduce copies
typename MatrixType::Storage newdata(m_matrix.nonZeros() - nnz_previous + nnz);
std::memcpy(&newdata.value(0), &m_matrix.data().value(0), nnz_head*sizeof(Scalar));
std::memcpy(&newdata.index(0), &m_matrix.data().index(0), nnz_head*sizeof(Index));
std::memcpy(&newdata.value(nnz_head), &tmp.data().value(0), nnz*sizeof(Scalar));
std::memcpy(&newdata.index(nnz_head), &tmp.data().index(0), nnz*sizeof(Index));
std::memcpy(&newdata.value(nnz_head+nnz), &matrix.data().value(tail), nnz_tail*sizeof(Scalar));
std::memcpy(&newdata.index(nnz_head+nnz), &matrix.data().index(tail), nnz_tail*sizeof(Index));
matrix.data().swap(newdata);
}
else
{
// evaluate/copy vector per vector
for (Index j=0; j<m_outerSize.value(); ++j)
// no need to realloc, simply copy the tail at its respective position and insert tmp
matrix.data().resize(nnz_head + nnz + nnz_tail);
if(nnz<nnz_previous)
{
SparseVector<Scalar,IsRowMajor ? RowMajorBit : 0> aux(other.innerVector(j));
m_matrix.const_cast_derived()._data()[m_outerStart+j].swap(aux._data());
std::memcpy(&matrix.data().value(nnz_head+nnz), &matrix.data().value(tail), nnz_tail*sizeof(Scalar));
std::memcpy(&matrix.data().index(nnz_head+nnz), &matrix.data().index(tail), nnz_tail*sizeof(Index));
}
else
{
for(Index i=nnz_tail-1; i>=0; --i)
{
matrix.data().value(nnz_head+nnz+i) = matrix.data().value(tail+i);
matrix.data().index(nnz_head+nnz+i) = matrix.data().index(tail+i);
}
}
std::memcpy(&matrix.data().value(nnz_head), &tmp.data().value(0), nnz*sizeof(Scalar));
std::memcpy(&matrix.data().index(nnz_head), &tmp.data().index(0), nnz*sizeof(Index));
}
// update outer index pointers
Index id = nnz_head;
for(Index k=1; k<m_outerSize.value(); ++k)
{
matrix._outerIndexPtr()[m_outerStart+k] = id;
id += tmp.innerVector(k).nonZeros();
}
std::ptrdiff_t offset = nnz - nnz_previous;
for(Index k = m_outerStart + m_outerSize.value(); k<=matrix.outerSize(); ++k)
{
matrix._outerIndexPtr()[k] += offset;
}
return *this;
}
@ -279,8 +332,9 @@ class SparseInnerVectorSet<SparseMatrix<_Scalar, _Options>, Size>
Index nonZeros() const
{
return size_t(m_matrix._outerIndexPtr()[m_outerStart+m_outerSize.value()])
- size_t(m_matrix._outerIndexPtr()[m_outerStart]); }
return std::size_t(m_matrix._outerIndexPtr()[m_outerStart+m_outerSize.value()])
- std::size_t(m_matrix._outerIndexPtr()[m_outerStart]);
}
const Scalar& lastCoeff() const
{

View File

@ -74,6 +74,7 @@ class SparseMatrix
typedef MappedSparseMatrix<Scalar,Flags> Map;
using Base::IsRowMajor;
typedef CompressedStorage<Scalar,Index> Storage;
protected:
@ -102,6 +103,9 @@ class SparseMatrix
inline const Index* _outerIndexPtr() const { return m_outerIndex; }
inline Index* _outerIndexPtr() { return m_outerIndex; }
inline Storage& data() { return m_data; }
inline const Storage& data() const { return m_data; }
inline Scalar coeff(Index row, Index col) const
{
const Index outer = IsRowMajor ? row : col;