Fix bug #596 : Recover plain SparseMatrix from SparseQR matrixQ()

This commit is contained in:
Desire NUENTSA 2013-05-21 17:35:10 +02:00
parent bd7511fc36
commit cf939f154f
5 changed files with 112 additions and 28 deletions

View File

@ -300,7 +300,7 @@ template<typename OtherDerived>
EIGEN_STRONG_INLINE Derived &
SparseMatrixBase<Derived>::operator-=(const SparseMatrixBase<OtherDerived> &other)
{
return *this = derived() - other.derived();
return derived() = derived() - other.derived();
}
template<typename Derived>
@ -308,7 +308,7 @@ template<typename OtherDerived>
EIGEN_STRONG_INLINE Derived &
SparseMatrixBase<Derived>::operator+=(const SparseMatrixBase<OtherDerived>& other)
{
return *this = derived() + other.derived();
return derived() = derived() + other.derived();
}
template<typename Derived>

View File

@ -673,6 +673,14 @@ class SparseMatrix
m_data.swap(other.m_data);
}
/** Sets *this to the identity matrix */
inline void setIdentity()
{
eigen_assert(rows() == cols() && "ONLY FOR SQUARED MATRICES");
this->setZero();
for (int j = 0; j < rows(); j++)
this->insert(j,j) = Scalar(1.0);
}
inline SparseMatrix& operator=(const SparseMatrix& other)
{
if (other.isRValue())

View File

@ -21,6 +21,8 @@ namespace internal {
template <typename SparseQRType> struct traits<SparseQRMatrixQReturnType<SparseQRType> >
{
typedef typename SparseQRType::MatrixType ReturnType;
typedef typename ReturnType::Index Index;
typedef typename ReturnType::StorageKind StorageKind;
};
template <typename SparseQRType> struct traits<SparseQRMatrixQTransposeReturnType<SparseQRType> >
{
@ -72,10 +74,10 @@ class SparseQR
typedef Matrix<Scalar, Dynamic, 1> ScalarVector;
typedef PermutationMatrix<Dynamic, Dynamic, Index> PermutationType;
public:
SparseQR () : m_isInitialized(false), m_analysisIsok(false), m_lastError(""), m_useDefaultThreshold(true)
SparseQR () : m_isInitialized(false), m_analysisIsok(false), m_lastError(""), m_useDefaultThreshold(true),m_isQSorted(false)
{ }
SparseQR(const MatrixType& mat) : m_isInitialized(false), m_analysisIsok(false), m_lastError(""), m_useDefaultThreshold(true)
SparseQR(const MatrixType& mat) : m_isInitialized(false), m_analysisIsok(false), m_lastError(""), m_useDefaultThreshold(true),m_isQSorted(false)
{
compute(mat);
}
@ -110,11 +112,23 @@ class SparseQR
}
/** \returns an expression of the matrix Q as products of sparse Householder reflectors.
* You can do the following to get an actual SparseMatrix representation of Q:
* \code
* SparseMatrix<double> Q = SparseQR<SparseMatrix<double> >(A).matrixQ();
* \endcode
*/
* The common usage of this function is to apply it to a dense matrix or vector
* \code
* VectorXd B1, B2;
* // Initialize B1
* B2 = matrixQ() * B1;
* \endcode
*
* To get a plain SparseMatrix representation of Q:
* \code
* SparseMatrix<double> Q;
* Q = SparseQR<SparseMatrix<double> >(A).matrixQ();
* \endcode
* Internally, this call simply performs a sparse product between the matrix Q
* and a sparse identity matrix. However, due to the fact that the sparse
* reflectors are stored unsorted, two transpositions are needed to sort
* them before performing the product.
*/
SparseQRMatrixQReturnType<SparseQR> matrixQ() const
{ return SparseQRMatrixQReturnType<SparseQR>(*this); }
@ -158,6 +172,7 @@ class SparseQR
return true;
}
/** Sets the threshold that is used to determine linearly dependent columns during the factorization.
*
* In practice, if during the factorization the norm of the column that has to be eliminated is below
@ -180,6 +195,13 @@ class SparseQR
eigen_assert(this->rows() == B.rows() && "SparseQR::solve() : invalid number of rows in the right hand side matrix");
return internal::solve_retval<SparseQR, Rhs>(*this, B.derived());
}
template<typename Rhs>
inline const internal::sparse_solve_retval<SparseQR, Rhs> solve(const SparseMatrixBase<Rhs>& B) const
{
eigen_assert(m_isInitialized && "The factorization should be called first, use compute()");
eigen_assert(this->rows() == B.rows() && "SparseQR::solve() : invalid number of rows in the right hand side matrix");
return internal::sparse_solve_retval<SparseQR, Rhs>(*this, B.derived());
}
/** \brief Reports whether previous computation was successful.
*
@ -194,6 +216,16 @@ class SparseQR
eigen_assert(m_isInitialized && "Decomposition is not initialized.");
return m_info;
}
protected:
inline void sort_matrix_Q()
{
// The matrix Q is sorted during the transposition
SparseMatrix<Scalar, RowMajor, Index> mQrm(this->m_Q);
this->m_Q = mQrm;
this->m_isQSorted = true;
}
protected:
bool m_isInitialized;
@ -213,8 +245,10 @@ class SparseQR
Index m_nonzeropivots; // Number of non zero pivots found
IndexVector m_etree; // Column elimination tree
IndexVector m_firstRowElt; // First element in each row
bool m_isQSorted; // whether Q is sorted or not
template <typename, typename > friend struct SparseQR_QProduct;
template <typename > friend struct SparseQRMatrixQReturnType;
};
@ -462,6 +496,7 @@ void SparseQR<MatrixType,OrderingType>::factorize(const MatrixType& mat)
m_Q.makeCompressed();
m_R.finalize();
m_R.makeCompressed();
m_isQSorted = false;
m_nonzeropivots = nonzeroCol;
@ -494,7 +529,18 @@ struct solve_retval<SparseQR<_MatrixType,OrderingType>, Rhs>
dec()._solve(rhs(),dst);
}
};
template<typename _MatrixType, typename OrderingType, typename Rhs>
struct sparse_solve_retval<SparseQR<_MatrixType, OrderingType>, Rhs>
: sparse_solve_retval_base<SparseQR<_MatrixType, OrderingType>, Rhs>
{
typedef SparseQR<_MatrixType, OrderingType> Dec;
EIGEN_MAKE_SPARSE_SOLVE_HELPERS(Dec, Rhs)
template<typename Dest> void evalTo(Dest& dst) const
{
this->defaultEvalTo(dst);
}
};
} // end namespace internal
template <typename SparseQRType, typename Derived>
@ -513,34 +559,35 @@ struct SparseQR_QProduct : ReturnByValue<SparseQR_QProduct<SparseQRType, Derived
template<typename DesType>
void evalTo(DesType& res) const
{
Index n = m_qr.cols();
Index n = m_qr.cols();
res = m_other;
if (m_transpose)
{
eigen_assert(m_qr.m_Q.rows() == m_other.rows() && "Non conforming object sizes");
// Compute res = Q' * other :
res = m_other;
for (Index k = 0; k < n; k++)
{
Scalar tau = Scalar(0);
tau = m_qr.m_Q.col(k).dot(res);
tau = tau * m_qr.m_hcoeffs(k);
for (typename MatrixType::InnerIterator itq(m_qr.m_Q, k); itq; ++itq)
//Compute res = Q' * other column by column
for(Index j = 0; j < res.cols(); j++){
for (Index k = 0; k < n; k++)
{
res(itq.row()) -= itq.value() * tau;
Scalar tau = Scalar(0);
tau = m_qr.m_Q.col(k).dot(res.col(j));
tau = tau * m_qr.m_hcoeffs(k);
res.col(j) -= tau * m_qr.m_Q.col(k);
}
}
}
else
{
eigen_assert(m_qr.m_Q.cols() == m_other.rows() && "Non conforming object sizes");
// Compute res = Q * other :
res = m_other;
for (Index k = n-1; k >=0; k--)
// Compute res = Q' * other column by column
for(Index j = 0; j < res.cols(); j++)
{
Scalar tau = Scalar(0);
tau = m_qr.m_Q.col(k).dot(res);
tau = tau * m_qr.m_hcoeffs(k);
res -= tau * m_qr.m_Q.col(k);
for (Index k = n-1; k >=0; k--)
{
Scalar tau = Scalar(0);
tau = m_qr.m_Q.col(k).dot(res.col(j));
tau = tau * m_qr.m_hcoeffs(k);
res.col(j) -= tau * m_qr.m_Q.col(k);
}
}
}
}
@ -551,8 +598,11 @@ struct SparseQR_QProduct : ReturnByValue<SparseQR_QProduct<SparseQRType, Derived
};
template<typename SparseQRType>
struct SparseQRMatrixQReturnType
struct SparseQRMatrixQReturnType : public EigenBase<SparseQRMatrixQReturnType<SparseQRType> >
{
typedef typename SparseQRType::Index Index;
typedef typename SparseQRType::Scalar Scalar;
typedef Matrix<Scalar,Dynamic,Dynamic> DenseMatrix;
SparseQRMatrixQReturnType(const SparseQRType& qr) : m_qr(qr) {}
template<typename Derived>
SparseQR_QProduct<SparseQRType, Derived> operator*(const MatrixBase<Derived>& other)
@ -563,11 +613,29 @@ struct SparseQRMatrixQReturnType
{
return SparseQRMatrixQTransposeReturnType<SparseQRType>(m_qr);
}
inline Index rows() const { return m_qr.rows(); }
inline Index cols() const { return m_qr.cols(); }
// To use for operations with the transpose of Q
SparseQRMatrixQTransposeReturnType<SparseQRType> transpose() const
{
return SparseQRMatrixQTransposeReturnType<SparseQRType>(m_qr);
}
template<typename Dest> void evalTo(MatrixBase<Dest>& dest) const
{
dest.resize(m_qr.rows(), m_qr.cols());
dest.derived() = m_qr.matrixQ() * Dest::Identity(m_qr.rows(), m_qr.rows());
}
template<typename Dest> void evalTo(SparseMatrixBase<Dest>& dest) const
{
Dest idMat(m_qr.rows(), m_qr.rows());
idMat.setIdentity();
dest.derived().resize(m_qr.rows(), m_qr.cols());
// Sort the sparse householder reflectors if needed
if(!m_qr.m_isQSorted)
const_cast<SparseQRType *>(&m_qr)->sort_matrix_Q();
dest.derived() = SparseQR_QProduct<SparseQRType, Dest>(m_qr, idMat, false);
}
const SparseQRType& m_qr;
};

View File

@ -71,6 +71,14 @@ template<typename Scalar> void test_sparseqr_scalar()
VERIFY((dA * refX - b).norm() * 2 > (A * x - b).norm() );
else
VERIFY_IS_APPROX(x, refX);
// Compute explicitly the matrix Q
MatrixType Q, QtQ, idM;
Q = solver.matrixQ();
//Check ||Q' * Q - I ||
QtQ = Q * Q.adjoint();
idM.resize(Q.rows(), Q.rows()); idM.setIdentity();
VERIFY(idM.isApprox(QtQ));
}
void test_sparseqr()
{

View File

@ -302,8 +302,8 @@ LevenbergMarquardt<FunctorType>::minimizeInit(FVectorType &x)
for (Index j = 0; j < n; ++j)
if (m_diag[j] <= 0.)
{
return LevenbergMarquardtSpace::ImproperInputParameters;
m_info = InvalidInput;
return LevenbergMarquardtSpace::ImproperInputParameters;
}
/* evaluate the function at the starting point */