mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-23 23:03:15 +08:00
fix transposed matrix product bug
This commit is contained in:
parent
112ad8b846
commit
574bc8820d
@ -93,6 +93,23 @@ class Product
|
|||||||
typedef internal::remove_all_t<LhsNested> LhsNestedCleaned;
|
typedef internal::remove_all_t<LhsNested> LhsNestedCleaned;
|
||||||
typedef internal::remove_all_t<RhsNested> RhsNestedCleaned;
|
typedef internal::remove_all_t<RhsNested> RhsNestedCleaned;
|
||||||
|
|
||||||
|
private:
|
||||||
|
using LhsTransposeType = Transpose<const LhsNestedCleaned>;
|
||||||
|
using LhsScalar = typename internal::traits<LhsNestedCleaned>::Scalar;
|
||||||
|
using LhsConjugateTransposeType = CwiseUnaryOp<internal::scalar_conjugate_op<LhsScalar>, LhsTransposeType>;
|
||||||
|
using LhsAdjointType =
|
||||||
|
std::conditional_t<is_complex_helper<LhsScalar>::value, LhsConjugateTransposeType, LhsTransposeType>;
|
||||||
|
|
||||||
|
using RhsTransposeType = Transpose<const RhsNestedCleaned>;
|
||||||
|
using RhsScalar = typename internal::traits<RhsNestedCleaned>::Scalar;
|
||||||
|
using RhsConjugateTransposeType = CwiseUnaryOp<internal::scalar_conjugate_op<RhsScalar>, RhsTransposeType>;
|
||||||
|
using RhsAdjointType =
|
||||||
|
std::conditional_t<is_complex_helper<RhsScalar>::value, RhsConjugateTransposeType, RhsTransposeType>;
|
||||||
|
|
||||||
|
public:
|
||||||
|
using TransposeReturnType = Product<RhsTransposeType, LhsTransposeType, Option>;
|
||||||
|
using AdjointReturnType = Product<RhsAdjointType, LhsAdjointType, Option>;
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Product(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Product(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) {
|
||||||
eigen_assert(lhs.cols() == rhs.rows() && "invalid matrix product" &&
|
eigen_assert(lhs.cols() == rhs.rows() && "invalid matrix product" &&
|
||||||
"if you wanted a coeff-wise or a dot product use the respective explicit functions");
|
"if you wanted a coeff-wise or a dot product use the respective explicit functions");
|
||||||
@ -104,6 +121,9 @@ class Product
|
|||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const LhsNestedCleaned& lhs() const { return m_lhs; }
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const LhsNestedCleaned& lhs() const { return m_lhs; }
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const RhsNestedCleaned& rhs() const { return m_rhs; }
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const RhsNestedCleaned& rhs() const { return m_rhs; }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeReturnType transpose() const;
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointReturnType adjoint() const;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
LhsNested m_lhs;
|
LhsNested m_lhs;
|
||||||
RhsNested m_rhs;
|
RhsNested m_rhs;
|
||||||
|
@ -196,6 +196,18 @@ EIGEN_DEVICE_FUNC inline const typename MatrixBase<Derived>::AdjointReturnType M
|
|||||||
return AdjointReturnType(this->transpose());
|
return AdjointReturnType(this->transpose());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Lhs_, typename Rhs_, int Option>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename Product<Lhs_, Rhs_, Option>::TransposeReturnType
|
||||||
|
Product<Lhs_, Rhs_, Option>::transpose() const {
|
||||||
|
return TransposeReturnType(m_rhs.transpose(), m_lhs.transpose());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Lhs_, typename Rhs_, int Option>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE typename Product<Lhs_, Rhs_, Option>::AdjointReturnType
|
||||||
|
Product<Lhs_, Rhs_, Option>::adjoint() const {
|
||||||
|
return AdjointReturnType(m_rhs.adjoint(), m_lhs.adjoint());
|
||||||
|
}
|
||||||
|
|
||||||
/***************************************************************************
|
/***************************************************************************
|
||||||
* "in place" transpose implementation
|
* "in place" transpose implementation
|
||||||
***************************************************************************/
|
***************************************************************************/
|
||||||
|
@ -1011,6 +1011,16 @@ struct ScalarBinaryOpTraits<void, void, BinaryOp> {
|
|||||||
typedef void ReturnType;
|
typedef void ReturnType;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename Scalar>
|
||||||
|
struct is_complex_helper {
|
||||||
|
static constexpr bool value = NumTraits<Scalar>::IsComplex;
|
||||||
|
};
|
||||||
|
// NumTraits<void> is not defined by intent
|
||||||
|
template <>
|
||||||
|
struct is_complex_helper<void> {
|
||||||
|
static constexpr bool value = false;
|
||||||
|
};
|
||||||
|
|
||||||
// We require Lhs and Rhs to have "compatible" scalar types.
|
// We require Lhs and Rhs to have "compatible" scalar types.
|
||||||
// It is tempting to always allow mixing different types but remember that this is often impossible in the vectorized
|
// It is tempting to always allow mixing different types but remember that this is often impossible in the vectorized
|
||||||
// paths. So allowing mixing different types gives very unexpected errors when enabling vectorization, when the user
|
// paths. So allowing mixing different types gives very unexpected errors when enabling vectorization, when the user
|
||||||
|
@ -66,6 +66,8 @@ void product_notemporary(const MatrixType& m) {
|
|||||||
VERIFY_EVALUATION_COUNT(m3 = (m1 * m2.adjoint()), 1);
|
VERIFY_EVALUATION_COUNT(m3 = (m1 * m2.adjoint()), 1);
|
||||||
VERIFY_EVALUATION_COUNT(m3 = (m1 * m2.adjoint()).transpose(), 1);
|
VERIFY_EVALUATION_COUNT(m3 = (m1 * m2.adjoint()).transpose(), 1);
|
||||||
VERIFY_EVALUATION_COUNT(m3.noalias() = m1 * m2.adjoint(), 0);
|
VERIFY_EVALUATION_COUNT(m3.noalias() = m1 * m2.adjoint(), 0);
|
||||||
|
VERIFY_EVALUATION_COUNT(m3.noalias() = (m1 * m2.adjoint()).transpose(), 0);
|
||||||
|
VERIFY_EVALUATION_COUNT(m3.noalias() = (m1 * m2.transpose()).adjoint(), 0);
|
||||||
|
|
||||||
VERIFY_EVALUATION_COUNT(m3 = s1 * (m1 * m2.transpose()), 1);
|
VERIFY_EVALUATION_COUNT(m3 = s1 * (m1 * m2.transpose()), 1);
|
||||||
// VERIFY_EVALUATION_COUNT( m3 = m3 + s1 * (m1 * m2.transpose()), 1);
|
// VERIFY_EVALUATION_COUNT( m3 = m3 + s1 * (m1 * m2.transpose()), 1);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user