mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-10-10 07:06:32 +08:00
Eigen transpose product
This commit is contained in:
parent
fb95e90f7f
commit
0ee5c90aa9
@ -21,7 +21,7 @@ class ProductImpl;
|
||||
namespace internal {
|
||||
|
||||
template <typename Lhs, typename Rhs, int Option>
|
||||
struct traits<Product<Lhs, Rhs, Option> > {
|
||||
struct traits<Product<Lhs, Rhs, Option>> {
|
||||
typedef remove_all_t<Lhs> LhsCleaned;
|
||||
typedef remove_all_t<Rhs> RhsCleaned;
|
||||
typedef traits<LhsCleaned> LhsTraits;
|
||||
@ -55,6 +55,129 @@ struct traits<Product<Lhs, Rhs, Option> > {
|
||||
};
|
||||
};
|
||||
|
||||
struct TransposeProductEnum {
|
||||
// convenience enumerations to specialize transposed products
|
||||
enum : int {
|
||||
Default = 0x00,
|
||||
Matrix = 0x01,
|
||||
Permutation = 0x02,
|
||||
MatrixMatrix = (Matrix << 8) | Matrix,
|
||||
MatrixPermutation = (Matrix << 8) | Permutation,
|
||||
PermutationMatrix = (Permutation << 8) | Matrix
|
||||
};
|
||||
};
|
||||
template <typename Xpr>
|
||||
struct TransposeKind {
|
||||
static constexpr int Kind = is_matrix_base_xpr<Xpr>::value ? TransposeProductEnum::Matrix
|
||||
: is_permutation_base_xpr<Xpr>::value ? TransposeProductEnum::Permutation
|
||||
: TransposeProductEnum::Default;
|
||||
};
|
||||
|
||||
template <typename Lhs, typename Rhs>
|
||||
struct TransposeProductKind {
|
||||
static constexpr int Kind = (TransposeKind<Lhs>::Kind << 8) | TransposeKind<Rhs>::Kind;
|
||||
};
|
||||
|
||||
template <typename Lhs, typename Rhs, int Option, int Kind = TransposeProductKind<Lhs, Rhs>::Kind>
|
||||
struct product_transpose_helper {
|
||||
// by default, don't optimize the transposed product
|
||||
using Derived = Product<Lhs, Rhs, Option>;
|
||||
using Scalar = typename Derived::Scalar;
|
||||
using TransposeType = Transpose<const Derived>;
|
||||
using ConjugateTransposeType = CwiseUnaryOp<scalar_conjugate_op<Scalar>, TransposeType>;
|
||||
using AdjointType = std::conditional_t<NumTraits<Scalar>::IsComplex, ConjugateTransposeType, TransposeType>;
|
||||
|
||||
// return (lhs * rhs)^T
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeType run_transpose(const Derived& derived) {
|
||||
return TransposeType(derived);
|
||||
}
|
||||
// return (lhs * rhs)^H
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointType run_adjoint(const Derived& derived) {
|
||||
return AdjointType(TransposeType(derived));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Lhs, typename Rhs, int Option>
|
||||
struct product_transpose_helper<Lhs, Rhs, Option, TransposeProductEnum::MatrixMatrix> {
|
||||
// expand the transposed matrix-matrix product
|
||||
using Derived = Product<Lhs, Rhs, Option>;
|
||||
|
||||
using LhsScalar = typename traits<Lhs>::Scalar;
|
||||
using LhsTransposeType = typename DenseBase<Lhs>::ConstTransposeReturnType;
|
||||
using LhsConjugateTransposeType = CwiseUnaryOp<scalar_conjugate_op<LhsScalar>, LhsTransposeType>;
|
||||
using LhsAdjointType =
|
||||
std::conditional_t<NumTraits<LhsScalar>::IsComplex, LhsConjugateTransposeType, LhsTransposeType>;
|
||||
|
||||
using RhsScalar = typename traits<Rhs>::Scalar;
|
||||
using RhsTransposeType = typename DenseBase<Rhs>::ConstTransposeReturnType;
|
||||
using RhsConjugateTransposeType = CwiseUnaryOp<scalar_conjugate_op<RhsScalar>, RhsTransposeType>;
|
||||
using RhsAdjointType =
|
||||
std::conditional_t<NumTraits<RhsScalar>::IsComplex, RhsConjugateTransposeType, RhsTransposeType>;
|
||||
|
||||
using TransposeType = Product<RhsTransposeType, LhsTransposeType, Option>;
|
||||
using AdjointType = Product<RhsAdjointType, LhsAdjointType, Option>;
|
||||
|
||||
// return rhs^T * lhs^T
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeType run_transpose(const Derived& derived) {
|
||||
return TransposeType(RhsTransposeType(derived.rhs()), LhsTransposeType(derived.lhs()));
|
||||
}
|
||||
// return rhs^H * lhs^H
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointType run_adjoint(const Derived& derived) {
|
||||
return AdjointType(RhsAdjointType(RhsTransposeType(derived.rhs())),
|
||||
LhsAdjointType(LhsTransposeType(derived.lhs())));
|
||||
}
|
||||
};
|
||||
template <typename Lhs, typename Rhs, int Option>
|
||||
struct product_transpose_helper<Lhs, Rhs, Option, TransposeProductEnum::PermutationMatrix> {
|
||||
// expand the transposed permutation-matrix product
|
||||
using Derived = Product<Lhs, Rhs, Option>;
|
||||
|
||||
using LhsInverseType = typename PermutationBase<Lhs>::InverseReturnType;
|
||||
|
||||
using RhsScalar = typename traits<Rhs>::Scalar;
|
||||
using RhsTransposeType = typename DenseBase<Rhs>::ConstTransposeReturnType;
|
||||
using RhsConjugateTransposeType = CwiseUnaryOp<scalar_conjugate_op<RhsScalar>, RhsTransposeType>;
|
||||
using RhsAdjointType =
|
||||
std::conditional_t<NumTraits<RhsScalar>::IsComplex, RhsConjugateTransposeType, RhsTransposeType>;
|
||||
|
||||
using TransposeType = Product<RhsTransposeType, LhsInverseType, Option>;
|
||||
using AdjointType = Product<RhsAdjointType, LhsInverseType, Option>;
|
||||
|
||||
// return rhs^T * lhs^-1
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeType run_transpose(const Derived& derived) {
|
||||
return TransposeType(RhsTransposeType(derived.rhs()), LhsInverseType(derived.lhs()));
|
||||
}
|
||||
// return rhs^H * lhs^-1
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointType run_adjoint(const Derived& derived) {
|
||||
return AdjointType(RhsAdjointType(RhsTransposeType(derived.rhs())), LhsInverseType(derived.lhs()));
|
||||
}
|
||||
};
|
||||
template <typename Lhs, typename Rhs, int Option>
|
||||
struct product_transpose_helper<Lhs, Rhs, Option, TransposeProductEnum::MatrixPermutation> {
|
||||
// expand the transposed matrix-permutation product
|
||||
using Derived = Product<Lhs, Rhs, Option>;
|
||||
|
||||
using LhsScalar = typename traits<Lhs>::Scalar;
|
||||
using LhsTransposeType = typename DenseBase<Lhs>::ConstTransposeReturnType;
|
||||
using LhsConjugateTransposeType = CwiseUnaryOp<scalar_conjugate_op<LhsScalar>, LhsTransposeType>;
|
||||
using LhsAdjointType =
|
||||
std::conditional_t<NumTraits<LhsScalar>::IsComplex, LhsConjugateTransposeType, LhsTransposeType>;
|
||||
|
||||
using RhsInverseType = typename PermutationBase<Rhs>::InverseReturnType;
|
||||
|
||||
using TransposeType = Product<RhsInverseType, LhsTransposeType, Option>;
|
||||
using AdjointType = Product<RhsInverseType, LhsAdjointType, Option>;
|
||||
|
||||
// return rhs^-1 * lhs^T
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeType run_transpose(const Derived& derived) {
|
||||
return TransposeType(RhsInverseType(derived.rhs()), LhsTransposeType(derived.lhs()));
|
||||
}
|
||||
// return rhs^-1 * lhs^H
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointType run_adjoint(const Derived& derived) {
|
||||
return AdjointType(RhsInverseType(derived.rhs()), LhsAdjointType(LhsTransposeType(derived.lhs())));
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
/** \class Product
|
||||
@ -93,6 +216,9 @@ class Product
|
||||
typedef internal::remove_all_t<LhsNested> LhsNestedCleaned;
|
||||
typedef internal::remove_all_t<RhsNested> RhsNestedCleaned;
|
||||
|
||||
using TransposeReturnType = typename internal::product_transpose_helper<Lhs, Rhs, Option>::TransposeType;
|
||||
using AdjointReturnType = typename internal::product_transpose_helper<Lhs, Rhs, Option>::AdjointType;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Product(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) {
|
||||
eigen_assert(lhs.cols() == rhs.rows() && "invalid matrix product" &&
|
||||
"if you wanted a coeff-wise or a dot product use the respective explicit functions");
|
||||
@ -104,6 +230,13 @@ class Product
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const LhsNestedCleaned& lhs() const { return m_lhs; }
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const RhsNestedCleaned& rhs() const { return m_rhs; }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeReturnType transpose() const {
|
||||
return internal::product_transpose_helper<Lhs, Rhs, Option>::run_transpose(*this);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointReturnType adjoint() const {
|
||||
return internal::product_transpose_helper<Lhs, Rhs, Option>::run_adjoint(*this);
|
||||
}
|
||||
|
||||
protected:
|
||||
LhsNested m_lhs;
|
||||
RhsNested m_rhs;
|
||||
@ -112,12 +245,12 @@ class Product
|
||||
namespace internal {
|
||||
|
||||
template <typename Lhs, typename Rhs, int Option, int ProductTag = internal::product_type<Lhs, Rhs>::ret>
|
||||
class dense_product_base : public internal::dense_xpr_base<Product<Lhs, Rhs, Option> >::type {};
|
||||
class dense_product_base : public internal::dense_xpr_base<Product<Lhs, Rhs, Option>>::type {};
|
||||
|
||||
/** Conversion to scalar for inner-products */
|
||||
template <typename Lhs, typename Rhs, int Option>
|
||||
class dense_product_base<Lhs, Rhs, Option, InnerProduct>
|
||||
: public internal::dense_xpr_base<Product<Lhs, Rhs, Option> >::type {
|
||||
: public internal::dense_xpr_base<Product<Lhs, Rhs, Option>>::type {
|
||||
typedef Product<Lhs, Rhs, Option> ProductXpr;
|
||||
typedef typename internal::dense_xpr_base<ProductXpr>::type Base;
|
||||
|
||||
|
@ -928,6 +928,12 @@ template <typename XprType, int BlockRows, int BlockCols, bool InnerPanel>
|
||||
struct block_xpr_helper<const Block<XprType, BlockRows, BlockCols, InnerPanel>>
|
||||
: block_xpr_helper<Block<XprType, BlockRows, BlockCols, InnerPanel>> {};
|
||||
|
||||
template <typename XprType>
|
||||
struct is_matrix_base_xpr : std::is_base_of<MatrixBase<remove_all_t<XprType>>, remove_all_t<XprType>> {};
|
||||
|
||||
template <typename XprType>
|
||||
struct is_permutation_base_xpr : std::is_base_of<PermutationBase<remove_all_t<XprType>>, remove_all_t<XprType>> {};
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
/** \class ScalarBinaryOpTraits
|
||||
|
@ -65,6 +65,8 @@ void product_notemporary(const MatrixType& m) {
|
||||
|
||||
VERIFY_EVALUATION_COUNT(m3 = (m1 * m2.adjoint()), 1);
|
||||
VERIFY_EVALUATION_COUNT(m3 = (m1 * m2.adjoint()).transpose(), 1);
|
||||
VERIFY_EVALUATION_COUNT(m3.noalias() = (m1 * m2.adjoint()).transpose(), 0);
|
||||
VERIFY_EVALUATION_COUNT(m3.noalias() = (m1 * m2.transpose()).adjoint(), 0);
|
||||
VERIFY_EVALUATION_COUNT(m3.noalias() = m1 * m2.adjoint(), 0);
|
||||
|
||||
VERIFY_EVALUATION_COUNT(m3 = s1 * (m1 * m2.transpose()), 1);
|
||||
@ -75,6 +77,8 @@ void product_notemporary(const MatrixType& m) {
|
||||
VERIFY_EVALUATION_COUNT(m3 = m3 - (m1 * m2.adjoint()), 1);
|
||||
|
||||
VERIFY_EVALUATION_COUNT(m3 = m3 + (m1 * m2.adjoint()).transpose(), 1);
|
||||
VERIFY_EVALUATION_COUNT(m3.noalias() = m3 + (m1 * m2.adjoint()).transpose(), 0);
|
||||
VERIFY_EVALUATION_COUNT(m3.noalias() = m3 + (m1 * m2.transpose()).adjoint(), 0);
|
||||
VERIFY_EVALUATION_COUNT(m3.noalias() = m3 + m1 * m2.transpose(), 0);
|
||||
VERIFY_EVALUATION_COUNT(m3.noalias() += m3 + m1 * m2.transpose(), 0);
|
||||
VERIFY_EVALUATION_COUNT(m3.noalias() -= m3 + m1 * m2.transpose(), 0);
|
||||
|
Loading…
x
Reference in New Issue
Block a user