Revert "fix transposed matrix product bug"

This reverts merge request !1598
This commit is contained in:
Charles Schlosser 2024-04-23 14:07:11 +00:00
parent 9cec679ef1
commit 34967b0b5b
4 changed files with 3 additions and 47 deletions

View File

@ -21,7 +21,7 @@ class ProductImpl;
namespace internal {
template <typename Lhs, typename Rhs, int Option>
struct traits<Product<Lhs, Rhs, Option>> {
struct traits<Product<Lhs, Rhs, Option> > {
typedef remove_all_t<Lhs> LhsCleaned;
typedef remove_all_t<Rhs> RhsCleaned;
typedef traits<LhsCleaned> LhsTraits;
@ -93,23 +93,6 @@ class Product
typedef internal::remove_all_t<LhsNested> LhsNestedCleaned;
typedef internal::remove_all_t<RhsNested> RhsNestedCleaned;
private:
using LhsTransposeType = Transpose<const LhsNestedCleaned>;
using LhsScalar = typename internal::traits<LhsNestedCleaned>::Scalar;
using LhsConjugateTransposeType = CwiseUnaryOp<internal::scalar_conjugate_op<LhsScalar>, LhsTransposeType>;
using LhsAdjointType =
std::conditional_t<is_complex_helper<LhsScalar>::value, LhsConjugateTransposeType, LhsTransposeType>;
using RhsTransposeType = Transpose<const RhsNestedCleaned>;
using RhsScalar = typename internal::traits<RhsNestedCleaned>::Scalar;
using RhsConjugateTransposeType = CwiseUnaryOp<internal::scalar_conjugate_op<RhsScalar>, RhsTransposeType>;
using RhsAdjointType =
std::conditional_t<is_complex_helper<RhsScalar>::value, RhsConjugateTransposeType, RhsTransposeType>;
public:
using TransposeReturnType = Product<RhsTransposeType, LhsTransposeType, Option>;
using AdjointReturnType = Product<RhsAdjointType, LhsAdjointType, Option>;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Product(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) {
eigen_assert(lhs.cols() == rhs.rows() && "invalid matrix product" &&
"if you wanted a coeff-wise or a dot product use the respective explicit functions");
@ -121,9 +104,6 @@ class Product
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const LhsNestedCleaned& lhs() const { return m_lhs; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const RhsNestedCleaned& rhs() const { return m_rhs; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeReturnType transpose() const;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointReturnType adjoint() const;
protected:
LhsNested m_lhs;
RhsNested m_rhs;
@ -132,12 +112,12 @@ class Product
namespace internal {
template <typename Lhs, typename Rhs, int Option, int ProductTag = internal::product_type<Lhs, Rhs>::ret>
class dense_product_base : public internal::dense_xpr_base<Product<Lhs, Rhs, Option>>::type {};
class dense_product_base : public internal::dense_xpr_base<Product<Lhs, Rhs, Option> >::type {};
/** Conversion to scalar for inner-products */
template <typename Lhs, typename Rhs, int Option>
class dense_product_base<Lhs, Rhs, Option, InnerProduct>
: public internal::dense_xpr_base<Product<Lhs, Rhs, Option>>::type {
: public internal::dense_xpr_base<Product<Lhs, Rhs, Option> >::type {
typedef Product<Lhs, Rhs, Option> ProductXpr;
typedef typename internal::dense_xpr_base<ProductXpr>::type Base;

View File

@ -196,18 +196,6 @@ EIGEN_DEVICE_FUNC inline const typename MatrixBase<Derived>::AdjointReturnType M
return AdjointReturnType(this->transpose());
}
template <typename Lhs_, typename Rhs_, int Option>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename Product<Lhs_, Rhs_, Option>::TransposeReturnType
Product<Lhs_, Rhs_, Option>::transpose() const {
return TransposeReturnType(m_rhs.transpose(), m_lhs.transpose());
}
template <typename Lhs_, typename Rhs_, int Option>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename Product<Lhs_, Rhs_, Option>::AdjointReturnType
Product<Lhs_, Rhs_, Option>::adjoint() const {
return AdjointReturnType(m_rhs.adjoint(), m_lhs.adjoint());
}
/***************************************************************************
* "in place" transpose implementation
***************************************************************************/

View File

@ -1011,16 +1011,6 @@ struct ScalarBinaryOpTraits<void, void, BinaryOp> {
typedef void ReturnType;
};
template <typename Scalar>
struct is_complex_helper {
static constexpr bool value = NumTraits<Scalar>::IsComplex;
};
// NumTraits<void> is not defined by intent
template <>
struct is_complex_helper<void> {
static constexpr bool value = false;
};
// We require Lhs and Rhs to have "compatible" scalar types.
// It is tempting to always allow mixing different types but remember that this is often impossible in the vectorized
// paths. So allowing mixing different types gives very unexpected errors when enabling vectorization, when the user

View File

@ -66,8 +66,6 @@ void product_notemporary(const MatrixType& m) {
VERIFY_EVALUATION_COUNT(m3 = (m1 * m2.adjoint()), 1);
VERIFY_EVALUATION_COUNT(m3 = (m1 * m2.adjoint()).transpose(), 1);
VERIFY_EVALUATION_COUNT(m3.noalias() = m1 * m2.adjoint(), 0);
VERIFY_EVALUATION_COUNT(m3.noalias() = (m1 * m2.adjoint()).transpose(), 0);
VERIFY_EVALUATION_COUNT(m3.noalias() = (m1 * m2.transpose()).adjoint(), 0);
VERIFY_EVALUATION_COUNT(m3 = s1 * (m1 * m2.transpose()), 1);
// VERIFY_EVALUATION_COUNT( m3 = m3 + s1 * (m1 * m2.transpose()), 1);