fix bug #341: trisove on MappedSparseMatrix

This commit is contained in:
Gael Guennebaud 2011-12-04 14:57:43 +01:00
parent 9353bbac4a
commit 9bd902ed9c
2 changed files with 39 additions and 6 deletions

View File

@ -111,6 +111,7 @@ class MappedSparseMatrix
}
class InnerIterator;
class ReverseInnerIterator;
/** \returns the number of non zero coefficients */
inline Index nonZeros() const { return m_nnz; }
@ -136,12 +137,6 @@ class MappedSparseMatrix<Scalar,_Flags,_Index>::InnerIterator
m_end(mat.outerIndexPtr()[outer+1])
{}
template<unsigned int Added, unsigned int Removed>
InnerIterator(const Flagged<MappedSparseMatrix,Added,Removed>& mat, Index outer)
: m_matrix(mat._expression()), m_id(m_matrix.outerIndexPtr()[outer]),
m_start(m_id), m_end(m_matrix.outerIndexPtr()[outer+1])
{}
inline InnerIterator& operator++() { m_id++; return *this; }
inline Scalar value() const { return m_matrix.valuePtr()[m_id]; }
@ -161,4 +156,35 @@ class MappedSparseMatrix<Scalar,_Flags,_Index>::InnerIterator
const Index m_end;
};
template<typename Scalar, int _Flags, typename _Index>
class MappedSparseMatrix<Scalar,_Flags,_Index>::ReverseInnerIterator
{
public:
ReverseInnerIterator(const MappedSparseMatrix& mat, Index outer)
: m_matrix(mat),
m_outer(outer),
m_id(mat.outerIndexPtr()[outer+1]),
m_start(mat.outerIndexPtr()[outer]),
m_end(m_id)
{}
inline ReverseInnerIterator& operator--() { m_id--; return *this; }
inline Scalar value() const { return m_matrix.valuePtr()[m_id-1]; }
inline Scalar& valueRef() { return const_cast<Scalar&>(m_matrix.valuePtr()[m_id-1]); }
inline Index index() const { return m_matrix.innerIndexPtr()[m_id-1]; }
inline Index row() const { return IsRowMajor ? m_outer : index(); }
inline Index col() const { return IsRowMajor ? index() : m_outer; }
inline operator bool() const { return (m_id <= m_end) && (m_id>m_start); }
protected:
const MappedSparseMatrix& m_matrix;
const Index m_outer;
Index m_id;
const Index m_start;
const Index m_end;
};
#endif // EIGEN_MAPPED_SPARSEMATRIX_H

View File

@ -74,6 +74,13 @@ template<typename Scalar> void sparse_solvers(int rows, int cols)
m2.template triangularView<Upper>().solve(vec3));
VERIFY_IS_APPROX(refMat2.conjugate().template triangularView<Upper>().solve(vec2),
m2.conjugate().template triangularView<Upper>().solve(vec3));
{
SparseMatrix<Scalar> cm2(m2);
//Index rows, Index cols, Index nnz, Index* outerIndexPtr, Index* innerIndexPtr, Scalar* valuePtr
MappedSparseMatrix<Scalar> mm2(rows, cols, cm2.nonZeros(), cm2.outerIndexPtr(), cm2.innerIndexPtr(), cm2.valuePtr());
VERIFY_IS_APPROX(refMat2.conjugate().template triangularView<Upper>().solve(vec2),
mm2.conjugate().template triangularView<Upper>().solve(vec3));
}
// lower - transpose
initSparse<Scalar>(density, refMat2, m2, ForceNonZeroDiag|MakeLowerTriangular, &zeroCoords, &nonzeroCoords);