From 3af4c6c1c9327411d13386e4719ce48f866c7567 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Fri, 19 Jun 2015 11:50:24 +0200 Subject: [PATCH] Make Transpositions use evaluators --- Eigen/src/Core/PermutationMatrix.h | 3 - Eigen/src/Core/ProductEvaluators.h | 73 +++++++++++++++++ Eigen/src/Core/Transpositions.h | 127 ++++++++++++----------------- Eigen/src/Core/util/Constants.h | 1 + 4 files changed, 128 insertions(+), 76 deletions(-) diff --git a/Eigen/src/Core/PermutationMatrix.h b/Eigen/src/Core/PermutationMatrix.h index 9a0c03612..8c9afd4ee 100644 --- a/Eigen/src/Core/PermutationMatrix.h +++ b/Eigen/src/Core/PermutationMatrix.h @@ -537,9 +537,6 @@ class PermutationWrapper : public PermutationBase diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index 9d1cb5d56..4c673a6cb 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -929,6 +929,79 @@ struct generic_product_impl, MatrixShape, PermutationShape, } }; + +/*************************************************************************** +* Products with transpositions matrices +***************************************************************************/ + +// FIXME could we unify Transpositions and Permutation into a single "shape"?? + +/** \internal + * \class transposition_matrix_product + * Internal helper class implementing the product between a permutation matrix and a matrix. + */ +template +struct transposition_matrix_product +{ + template + static inline void evalTo(Dest& dst, const TranspositionType& tr, const MatrixType& mat) + { + typedef typename TranspositionType::StorageIndex StorageIndex; + const Index size = tr.size(); + StorageIndex j = 0; + + if(!(is_same::value && extract_data(dst) == extract_data(mat))) + dst = mat; + + for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k +struct generic_product_impl +{ + template + static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) + { + transposition_matrix_product::run(dst, lhs, rhs); + } +}; + +template +struct generic_product_impl +{ + template + static void evalTo(Dest& dst, const Lhs& lhs, const Rhs& rhs) + { + transposition_matrix_product::run(dst, rhs, lhs); + } +}; + +template +struct generic_product_impl, Rhs, TranspositionsShape, MatrixShape, ProductTag> +{ + template + static void evalTo(Dest& dst, const Transpose& lhs, const Rhs& rhs) + { + transposition_matrix_product::run(dst, lhs.nestedPermutation(), rhs); + } +}; + +template +struct generic_product_impl, MatrixShape, TranspositionsShape, ProductTag> +{ + template + static void evalTo(Dest& dst, const Lhs& lhs, const Transpose& rhs) + { + transposition_matrix_product::run(dst, rhs.nestedPermutation(), lhs); + } +}; + } // end namespace internal } // end namespace Eigen diff --git a/Eigen/src/Core/Transpositions.h b/Eigen/src/Core/Transpositions.h index b08df1ead..dad4f56c9 100644 --- a/Eigen/src/Core/Transpositions.h +++ b/Eigen/src/Core/Transpositions.h @@ -41,10 +41,6 @@ namespace Eigen { * \sa class PermutationMatrix */ -namespace internal { -template struct transposition_matrix_product_retval; -} - template class TranspositionsBase { @@ -325,77 +321,32 @@ class TranspositionsWrapper const typename IndicesType::Nested m_indices; }; + + /** \returns the \a matrix with the \a transpositions applied to the columns. */ -template -inline const internal::transposition_matrix_product_retval -operator*(const MatrixBase& matrix, - const TranspositionsBase &transpositions) +template +EIGEN_DEVICE_FUNC +const Product +operator*(const MatrixBase &matrix, + const TranspositionsBase& transpositions) { - return internal::transposition_matrix_product_retval - - (transpositions.derived(), matrix.derived()); + return Product + (matrix.derived(), transpositions.derived()); } /** \returns the \a matrix with the \a transpositions applied to the rows. */ -template -inline const internal::transposition_matrix_product_retval - -operator*(const TranspositionsBase &transpositions, - const MatrixBase& matrix) +template +EIGEN_DEVICE_FUNC +const Product +operator*(const TranspositionsBase &transpositions, + const MatrixBase& matrix) { - return internal::transposition_matrix_product_retval - - (transpositions.derived(), matrix.derived()); + return Product + (transpositions.derived(), matrix.derived()); } -namespace internal { - -template -struct traits > -{ - typedef typename MatrixType::PlainObject ReturnType; -}; - -template -struct transposition_matrix_product_retval - : public ReturnByValue > -{ - typedef typename remove_all::type MatrixTypeNestedCleaned; - typedef typename TranspositionType::StorageIndex StorageIndex; - - transposition_matrix_product_retval(const TranspositionType& tr, const MatrixType& matrix) - : m_transpositions(tr), m_matrix(matrix) - {} - - inline Index rows() const { return m_matrix.rows(); } - inline Index cols() const { return m_matrix.cols(); } - - template inline void evalTo(Dest& dst) const - { - const Index size = m_transpositions.size(); - StorageIndex j = 0; - - if(!(is_same::value && extract_data(dst) == extract_data(m_matrix))) - dst = m_matrix; - - for(Index k=(Transposed?size-1:0) ; Transposed?k>=0:k > /** \returns the \a matrix with the inverse transpositions applied to the columns. */ - template friend - inline const internal::transposition_matrix_product_retval - operator*(const MatrixBase& matrix, const Transpose& trt) + template friend + const Product + operator*(const MatrixBase& matrix, const Transpose& trt) { - return internal::transposition_matrix_product_retval(trt.m_transpositions, matrix.derived()); + return Product(matrix.derived(), trt.derived()); } /** \returns the \a matrix with the inverse transpositions applied to the rows. */ - template - inline const internal::transposition_matrix_product_retval - operator*(const MatrixBase& matrix) const + template + const Product + operator*(const MatrixBase& matrix) const { - return internal::transposition_matrix_product_retval(m_transpositions, matrix.derived()); + return Product(*this, matrix.derived()); } protected: const TranspositionType& m_transpositions; }; +namespace internal { + +// TODO currently a Transpositions expression has the form Transpositions or TranspositionsWrapper +// or their transpose; in the future shape should be defined by the expression traits +template +struct evaluator_traits > +{ + typedef typename storage_kind_to_evaluator_kind::Kind Kind; + typedef TranspositionsShape Shape; + static const int AssumeAliasing = 0; +}; + +template +struct evaluator_traits > +{ + typedef typename storage_kind_to_evaluator_kind::Kind Kind; + typedef TranspositionsShape Shape; + static const int AssumeAliasing = 0; +}; + +template +struct evaluator_traits > > +{ + typedef typename storage_kind_to_evaluator_kind::Kind Kind; + typedef TranspositionsShape Shape; + static const int AssumeAliasing = 0; +}; + +} // end namespace internal + } // end namespace Eigen #endif // EIGEN_TRANSPOSITIONS_H diff --git a/Eigen/src/Core/util/Constants.h b/Eigen/src/Core/util/Constants.h index 419409608..3e6c75444 100644 --- a/Eigen/src/Core/util/Constants.h +++ b/Eigen/src/Core/util/Constants.h @@ -482,6 +482,7 @@ struct BandShape { static std::string debugName() { return "BandSha struct TriangularShape { static std::string debugName() { return "TriangularShape"; } }; struct SelfAdjointShape { static std::string debugName() { return "SelfAdjointShape"; } }; struct PermutationShape { static std::string debugName() { return "PermutationShape"; } }; +struct TranspositionsShape { static std::string debugName() { return "TranspositionsShape"; } }; struct SparseShape { static std::string debugName() { return "SparseShape"; } }; namespace internal {