Enable CompleteOrthogonalDecomposition::pseudoInverse with non-square fixed-size matrices.

This commit is contained in:
Gael Guennebaud 2019-11-13 21:16:53 +01:00
parent 002e5b6db6
commit 8496f86f84
4 changed files with 22 additions and 7 deletions

View File

@ -1,7 +1,7 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2014 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2014-2019 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@ -44,7 +44,6 @@ class Inverse : public InverseImpl<XprType,typename internal::traits<XprType>::S
{
public:
typedef typename XprType::StorageIndex StorageIndex;
typedef typename XprType::PlainObject PlainObject;
typedef typename XprType::Scalar Scalar;
typedef typename internal::ref_selector<XprType>::type XprTypeNested;
typedef typename internal::remove_all<XprTypeNested>::type XprTypeNestedCleaned;
@ -55,8 +54,8 @@ public:
: m_xpr(xpr)
{}
EIGEN_DEVICE_FUNC Index rows() const { return m_xpr.rows(); }
EIGEN_DEVICE_FUNC Index cols() const { return m_xpr.cols(); }
EIGEN_DEVICE_FUNC Index rows() const { return m_xpr.cols(); }
EIGEN_DEVICE_FUNC Index cols() const { return m_xpr.rows(); }
EIGEN_DEVICE_FUNC const XprTypeNestedCleaned& nestedExpression() const { return m_xpr; }

View File

@ -592,6 +592,13 @@ void CompleteOrthogonalDecomposition<_MatrixType>::_solve_impl_transposed(const
namespace internal {
template<typename MatrixType>
struct traits<Inverse<CompleteOrthogonalDecomposition<MatrixType> > >
: traits<typename Transpose<typename MatrixType::PlainObject>::PlainObject>
{
enum { Flags = 0 };
};
template<typename DstXprType, typename MatrixType>
struct Assignment<DstXprType, Inverse<CompleteOrthogonalDecomposition<MatrixType> >, internal::assign_op<typename DstXprType::Scalar,typename CompleteOrthogonalDecomposition<MatrixType>::Scalar>, Dense2Dense>
{
@ -599,7 +606,8 @@ struct Assignment<DstXprType, Inverse<CompleteOrthogonalDecomposition<MatrixType
typedef Inverse<CodType> SrcXprType;
static void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op<typename DstXprType::Scalar,typename CodType::Scalar> &)
{
dst = src.nestedExpression().solve(MatrixType::Identity(src.rows(), src.rows()));
typedef Matrix<typename CodType::Scalar, CodType::RowsAtCompileTime, CodType::RowsAtCompileTime, 0, CodType::MaxRowsAtCompileTime, CodType::MaxRowsAtCompileTime> IdentityMatrixType;
dst = src.nestedExpression().solve(IdentityMatrixType::Identity(src.cols(), src.cols()));
}
};

View File

@ -54,7 +54,11 @@ template<typename MatrixType> void permutationmatrices(const MatrixType& m)
m_permuted = m_original;
VERIFY_EVALUATION_COUNT(m_permuted = lp * m_permuted * rp, 1);
VERIFY_IS_APPROX(m_permuted, lm*m_original*rm);
LeftPermutationType lpi;
lpi = lp.inverse();
VERIFY_IS_APPROX(lpi*m_permuted,lp.inverse()*m_permuted);
VERIFY_IS_APPROX(lp.inverse()*m_permuted*rp.inverse(), m_original);
VERIFY_IS_APPROX(lv.asPermutation().inverse()*m_permuted*rv.asPermutation().inverse(), m_original);
VERIFY_IS_APPROX(MapLeftPerm(lv.data(),lv.size()).inverse()*m_permuted*MapRightPerm(rv.data(),rv.size()).inverse(), m_original);

View File

@ -70,10 +70,11 @@ void cod_fixedsize() {
Cols = MatrixType::ColsAtCompileTime
};
typedef typename MatrixType::Scalar Scalar;
typedef CompleteOrthogonalDecomposition<Matrix<Scalar, Rows, Cols> > COD;
int rank = internal::random<int>(1, (std::min)(int(Rows), int(Cols)) - 1);
Matrix<Scalar, Rows, Cols> matrix;
createRandomPIMatrixOfRank(rank, Rows, Cols, matrix);
CompleteOrthogonalDecomposition<Matrix<Scalar, Rows, Cols> > cod(matrix);
COD cod(matrix);
VERIFY(rank == cod.rank());
VERIFY(Cols - cod.rank() == cod.dimensionOfKernel());
VERIFY(cod.isInjective() == (rank == Rows));
@ -90,6 +91,9 @@ void cod_fixedsize() {
JacobiSVD<MatrixType> svd(matrix, ComputeFullU | ComputeFullV);
Matrix<Scalar, Cols, Cols2> svd_solution = svd.solve(rhs);
VERIFY_IS_APPROX(cod_solution, svd_solution);
typename Inverse<COD>::PlainObject pinv = cod.pseudoInverse();
VERIFY_IS_APPROX(cod_solution, pinv * rhs);
}
template<typename MatrixType> void qr()