Add internal method _solve_impl_transposed() to LU decomposition classes that solves A^T x = b or A^* x = b.

This commit is contained in:
Rasmus Munk Larsen 2015-11-30 13:39:24 -08:00
parent 274b2272b7
commit 1663d15da7
3 changed files with 148 additions and 23 deletions

View File

@ -10,7 +10,7 @@
#ifndef EIGEN_LU_H #ifndef EIGEN_LU_H
#define EIGEN_LU_H #define EIGEN_LU_H
namespace Eigen { namespace Eigen {
namespace internal { namespace internal {
template<typename _MatrixType> struct traits<FullPivLU<_MatrixType> > template<typename _MatrixType> struct traits<FullPivLU<_MatrixType> >
@ -384,22 +384,26 @@ template<typename _MatrixType> class FullPivLU
inline Index rows() const { return m_lu.rows(); } inline Index rows() const { return m_lu.rows(); }
inline Index cols() const { return m_lu.cols(); } inline Index cols() const { return m_lu.cols(); }
#ifndef EIGEN_PARSED_BY_DOXYGEN #ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename RhsType, typename DstType> template<typename RhsType, typename DstType>
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
void _solve_impl(const RhsType &rhs, DstType &dst) const; void _solve_impl(const RhsType &rhs, DstType &dst) const;
template<bool Conjugate, typename RhsType, typename DstType>
EIGEN_DEVICE_FUNC
void _solve_impl_transposed(const RhsType &rhs, DstType &dst) const;
#endif #endif
protected: protected:
static void check_template_parameters() static void check_template_parameters()
{ {
EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar); EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar);
} }
void computeInPlace(); void computeInPlace();
MatrixType m_lu; MatrixType m_lu;
PermutationPType m_p; PermutationPType m_p;
PermutationQType m_q; PermutationQType m_q;
@ -447,15 +451,15 @@ template<typename InputType>
FullPivLU<MatrixType>& FullPivLU<MatrixType>::compute(const EigenBase<InputType>& matrix) FullPivLU<MatrixType>& FullPivLU<MatrixType>::compute(const EigenBase<InputType>& matrix)
{ {
check_template_parameters(); check_template_parameters();
// the permutations are stored as int indices, so just to be sure: // the permutations are stored as int indices, so just to be sure:
eigen_assert(matrix.rows()<=NumTraits<int>::highest() && matrix.cols()<=NumTraits<int>::highest()); eigen_assert(matrix.rows()<=NumTraits<int>::highest() && matrix.cols()<=NumTraits<int>::highest());
m_isInitialized = true; m_isInitialized = true;
m_lu = matrix.derived(); m_lu = matrix.derived();
computeInPlace(); computeInPlace();
return *this; return *this;
} }
@ -709,7 +713,7 @@ struct image_retval<FullPivLU<_MatrixType> >
template<typename _MatrixType> template<typename _MatrixType>
template<typename RhsType, typename DstType> template<typename RhsType, typename DstType>
void FullPivLU<_MatrixType>::_solve_impl(const RhsType &rhs, DstType &dst) const void FullPivLU<_MatrixType>::_solve_impl(const RhsType &rhs, DstType &dst) const
{ {
/* The decomposition PAQ = LU can be rewritten as A = P^{-1} L U Q^{-1}. /* The decomposition PAQ = LU can be rewritten as A = P^{-1} L U Q^{-1}.
* So we proceed as follows: * So we proceed as follows:
* Step 1: compute c = P * rhs. * Step 1: compute c = P * rhs.
@ -753,6 +757,70 @@ void FullPivLU<_MatrixType>::_solve_impl(const RhsType &rhs, DstType &dst) const
for(Index i = nonzero_pivots; i < m_lu.cols(); ++i) for(Index i = nonzero_pivots; i < m_lu.cols(); ++i)
dst.row(permutationQ().indices().coeff(i)).setZero(); dst.row(permutationQ().indices().coeff(i)).setZero();
} }
template<typename _MatrixType>
template<bool Conjugate, typename RhsType, typename DstType>
void FullPivLU<_MatrixType>::_solve_impl_transposed(const RhsType &rhs, DstType &dst) const
{
/* The decomposition PAQ = LU can be rewritten as A = P^{-1} L U Q^{-1},
* and since permutations are real and unitary, we can write this
* as A^T = Q U^T L^T P,
* So we proceed as follows:
* Step 1: compute c = Q^T rhs.
* Step 2: replace c by the solution x to U^T x = c. May or may not exist.
* Step 3: replace c by the solution x to L^T x = c.
* Step 4: result = P^T c.
* If Conjugate is true, replace "^T" by "^*" above.
*/
const Index rows = this->rows(), cols = this->cols(),
nonzero_pivots = this->rank();
eigen_assert(rhs.rows() == cols);
const Index smalldim = (std::min)(rows, cols);
if(nonzero_pivots == 0)
{
dst.setZero();
return;
}
typename RhsType::PlainObject c(rhs.rows(), rhs.cols());
// Step 1
c = permutationQ().inverse() * rhs;
if (Conjugate) {
// Step 2
m_lu.topLeftCorner(nonzero_pivots, nonzero_pivots)
.template triangularView<Upper>()
.adjoint()
.solveInPlace(c.topRows(nonzero_pivots));
// Step 3
m_lu.topLeftCorner(smalldim, smalldim)
.template triangularView<UnitLower>()
.adjoint()
.solveInPlace(c.topRows(smalldim));
} else {
// Step 2
m_lu.topLeftCorner(nonzero_pivots, nonzero_pivots)
.template triangularView<Upper>()
.transpose()
.solveInPlace(c.topRows(nonzero_pivots));
// Step 3
m_lu.topLeftCorner(smalldim, smalldim)
.template triangularView<UnitLower>()
.transpose()
.solveInPlace(c.topRows(smalldim));
}
// Step 4
PermutationPType invp = permutationP().inverse().eval();
for(Index i = 0; i < smalldim; ++i)
dst.row(invp.indices().coeff(i)) = c.row(i);
for(Index i = smalldim; i < rows; ++i)
dst.row(invp.indices().coeff(i)).setZero();
}
#endif #endif
namespace internal { namespace internal {
@ -765,7 +833,7 @@ struct Assignment<DstXprType, Inverse<FullPivLU<MatrixType> >, internal::assign_
typedef FullPivLU<MatrixType> LuType; typedef FullPivLU<MatrixType> LuType;
typedef Inverse<LuType> SrcXprType; typedef Inverse<LuType> SrcXprType;
static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar> &) static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar> &)
{ {
dst = src.nestedExpression().solve(MatrixType::Identity(src.rows(), src.cols())); dst = src.nestedExpression().solve(MatrixType::Identity(src.rows(), src.cols()));
} }
}; };

View File

@ -11,7 +11,7 @@
#ifndef EIGEN_PARTIALLU_H #ifndef EIGEN_PARTIALLU_H
#define EIGEN_PARTIALLU_H #define EIGEN_PARTIALLU_H
namespace Eigen { namespace Eigen {
namespace internal { namespace internal {
template<typename _MatrixType> struct traits<PartialPivLU<_MatrixType> > template<typename _MatrixType> struct traits<PartialPivLU<_MatrixType> >
@ -185,7 +185,7 @@ template<typename _MatrixType> class PartialPivLU
inline Index rows() const { return m_lu.rows(); } inline Index rows() const { return m_lu.rows(); }
inline Index cols() const { return m_lu.cols(); } inline Index cols() const { return m_lu.cols(); }
#ifndef EIGEN_PARSED_BY_DOXYGEN #ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename RhsType, typename DstType> template<typename RhsType, typename DstType>
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
@ -206,17 +206,44 @@ template<typename _MatrixType> class PartialPivLU
m_lu.template triangularView<UnitLower>().solveInPlace(dst); m_lu.template triangularView<UnitLower>().solveInPlace(dst);
// Step 3 // Step 3
m_lu.template triangularView<Upper>().solveInPlace(dst); m_lu.template triangularView<Upper>().solveInPlace(dst);
}
template<bool Conjugate, typename RhsType, typename DstType>
EIGEN_DEVICE_FUNC
void _solve_impl_transposed(const RhsType &rhs, DstType &dst) const {
/* The decomposition PA = LU can be rewritten as A = P^{-1} L U.
* So we proceed as follows:
* Step 1: compute c = Pb.
* Step 2: replace c by the solution x to Lx = c.
* Step 3: replace c by the solution x to Ux = c.
*/
eigen_assert(rhs.rows() == m_lu.cols());
if (Conjugate) {
// Step 1
dst = m_lu.template triangularView<Upper>().adjoint().solve(rhs);
// Step 2
m_lu.template triangularView<UnitLower>().adjoint().solveInPlace(dst);
} else {
// Step 1
dst = m_lu.template triangularView<Upper>().transpose().solve(rhs);
// Step 2
m_lu.template triangularView<UnitLower>().transpose().solveInPlace(dst);
}
// Step 3
dst = permutationP().transpose() * dst;
} }
#endif #endif
protected: protected:
static void check_template_parameters() static void check_template_parameters()
{ {
EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar); EIGEN_STATIC_ASSERT_NON_INTEGER(Scalar);
} }
MatrixType m_lu; MatrixType m_lu;
PermutationType m_p; PermutationType m_p;
TranspositionType m_rowsTranspositions; TranspositionType m_rowsTranspositions;
@ -295,7 +322,7 @@ struct partial_lu_impl
{ {
Index rrows = rows-k-1; Index rrows = rows-k-1;
Index rcols = cols-k-1; Index rcols = cols-k-1;
Index row_of_biggest_in_col; Index row_of_biggest_in_col;
Score biggest_in_corner Score biggest_in_corner
= lu.col(k).tail(rows-k).unaryExpr(Scoring()).maxCoeff(&row_of_biggest_in_col); = lu.col(k).tail(rows-k).unaryExpr(Scoring()).maxCoeff(&row_of_biggest_in_col);
@ -436,10 +463,10 @@ template<typename InputType>
PartialPivLU<MatrixType>& PartialPivLU<MatrixType>::compute(const EigenBase<InputType>& matrix) PartialPivLU<MatrixType>& PartialPivLU<MatrixType>::compute(const EigenBase<InputType>& matrix)
{ {
check_template_parameters(); check_template_parameters();
// the row permutation is stored as int indices, so just to be sure: // the row permutation is stored as int indices, so just to be sure:
eigen_assert(matrix.rows()<NumTraits<int>::highest()); eigen_assert(matrix.rows()<NumTraits<int>::highest());
m_lu = matrix.derived(); m_lu = matrix.derived();
eigen_assert(matrix.rows() == matrix.cols() && "PartialPivLU is only for square (and moreover invertible) matrices"); eigen_assert(matrix.rows() == matrix.cols() && "PartialPivLU is only for square (and moreover invertible) matrices");
@ -492,7 +519,7 @@ struct Assignment<DstXprType, Inverse<PartialPivLU<MatrixType> >, internal::assi
typedef PartialPivLU<MatrixType> LuType; typedef PartialPivLU<MatrixType> LuType;
typedef Inverse<LuType> SrcXprType; typedef Inverse<LuType> SrcXprType;
static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar> &) static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<Scalar> &)
{ {
dst = src.nestedExpression().solve(MatrixType::Identity(src.rows(), src.cols())); dst = src.nestedExpression().solve(MatrixType::Identity(src.rows(), src.cols()));
} }
}; };

View File

@ -92,6 +92,20 @@ template<typename MatrixType> void lu_non_invertible()
// test that the code, which does resize(), may be applied to an xpr // test that the code, which does resize(), may be applied to an xpr
m2.block(0,0,m2.rows(),m2.cols()) = lu.solve(m3); m2.block(0,0,m2.rows(),m2.cols()) = lu.solve(m3);
VERIFY_IS_APPROX(m3, m1*m2); VERIFY_IS_APPROX(m3, m1*m2);
// test solve with transposed
m3 = MatrixType::Random(rows,cols2);
m2 = m1.transpose()*m3;
m3 = MatrixType::Random(rows,cols2);
lu.template _solve_impl_transposed<false>(m2, m3);
VERIFY_IS_APPROX(m2, m1.transpose()*m3);
// test solve with conjugate transposed
m3 = MatrixType::Random(rows,cols2);
m2 = m1.adjoint()*m3;
m3 = MatrixType::Random(rows,cols2);
lu.template _solve_impl_transposed<true>(m2, m3);
VERIFY_IS_APPROX(m2, m1.adjoint()*m3);
} }
template<typename MatrixType> void lu_invertible() template<typename MatrixType> void lu_invertible()
@ -124,6 +138,12 @@ template<typename MatrixType> void lu_invertible()
m2 = lu.solve(m3); m2 = lu.solve(m3);
VERIFY_IS_APPROX(m3, m1*m2); VERIFY_IS_APPROX(m3, m1*m2);
VERIFY_IS_APPROX(m2, lu.inverse()*m3); VERIFY_IS_APPROX(m2, lu.inverse()*m3);
// test solve with transposed
lu.template _solve_impl_transposed<false>(m3, m2);
VERIFY_IS_APPROX(m3, m1.transpose()*m2);
// test solve with conjugate transposed
lu.template _solve_impl_transposed<true>(m3, m2);
VERIFY_IS_APPROX(m3, m1.adjoint()*m2);
// Regression test for Bug 302 // Regression test for Bug 302
MatrixType m4 = MatrixType::Random(size,size); MatrixType m4 = MatrixType::Random(size,size);
@ -136,14 +156,24 @@ template<typename MatrixType> void lu_partial_piv()
PartialPivLU.h PartialPivLU.h
*/ */
typedef typename MatrixType::Index Index; typedef typename MatrixType::Index Index;
Index rows = internal::random<Index>(1,4); Index size = internal::random<Index>(1,4);
Index cols = rows;
MatrixType m1(cols, rows); MatrixType m1(size, size), m2(size, size), m3(size, size);
m1.setRandom(); m1.setRandom();
PartialPivLU<MatrixType> plu(m1); PartialPivLU<MatrixType> plu(m1);
VERIFY_IS_APPROX(m1, plu.reconstructedMatrix()); VERIFY_IS_APPROX(m1, plu.reconstructedMatrix());
m3 = MatrixType::Random(size,size);
m2 = plu.solve(m3);
VERIFY_IS_APPROX(m3, m1*m2);
VERIFY_IS_APPROX(m2, plu.inverse()*m3);
// test solve with transposed
plu.template _solve_impl_transposed<false>(m3, m2);
VERIFY_IS_APPROX(m3, m1.transpose()*m2);
// test solve with conjugate transposed
plu.template _solve_impl_transposed<true>(m3, m2);
VERIFY_IS_APPROX(m3, m1.adjoint()*m2);
} }
template<typename MatrixType> void lu_verify_assert() template<typename MatrixType> void lu_verify_assert()