bug #901: fix triangular-view with unit diagonal of sparse rectangular matrices.

This commit is contained in:
Gael Guennebaud 2016-02-12 15:58:31 +01:00
parent b35d1a122e
commit 0a537cb2d8
2 changed files with 14 additions and 12 deletions

View File

@ -70,20 +70,20 @@ class TriangularViewImpl<MatrixType,Mode,Sparse>::InnerIterator : public MatrixT
public: public:
EIGEN_STRONG_INLINE InnerIterator(const TriangularViewImpl& view, Index outer) EIGEN_STRONG_INLINE InnerIterator(const TriangularViewImpl& view, Index outer)
: Base(view.derived().nestedExpression(), outer), m_returnOne(false) : Base(view.derived().nestedExpression(), outer), m_returnOne(false), m_containsDiag(Base::outer()<view.innerSize())
{ {
if(SkipFirst) if(SkipFirst)
{ {
while((*this) && ((HasUnitDiag||SkipDiag) ? this->index()<=outer : this->index()<outer)) while((*this) && ((HasUnitDiag||SkipDiag) ? this->index()<=outer : this->index()<outer))
Base::operator++(); Base::operator++();
if(HasUnitDiag) if(HasUnitDiag)
m_returnOne = true; m_returnOne = m_containsDiag;
} }
else if(HasUnitDiag && ((!Base::operator bool()) || Base::index()>=Base::outer())) else if(HasUnitDiag && ((!Base::operator bool()) || Base::index()>=Base::outer()))
{ {
if((!SkipFirst) && Base::operator bool()) if((!SkipFirst) && Base::operator bool())
Base::operator++(); Base::operator++();
m_returnOne = true; m_returnOne = m_containsDiag;
} }
} }
@ -98,7 +98,7 @@ class TriangularViewImpl<MatrixType,Mode,Sparse>::InnerIterator : public MatrixT
{ {
if((!SkipFirst) && Base::operator bool()) if((!SkipFirst) && Base::operator bool())
Base::operator++(); Base::operator++();
m_returnOne = true; m_returnOne = m_containsDiag;
} }
} }
return *this; return *this;
@ -130,6 +130,7 @@ class TriangularViewImpl<MatrixType,Mode,Sparse>::InnerIterator : public MatrixT
} }
protected: protected:
bool m_returnOne; bool m_returnOne;
bool m_containsDiag;
}; };
template<typename MatrixType, unsigned int Mode> template<typename MatrixType, unsigned int Mode>
@ -193,7 +194,7 @@ public:
Flags = XprType::Flags Flags = XprType::Flags
}; };
explicit unary_evaluator(const XprType &xpr) : m_argImpl(xpr.nestedExpression()) {} explicit unary_evaluator(const XprType &xpr) : m_argImpl(xpr.nestedExpression()), m_arg(xpr.nestedExpression()) {}
inline Index nonZerosEstimate() const { inline Index nonZerosEstimate() const {
return m_argImpl.nonZerosEstimate(); return m_argImpl.nonZerosEstimate();
@ -205,20 +206,20 @@ public:
public: public:
EIGEN_STRONG_INLINE InnerIterator(const unary_evaluator& xprEval, Index outer) EIGEN_STRONG_INLINE InnerIterator(const unary_evaluator& xprEval, Index outer)
: Base(xprEval.m_argImpl,outer), m_returnOne(false) : Base(xprEval.m_argImpl,outer), m_returnOne(false), m_containsDiag(Base::outer()<xprEval.m_arg.innerSize())
{ {
if(SkipFirst) if(SkipFirst)
{ {
while((*this) && ((HasUnitDiag||SkipDiag) ? this->index()<=outer : this->index()<outer)) while((*this) && ((HasUnitDiag||SkipDiag) ? this->index()<=outer : this->index()<outer))
Base::operator++(); Base::operator++();
if(HasUnitDiag) if(HasUnitDiag)
m_returnOne = true; m_returnOne = m_containsDiag;
} }
else if(HasUnitDiag && ((!Base::operator bool()) || Base::index()>=Base::outer())) else if(HasUnitDiag && ((!Base::operator bool()) || Base::index()>=Base::outer()))
{ {
if((!SkipFirst) && Base::operator bool()) if((!SkipFirst) && Base::operator bool())
Base::operator++(); Base::operator++();
m_returnOne = true; // FIXME check innerSize()>outer(); m_returnOne = m_containsDiag;
} }
} }
@ -233,7 +234,7 @@ public:
{ {
if((!SkipFirst) && Base::operator bool()) if((!SkipFirst) && Base::operator bool())
Base::operator++(); Base::operator++();
m_returnOne = true; // FIXME check innerSize()>outer(); m_returnOne = m_containsDiag;
} }
} }
return *this; return *this;
@ -266,12 +267,14 @@ public:
protected: protected:
bool m_returnOne; bool m_returnOne;
bool m_containsDiag;
private: private:
Scalar& valueRef(); Scalar& valueRef();
}; };
protected: protected:
evaluator<ArgType> m_argImpl; evaluator<ArgType> m_argImpl;
const ArgType& m_arg;
}; };
} // end namespace internal } // end namespace internal

View File

@ -21,8 +21,8 @@ template<typename SparseMatrixType> void sparse_basic(const SparseMatrixType& re
const Index rows = ref.rows(); const Index rows = ref.rows();
const Index cols = ref.cols(); const Index cols = ref.cols();
const Index inner = ref.innerSize(); //const Index inner = ref.innerSize();
const Index outer = ref.outerSize(); //const Index outer = ref.outerSize();
typedef typename SparseMatrixType::Scalar Scalar; typedef typename SparseMatrixType::Scalar Scalar;
enum { Flags = SparseMatrixType::Flags }; enum { Flags = SparseMatrixType::Flags };
@ -327,7 +327,6 @@ template<typename SparseMatrixType> void sparse_basic(const SparseMatrixType& re
m3 = m2.template triangularView<Upper>(); m3 = m2.template triangularView<Upper>();
VERIFY_IS_APPROX(m3, refMat3); VERIFY_IS_APPROX(m3, refMat3);
if(inner>=outer) // FIXME this should be implemented for outer>inner as well
{ {
refMat3 = refMat2.template triangularView<UnitUpper>(); refMat3 = refMat2.template triangularView<UnitUpper>();
m3 = m2.template triangularView<UnitUpper>(); m3 = m2.template triangularView<UnitUpper>();