fix bug #356: fix TriangularView::InnerIterator for unit diagonals

This commit is contained in:
Gael Guennebaud 2011-12-04 14:39:24 +01:00
parent 32917515df
commit 9353bbac4a
3 changed files with 110 additions and 16 deletions

View File

@ -113,7 +113,7 @@ template<typename MatrixType, unsigned int UpLo> class SparseSelfAdjointView
SparseSelfAdjointView& rankUpdate(const SparseMatrixBase<DerivedU>& u, Scalar alpha = Scalar(1));
/** \internal triggered by sparse_matrix = SparseSelfadjointView; */
template<typename DestScalar> void evalTo(SparseMatrix<DestScalar,ColMajor,Index>& _dest) const
template<typename DestScalar,int StorageOrder> void evalTo(SparseMatrix<DestScalar,StorageOrder,Index>& _dest) const
{
internal::permute_symm_to_fullsymm<UpLo>(m_matrix, _dest);
}

View File

@ -37,9 +37,10 @@ struct traits<SparseTriangularView<MatrixType,Mode> >
template<typename MatrixType, int Mode> class SparseTriangularView
: public SparseMatrixBase<SparseTriangularView<MatrixType,Mode> >
{
enum { SkipFirst = (Mode==Lower && !(MatrixType::Flags&RowMajorBit))
|| (Mode==Upper && (MatrixType::Flags&RowMajorBit)),
SkipLast = !SkipFirst
enum { SkipFirst = ((Mode&Lower) && !(MatrixType::Flags&RowMajorBit))
|| ((Mode&Upper) && (MatrixType::Flags&RowMajorBit)),
SkipLast = !SkipFirst,
HasUnitDiag = (Mode&UnitDiag) ? 1 : 0
};
public:
@ -81,19 +82,61 @@ class SparseTriangularView<MatrixType,Mode>::InnerIterator : public MatrixType::
public:
EIGEN_STRONG_INLINE InnerIterator(const SparseTriangularView& view, Index outer)
: Base(view.nestedExpression(), outer)
: Base(view.nestedExpression(), outer), m_returnOne(false)
{
if(SkipFirst)
while((*this) && this->index()<outer)
++(*this);
{
while((*this) && (HasUnitDiag ? this->index()<=outer : this->index()<outer))
Base::operator++();
if(HasUnitDiag)
m_returnOne = true;
}
else if(HasUnitDiag && ((!Base::operator bool()) || Base::index()>=Base::outer()))
{
if((!SkipFirst) && Base::operator bool())
Base::operator++();
m_returnOne = true;
}
}
EIGEN_STRONG_INLINE InnerIterator& operator++()
{
if(HasUnitDiag && m_returnOne)
m_returnOne = false;
else
{
Base::operator++();
if(HasUnitDiag && (!SkipFirst) && ((!Base::operator bool()) || Base::index()>=Base::outer()))
{
if((!SkipFirst) && Base::operator bool())
Base::operator++();
m_returnOne = true;
}
}
return *this;
}
inline Index row() const { return Base::row(); }
inline Index col() const { return Base::col(); }
inline Index index() const
{
if(HasUnitDiag && m_returnOne) return Base::outer();
else return Base::index();
}
inline Scalar value() const
{
if(HasUnitDiag && m_returnOne) return Scalar(1);
else return Base::value();
}
EIGEN_STRONG_INLINE operator bool() const
{
return SkipFirst ? Base::operator bool() : (Base::operator bool() && this->index() <= this->outer());
if(HasUnitDiag && m_returnOne)
return true;
return (SkipFirst ? Base::operator bool() : (Base::operator bool() && this->index() <= this->outer()));
}
protected:
bool m_returnOne;
};
template<typename MatrixType, int Mode>
@ -105,10 +148,15 @@ class SparseTriangularView<MatrixType,Mode>::ReverseInnerIterator : public Matri
EIGEN_STRONG_INLINE ReverseInnerIterator(const SparseTriangularView& view, Index outer)
: Base(view.nestedExpression(), outer)
{
eigen_assert((!HasUnitDiag) && "ReverseInnerIterator does not support yet triangular views with a unit diagonal");
if(SkipLast)
while((*this) && this->index()>outer)
--(*this);
}
EIGEN_STRONG_INLINE InnerIterator& operator--()
{ Base::operator--(); return *this; }
inline Index row() const { return Base::row(); }
inline Index col() const { return Base::col(); }

View File

@ -196,7 +196,10 @@ template<typename SparseMatrixType> void sparse_basic(const SparseMatrixType& re
VERIFY_IS_APPROX(m1+=m2, refM1+=refM2);
VERIFY_IS_APPROX(m1-=m2, refM1-=refM2);
VERIFY_IS_APPROX(m1.col(0).dot(refM2.row(0)), refM1.col(0).dot(refM2.row(0)));
if(SparseMatrixType::IsRowMajor)
VERIFY_IS_APPROX(m1.innerVector(0).dot(refM2.row(0)), refM1.row(0).dot(refM2.row(0)));
else
VERIFY_IS_APPROX(m1.innerVector(0).dot(refM2.row(0)), refM1.col(0).dot(refM2.row(0)));
VERIFY_IS_APPROX(m1.conjugate(), refM1.conjugate());
VERIFY_IS_APPROX(m1.real(), refM1.real());
@ -225,8 +228,15 @@ template<typename SparseMatrixType> void sparse_basic(const SparseMatrixType& re
initSparse<Scalar>(density, refMat2, m2);
int j0 = internal::random<int>(0,rows-1);
int j1 = internal::random<int>(0,rows-1);
VERIFY_IS_APPROX(m2.innerVector(j0), refMat2.col(j0));
VERIFY_IS_APPROX(m2.innerVector(j0)+m2.innerVector(j1), refMat2.col(j0)+refMat2.col(j1));
if(SparseMatrixType::IsRowMajor)
VERIFY_IS_APPROX(m2.innerVector(j0), refMat2.row(j0));
else
VERIFY_IS_APPROX(m2.innerVector(j0), refMat2.col(j0));
if(SparseMatrixType::IsRowMajor)
VERIFY_IS_APPROX(m2.innerVector(j0)+m2.innerVector(j1), refMat2.row(j0)+refMat2.row(j1));
else
VERIFY_IS_APPROX(m2.innerVector(j0)+m2.innerVector(j1), refMat2.col(j0)+refMat2.col(j1));
//m2.innerVector(j0) = 2*m2.innerVector(j1);
//refMat2.col(j0) = 2*refMat2.col(j1);
//VERIFY_IS_APPROX(m2, refMat2);
@ -240,9 +250,16 @@ template<typename SparseMatrixType> void sparse_basic(const SparseMatrixType& re
int j0 = internal::random<int>(0,rows-2);
int j1 = internal::random<int>(0,rows-2);
int n0 = internal::random<int>(1,rows-(std::max)(j0,j1));
VERIFY_IS_APPROX(m2.innerVectors(j0,n0), refMat2.block(0,j0,rows,n0));
VERIFY_IS_APPROX(m2.innerVectors(j0,n0)+m2.innerVectors(j1,n0),
refMat2.block(0,j0,rows,n0)+refMat2.block(0,j1,rows,n0));
if(SparseMatrixType::IsRowMajor)
VERIFY_IS_APPROX(m2.innerVectors(j0,n0), refMat2.block(j0,0,n0,cols));
else
VERIFY_IS_APPROX(m2.innerVectors(j0,n0), refMat2.block(0,j0,rows,n0));
if(SparseMatrixType::IsRowMajor)
VERIFY_IS_APPROX(m2.innerVectors(j0,n0)+m2.innerVectors(j1,n0),
refMat2.block(j0,0,n0,cols)+refMat2.block(j1,0,n0,cols));
else
VERIFY_IS_APPROX(m2.innerVectors(j0,n0)+m2.innerVectors(j1,n0),
refMat2.block(0,j0,rows,n0)+refMat2.block(0,j1,rows,n0));
//m2.innerVectors(j0,n0) = m2.innerVectors(j0,n0) + m2.innerVectors(j1,n0);
//refMat2.block(0,j0,rows,n0) = refMat2.block(0,j0,rows,n0) + refMat2.block(0,j1,rows,n0);
}
@ -272,7 +289,11 @@ template<typename SparseMatrixType> void sparse_basic(const SparseMatrixType& re
else
{
countTrueNonZero++;
m2.insertBackByOuterInner(j,i) = refM2(i,j) = Scalar(1);
m2.insertBackByOuterInner(j,i) = Scalar(1);
if(SparseMatrixType::IsRowMajor)
refM2(j,i) = Scalar(1);
else
refM2(i,j) = Scalar(1);
}
}
}
@ -283,8 +304,31 @@ template<typename SparseMatrixType> void sparse_basic(const SparseMatrixType& re
VERIFY(countTrueNonZero==m2.nonZeros());
VERIFY_IS_APPROX(m2, refM2);
}
// test triangularView
{
DenseMatrix refMat2(rows, rows), refMat3(rows, rows);
SparseMatrixType m2(rows, rows), m3(rows, rows);
initSparse<Scalar>(density, refMat2, m2);
refMat3 = refMat2.template triangularView<Lower>();
m3 = m2.template triangularView<Lower>();
VERIFY_IS_APPROX(m3, refMat3);
refMat3 = refMat2.template triangularView<Upper>();
m3 = m2.template triangularView<Upper>();
VERIFY_IS_APPROX(m3, refMat3);
refMat3 = refMat2.template triangularView<UnitUpper>();
m3 = m2.template triangularView<UnitUpper>();
VERIFY_IS_APPROX(m3, refMat3);
refMat3 = refMat2.template triangularView<UnitLower>();
m3 = m2.template triangularView<UnitLower>();
VERIFY_IS_APPROX(m3, refMat3);
}
// test selfadjointView
if(!SparseMatrixType::IsRowMajor)
{
DenseMatrix refMat2(rows, rows), refMat3(rows, rows);
SparseMatrixType m2(rows, rows), m3(rows, rows);
@ -308,8 +352,10 @@ void test_sparse_basic()
for(int i = 0; i < g_repeat; i++) {
int s = Eigen::internal::random<int>(1,50);
CALL_SUBTEST_1(( sparse_basic(SparseMatrix<double>(8, 8)) ));
CALL_SUBTEST_2(( sparse_basic(SparseMatrix<std::complex<double> >(s, s)) ));
CALL_SUBTEST_2(( sparse_basic(SparseMatrix<std::complex<double>, ColMajor>(s, s)) ));
CALL_SUBTEST_2(( sparse_basic(SparseMatrix<std::complex<double>, RowMajor>(s, s)) ));
CALL_SUBTEST_1(( sparse_basic(SparseMatrix<double>(s, s)) ));
CALL_SUBTEST_1(( sparse_basic(SparseMatrix<double,ColMajor,long int>(s, s)) ));
CALL_SUBTEST_1(( sparse_basic(SparseMatrix<double,RowMajor,long int>(s, s)) ));
}
}