Clean a bit the implementation of inverse permutations

This commit is contained in:
Gael Guennebaud 2015-10-08 18:36:39 +02:00
parent 8d00a953af
commit d866279364
5 changed files with 50 additions and 75 deletions

View File

@ -382,8 +382,6 @@ using std::ptrdiff_t;
#include "src/Core/DiagonalMatrix.h"
#include "src/Core/Diagonal.h"
#include "src/Core/DiagonalProduct.h"
#include "src/Core/PermutationMatrix.h"
#include "src/Core/Transpositions.h"
#include "src/Core/Redux.h"
#include "src/Core/Visitor.h"
#include "src/Core/Fuzzy.h"
@ -393,6 +391,8 @@ using std::ptrdiff_t;
#include "src/Core/GeneralProduct.h"
#include "src/Core/Solve.h"
#include "src/Core/Inverse.h"
#include "src/Core/PermutationMatrix.h"
#include "src/Core/Transpositions.h"
#include "src/Core/TriangularMatrix.h"
#include "src/Core/SelfAdjointView.h"
#include "src/Core/products/GeneralBlockPanelKernel.h"

View File

@ -47,11 +47,12 @@ public:
typedef typename XprType::PlainObject PlainObject;
typedef typename internal::ref_selector<XprType>::type XprTypeNested;
typedef typename internal::remove_all<XprTypeNested>::type XprTypeNestedCleaned;
typedef typename internal::ref_selector<Inverse>::type Nested;
explicit Inverse(const XprType &xpr)
: m_xpr(xpr)
{}
EIGEN_DEVICE_FUNC Index rows() const { return m_xpr.rows(); }
EIGEN_DEVICE_FUNC Index cols() const { return m_xpr.cols(); }

View File

