Use Ref<> to ensure that both x and b in Ax=b are compatible with Umfpack/SuperLU expectations

This commit is contained in:
Gael Guennebaud 2015-02-03 23:46:05 +01:00
parent ebdf6a2dbb
commit b1eca55328
2 changed files with 31 additions and 5 deletions

View File

@ -627,8 +627,12 @@ void SuperLU<MatrixType>::_solve_impl(const MatrixBase<Rhs> &b, MatrixBase<Dest>
m_sluFerr.resize(rhsCols); m_sluFerr.resize(rhsCols);
m_sluBerr.resize(rhsCols); m_sluBerr.resize(rhsCols);
m_sluB = SluMatrix::Map(b.const_cast_derived());
m_sluX = SluMatrix::Map(x.derived()); Ref<const Matrix<typename Rhs::Scalar,Dynamic,Dynamic,ColMajor> > b_ref(b);
Ref<const Matrix<typename Dest::Scalar,Dynamic,Dynamic,ColMajor> > x_ref(x);
m_sluB = SluMatrix::Map(b_ref.const_cast_derived());
m_sluX = SluMatrix::Map(x_ref.const_cast_derived());
typename Rhs::PlainObject b_cpy; typename Rhs::PlainObject b_cpy;
if(m_sluEqued!='N') if(m_sluEqued!='N')
@ -651,6 +655,10 @@ void SuperLU<MatrixType>::_solve_impl(const MatrixBase<Rhs> &b, MatrixBase<Dest>
&m_sluFerr[0], &m_sluBerr[0], &m_sluFerr[0], &m_sluBerr[0],
&m_sluStat, &info, Scalar()); &m_sluStat, &info, Scalar());
StatFree(&m_sluStat); StatFree(&m_sluStat);
if(&x.coeffRef(0) != x_ref.data())
x = x_ref;
m_info = info==0 ? Success : NumericalIssue; m_info = info==0 ? Success : NumericalIssue;
} }
@ -938,8 +946,12 @@ void SuperILU<MatrixType>::_solve_impl(const MatrixBase<Rhs> &b, MatrixBase<Dest
m_sluFerr.resize(rhsCols); m_sluFerr.resize(rhsCols);
m_sluBerr.resize(rhsCols); m_sluBerr.resize(rhsCols);
m_sluB = SluMatrix::Map(b.const_cast_derived());
m_sluX = SluMatrix::Map(x.derived()); Ref<const Matrix<typename Rhs::Scalar,Dynamic,Dynamic,ColMajor> > b_ref(b);
Ref<const Matrix<typename Dest::Scalar,Dynamic,Dynamic,ColMajor> > x_ref(x);
m_sluB = SluMatrix::Map(b_ref.const_cast_derived());
m_sluX = SluMatrix::Map(x_ref.const_cast_derived());
typename Rhs::PlainObject b_cpy; typename Rhs::PlainObject b_cpy;
if(m_sluEqued!='N') if(m_sluEqued!='N')
@ -962,6 +974,9 @@ void SuperILU<MatrixType>::_solve_impl(const MatrixBase<Rhs> &b, MatrixBase<Dest
&recip_pivot_growth, &rcond, &recip_pivot_growth, &rcond,
&m_sluStat, &info, Scalar()); &m_sluStat, &info, Scalar());
StatFree(&m_sluStat); StatFree(&m_sluStat);
if(&x.coeffRef(0) != x_ref.data())
x = x_ref;
m_info = info==0 ? Success : NumericalIssue; m_info = info==0 ? Success : NumericalIssue;
} }

View File

@ -403,11 +403,22 @@ bool UmfPackLU<MatrixType>::_solve_impl(const MatrixBase<BDerived> &b, MatrixBas
eigen_assert(b.derived().data() != x.derived().data() && " Umfpack does not support inplace solve"); eigen_assert(b.derived().data() != x.derived().data() && " Umfpack does not support inplace solve");
int errorCode; int errorCode;
Scalar* x_ptr = 0;
Matrix<Scalar,Dynamic,1> x_tmp;
if(x.innerStride()!=1)
{
x_tmp.resize(x.rows());
x_ptr = x_tmp.data();
}
for (int j=0; j<rhsCols; ++j) for (int j=0; j<rhsCols; ++j)
{ {
if(x.innerStride()==1)
x_ptr = &x.col(j).coeffRef(0);
errorCode = umfpack_solve(UMFPACK_A, errorCode = umfpack_solve(UMFPACK_A,
m_outerIndexPtr, m_innerIndexPtr, m_valuePtr, m_outerIndexPtr, m_innerIndexPtr, m_valuePtr,
&x.col(j).coeffRef(0), &b.const_cast_derived().col(j).coeffRef(0), m_numeric, 0, 0); x_ptr, &b.const_cast_derived().col(j).coeffRef(0), m_numeric, 0, 0);
if(x.innerStride()!=1)
x.col(j) = x_tmp;
if (errorCode!=0) if (errorCode!=0)
return false; return false;
} }