sparse module: much much faster transposition code

This commit is contained in:
Gael Guennebaud 2008-10-18 11:11:10 +00:00
parent 727dfa1c43
commit cfca7f71fe
2 changed files with 54 additions and 5 deletions

View File

@ -65,6 +65,7 @@ class SparseMatrix
enum { enum {
RowMajor = SparseBase::RowMajor RowMajor = SparseBase::RowMajor
}; };
typedef SparseMatrix<Scalar,(Flags&~RowMajorBit)|(RowMajor?RowMajorBit:0)> TransposedSparseMatrix;
int m_outerSize; int m_outerSize;
int m_innerSize; int m_innerSize;
@ -225,8 +226,7 @@ class SparseMatrix
else else
{ {
resize(other.rows(), other.cols()); resize(other.rows(), other.cols());
for (int j=0; j<=m_outerSize; ++j) memcpy(m_outerIndex, other.m_outerIndex, (m_outerSize+1)*sizeof(int));
m_outerIndex[j] = other.m_outerIndex[j];
m_data = other.m_data; m_data = other.m_data;
} }
return *this; return *this;
@ -235,8 +235,54 @@ class SparseMatrix
// Assignment from a generic expression `other`.
//
// NOTE(review): this listing is a side-by-side diff rendering, so on a few
// lines the removed (left) and added (right) text are fused together; the
// comments below describe the added implementation.
//
// If the RowMajorBit of this matrix's Flags differs from the one of
// OtherDerived::Flags, the assignment is done with a two-pass scheme that
// writes directly into m_outerIndex / m_data; otherwise it is forwarded to
// SparseMatrixBase<SparseMatrix>::operator=.
// Returns *this in the transposing branch, and the result of the base-class
// operator= otherwise.
template<typename OtherDerived> template<typename OtherDerived>
inline SparseMatrix& operator=(const MatrixBase<OtherDerived>& other) inline SparseMatrix& operator=(const MatrixBase<OtherDerived>& other)
{ {
// std::cout << "SparseMatrix& operator=(const MatrixBase<OtherDerived>& other)\n"; const bool needToTranspose = (Flags & RowMajorBit) != (OtherDerived::Flags & RowMajorBit);
return SparseMatrixBase<SparseMatrix>::operator=(other.derived()); if (needToTranspose)
{
// two-pass algorithm:
// 1 - compute the number of coeffs per dest inner vector
// 2 - do the actual copy/eval
// Since each coeff of the rhs has to be evaluated twice, let's evaluate it if needed
typedef typename ei_nested<OtherDerived,2>::type OtherCopy;
OtherCopy otherCopy(other.derived());
typedef typename ei_cleantype<OtherCopy>::type _OtherCopy;
resize(other.rows(), other.cols());
// reset the per-outer-vector counters before counting
Map<VectorXi>(m_outerIndex,outerSize()).setZero();
// pass 1
// FIXME the above copy could be merged with that pass
// the inner index of each source coeff selects the counter to increment
for (int j=0; j<otherCopy.outerSize(); ++j)
for (typename _OtherCopy::InnerIterator it(otherCopy, j); it; ++it)
m_outerIndex[it.index()]++;
// prefix sum: turn the counts into start offsets; `positions` keeps a
// separate copy of each offset that pass 2 advances as it writes
int count = 0;
VectorXi positions(outerSize());
for (int j=0; j<outerSize(); ++j)
{
int tmp = m_outerIndex[j];
m_outerIndex[j] = count;
positions[j] = count;
count += tmp;
}
// final entry holds the total number of counted coeffs
m_outerIndex[outerSize()] = count;
// alloc
m_data.resize(count);
// pass 2
// each source coeff is stored at the next free slot selected by its inner
// index, and the source outer index j is recorded as its index
for (int j=0; j<otherCopy.outerSize(); ++j)
for (typename _OtherCopy::InnerIterator it(otherCopy, j); it; ++it)
{
int pos = positions[it.index()]++;
m_data.index(pos) = j;
m_data.value(pos) = it.value();
}
return *this;
}
else
{
// there is no special optimization
return SparseMatrixBase<SparseMatrix>::operator=(other.derived());
}
} }
friend std::ostream & operator << (std::ostream & s, const SparseMatrix& m) friend std::ostream & operator << (std::ostream & s, const SparseMatrix& m)

View File

@ -169,7 +169,10 @@ struct ei_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
} }
} }
for (typename AmbiVector<Scalar>::Iterator it(tempVector); it; ++it) for (typename AmbiVector<Scalar>::Iterator it(tempVector); it; ++it)
res.fill(it.index(), j) = it.value(); if (ResultType::Flags&RowMajorBit)
res.fill(j,it.index()) = it.value();
else
res.fill(it.index(), j) = it.value();
} }
res.endFill(); res.endFill();
} }