mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-10-10 23:21:29 +08:00
Eigen transpose product
This commit is contained in:
parent
fb95e90f7f
commit
0ee5c90aa9
@ -21,7 +21,7 @@ class ProductImpl;
|
|||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
template <typename Lhs, typename Rhs, int Option>
|
template <typename Lhs, typename Rhs, int Option>
|
||||||
struct traits<Product<Lhs, Rhs, Option> > {
|
struct traits<Product<Lhs, Rhs, Option>> {
|
||||||
typedef remove_all_t<Lhs> LhsCleaned;
|
typedef remove_all_t<Lhs> LhsCleaned;
|
||||||
typedef remove_all_t<Rhs> RhsCleaned;
|
typedef remove_all_t<Rhs> RhsCleaned;
|
||||||
typedef traits<LhsCleaned> LhsTraits;
|
typedef traits<LhsCleaned> LhsTraits;
|
||||||
@ -55,6 +55,129 @@ struct traits<Product<Lhs, Rhs, Option> > {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct TransposeProductEnum {
|
||||||
|
// convenience enumerations to specialize transposed products
|
||||||
|
enum : int {
|
||||||
|
Default = 0x00,
|
||||||
|
Matrix = 0x01,
|
||||||
|
Permutation = 0x02,
|
||||||
|
MatrixMatrix = (Matrix << 8) | Matrix,
|
||||||
|
MatrixPermutation = (Matrix << 8) | Permutation,
|
||||||
|
PermutationMatrix = (Permutation << 8) | Matrix
|
||||||
|
};
|
||||||
|
};
|
||||||
|
template <typename Xpr>
|
||||||
|
struct TransposeKind {
|
||||||
|
static constexpr int Kind = is_matrix_base_xpr<Xpr>::value ? TransposeProductEnum::Matrix
|
||||||
|
: is_permutation_base_xpr<Xpr>::value ? TransposeProductEnum::Permutation
|
||||||
|
: TransposeProductEnum::Default;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Lhs, typename Rhs>
|
||||||
|
struct TransposeProductKind {
|
||||||
|
static constexpr int Kind = (TransposeKind<Lhs>::Kind << 8) | TransposeKind<Rhs>::Kind;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Lhs, typename Rhs, int Option, int Kind = TransposeProductKind<Lhs, Rhs>::Kind>
|
||||||
|
struct product_transpose_helper {
|
||||||
|
// by default, don't optimize the transposed product
|
||||||
|
using Derived = Product<Lhs, Rhs, Option>;
|
||||||
|
using Scalar = typename Derived::Scalar;
|
||||||
|
using TransposeType = Transpose<const Derived>;
|
||||||
|
using ConjugateTransposeType = CwiseUnaryOp<scalar_conjugate_op<Scalar>, TransposeType>;
|
||||||
|
using AdjointType = std::conditional_t<NumTraits<Scalar>::IsComplex, ConjugateTransposeType, TransposeType>;
|
||||||
|
|
||||||
|
// return (lhs * rhs)^T
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeType run_transpose(const Derived& derived) {
|
||||||
|
return TransposeType(derived);
|
||||||
|
}
|
||||||
|
// return (lhs * rhs)^H
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointType run_adjoint(const Derived& derived) {
|
||||||
|
return AdjointType(TransposeType(derived));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename Lhs, typename Rhs, int Option>
|
||||||
|
struct product_transpose_helper<Lhs, Rhs, Option, TransposeProductEnum::MatrixMatrix> {
|
||||||
|
// expand the transposed matrix-matrix product
|
||||||
|
using Derived = Product<Lhs, Rhs, Option>;
|
||||||
|
|
||||||
|
using LhsScalar = typename traits<Lhs>::Scalar;
|
||||||
|
using LhsTransposeType = typename DenseBase<Lhs>::ConstTransposeReturnType;
|
||||||
|
using LhsConjugateTransposeType = CwiseUnaryOp<scalar_conjugate_op<LhsScalar>, LhsTransposeType>;
|
||||||
|
using LhsAdjointType =
|
||||||
|
std::conditional_t<NumTraits<LhsScalar>::IsComplex, LhsConjugateTransposeType, LhsTransposeType>;
|
||||||
|
|
||||||
|
using RhsScalar = typename traits<Rhs>::Scalar;
|
||||||
|
using RhsTransposeType = typename DenseBase<Rhs>::ConstTransposeReturnType;
|
||||||
|
using RhsConjugateTransposeType = CwiseUnaryOp<scalar_conjugate_op<RhsScalar>, RhsTransposeType>;
|
||||||
|
using RhsAdjointType =
|
||||||
|
std::conditional_t<NumTraits<RhsScalar>::IsComplex, RhsConjugateTransposeType, RhsTransposeType>;
|
||||||
|
|
||||||
|
using TransposeType = Product<RhsTransposeType, LhsTransposeType, Option>;
|
||||||
|
using AdjointType = Product<RhsAdjointType, LhsAdjointType, Option>;
|
||||||
|
|
||||||
|
// return rhs^T * lhs^T
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeType run_transpose(const Derived& derived) {
|
||||||
|
return TransposeType(RhsTransposeType(derived.rhs()), LhsTransposeType(derived.lhs()));
|
||||||
|
}
|
||||||
|
// return rhs^H * lhs^H
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointType run_adjoint(const Derived& derived) {
|
||||||
|
return AdjointType(RhsAdjointType(RhsTransposeType(derived.rhs())),
|
||||||
|
LhsAdjointType(LhsTransposeType(derived.lhs())));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
template <typename Lhs, typename Rhs, int Option>
|
||||||
|
struct product_transpose_helper<Lhs, Rhs, Option, TransposeProductEnum::PermutationMatrix> {
|
||||||
|
// expand the transposed permutation-matrix product
|
||||||
|
using Derived = Product<Lhs, Rhs, Option>;
|
||||||
|
|
||||||
|
using LhsInverseType = typename PermutationBase<Lhs>::InverseReturnType;
|
||||||
|
|
||||||
|
using RhsScalar = typename traits<Rhs>::Scalar;
|
||||||
|
using RhsTransposeType = typename DenseBase<Rhs>::ConstTransposeReturnType;
|
||||||
|
using RhsConjugateTransposeType = CwiseUnaryOp<scalar_conjugate_op<RhsScalar>, RhsTransposeType>;
|
||||||
|
using RhsAdjointType =
|
||||||
|
std::conditional_t<NumTraits<RhsScalar>::IsComplex, RhsConjugateTransposeType, RhsTransposeType>;
|
||||||
|
|
||||||
|
using TransposeType = Product<RhsTransposeType, LhsInverseType, Option>;
|
||||||
|
using AdjointType = Product<RhsAdjointType, LhsInverseType, Option>;
|
||||||
|
|
||||||
|
// return rhs^T * lhs^-1
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeType run_transpose(const Derived& derived) {
|
||||||
|
return TransposeType(RhsTransposeType(derived.rhs()), LhsInverseType(derived.lhs()));
|
||||||
|
}
|
||||||
|
// return rhs^H * lhs^-1
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointType run_adjoint(const Derived& derived) {
|
||||||
|
return AdjointType(RhsAdjointType(RhsTransposeType(derived.rhs())), LhsInverseType(derived.lhs()));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
template <typename Lhs, typename Rhs, int Option>
|
||||||
|
struct product_transpose_helper<Lhs, Rhs, Option, TransposeProductEnum::MatrixPermutation> {
|
||||||
|
// expand the transposed matrix-permutation product
|
||||||
|
using Derived = Product<Lhs, Rhs, Option>;
|
||||||
|
|
||||||
|
using LhsScalar = typename traits<Lhs>::Scalar;
|
||||||
|
using LhsTransposeType = typename DenseBase<Lhs>::ConstTransposeReturnType;
|
||||||
|
using LhsConjugateTransposeType = CwiseUnaryOp<scalar_conjugate_op<LhsScalar>, LhsTransposeType>;
|
||||||
|
using LhsAdjointType =
|
||||||
|
std::conditional_t<NumTraits<LhsScalar>::IsComplex, LhsConjugateTransposeType, LhsTransposeType>;
|
||||||
|
|
||||||
|
using RhsInverseType = typename PermutationBase<Rhs>::InverseReturnType;
|
||||||
|
|
||||||
|
using TransposeType = Product<RhsInverseType, LhsTransposeType, Option>;
|
||||||
|
using AdjointType = Product<RhsInverseType, LhsAdjointType, Option>;
|
||||||
|
|
||||||
|
// return rhs^-1 * lhs^T
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeType run_transpose(const Derived& derived) {
|
||||||
|
return TransposeType(RhsInverseType(derived.rhs()), LhsTransposeType(derived.lhs()));
|
||||||
|
}
|
||||||
|
// return rhs^-1 * lhs^H
|
||||||
|
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointType run_adjoint(const Derived& derived) {
|
||||||
|
return AdjointType(RhsInverseType(derived.rhs()), LhsAdjointType(LhsTransposeType(derived.lhs())));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
/** \class Product
|
/** \class Product
|
||||||
@ -93,6 +216,9 @@ class Product
|
|||||||
typedef internal::remove_all_t<LhsNested> LhsNestedCleaned;
|
typedef internal::remove_all_t<LhsNested> LhsNestedCleaned;
|
||||||
typedef internal::remove_all_t<RhsNested> RhsNestedCleaned;
|
typedef internal::remove_all_t<RhsNested> RhsNestedCleaned;
|
||||||
|
|
||||||
|
using TransposeReturnType = typename internal::product_transpose_helper<Lhs, Rhs, Option>::TransposeType;
|
||||||
|
using AdjointReturnType = typename internal::product_transpose_helper<Lhs, Rhs, Option>::AdjointType;
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Product(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) {
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Product(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) {
|
||||||
eigen_assert(lhs.cols() == rhs.rows() && "invalid matrix product" &&
|
eigen_assert(lhs.cols() == rhs.rows() && "invalid matrix product" &&
|
||||||
"if you wanted a coeff-wise or a dot product use the respective explicit functions");
|
"if you wanted a coeff-wise or a dot product use the respective explicit functions");
|
||||||
@ -104,6 +230,13 @@ class Product
|
|||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const LhsNestedCleaned& lhs() const { return m_lhs; }
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const LhsNestedCleaned& lhs() const { return m_lhs; }
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const RhsNestedCleaned& rhs() const { return m_rhs; }
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const RhsNestedCleaned& rhs() const { return m_rhs; }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeReturnType transpose() const {
|
||||||
|
return internal::product_transpose_helper<Lhs, Rhs, Option>::run_transpose(*this);
|
||||||
|
}
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointReturnType adjoint() const {
|
||||||
|
return internal::product_transpose_helper<Lhs, Rhs, Option>::run_adjoint(*this);
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
LhsNested m_lhs;
|
LhsNested m_lhs;
|
||||||
RhsNested m_rhs;
|
RhsNested m_rhs;
|
||||||
@ -112,12 +245,12 @@ class Product
|
|||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
template <typename Lhs, typename Rhs, int Option, int ProductTag = internal::product_type<Lhs, Rhs>::ret>
|
template <typename Lhs, typename Rhs, int Option, int ProductTag = internal::product_type<Lhs, Rhs>::ret>
|
||||||
class dense_product_base : public internal::dense_xpr_base<Product<Lhs, Rhs, Option> >::type {};
|
class dense_product_base : public internal::dense_xpr_base<Product<Lhs, Rhs, Option>>::type {};
|
||||||
|
|
||||||
/** Conversion to scalar for inner-products */
|
/** Conversion to scalar for inner-products */
|
||||||
template <typename Lhs, typename Rhs, int Option>
|
template <typename Lhs, typename Rhs, int Option>
|
||||||
class dense_product_base<Lhs, Rhs, Option, InnerProduct>
|
class dense_product_base<Lhs, Rhs, Option, InnerProduct>
|
||||||
: public internal::dense_xpr_base<Product<Lhs, Rhs, Option> >::type {
|
: public internal::dense_xpr_base<Product<Lhs, Rhs, Option>>::type {
|
||||||
typedef Product<Lhs, Rhs, Option> ProductXpr;
|
typedef Product<Lhs, Rhs, Option> ProductXpr;
|
||||||
typedef typename internal::dense_xpr_base<ProductXpr>::type Base;
|
typedef typename internal::dense_xpr_base<ProductXpr>::type Base;
|
||||||
|
|
||||||
|
@ -928,6 +928,12 @@ template <typename XprType, int BlockRows, int BlockCols, bool InnerPanel>
|
|||||||
struct block_xpr_helper<const Block<XprType, BlockRows, BlockCols, InnerPanel>>
|
struct block_xpr_helper<const Block<XprType, BlockRows, BlockCols, InnerPanel>>
|
||||||
: block_xpr_helper<Block<XprType, BlockRows, BlockCols, InnerPanel>> {};
|
: block_xpr_helper<Block<XprType, BlockRows, BlockCols, InnerPanel>> {};
|
||||||
|
|
||||||
|
template <typename XprType>
|
||||||
|
struct is_matrix_base_xpr : std::is_base_of<MatrixBase<remove_all_t<XprType>>, remove_all_t<XprType>> {};
|
||||||
|
|
||||||
|
template <typename XprType>
|
||||||
|
struct is_permutation_base_xpr : std::is_base_of<PermutationBase<remove_all_t<XprType>>, remove_all_t<XprType>> {};
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
/** \class ScalarBinaryOpTraits
|
/** \class ScalarBinaryOpTraits
|
||||||
|
@ -65,6 +65,8 @@ void product_notemporary(const MatrixType& m) {
|
|||||||
|
|
||||||
VERIFY_EVALUATION_COUNT(m3 = (m1 * m2.adjoint()), 1);
|
VERIFY_EVALUATION_COUNT(m3 = (m1 * m2.adjoint()), 1);
|
||||||
VERIFY_EVALUATION_COUNT(m3 = (m1 * m2.adjoint()).transpose(), 1);
|
VERIFY_EVALUATION_COUNT(m3 = (m1 * m2.adjoint()).transpose(), 1);
|
||||||
|
VERIFY_EVALUATION_COUNT(m3.noalias() = (m1 * m2.adjoint()).transpose(), 0);
|
||||||
|
VERIFY_EVALUATION_COUNT(m3.noalias() = (m1 * m2.transpose()).adjoint(), 0);
|
||||||
VERIFY_EVALUATION_COUNT(m3.noalias() = m1 * m2.adjoint(), 0);
|
VERIFY_EVALUATION_COUNT(m3.noalias() = m1 * m2.adjoint(), 0);
|
||||||
|
|
||||||
VERIFY_EVALUATION_COUNT(m3 = s1 * (m1 * m2.transpose()), 1);
|
VERIFY_EVALUATION_COUNT(m3 = s1 * (m1 * m2.transpose()), 1);
|
||||||
@ -75,6 +77,8 @@ void product_notemporary(const MatrixType& m) {
|
|||||||
VERIFY_EVALUATION_COUNT(m3 = m3 - (m1 * m2.adjoint()), 1);
|
VERIFY_EVALUATION_COUNT(m3 = m3 - (m1 * m2.adjoint()), 1);
|
||||||
|
|
||||||
VERIFY_EVALUATION_COUNT(m3 = m3 + (m1 * m2.adjoint()).transpose(), 1);
|
VERIFY_EVALUATION_COUNT(m3 = m3 + (m1 * m2.adjoint()).transpose(), 1);
|
||||||
|
VERIFY_EVALUATION_COUNT(m3.noalias() = m3 + (m1 * m2.adjoint()).transpose(), 0);
|
||||||
|
VERIFY_EVALUATION_COUNT(m3.noalias() = m3 + (m1 * m2.transpose()).adjoint(), 0);
|
||||||
VERIFY_EVALUATION_COUNT(m3.noalias() = m3 + m1 * m2.transpose(), 0);
|
VERIFY_EVALUATION_COUNT(m3.noalias() = m3 + m1 * m2.transpose(), 0);
|
||||||
VERIFY_EVALUATION_COUNT(m3.noalias() += m3 + m1 * m2.transpose(), 0);
|
VERIFY_EVALUATION_COUNT(m3.noalias() += m3 + m1 * m2.transpose(), 0);
|
||||||
VERIFY_EVALUATION_COUNT(m3.noalias() -= m3 + m1 * m2.transpose(), 0);
|
VERIFY_EVALUATION_COUNT(m3.noalias() -= m3 + m1 * m2.transpose(), 0);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user