Clean a bit the implementation of inverse permutations

This commit is contained in:
Gael Guennebaud 2015-10-08 18:36:39 +02:00
parent 8d00a953af
commit d866279364
5 changed files with 50 additions and 75 deletions

View File

@ -382,8 +382,6 @@ using std::ptrdiff_t;
#include "src/Core/DiagonalMatrix.h" #include "src/Core/DiagonalMatrix.h"
#include "src/Core/Diagonal.h" #include "src/Core/Diagonal.h"
#include "src/Core/DiagonalProduct.h" #include "src/Core/DiagonalProduct.h"
#include "src/Core/PermutationMatrix.h"
#include "src/Core/Transpositions.h"
#include "src/Core/Redux.h" #include "src/Core/Redux.h"
#include "src/Core/Visitor.h" #include "src/Core/Visitor.h"
#include "src/Core/Fuzzy.h" #include "src/Core/Fuzzy.h"
@ -393,6 +391,8 @@ using std::ptrdiff_t;
#include "src/Core/GeneralProduct.h" #include "src/Core/GeneralProduct.h"
#include "src/Core/Solve.h" #include "src/Core/Solve.h"
#include "src/Core/Inverse.h" #include "src/Core/Inverse.h"
#include "src/Core/PermutationMatrix.h"
#include "src/Core/Transpositions.h"
#include "src/Core/TriangularMatrix.h" #include "src/Core/TriangularMatrix.h"
#include "src/Core/SelfAdjointView.h" #include "src/Core/SelfAdjointView.h"
#include "src/Core/products/GeneralBlockPanelKernel.h" #include "src/Core/products/GeneralBlockPanelKernel.h"

View File

@ -47,11 +47,12 @@ public:
typedef typename XprType::PlainObject PlainObject; typedef typename XprType::PlainObject PlainObject;
typedef typename internal::ref_selector<XprType>::type XprTypeNested; typedef typename internal::ref_selector<XprType>::type XprTypeNested;
typedef typename internal::remove_all<XprTypeNested>::type XprTypeNestedCleaned; typedef typename internal::remove_all<XprTypeNested>::type XprTypeNestedCleaned;
typedef typename internal::ref_selector<Inverse>::type Nested;
explicit Inverse(const XprType &xpr) explicit Inverse(const XprType &xpr)
: m_xpr(xpr) : m_xpr(xpr)
{} {}
EIGEN_DEVICE_FUNC Index rows() const { return m_xpr.rows(); } EIGEN_DEVICE_FUNC Index rows() const { return m_xpr.rows(); }
EIGEN_DEVICE_FUNC Index cols() const { return m_xpr.cols(); } EIGEN_DEVICE_FUNC Index cols() const { return m_xpr.cols(); }

View File

