// Patch overview: unified diff touching Eigen/src/Core/Product.h, Eigen/src/Core/util/XprHelper.h and
// test/product_notemporary.cpp. It lets Product expose its own transpose()/adjoint() so that a transposed
// product can be returned as a new Product of rewritten operands instead of a Transpose wrapper.
//
// NOTE(review): this text appears to have lost its line breaks and the contents of template angle brackets
// (e.g. `template` directly followed by `struct`, `traits >`, `using Derived = Product;`) -- presumably an
// extraction artifact. Confirm against the upstream commit before applying or compiling.
//
// Product.h (namespace internal):
//   * TransposeProductEnum -- integer tags Default / Matrix / Permutation, plus two-operand tags packed as
//     (lhs tag << 8) | rhs tag: MatrixMatrix, MatrixPermutation, PermutationMatrix.
//   * TransposeKind -- selects Matrix, Permutation or Default for one operand, using is_matrix_base_xpr and
//     is_permutation_base_xpr.
//   * TransposeProductKind -- combines the kinds of both operands with the same (<< 8) | packing.
//   * product_transpose_helper -- the primary template does not optimize: run_transpose wraps the whole
//     product in Transpose, and run_adjoint wraps that in AdjointType, which is the conjugating CwiseUnaryOp
//     when NumTraits<...>::IsComplex is set and the plain TransposeType otherwise.
//     The three specializations rewrite the expression instead (as stated by their in-code comments):
//       - matrix * matrix:       (lhs * rhs)^T -> rhs^T  * lhs^T,   (lhs * rhs)^H -> rhs^H  * lhs^H
//       - permutation * matrix:  (lhs * rhs)^T -> rhs^T  * lhs^-1,  (lhs * rhs)^H -> rhs^H  * lhs^-1
//       - matrix * permutation:  (lhs * rhs)^T -> rhs^-1 * lhs^T,   (lhs * rhs)^H -> rhs^-1 * lhs^H
//     The permutation operand is replaced by its InverseReturnType in both the transpose and adjoint cases.
//
// Product.h (class Product):
//   * adds TransposeReturnType / AdjointReturnType aliases taken from product_transpose_helper;
//   * adds transpose() and adjoint() const members forwarding to run_transpose(*this) / run_adjoint(*this).
//   * the remaining -/+ pairs in traits and dense_product_base only change `> >` spacing to `>>`.
//
// XprHelper.h:
//   * adds is_matrix_base_xpr and is_permutation_base_xpr, both derived from std::is_base_of.
//
// test/product_notemporary.cpp:
//   * adds VERIFY_EVALUATION_COUNT checks with an expected count of 0 for m3.noalias() assignments of
//     (m1 * m2.adjoint()).transpose() and (m1 * m2.transpose()).adjoint(), both alone and added to m3.
diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h index 6bad832e0..37683e3c2 100644 --- a/Eigen/src/Core/Product.h +++ b/Eigen/src/Core/Product.h @@ -21,7 +21,7 @@ class ProductImpl; namespace internal { template -struct traits > { +struct traits> { typedef remove_all_t LhsCleaned; typedef remove_all_t RhsCleaned; typedef traits LhsTraits; @@ -55,6 +55,129 @@ struct traits > { }; }; +struct TransposeProductEnum { + // convenience enumerations to specialize transposed products + enum : int { + Default = 0x00, + Matrix = 0x01, + Permutation = 0x02, + MatrixMatrix = (Matrix << 8) | Matrix, + MatrixPermutation = (Matrix << 8) | Permutation, + PermutationMatrix = (Permutation << 8) | Matrix + }; +}; +template +struct TransposeKind { + static constexpr int Kind = is_matrix_base_xpr::value ? TransposeProductEnum::Matrix + : is_permutation_base_xpr::value ? TransposeProductEnum::Permutation + : TransposeProductEnum::Default; +}; + +template +struct TransposeProductKind { + static constexpr int Kind = (TransposeKind::Kind << 8) | TransposeKind::Kind; +}; + +template ::Kind> +struct product_transpose_helper { + // by default, don't optimize the transposed product + using Derived = Product; + using Scalar = typename Derived::Scalar; + using TransposeType = Transpose; + using ConjugateTransposeType = CwiseUnaryOp, TransposeType>; + using AdjointType = std::conditional_t::IsComplex, ConjugateTransposeType, TransposeType>; + + // return (lhs * rhs)^T + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeType run_transpose(const Derived& derived) { + return TransposeType(derived); + } + // return (lhs * rhs)^H + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointType run_adjoint(const Derived& derived) { + return AdjointType(TransposeType(derived)); + } +}; + +template +struct product_transpose_helper { + // expand the transposed matrix-matrix product + using Derived = Product; + + using LhsScalar = typename traits::Scalar; + using LhsTransposeType = typename 
DenseBase::ConstTransposeReturnType; + using LhsConjugateTransposeType = CwiseUnaryOp, LhsTransposeType>; + using LhsAdjointType = + std::conditional_t::IsComplex, LhsConjugateTransposeType, LhsTransposeType>; + + using RhsScalar = typename traits::Scalar; + using RhsTransposeType = typename DenseBase::ConstTransposeReturnType; + using RhsConjugateTransposeType = CwiseUnaryOp, RhsTransposeType>; + using RhsAdjointType = + std::conditional_t::IsComplex, RhsConjugateTransposeType, RhsTransposeType>; + + using TransposeType = Product; + using AdjointType = Product; + + // return rhs^T * lhs^T + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeType run_transpose(const Derived& derived) { + return TransposeType(RhsTransposeType(derived.rhs()), LhsTransposeType(derived.lhs())); + } + // return rhs^H * lhs^H + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointType run_adjoint(const Derived& derived) { + return AdjointType(RhsAdjointType(RhsTransposeType(derived.rhs())), + LhsAdjointType(LhsTransposeType(derived.lhs()))); + } +}; +template +struct product_transpose_helper { + // expand the transposed permutation-matrix product + using Derived = Product; + + using LhsInverseType = typename PermutationBase::InverseReturnType; + + using RhsScalar = typename traits::Scalar; + using RhsTransposeType = typename DenseBase::ConstTransposeReturnType; + using RhsConjugateTransposeType = CwiseUnaryOp, RhsTransposeType>; + using RhsAdjointType = + std::conditional_t::IsComplex, RhsConjugateTransposeType, RhsTransposeType>; + + using TransposeType = Product; + using AdjointType = Product; + + // return rhs^T * lhs^-1 + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeType run_transpose(const Derived& derived) { + return TransposeType(RhsTransposeType(derived.rhs()), LhsInverseType(derived.lhs())); + } + // return rhs^H * lhs^-1 + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointType run_adjoint(const Derived& derived) { + return 
AdjointType(RhsAdjointType(RhsTransposeType(derived.rhs())), LhsInverseType(derived.lhs())); + } +}; +template +struct product_transpose_helper { + // expand the transposed matrix-permutation product + using Derived = Product; + + using LhsScalar = typename traits::Scalar; + using LhsTransposeType = typename DenseBase::ConstTransposeReturnType; + using LhsConjugateTransposeType = CwiseUnaryOp, LhsTransposeType>; + using LhsAdjointType = + std::conditional_t::IsComplex, LhsConjugateTransposeType, LhsTransposeType>; + + using RhsInverseType = typename PermutationBase::InverseReturnType; + + using TransposeType = Product; + using AdjointType = Product; + + // return rhs^-1 * lhs^T + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeType run_transpose(const Derived& derived) { + return TransposeType(RhsInverseType(derived.rhs()), LhsTransposeType(derived.lhs())); + } + // return rhs^-1 * lhs^H + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointType run_adjoint(const Derived& derived) { + return AdjointType(RhsInverseType(derived.rhs()), LhsAdjointType(LhsTransposeType(derived.lhs()))); + } +}; + } // end namespace internal /** \class Product @@ -93,6 +216,9 @@ class Product typedef internal::remove_all_t LhsNestedCleaned; typedef internal::remove_all_t RhsNestedCleaned; + using TransposeReturnType = typename internal::product_transpose_helper::TransposeType; + using AdjointReturnType = typename internal::product_transpose_helper::AdjointType; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Product(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) { eigen_assert(lhs.cols() == rhs.rows() && "invalid matrix product" && "if you wanted a coeff-wise or a dot product use the respective explicit functions"); @@ -104,6 +230,13 @@ class Product EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const LhsNestedCleaned& lhs() const { return m_lhs; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const RhsNestedCleaned& rhs() const { return m_rhs; } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE 
TransposeReturnType transpose() const { + return internal::product_transpose_helper::run_transpose(*this); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointReturnType adjoint() const { + return internal::product_transpose_helper::run_adjoint(*this); + } + protected: LhsNested m_lhs; RhsNested m_rhs; @@ -112,12 +245,12 @@ class Product namespace internal { template ::ret> -class dense_product_base : public internal::dense_xpr_base >::type {}; +class dense_product_base : public internal::dense_xpr_base>::type {}; /** Conversion to scalar for inner-products */ template class dense_product_base - : public internal::dense_xpr_base >::type { + : public internal::dense_xpr_base>::type { typedef Product ProductXpr; typedef typename internal::dense_xpr_base::type Base; diff --git a/Eigen/src/Core/util/XprHelper.h b/Eigen/src/Core/util/XprHelper.h index 555faa1cc..a6a7d3fbb 100644 --- a/Eigen/src/Core/util/XprHelper.h +++ b/Eigen/src/Core/util/XprHelper.h @@ -928,6 +928,12 @@ template struct block_xpr_helper> : block_xpr_helper> {}; +template +struct is_matrix_base_xpr : std::is_base_of>, remove_all_t> {}; + +template +struct is_permutation_base_xpr : std::is_base_of>, remove_all_t> {}; + } // end namespace internal /** \class ScalarBinaryOpTraits diff --git a/test/product_notemporary.cpp b/test/product_notemporary.cpp index 77a296966..c22ea13cd 100644 --- a/test/product_notemporary.cpp +++ b/test/product_notemporary.cpp @@ -65,6 +65,8 @@ void product_notemporary(const MatrixType& m) { VERIFY_EVALUATION_COUNT(m3 = (m1 * m2.adjoint()), 1); VERIFY_EVALUATION_COUNT(m3 = (m1 * m2.adjoint()).transpose(), 1); + VERIFY_EVALUATION_COUNT(m3.noalias() = (m1 * m2.adjoint()).transpose(), 0); + VERIFY_EVALUATION_COUNT(m3.noalias() = (m1 * m2.transpose()).adjoint(), 0); VERIFY_EVALUATION_COUNT(m3.noalias() = m1 * m2.adjoint(), 0); VERIFY_EVALUATION_COUNT(m3 = s1 * (m1 * m2.transpose()), 1); @@ -75,6 +77,8 @@ void product_notemporary(const MatrixType& m) { VERIFY_EVALUATION_COUNT(m3 = m3 
- (m1 * m2.adjoint()), 1); VERIFY_EVALUATION_COUNT(m3 = m3 + (m1 * m2.adjoint()).transpose(), 1); + VERIFY_EVALUATION_COUNT(m3.noalias() = m3 + (m1 * m2.adjoint()).transpose(), 0); + VERIFY_EVALUATION_COUNT(m3.noalias() = m3 + (m1 * m2.transpose()).adjoint(), 0); VERIFY_EVALUATION_COUNT(m3.noalias() = m3 + m1 * m2.transpose(), 0); VERIFY_EVALUATION_COUNT(m3.noalias() += m3 + m1 * m2.transpose(), 0); VERIFY_EVALUATION_COUNT(m3.noalias() -= m3 + m1 * m2.transpose(), 0);