@ -2,7 +2,7 @@
// for linear algebra.
//
// Copyright (C) 2009 Benoit Jacob <jacob.benoit.1@gmail.com>
// Copyright (C) 2009-2011 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2009-2015 Gael Guennebaud <gael.guennebaud@inria.fr>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
@ -13,9 +13,6 @@
namespace Eigen {
// TODO: this does not seems to be needed at all:
// template<int RowCol,typename IndicesType,typename MatrixType, typename StorageKind> class PermutedImpl;
/** \class PermutationBase
* \ingroup Core_Module
*
@ -67,8 +64,9 @@ class PermutationBase : public EigenBase<Derived>
DenseMatrixType;
typedef PermutationMatrix<IndicesType::SizeAtCompileTime,IndicesType::MaxSizeAtCompileTime,StorageIndex>
PlainPermutationType;
typedef PlainPermutationType PlainObject;
using Base::derived;
typedef Transpose<PermutationBase> TransposeReturnType;
typedef Inverse<Derived> InverseReturnType;
typedef void Scalar;
#endif
@ -196,14 +194,14 @@ class PermutationBase : public EigenBase<Derived>
*
* \note \note_try_to_help_rvo
*/
inline TransposeReturnType inverse() const
{ return TransposeReturnType(derived()); }
inline InverseReturnType inverse() const
{ return InverseReturnType(derived()); }
/** \returns the tranpose permutation matrix.
*
* \note \note_try_to_help_rvo
*/
inline TransposeReturnType transpose() const
{ return TransposeReturnType(derived()); }
inline InverseReturnType transpose() const
{ return InverseReturnType(derived()); }
/**** multiplication helpers to hopefully get RVO ****/
@ -238,7 +236,7 @@ class PermutationBase : public EigenBase<Derived>
* \note \note_try_to_help_rvo
*/
template<typename Other>
inline PlainPermutationType operator*(const Transpose<PermutationBase<Other> >& other) const
inline PlainPermutationType operator*(const InverseImpl<Other,PermutationStorage>& other) const
{ return PlainPermutationType(internal::PermPermProduct, *this, other.eval()); }
/** \returns the product of an inverse permutation with another permutation.
@ -246,7 +244,7 @@ class PermutationBase : public EigenBase<Derived>
* \note \note_try_to_help_rvo
*/
template<typename Other> friend
inline PlainPermutationType operator*(const Transpose<PermutationBase<Other> >& other, const PermutationBase& perm)
inline PlainPermutationType operator*(const InverseImpl<Other, PermutationStorage>& other, const PermutationBase& perm)
{ return PlainPermutationType(internal::PermPermProduct, other.eval(), perm); }
/** \returns the determinant of the permutation matrix, which is either 1 or -1 depending on the parity of the permutation.
@ -398,13 +396,13 @@ class PermutationMatrix : public PermutationBase<PermutationMatrix<SizeAtCompile
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename Other>
PermutationMatrix(const Transpose<PermutationBase<Other> >& other)
: m_indices(other.nestedExpression().size())
PermutationMatrix(const InverseImpl<Other,PermutationStorage>& other)
: m_indices(other.derived().nestedExpression().size())
{
eigen_internal_assert(m_indices.size() <= NumTraits<StorageIndex>::highest());
StorageIndex end = StorageIndex(m_indices.size());
for (StorageIndex i=0; i<end;++i)
m_indices.coeffRef(other.nestedExpression().indices().coeff(i)) = i;
m_indices.coeffRef(other.derived().nestedExpression().indices().coeff(i)) = i;
}
template<typename Lhs,typename Rhs>
PermutationMatrix(internal::PermPermProduct_t, const Lhs& lhs, const Rhs& rhs)
@ -564,84 +562,61 @@ operator*(const PermutationBase<PermutationDerived> &permutation,
(permutation.derived(), matrix.derived());
}
namespace internal {
/* Template partial specialization for transposed/inverse permutations */
template<typename Derived>
struct traits<Transpose<PermutationBase<Derived> > >
: traits<Derived>
{};
} // end namespace internal
// TODO: the specificties should be handled by the evaluator,
// at the very least we should only specialize TransposeImpl
template<typename Derived>
class Transpose<PermutationBase<Derived> >
: public EigenBase<Transpose<PermutationBase<Derived> > >
template<typename PermutationType>
class InverseImpl<PermutationType, PermutationStorage>
: public EigenBase<Inverse<PermutationType> >
{
typedef Derived PermutationType;
typedef typename PermutationType::IndicesType IndicesType;
typedef typename PermutationType::PlainPermutationType PlainPermutationType;
typedef internal::traits<PermutationType> PermTraits;
protected:
InverseImpl() {}
public:
typedef Inverse<PermutationType> InverseType;
using EigenBase<Inverse<PermutationType> >::derived;
#ifndef EIGEN_PARSED_BY_DOXYGEN
typedef internal::traits<PermutationType> Traits;
typedef typename Derived::DenseMatrixType DenseMatrixType;
typedef typename PermutationType::DenseMatrixType DenseMatrixType;
enum {
Flags = Traits::Flags,
RowsAtCompileTime = Traits::RowsAtCompileTime,
ColsAtCompileTime = Traits::ColsAtCompileTime,
MaxRowsAtCompileTime = Traits::MaxRowsAtCompileTime,
MaxColsAtCompileTime = Traits::MaxColsAtCompileTime
RowsAtCompileTime = PermTraits::RowsAtCompileTime,
ColsAtCompileTime = PermTraits::ColsAtCompileTime,
MaxRowsAtCompileTime = PermTraits::MaxRowsAtCompileTime,
MaxColsAtCompileTime = PermTraits::MaxColsAtCompileTime
};
typedef typename Traits::Scalar Scalar;
typedef typename Traits::StorageIndex StorageIndex;
#endif
Transpose(const PermutationType& p) : m_permutation(p) {}
inline Index rows() const { return m_permutation.rows(); }
inline Index cols() const { return m_permutation.cols(); }
#ifndef EIGEN_PARSED_BY_DOXYGEN
template<typename DenseDerived>
void evalTo(MatrixBase<DenseDerived>& other) const
{
other.setZero();
for (Index i=0; i<rows();++i)
other.coeffRef(i, m_permutation.indices().coeff(i)) = typename DenseDerived::Scalar(1);
for (Index i=0; i<derived().rows();++i)
other.coeffRef(i, derived().nestedExpression().indices().coeff(i)) = typename DenseDerived::Scalar(1);
}
#endif
/** \return the equivalent permutation matrix */
PlainPermutationType eval() const { return *this; }
PlainPermutationType eval() const { return derived(); }
DenseMatrixType toDenseMatrix() const { return *this; }
DenseMatrixType toDenseMatrix() const { return derived(); }
/** \returns the matrix with the inverse permutation applied to the columns.
*/
template<typename OtherDerived> friend
const Product<OtherDerived, Transpose, AliasFreeProduct>
operator*(const MatrixBase<OtherDerived>& matrix, const Transpose& trPerm)
const Product<OtherDerived, InverseType, AliasFreeProduct>
operator*(const MatrixBase<OtherDerived>& matrix, const InverseType& trPerm)
{
return Product<OtherDerived, Transpose, AliasFreeProduct>(matrix.derived(), trPerm.derived());
return Product<OtherDerived, InverseType, AliasFreeProduct>(matrix.derived(), trPerm.derived());
}
/** \returns the matrix with the inverse permutation applied to the rows.
*/
template<typename OtherDerived>
const Product<Transpose, OtherDerived, AliasFreeProduct>
const Product<InverseType, OtherDerived, AliasFreeProduct>
operator*(const MatrixBase<OtherDerived>& matrix) const
{
return Product<Transpose, OtherDerived, AliasFreeProduct>(*this, matrix.derived());
return Product<InverseType, OtherDerived, AliasFreeProduct>(derived(), matrix.derived());
}
const PermutationType& nestedExpression() const { return m_permutation; }
protected:
const PermutationType& m_permutation;
};
template<typename Derived>

