sparse module: much much faster transposition code

This commit is contained in:
Gael Guennebaud 2008-10-18 11:11:10 +00:00
parent 727dfa1c43
commit cfca7f71fe
2 changed files with 54 additions and 5 deletions

View File

@ -65,6 +65,7 @@ class SparseMatrix
enum {
RowMajor = SparseBase::RowMajor
};
typedef SparseMatrix<Scalar,(Flags&~RowMajorBit)|(RowMajor?RowMajorBit:0)> TransposedSparseMatrix;
int m_outerSize;
int m_innerSize;
@ -225,8 +226,7 @@ class SparseMatrix
else
{
resize(other.rows(), other.cols());
for (int j=0; j<=m_outerSize; ++j)
m_outerIndex[j] = other.m_outerIndex[j];
memcpy(m_outerIndex, other.m_outerIndex, (m_outerSize+1)*sizeof(int));
m_data = other.m_data;
}
return *this;
@ -235,9 +235,55 @@ class SparseMatrix
template<typename OtherDerived>
inline SparseMatrix& operator=(const MatrixBase<OtherDerived>& other)
{
// std::cout << "SparseMatrix& operator=(const MatrixBase<OtherDerived>& other)\n";
const bool needToTranspose = (Flags & RowMajorBit) != (OtherDerived::Flags & RowMajorBit);
if (needToTranspose)
{
// two passes algorithm:
// 1 - compute the number of coeffs per dest inner vector
// 2 - do the actual copy/eval
// Since each coeff of the rhs has to be evaluated twice, let's evauluate it if needed
typedef typename ei_nested<OtherDerived,2>::type OtherCopy;
OtherCopy otherCopy(other.derived());
typedef typename ei_cleantype<OtherCopy>::type _OtherCopy;
resize(other.rows(), other.cols());
Map<VectorXi>(m_outerIndex,outerSize()).setZero();
// pass 1
// FIXME the above copy could be merged with that pass
for (int j=0; j<otherCopy.outerSize(); ++j)
for (typename _OtherCopy::InnerIterator it(otherCopy, j); it; ++it)
m_outerIndex[it.index()]++;
// prefix sum
int count = 0;
VectorXi positions(outerSize());
for (int j=0; j<outerSize(); ++j)
{
int tmp = m_outerIndex[j];
m_outerIndex[j] = count;
positions[j] = count;
count += tmp;
}
m_outerIndex[outerSize()] = count;
// alloc
m_data.resize(count);
// pass 2
for (int j=0; j<otherCopy.outerSize(); ++j)
for (typename _OtherCopy::InnerIterator it(otherCopy, j); it; ++it)
{
int pos = positions[it.index()]++;
m_data.index(pos) = j;
m_data.value(pos) = it.value();
}
return *this;
}
else
{
// there is no special optimization
return SparseMatrixBase<SparseMatrix>::operator=(other.derived());
}
}
friend std::ostream & operator << (std::ostream & s, const SparseMatrix& m)
{

View File

@ -169,6 +169,9 @@ struct ei_sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
}
}
for (typename AmbiVector<Scalar>::Iterator it(tempVector); it; ++it)
if (ResultType::Flags&RowMajorBit)
res.fill(j,it.index()) = it.value();
else
res.fill(it.index(), j) = it.value();
}
res.endFill();