UmfPack support: fix redundant evaluation/copies when calling compute() and support generic expressions as input

This commit is contained in:
Gael Guennebaud 2014-12-02 17:30:57 +01:00
parent 775f7e5fbb
commit 433bce5c3a

View File

@ -107,6 +107,16 @@ inline int umfpack_get_determinant(std::complex<double> *Mx, double *Ex, void *N
return umfpack_zi_get_determinant(&mx_real,0,Ex,NumericHandle,User_Info);
}
namespace internal {
template<typename T> struct umfpack_helper_is_sparse_plain : false_type {};
template<typename Scalar, int Options, typename StorageIndex>
struct umfpack_helper_is_sparse_plain<SparseMatrix<Scalar,Options,StorageIndex> >
: true_type {};
template<typename Scalar, int Options, typename StorageIndex>
struct umfpack_helper_is_sparse_plain<MappedSparseMatrix<Scalar,Options,StorageIndex> >
: true_type {};
}
/** \ingroup UmfPackSupport_Module
* \brief A sparse LU factorization and solver based on UmfPack
*
@ -199,8 +209,11 @@ class UmfPackLU : public SparseSolverBase<UmfPackLU<_MatrixType> >
template<typename InputMatrixType>
void compute(const InputMatrixType& matrix)
{
analyzePattern(matrix);
factorize(matrix);
if(m_symbolic) umfpack_free_symbolic(&m_symbolic,Scalar());
if(m_numeric) umfpack_free_numeric(&m_numeric,Scalar());
grapInput(matrix.derived());
analyzePattern_impl();
factorize_impl();
}
/** Performs a symbolic decomposition on the sparcity of \a matrix.
@ -212,22 +225,12 @@ class UmfPackLU : public SparseSolverBase<UmfPackLU<_MatrixType> >
template<typename InputMatrixType>
void analyzePattern(const InputMatrixType& matrix)
{
if(m_symbolic)
umfpack_free_symbolic(&m_symbolic,Scalar());
if(m_numeric)
umfpack_free_numeric(&m_numeric,Scalar());
if(m_symbolic) umfpack_free_symbolic(&m_symbolic,Scalar());
if(m_numeric) umfpack_free_numeric(&m_numeric,Scalar());
grapInput(matrix);
grapInput(matrix.derived());
int errorCode = 0;
errorCode = umfpack_symbolic(matrix.rows(), matrix.cols(), m_outerIndexPtr, m_innerIndexPtr, m_valuePtr,
&m_symbolic, 0, 0);
m_isInitialized = true;
m_info = errorCode ? InvalidInput : Success;
m_analysisIsOk = true;
m_factorizationIsOk = false;
m_extractedDataAreDirty = true;
analyzePattern_impl();
}
/** Performs a numeric decomposition of \a matrix
@ -243,15 +246,9 @@ class UmfPackLU : public SparseSolverBase<UmfPackLU<_MatrixType> >
if(m_numeric)
umfpack_free_numeric(&m_numeric,Scalar());
grapInput(matrix);
int errorCode;
errorCode = umfpack_numeric(m_outerIndexPtr, m_innerIndexPtr, m_valuePtr,
m_symbolic, &m_numeric, 0, 0);
m_info = errorCode ? NumericalIssue : Success;
m_factorizationIsOk = true;
m_extractedDataAreDirty = true;
grapInput(matrix.derived());
factorize_impl();
}
#ifndef EIGEN_PARSED_BY_DOXYGEN
@ -266,7 +263,6 @@ class UmfPackLU : public SparseSolverBase<UmfPackLU<_MatrixType> >
protected:
void init()
{
m_info = InvalidInput;
@ -280,7 +276,7 @@ class UmfPackLU : public SparseSolverBase<UmfPackLU<_MatrixType> >
}
template<typename InputMatrixType>
void grapInput(const InputMatrixType& mat)
void grapInput_impl(const InputMatrixType& mat, internal::true_type)
{
m_copyMatrix.resize(mat.rows(), mat.cols());
if( ((MatrixType::Flags&RowMajorBit)==RowMajorBit) || sizeof(typename MatrixType::Index)!=sizeof(int) || !mat.isCompressed() )
@ -298,6 +294,45 @@ class UmfPackLU : public SparseSolverBase<UmfPackLU<_MatrixType> >
m_valuePtr = mat.valuePtr();
}
}
template<typename InputMatrixType>
void grapInput_impl(const InputMatrixType& mat, internal::false_type)
{
m_copyMatrix = mat;
m_outerIndexPtr = m_copyMatrix.outerIndexPtr();
m_innerIndexPtr = m_copyMatrix.innerIndexPtr();
m_valuePtr = m_copyMatrix.valuePtr();
}
template<typename InputMatrixType>
void grapInput(const InputMatrixType& mat)
{
grapInput_impl(mat, internal::umfpack_helper_is_sparse_plain<InputMatrixType>());
}
void analyzePattern_impl()
{
int errorCode = 0;
errorCode = umfpack_symbolic(m_copyMatrix.rows(), m_copyMatrix.cols(), m_outerIndexPtr, m_innerIndexPtr, m_valuePtr,
&m_symbolic, 0, 0);
m_isInitialized = true;
m_info = errorCode ? InvalidInput : Success;
m_analysisIsOk = true;
m_factorizationIsOk = false;
m_extractedDataAreDirty = true;
}
void factorize_impl()
{
int errorCode;
errorCode = umfpack_numeric(m_outerIndexPtr, m_innerIndexPtr, m_valuePtr,
m_symbolic, &m_numeric, 0, 0);
m_info = errorCode ? NumericalIssue : Success;
m_factorizationIsOk = true;
m_extractedDataAreDirty = true;
}
// cached data to reduce reallocation, etc.
mutable LUMatrixType m_l;