View File

@ -908,20 +908,20 @@ struct generic_product_impl<Lhs, Rhs, MatrixShape, PermutationShape, ProductTag>
};
template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
struct generic_product_impl<Transpose<Lhs>, Rhs, PermutationShape, MatrixShape, ProductTag>
struct generic_product_impl<Inverse<Lhs>, Rhs, PermutationShape, MatrixShape, ProductTag>
{
template<typename Dest>
static void evalTo(Dest& dst, const Transpose<Lhs>& lhs, const Rhs& rhs)
static void evalTo(Dest& dst, const Inverse<Lhs>& lhs, const Rhs& rhs)
{
permutation_matrix_product<Rhs, OnTheLeft, true, MatrixShape>::run(dst, lhs.nestedExpression(), rhs);
}
};
template<typename Lhs, typename Rhs, int ProductTag, typename MatrixShape>
struct generic_product_impl<Lhs, Transpose<Rhs>, MatrixShape, PermutationShape, ProductTag>
struct generic_product_impl<Lhs, Inverse<Rhs>, MatrixShape, PermutationShape, ProductTag>
{
template<typename Dest>
static void evalTo(Dest& dst, const Lhs& lhs, const Transpose<Rhs>& rhs)
static void evalTo(Dest& dst, const Lhs& lhs, const Inverse<Rhs>& rhs)
{
permutation_matrix_product<Lhs, OnTheRight, true, MatrixShape>::run(dst, rhs.nestedExpression(), lhs);
}

View File

@ -144,23 +144,22 @@ operator*( const PermutationBase<PermDerived>& perm, const SparseMatrixBase<Spar
{ return Product<PermDerived, SparseDerived>(perm.derived(), matrix.derived()); }
// TODO, the following specializations should not be needed as Transpose<Permutation*> should be a PermutationBase.
/** \returns the matrix with the inverse permutation applied to the columns.
*/
template<typename SparseDerived, typename PermDerived>
inline const Product<SparseDerived, Transpose<PermutationBase<PermDerived> > >
operator*(const SparseMatrixBase<SparseDerived>& matrix, const Transpose<PermutationBase<PermDerived> >& tperm)
template<typename SparseDerived, typename PermutationType>
inline const Product<SparseDerived, Inverse<PermutationType > >
operator*(const SparseMatrixBase<SparseDerived>& matrix, const InverseImpl<PermutationType, PermutationStorage>& tperm)
{
return Product<SparseDerived, Transpose<PermutationBase<PermDerived> > >(matrix.derived(), tperm);
return Product<SparseDerived, Inverse<PermutationType> >(matrix.derived(), tperm.derived());
}
/** \returns the matrix with the inverse permutation applied to the rows.
*/
template<typename SparseDerived, typename PermDerived>
inline const Product<Transpose<PermutationBase<PermDerived> >, SparseDerived>
operator*(const Transpose<PermutationBase<PermDerived> >& tperm, const SparseMatrixBase<SparseDerived>& matrix)
template<typename SparseDerived, typename PermutationType>
inline const Product<Inverse<PermutationType>, SparseDerived>
operator*(const InverseImpl<PermutationType,PermutationStorage>& tperm, const SparseMatrixBase<SparseDerived>& matrix)
{
return Product<Transpose<PermutationBase<PermDerived> >, SparseDerived>(tperm, matrix.derived());
return Product<Inverse<PermutationType>, SparseDerived>(tperm.derived(), matrix.derived());
}
} // end namespace Eigen