From fad36cc8148fc4f4581ebb5b7c4a0ae4438df00a Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Fri, 19 Jun 2015 10:51:57 +0200 Subject: [PATCH] Clean implementation of permutation * matrix products. --- Eigen/src/Core/PermutationMatrix.h | 78 ------------- Eigen/src/Core/ProductEvaluators.h | 93 +++++++++++++--- Eigen/src/SparseCore/SparsePermutation.h | 133 ++++++----------------- 3 files changed, 108 insertions(+), 196 deletions(-) diff --git a/Eigen/src/Core/PermutationMatrix.h b/Eigen/src/Core/PermutationMatrix.h index 99f5aecdd..9a0c03612 100644 --- a/Eigen/src/Core/PermutationMatrix.h +++ b/Eigen/src/Core/PermutationMatrix.h @@ -42,10 +42,6 @@ namespace Eigen { namespace internal { -template -struct permut_matrix_product_retval; -template -struct permut_sparsematrix_product_retval; enum PermPermProduct_t {PermPermProduct}; } // end namespace internal @@ -570,80 +566,6 @@ operator*(const PermutationBase &permutation, namespace internal { -template -struct traits > - : traits -{ - typedef typename MatrixType::PlainObject ReturnType; -}; - -template -struct permut_matrix_product_retval - : public ReturnByValue > -{ - typedef typename remove_all::type MatrixTypeNestedCleaned; - typedef typename MatrixType::StorageIndex StorageIndex; - - permut_matrix_product_retval(const PermutationType& perm, const MatrixType& matrix) - : m_permutation(perm), m_matrix(matrix) - {} - - inline Index rows() const { return m_matrix.rows(); } - inline Index cols() const { return m_matrix.cols(); } - - template inline void evalTo(Dest& dst) const - { - const Index n = Side==OnTheLeft ? rows() : cols(); - // FIXME we need an is_same for expression that is not sensitive to constness. For instance - // is_same_xpr, Block >::value should be true. - //if(is_same::value && extract_data(dst) == extract_data(m_matrix)) - if(is_same_dense(dst, m_matrix)) - { - // apply the permutation inplace - Matrix mask(m_permutation.size()); - mask.fill(false); - Index r = 0; - while(r < m_permutation.size()) - { - // search for the next seed - while(r=m_permutation.size()) - break; - // we got one, let's follow it until we are back to the seed - Index k0 = r++; - Index kPrev = k0; - mask.coeffRef(k0) = true; - for(Index k=m_permutation.indices().coeff(k0); k!=k0; k=m_permutation.indices().coeff(k)) - { - Block(dst, k) - .swap(Block - (dst,((Side==OnTheLeft) ^ Transposed) ? k0 : kPrev)); - - mask.coeffRef(k) = true; - kPrev = k; - } - } - } - else - { - for(Index i = 0; i < n; ++i) - { - Block - (dst, ((Side==OnTheLeft) ^ Transposed) ? m_permutation.indices().coeff(i) : i) - - = - - Block - (m_matrix, ((Side==OnTheRight) ^ Transposed) ? m_permutation.indices().coeff(i) : i); - } - } - } - - protected: - const PermutationType& m_permutation; - typename MatrixType::Nested m_matrix; -}; - /* Template partial specialization for transposed/inverse permutations */ template diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index 22b5e024b..9d1cb5d56 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -825,48 +825,107 @@ struct product_evaluator, ProductTag, DenseShape, /*************************************************************************** * Products with permutation matrices ***************************************************************************/ - -template -struct generic_product_impl + +/** \internal + * \class permutation_matrix_product + * Internal helper class implementing the product between a permutation matrix and a matrix. + * This class is specialized for DenseShape below and for SparseShape in SparseCore/SparsePermutation.h + */ +template +struct permutation_matrix_product; + +template +struct permutation_matrix_product +{ + typedef typename remove_all::type MatrixTypeCleaned; + + template + static inline void run(Dest& dst, const PermutationType& perm, const MatrixType& mat) + { + const Index n = Side==OnTheLeft ? mat.rows() : mat.cols(); + // FIXME we need an is_same for expression that is not sensitive to constness. For instance + // is_same_xpr, Block >::value should be true. + //if(is_same::value && extract_data(dst) == extract_data(mat)) + if(is_same_dense(dst, mat)) + { + // apply the permutation inplace + Matrix mask(perm.size()); + mask.fill(false); + Index r = 0; + while(r < perm.size()) + { + // search for the next seed + while(r=perm.size()) + break; + // we got one, let's follow it until we are back to the seed + Index k0 = r++; + Index kPrev = k0; + mask.coeffRef(k0) = true; + for(Index k=perm.indices().coeff(k0); k!=k0; k=perm.indices().coeff(k)) + { + Block(dst, k) + .swap(Block + (dst,((Side==OnTheLeft) ^ Transposed) ? k0 : kPrev)); + + mask.coeffRef(k) = true; + kPrev = k; + } + } + } + else + { + for(Index i = 0; i < n; ++i) + { + Block + (dst, ((Side==OnTheLeft) ^ Transposed) ? perm.indices().coeff(i) : i) + + = + + Block + (mat, ((Side==OnTheRight) ^ Transposed) ? perm.indices().coeff(i) : i); + } + } + } +}; + +template +struct generic_product_impl { template static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) { - permut_matrix_product_retval pmpr(lhs, rhs); - pmpr.evalTo(dst); + permutation_matrix_product::run(dst, lhs, rhs); } }; -template -struct generic_product_impl +template +struct generic_product_impl { template static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) { - permut_matrix_product_retval pmpr(rhs, lhs); - pmpr.evalTo(dst); + permutation_matrix_product::run(dst, rhs, lhs); } }; -template -struct generic_product_impl, Rhs, PermutationShape, DenseShape, ProductTag> +template +struct generic_product_impl, Rhs, PermutationShape, MatrixShape, ProductTag> { template static void evalTo(Dest& dst, const Transpose& lhs, const Rhs& rhs) { - permut_matrix_product_retval pmpr(lhs.nestedPermutation(), rhs); - pmpr.evalTo(dst); + permutation_matrix_product::run(dst, lhs.nestedPermutation(), rhs); } }; -template -struct generic_product_impl, DenseShape, PermutationShape, ProductTag> +template +struct generic_product_impl, MatrixShape, PermutationShape, ProductTag> { template static void evalTo(Dest& dst, const Lhs& lhs, const Transpose& rhs) { - permut_matrix_product_retval pmpr(rhs.nestedPermutation(), lhs); - pmpr.evalTo(dst); + permutation_matrix_product::run(dst, rhs.nestedPermutation(), lhs); } }; diff --git a/Eigen/src/SparseCore/SparsePermutation.h b/Eigen/src/SparseCore/SparsePermutation.h index 4be93c18c..b15128979 100644 --- a/Eigen/src/SparseCore/SparsePermutation.h +++ b/Eigen/src/SparseCore/SparsePermutation.h @@ -16,25 +16,8 @@ namespace Eigen { namespace internal { -template -struct traits > -{ - typedef typename remove_all::type MatrixTypeNestedCleaned; - typedef typename MatrixTypeNestedCleaned::Scalar Scalar; - typedef typename MatrixTypeNestedCleaned::StorageIndex StorageIndex; - enum { - SrcStorageOrder = MatrixTypeNestedCleaned::Flags&RowMajorBit ? RowMajor : ColMajor, - MoveOuter = SrcStorageOrder==RowMajor ? Side==OnTheLeft : Side==OnTheRight - }; - - typedef typename internal::conditional, - SparseMatrix >::type ReturnType; -}; - -template -struct permut_sparsematrix_product_retval - : public ReturnByValue > +template +struct permutation_matrix_product { typedef typename remove_all::type MatrixTypeNestedCleaned; typedef typename MatrixTypeNestedCleaned::Scalar Scalar; @@ -44,61 +27,55 @@ struct permut_sparsematrix_product_retval SrcStorageOrder = MatrixTypeNestedCleaned::Flags&RowMajorBit ? RowMajor : ColMajor, MoveOuter = SrcStorageOrder==RowMajor ? Side==OnTheLeft : Side==OnTheRight }; + + typedef typename internal::conditional, + SparseMatrix >::type ReturnType; - permut_sparsematrix_product_retval(const PermutationType& perm, const MatrixType& matrix) - : m_permutation(perm), m_matrix(matrix) - {} - - inline int rows() const { return m_matrix.rows(); } - inline int cols() const { return m_matrix.cols(); } - - template inline void evalTo(Dest& dst) const + template + static inline void run(Dest& dst, const PermutationType& perm, const MatrixType& mat) { if(MoveOuter) { - SparseMatrix tmp(m_matrix.rows(), m_matrix.cols()); - Matrix sizes(m_matrix.outerSize()); - for(Index j=0; j tmp(mat.rows(), mat.cols()); + Matrix sizes(mat.outerSize()); + for(Index j=0; j tmp(m_matrix.rows(), m_matrix.cols()); + SparseMatrix tmp(mat.rows(), mat.cols()); Matrix sizes(tmp.outerSize()); sizes.setZero(); - PermutationMatrix perm; + PermutationMatrix perm_cpy; if((Side==OnTheLeft) ^ Transposed) - perm = m_permutation; + perm_cpy = perm; else - perm = m_permutation.transpose(); + perm_cpy = perm.transpose(); - for(Index j=0; j struct product_promote_storage_type { typedef Sparse ret; }; template struct product_promote_storage_type { typedef Sparse ret; }; - -// TODO, the following need cleaning, this is just a copy-past of the dense case - -template -struct generic_product_impl -{ - template - static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) - { - permut_sparsematrix_product_retval pmpr(lhs, rhs); - pmpr.evalTo(dst); - } -}; - -template -struct generic_product_impl -{ - template - static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) - { - permut_sparsematrix_product_retval pmpr(rhs, lhs); - pmpr.evalTo(dst); - } -}; - -template -struct generic_product_impl, Rhs, PermutationShape, SparseShape, ProductTag> -{ - template - static void evalTo(Dest& dst, const Transpose& lhs, const Rhs& rhs) - { - permut_sparsematrix_product_retval pmpr(lhs.nestedPermutation(), rhs); - pmpr.evalTo(dst); - } -}; - -template -struct generic_product_impl, SparseShape, PermutationShape, ProductTag> -{ - template - static void evalTo(Dest& dst, const Lhs& lhs, const Transpose& rhs) - { - permut_sparsematrix_product_retval pmpr(rhs.nestedPermutation(), lhs); - pmpr.evalTo(dst); - } -}; // TODO, the following two overloads are only needed to define the right temporary type through -// typename traits >::ReturnType -// while it should be correctly handled by traits >::PlainObject +// typename traits >::ReturnType +// whereas it should be correctly handled by traits >::PlainObject template struct product_evaluator, ProductTag, PermutationShape, SparseShape, typename traits::Scalar, typename traits::Scalar> - : public evaluator >::ReturnType>::type + : public evaluator::ReturnType>::type { typedef Product XprType; - typedef typename traits >::ReturnType PlainObject; + typedef typename permutation_matrix_product::ReturnType PlainObject; typedef typename evaluator::type Base; explicit product_evaluator(const XprType& xpr) @@ -179,10 +110,10 @@ protected: template struct product_evaluator, ProductTag, SparseShape, PermutationShape, typename traits::Scalar, typename traits::Scalar> - : public evaluator >::ReturnType>::type + : public evaluator::ReturnType>::type { typedef Product XprType; - typedef typename traits >::ReturnType PlainObject; + typedef typename permutation_matrix_product::ReturnType PlainObject; typedef typename evaluator::type Base; explicit product_evaluator(const XprType& xpr)