@ -2,7 +2,7 @@
// for linear algebra. // for linear algebra.
// //
// Copyright (C) 2009 Benoit Jacob <jacob.benoit.1@gmail.com> // Copyright (C) 2009 Benoit Jacob <jacob.benoit.1@gmail.com>
// Copyright (C) 2009-2011 Gael Guennebaud <gael.guennebaud@inria.fr> // Copyright (C) 2009-2015 Gael Guennebaud <gael.guennebaud@inria.fr>
// //
// This Source Code Form is subject to the terms of the Mozilla // This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed // Public License v. 2.0. If a copy of the MPL was not distributed
@ -13,9 +13,6 @@
namespace Eigen { namespace Eigen {
// TODO: this does not seems to be needed at all:
// template<int RowCol,typename IndicesType,typename MatrixType, typename StorageKind> class PermutedImpl;
/** \class PermutationBase /** \class PermutationBase
* \ingroup Core_Module * \ingroup Core_Module
* *
@ -67,8 +64,9 @@ class PermutationBase : public EigenBase<Derived>
DenseMatrixType; DenseMatrixType;
typedef PermutationMatrix<IndicesType::SizeAtCompileTime,IndicesType::MaxSizeAtCompileTime,StorageIndex> typedef PermutationMatrix<IndicesType::SizeAtCompileTime,IndicesType::MaxSizeAtCompileTime,StorageIndex>
PlainPermutationType; PlainPermutationType;
typedef PlainPermutationType PlainObject;
using Base::derived; using Base::derived;
typedef Transpose<PermutationBase> TransposeReturnType; typedef Inverse<Derived> InverseReturnType;
typedef void Scalar; typedef void Scalar;
#endif #endif
@ -196,14 +194,14 @@ class PermutationBase : public EigenBase<Derived>
* *
* \note \note_try_to_help_rvo * \note \note_try_to_help_rvo
*/ */
inline TransposeReturnType inverse() const inline InverseReturnType inverse() const
{ return TransposeReturnType(derived()); } { return InverseReturnType(derived()); }
/** \returns the tranpose permutation matrix. /** \returns the tranpose permutation matrix.
* *
* \note \note_try_to_help_rvo * \note \note_try_to_help_rvo
*/ */
inline TransposeReturnType transpose() const inline InverseReturnType transpose() const
{ return TransposeReturnType(derived()); } { return InverseReturnType(derived()); }
/**** multiplication helpers to hopefully get RVO ****/ /**** multiplication helpers to hopefully get RVO ****/
@ -238,7 +236,7 @@ class PermutationBase : public EigenBase<Derived>
* \note \note_try_to_help_rvo * \note \note_try_to_help_rvo
*/ */
template<typename Other> template<typename Other>
inline PlainPermutationType operator*(const Transpose<PermutationBase<Other> >& other) const inline PlainPermutationType operator*(const InverseImpl<Other,PermutationStorage>& other) const
{ return PlainPermutationType(internal::PermPermProduct, *this, other.eval()); } { return PlainPermutationType(internal::PermPermProduct, *this, other.eval()); }
/** \returns the product of an inverse permutation with another permutation. /** \returns the product of an inverse permutation with another permutation.
@ -246,7 +244,7 @@ class PermutationBase : public EigenBase<Derived>
* \note \note_try_to_help_rvo * \note \note_try_to_help_rvo
*/ */
template<typename Other> friend template<typename Other> friend
inline PlainPermutationType operator*(const Transpose<PermutationBase<Other> >& other, const PermutationBase& perm) inline PlainPermutationType operator*(const InverseImpl<Other, PermutationStorage>& other, const PermutationBase& perm)
{ return PlainPermutationType(internal::PermPermProduct, other.eval(), perm); } { return PlainPermutationType(internal::PermPermProduct, other.eval(), perm); }
/** \returns the determinant of the permutation matrix, which is either 1 or -1 depending on the parity of the permutation. /** \returns the determinant of the permutation matrix, which is either 1 or -1 depending on the parity of the permutation.
@ -398,13 +396,13 @@ class PermutationMatrix : public PermutationBase<PermutationMatrix<SizeAtCompile
#ifndef EIGEN_PARSED_BY_DOXYGEN #ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename Other> template<typename Other>
PermutationMatrix(const Transpose<PermutationBase<Other> >& other) PermutationMatrix(const InverseImpl<Other,PermutationStorage>& other)
: m_indices(other.nestedExpression().size()) : m_indices(other.derived().nestedExpression().size())
{ {
eigen_internal_assert(m_indices.size() <= NumTraits<StorageIndex>::highest()); eigen_internal_assert(m_indices.size() <= NumTraits<StorageIndex>::highest());
StorageIndex end = StorageIndex(m_indices.size()); StorageIndex end = StorageIndex(m_indices.size());
for (StorageIndex i=0; i<end;++i) for (StorageIndex i=0; i<end;++i)
m_indices.coeffRef(other.nestedExpression().indices().coeff(i)) = i; m_indices.coeffRef(other.derived().nestedExpression().indices().coeff(i)) = i;
} }
template<typename Lhs,typename Rhs> template<typename Lhs,typename Rhs>
PermutationMatrix(internal::PermPermProduct_t, const Lhs& lhs, const Rhs& rhs) PermutationMatrix(internal::PermPermProduct_t, const Lhs& lhs, const Rhs& rhs)
@ -564,84 +562,61 @@ operator*(const PermutationBase<PermutationDerived> &permutation,
(permutation.derived(), matrix.derived()); (permutation.derived(), matrix.derived());
} }
namespace internal {
/* Template partial specialization for transposed/inverse permutations */ template<typename PermutationType>
class InverseImpl<PermutationType, PermutationStorage>
template<typename Derived> : public EigenBase<Inverse<PermutationType> >
struct traits<Transpose<PermutationBase<Derived> > >
: traits<Derived>
{};
} // end namespace internal
// TODO: the specificties should be handled by the evaluator,
// at the very least we should only specialize TransposeImpl
template<typename Derived>
class Transpose<PermutationBase<Derived> >
: public EigenBase<Transpose<PermutationBase<Derived> > >
{ {
typedef Derived PermutationType;
typedef typename PermutationType::IndicesType IndicesType;
typedef typename PermutationType::PlainPermutationType PlainPermutationType; typedef typename PermutationType::PlainPermutationType PlainPermutationType;
typedef internal::traits<PermutationType> PermTraits;
protected:
InverseImpl() {}
public: public:
typedef Inverse<PermutationType> InverseType;
using EigenBase<Inverse<PermutationType> >::derived;
#ifndef EIGEN_PARSED_BY_DOXYGEN #ifndef EIGEN_PARSED_BY_DOXYGEN
typedef internal::traits<PermutationType> Traits; typedef typename PermutationType::DenseMatrixType DenseMatrixType;
typedef typename Derived::DenseMatrixType DenseMatrixType;
enum { enum {
Flags = Traits::Flags, RowsAtCompileTime = PermTraits::RowsAtCompileTime,
RowsAtCompileTime = Traits::RowsAtCompileTime, ColsAtCompileTime = PermTraits::ColsAtCompileTime,
ColsAtCompileTime = Traits::ColsAtCompileTime, MaxRowsAtCompileTime = PermTraits::MaxRowsAtCompileTime,
MaxRowsAtCompileTime = Traits::MaxRowsAtCompileTime, MaxColsAtCompileTime = PermTraits::MaxColsAtCompileTime
MaxColsAtCompileTime = Traits::MaxColsAtCompileTime
}; };
typedef typename Traits::Scalar Scalar;
typedef typename Traits::StorageIndex StorageIndex;
#endif #endif
Transpose(const PermutationType& p) : m_permutation(p) {}
inline Index rows() const { return m_permutation.rows(); }
inline Index cols() const { return m_permutation.cols(); }
#ifndef EIGEN_PARSED_BY_DOXYGEN #ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename DenseDerived> template<typename DenseDerived>
void evalTo(MatrixBase<DenseDerived>& other) const void evalTo(MatrixBase<DenseDerived>& other) const
{ {
other.setZero(); other.setZero();
for (Index i=0; i<rows();++i) for (Index i=0; i<derived().rows();++i)
other.coeffRef(i, m_permutation.indices().coeff(i)) = typename DenseDerived::Scalar(1); other.coeffRef(i, derived().nestedExpression().indices().coeff(i)) = typename DenseDerived::Scalar(1);
} }
#endif #endif
/** \return the equivalent permutation matrix */ /** \return the equivalent permutation matrix */
PlainPermutationType eval() const { return *this; } PlainPermutationType eval() const { return derived(); }
DenseMatrixType toDenseMatrix() const { return *this; } DenseMatrixType toDenseMatrix() const { return derived(); }
/** \returns the matrix with the inverse permutation applied to the columns. /** \returns the matrix with the inverse permutation applied to the columns.
*/ */
template<typename OtherDerived> friend template<typename OtherDerived> friend
const Product<OtherDerived, Transpose, AliasFreeProduct> const Product<OtherDerived, InverseType, AliasFreeProduct>
operator*(const MatrixBase<OtherDerived>& matrix, const Transpose& trPerm) operator*(const MatrixBase<OtherDerived>& matrix, const InverseType& trPerm)
{ {
return Product<OtherDerived, Transpose, AliasFreeProduct>(matrix.derived(), trPerm.derived()); return Product<OtherDerived, InverseType, AliasFreeProduct>(matrix.derived(), trPerm.derived());
} }
/** \returns the matrix with the inverse permutation applied to the rows. /** \returns the matrix with the inverse permutation applied to the rows.
*/ */
template<typename OtherDerived> template<typename OtherDerived>
const Product<Transpose, OtherDerived, AliasFreeProduct> const Product<InverseType, OtherDerived, AliasFreeProduct>
operator*(const MatrixBase<OtherDerived>& matrix) const operator*(const MatrixBase<OtherDerived>& matrix) const
{ {
return Product<Transpose, OtherDerived, AliasFreeProduct>(*this, matrix.derived()); return Product<InverseType, OtherDerived, AliasFreeProduct>(derived(), matrix.derived());
} }
const PermutationType& nestedExpression() const { return m_permutation; }
protected:
const PermutationType& m_permutation;
}; };
template<typename Derived> template<typename Derived>

View File

@ -908,20 +908,20 @@ struct generic_product_impl<Lhs, Rhs, MatrixShape, PermutationShape, ProductTag>
}; };
template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape> template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
struct generic_product_impl<Transpose<Lhs>, Rhs, PermutationShape, MatrixShape, ProductTag> struct generic_product_impl<Inverse<Lhs>, Rhs, PermutationShape, MatrixShape, ProductTag>
{ {
template<typename Dest> template<typename Dest>
static void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs) static void evalTo(Dest& dst, const Inverse<Lhs>& lhs, const Rhs& rhs)
{ {
permutation_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedExpression(), rhs); permutation_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedExpression(), rhs);
} }
}; };
template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape> template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, PermutationShape, ProductTag> struct generic_product_impl<Lhs, Inverse<Rhs>, MatrixShape, PermutationShape, ProductTag>
{ {
template<typename Dest> template<typename Dest>
static void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs) static void evalTo(Dest& dst, const Lhs& lhs, const Inverse<Rhs>& rhs)
{ {
permutation_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedExpression(), lhs); permutation_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedExpression(), lhs);
} }

