From 574bc8820ddbc381075348f210cef19de7b22f2b Mon Sep 17 00:00:00 2001 From: Charles Schlosser Date: Tue, 23 Apr 2024 03:25:57 +0000 Subject: [PATCH] fix transposed matrix product bug --- Eigen/src/Core/Product.h | 26 +++++++++++++++++++++++--- Eigen/src/Core/Transpose.h | 12 ++++++++++++ Eigen/src/Core/util/XprHelper.h | 10 ++++++++++ test/product_notemporary.cpp | 2 ++ 4 files changed, 47 insertions(+), 3 deletions(-) diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h index 6bad832e0..23c90c223 100644 --- a/Eigen/src/Core/Product.h +++ b/Eigen/src/Core/Product.h @@ -21,7 +21,7 @@ class ProductImpl; namespace internal { template -struct traits > { +struct traits> { typedef remove_all_t LhsCleaned; typedef remove_all_t RhsCleaned; typedef traits LhsTraits; @@ -93,6 +93,23 @@ class Product typedef internal::remove_all_t LhsNestedCleaned; typedef internal::remove_all_t RhsNestedCleaned; + private: + using LhsTransposeType = Transpose; + using LhsScalar = typename internal::traits::Scalar; + using LhsConjugateTransposeType = CwiseUnaryOp, LhsTransposeType>; + using LhsAdjointType = + std::conditional_t::value, LhsConjugateTransposeType, LhsTransposeType>; + + using RhsTransposeType = Transpose; + using RhsScalar = typename internal::traits::Scalar; + using RhsConjugateTransposeType = CwiseUnaryOp, RhsTransposeType>; + using RhsAdjointType = + std::conditional_t::value, RhsConjugateTransposeType, RhsTransposeType>; + + public: + using TransposeReturnType = Product; + using AdjointReturnType = Product; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Product(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) { eigen_assert(lhs.cols() == rhs.rows() && "invalid matrix product" && "if you wanted a coeff-wise or a dot product use the respective explicit functions"); @@ -104,6 +121,9 @@ class Product EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const LhsNestedCleaned& lhs() const { return m_lhs; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const RhsNestedCleaned& rhs() const { return m_rhs; } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeReturnType transpose() const; + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointReturnType adjoint() const; + protected: LhsNested m_lhs; RhsNested m_rhs; @@ -112,12 +132,12 @@ class Product namespace internal { template ::ret> -class dense_product_base : public internal::dense_xpr_base >::type {}; +class dense_product_base : public internal::dense_xpr_base>::type {}; /** Conversion to scalar for inner-products */ template class dense_product_base - : public internal::dense_xpr_base >::type { + : public internal::dense_xpr_base>::type { typedef Product ProductXpr; typedef typename internal::dense_xpr_base::type Base; diff --git a/Eigen/src/Core/Transpose.h b/Eigen/src/Core/Transpose.h index 1cc7a2867..f12cf4a8d 100644 --- a/Eigen/src/Core/Transpose.h +++ b/Eigen/src/Core/Transpose.h @@ -196,6 +196,18 @@ EIGEN_DEVICE_FUNC inline const typename MatrixBase::AdjointReturnType M return AdjointReturnType(this->transpose()); } +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename Product::TransposeReturnType +Product::transpose() const { + return TransposeReturnType(m_rhs.transpose(), m_lhs.transpose()); +} + +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename Product::AdjointReturnType +Product::adjoint() const { + return AdjointReturnType(m_rhs.adjoint(), m_lhs.adjoint()); +} + /*************************************************************************** * "in place" transpose implementation ***************************************************************************/ diff --git a/Eigen/src/Core/util/XprHelper.h b/Eigen/src/Core/util/XprHelper.h index 555faa1cc..ea4a7c3fb 100644 --- a/Eigen/src/Core/util/XprHelper.h +++ b/Eigen/src/Core/util/XprHelper.h @@ -1011,6 +1011,16 @@ struct ScalarBinaryOpTraits { typedef void ReturnType; }; +template +struct is_complex_helper { + static constexpr bool value = NumTraits::IsComplex; +}; +// NumTraits is not defined by intent +template <> +struct is_complex_helper { + static constexpr bool value = false; +}; + // We require Lhs and Rhs to have "compatible" scalar types. // It is tempting to always allow mixing different types but remember that this is often impossible in the vectorized // paths. So allowing mixing different types gives very unexpected errors when enabling vectorization, when the user diff --git a/test/product_notemporary.cpp b/test/product_notemporary.cpp index 77a296966..f41a4b171 100644 --- a/test/product_notemporary.cpp +++ b/test/product_notemporary.cpp @@ -66,6 +66,8 @@ void product_notemporary(const MatrixType& m) { VERIFY_EVALUATION_COUNT(m3 = (m1 * m2.adjoint()), 1); VERIFY_EVALUATION_COUNT(m3 = (m1 * m2.adjoint()).transpose(), 1); VERIFY_EVALUATION_COUNT(m3.noalias() = m1 * m2.adjoint(), 0); + VERIFY_EVALUATION_COUNT(m3.noalias() = (m1 * m2.adjoint()).transpose(), 0); + VERIFY_EVALUATION_COUNT(m3.noalias() = (m1 * m2.transpose()).adjoint(), 0); VERIFY_EVALUATION_COUNT(m3 = s1 * (m1 * m2.transpose()), 1); // VERIFY_EVALUATION_COUNT( m3 = m3 + s1 * (m1 * m2.transpose()), 1);