View File

@ -144,23 +144,22 @@ operator*( const PermutationBase<PermDerived>& perm, const SparseMatrixBase<Spar
{ return Product<PermDerived, SparseDerived>(perm.derived(), matrix.derived()); } { return Product<PermDerived, SparseDerived>(perm.derived(), matrix.derived()); }
// TODO, the following specializations should not be needed as Transpose<Permutation*> should be a PermutationBase.
/** \returns the matrix with the inverse permutation applied to the columns. /** \returns the matrix with the inverse permutation applied to the columns.
*/ */
template<typename SparseDerived, typename PermDerived> template<typename SparseDerived, typename PermutationType>
inline const Product<SparseDerived, Transpose<PermutationBase<PermDerived> > > inline const Product<SparseDerived, Inverse<PermutationType > >
operator*(const SparseMatrixBase<SparseDerived>& matrix, const Transpose<PermutationBase<PermDerived> >& tperm) operator*(const SparseMatrixBase<SparseDerived>& matrix, const InverseImpl<PermutationType, PermutationStorage>& tperm)
{ {
return Product<SparseDerived, Transpose<PermutationBase<PermDerived> > >(matrix.derived(), tperm); return Product<SparseDerived, Inverse<PermutationType> >(matrix.derived(), tperm.derived());
} }
/** \returns the matrix with the inverse permutation applied to the rows. /** \returns the matrix with the inverse permutation applied to the rows.
*/ */
template<typename SparseDerived, typename PermDerived> template<typename SparseDerived, typename PermutationType>
inline const Product<Transpose<PermutationBase<PermDerived> >, SparseDerived> inline const Product<Inverse<PermutationType>, SparseDerived>
operator*(const Transpose<PermutationBase<PermDerived> >& tperm, const SparseMatrixBase<SparseDerived>& matrix) operator*(const InverseImpl<PermutationType,PermutationStorage>& tperm, const SparseMatrixBase<SparseDerived>& matrix)
{ {
return Product<Transpose<PermutationBase<PermDerived> >, SparseDerived>(tperm, matrix.derived()); return Product<Inverse<PermutationType>, SparseDerived>(tperm.derived(), matrix.derived());
} }
} // end namespace Eigen } // end namespace Eigen