Eigen transpose product

This commit is contained in:
Charles Schlosser 2024-04-30 13:32:52 +00:00
parent fb95e90f7f
commit 0ee5c90aa9
3 changed files with 146 additions and 3 deletions

View File

@ -21,7 +21,7 @@ class ProductImpl;
namespace internal { namespace internal {
template <typename Lhs, typename Rhs, int Option> template <typename Lhs, typename Rhs, int Option>
struct traits<Product<Lhs, Rhs, Option> > { struct traits<Product<Lhs, Rhs, Option>> {
typedef remove_all_t<Lhs> LhsCleaned; typedef remove_all_t<Lhs> LhsCleaned;
typedef remove_all_t<Rhs> RhsCleaned; typedef remove_all_t<Rhs> RhsCleaned;
typedef traits<LhsCleaned> LhsTraits; typedef traits<LhsCleaned> LhsTraits;
@ -55,6 +55,129 @@ struct traits<Product<Lhs, Rhs, Option> > {
}; };
}; };
struct TransposeProductEnum {
// convenience enumerations to specialize transposed products
enum : int {
Default = 0x00,
Matrix = 0x01,
Permutation = 0x02,
MatrixMatrix = (Matrix << 8) | Matrix,
MatrixPermutation = (Matrix << 8) | Permutation,
PermutationMatrix = (Permutation << 8) | Matrix
};
};
template <typename Xpr>
struct TransposeKind {
static constexpr int Kind = is_matrix_base_xpr<Xpr>::value ? TransposeProductEnum::Matrix
: is_permutation_base_xpr<Xpr>::value ? TransposeProductEnum::Permutation
: TransposeProductEnum::Default;
};
template <typename Lhs, typename Rhs>
struct TransposeProductKind {
static constexpr int Kind = (TransposeKind<Lhs>::Kind << 8) | TransposeKind<Rhs>::Kind;
};
template <typename Lhs, typename Rhs, int Option, int Kind = TransposeProductKind<Lhs, Rhs>::Kind>
struct product_transpose_helper {
// by default, don't optimize the transposed product
using Derived = Product<Lhs, Rhs, Option>;
using Scalar = typename Derived::Scalar;
using TransposeType = Transpose<const Derived>;
using ConjugateTransposeType = CwiseUnaryOp<scalar_conjugate_op<Scalar>, TransposeType>;
using AdjointType = std::conditional_t<NumTraits<Scalar>::IsComplex, ConjugateTransposeType, TransposeType>;
// return (lhs * rhs)^T
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeType run_transpose(const Derived& derived) {
return TransposeType(derived);
}
// return (lhs * rhs)^H
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointType run_adjoint(const Derived& derived) {
return AdjointType(TransposeType(derived));
}
};
template <typename Lhs, typename Rhs, int Option>
struct product_transpose_helper<Lhs, Rhs, Option, TransposeProductEnum::MatrixMatrix> {
// expand the transposed matrix-matrix product
using Derived = Product<Lhs, Rhs, Option>;
using LhsScalar = typename traits<Lhs>::Scalar;
using LhsTransposeType = typename DenseBase<Lhs>::ConstTransposeReturnType;
using LhsConjugateTransposeType = CwiseUnaryOp<scalar_conjugate_op<LhsScalar>, LhsTransposeType>;
using LhsAdjointType =
std::conditional_t<NumTraits<LhsScalar>::IsComplex, LhsConjugateTransposeType, LhsTransposeType>;
using RhsScalar = typename traits<Rhs>::Scalar;
using RhsTransposeType = typename DenseBase<Rhs>::ConstTransposeReturnType;
using RhsConjugateTransposeType = CwiseUnaryOp<scalar_conjugate_op<RhsScalar>, RhsTransposeType>;
using RhsAdjointType =
std::conditional_t<NumTraits<RhsScalar>::IsComplex, RhsConjugateTransposeType, RhsTransposeType>;
using TransposeType = Product<RhsTransposeType, LhsTransposeType, Option>;
using AdjointType = Product<RhsAdjointType, LhsAdjointType, Option>;
// return rhs^T * lhs^T
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeType run_transpose(const Derived& derived) {
return TransposeType(RhsTransposeType(derived.rhs()), LhsTransposeType(derived.lhs()));
}
// return rhs^H * lhs^H
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointType run_adjoint(const Derived& derived) {
return AdjointType(RhsAdjointType(RhsTransposeType(derived.rhs())),
LhsAdjointType(LhsTransposeType(derived.lhs())));
}
};
template <typename Lhs, typename Rhs, int Option>
struct product_transpose_helper<Lhs, Rhs, Option, TransposeProductEnum::PermutationMatrix> {
// expand the transposed permutation-matrix product
using Derived = Product<Lhs, Rhs, Option>;
using LhsInverseType = typename PermutationBase<Lhs>::InverseReturnType;
using RhsScalar = typename traits<Rhs>::Scalar;
using RhsTransposeType = typename DenseBase<Rhs>::ConstTransposeReturnType;
using RhsConjugateTransposeType = CwiseUnaryOp<scalar_conjugate_op<RhsScalar>, RhsTransposeType>;
using RhsAdjointType =
std::conditional_t<NumTraits<RhsScalar>::IsComplex, RhsConjugateTransposeType, RhsTransposeType>;
using TransposeType = Product<RhsTransposeType, LhsInverseType, Option>;
using AdjointType = Product<RhsAdjointType, LhsInverseType, Option>;
// return rhs^T * lhs^-1
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeType run_transpose(const Derived& derived) {
return TransposeType(RhsTransposeType(derived.rhs()), LhsInverseType(derived.lhs()));
}
// return rhs^H * lhs^-1
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointType run_adjoint(const Derived& derived) {
return AdjointType(RhsAdjointType(RhsTransposeType(derived.rhs())), LhsInverseType(derived.lhs()));
}
};
template <typename Lhs, typename Rhs, int Option>
struct product_transpose_helper<Lhs, Rhs, Option, TransposeProductEnum::MatrixPermutation> {
// expand the transposed matrix-permutation product
using Derived = Product<Lhs, Rhs, Option>;
using LhsScalar = typename traits<Lhs>::Scalar;
using LhsTransposeType = typename DenseBase<Lhs>::ConstTransposeReturnType;
using LhsConjugateTransposeType = CwiseUnaryOp<scalar_conjugate_op<LhsScalar>, LhsTransposeType>;
using LhsAdjointType =
std::conditional_t<NumTraits<LhsScalar>::IsComplex, LhsConjugateTransposeType, LhsTransposeType>;
using RhsInverseType = typename PermutationBase<Rhs>::InverseReturnType;
using TransposeType = Product<RhsInverseType, LhsTransposeType, Option>;
using AdjointType = Product<RhsInverseType, LhsAdjointType, Option>;
// return rhs^-1 * lhs^T
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeType run_transpose(const Derived& derived) {
return TransposeType(RhsInverseType(derived.rhs()), LhsTransposeType(derived.lhs()));
}
// return rhs^-1 * lhs^H
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointType run_adjoint(const Derived& derived) {
return AdjointType(RhsInverseType(derived.rhs()), LhsAdjointType(LhsTransposeType(derived.lhs())));
}
};
} // end namespace internal } // end namespace internal
/** \class Product /** \class Product
@ -93,6 +216,9 @@ class Product
typedef internal::remove_all_t<LhsNested> LhsNestedCleaned; typedef internal::remove_all_t<LhsNested> LhsNestedCleaned;
typedef internal::remove_all_t<RhsNested> RhsNestedCleaned; typedef internal::remove_all_t<RhsNested> RhsNestedCleaned;
using TransposeReturnType = typename internal::product_transpose_helper<Lhs, Rhs, Option>::TransposeType;
using AdjointReturnType = typename internal::product_transpose_helper<Lhs, Rhs, Option>::AdjointType;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Product(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Product(const Lhs& lhs, const Rhs& rhs) : m_lhs(lhs), m_rhs(rhs) {
eigen_assert(lhs.cols() == rhs.rows() && "invalid matrix product" && eigen_assert(lhs.cols() == rhs.rows() && "invalid matrix product" &&
"if you wanted a coeff-wise or a dot product use the respective explicit functions"); "if you wanted a coeff-wise or a dot product use the respective explicit functions");
@ -104,6 +230,13 @@ class Product
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const LhsNestedCleaned& lhs() const { return m_lhs; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const LhsNestedCleaned& lhs() const { return m_lhs; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const RhsNestedCleaned& rhs() const { return m_rhs; } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const RhsNestedCleaned& rhs() const { return m_rhs; }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TransposeReturnType transpose() const {
return internal::product_transpose_helper<Lhs, Rhs, Option>::run_transpose(*this);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE AdjointReturnType adjoint() const {
return internal::product_transpose_helper<Lhs, Rhs, Option>::run_adjoint(*this);
}
protected: protected:
LhsNested m_lhs; LhsNested m_lhs;
RhsNested m_rhs; RhsNested m_rhs;
@ -112,12 +245,12 @@ class Product
namespace internal { namespace internal {
template <typename Lhs, typename Rhs, int Option, int ProductTag = internal::product_type<Lhs, Rhs>::ret> template <typename Lhs, typename Rhs, int Option, int ProductTag = internal::product_type<Lhs, Rhs>::ret>
class dense_product_base : public internal::dense_xpr_base<Product<Lhs, Rhs, Option> >::type {}; class dense_product_base : public internal::dense_xpr_base<Product<Lhs, Rhs, Option>>::type {};
/** Conversion to scalar for inner-products */ /** Conversion to scalar for inner-products */
template <typename Lhs, typename Rhs, int Option> template <typename Lhs, typename Rhs, int Option>
class dense_product_base<Lhs, Rhs, Option, InnerProduct> class dense_product_base<Lhs, Rhs, Option, InnerProduct>
: public internal::dense_xpr_base<Product<Lhs, Rhs, Option> >::type { : public internal::dense_xpr_base<Product<Lhs, Rhs, Option>>::type {
typedef Product<Lhs, Rhs, Option> ProductXpr; typedef Product<Lhs, Rhs, Option> ProductXpr;
typedef typename internal::dense_xpr_base<ProductXpr>::type Base; typedef typename internal::dense_xpr_base<ProductXpr>::type Base;

View File

@ -928,6 +928,12 @@ template <typename XprType, int BlockRows, int BlockCols, bool InnerPanel>
struct block_xpr_helper<const Block<XprType, BlockRows, BlockCols, InnerPanel>> struct block_xpr_helper<const Block<XprType, BlockRows, BlockCols, InnerPanel>>
: block_xpr_helper<Block<XprType, BlockRows, BlockCols, InnerPanel>> {}; : block_xpr_helper<Block<XprType, BlockRows, BlockCols, InnerPanel>> {};
template <typename XprType>
struct is_matrix_base_xpr : std::is_base_of<MatrixBase<remove_all_t<XprType>>, remove_all_t<XprType>> {};
template <typename XprType>
struct is_permutation_base_xpr : std::is_base_of<PermutationBase<remove_all_t<XprType>>, remove_all_t<XprType>> {};
} // end namespace internal } // end namespace internal
/** \class ScalarBinaryOpTraits /** \class ScalarBinaryOpTraits

View File

@ -65,6 +65,8 @@ void product_notemporary(const MatrixType& m) {
VERIFY_EVALUATION_COUNT(m3 = (m1 * m2.adjoint()), 1); VERIFY_EVALUATION_COUNT(m3 = (m1 * m2.adjoint()), 1);
VERIFY_EVALUATION_COUNT(m3 = (m1 * m2.adjoint()).transpose(), 1); VERIFY_EVALUATION_COUNT(m3 = (m1 * m2.adjoint()).transpose(), 1);
VERIFY_EVALUATION_COUNT(m3.noalias() = (m1 * m2.adjoint()).transpose(), 0);
VERIFY_EVALUATION_COUNT(m3.noalias() = (m1 * m2.transpose()).adjoint(), 0);
VERIFY_EVALUATION_COUNT(m3.noalias() = m1 * m2.adjoint(), 0); VERIFY_EVALUATION_COUNT(m3.noalias() = m1 * m2.adjoint(), 0);
VERIFY_EVALUATION_COUNT(m3 = s1 * (m1 * m2.transpose()), 1); VERIFY_EVALUATION_COUNT(m3 = s1 * (m1 * m2.transpose()), 1);
@ -75,6 +77,8 @@ void product_notemporary(const MatrixType& m) {
VERIFY_EVALUATION_COUNT(m3 = m3 - (m1 * m2.adjoint()), 1); VERIFY_EVALUATION_COUNT(m3 = m3 - (m1 * m2.adjoint()), 1);
VERIFY_EVALUATION_COUNT(m3 = m3 + (m1 * m2.adjoint()).transpose(), 1); VERIFY_EVALUATION_COUNT(m3 = m3 + (m1 * m2.adjoint()).transpose(), 1);
VERIFY_EVALUATION_COUNT(m3.noalias() = m3 + (m1 * m2.adjoint()).transpose(), 0);
VERIFY_EVALUATION_COUNT(m3.noalias() = m3 + (m1 * m2.transpose()).adjoint(), 0);
VERIFY_EVALUATION_COUNT(m3.noalias() = m3 + m1 * m2.transpose(), 0); VERIFY_EVALUATION_COUNT(m3.noalias() = m3 + m1 * m2.transpose(), 0);
VERIFY_EVALUATION_COUNT(m3.noalias() += m3 + m1 * m2.transpose(), 0); VERIFY_EVALUATION_COUNT(m3.noalias() += m3 + m1 * m2.transpose(), 0);
VERIFY_EVALUATION_COUNT(m3.noalias() -= m3 + m1 * m2.transpose(), 0); VERIFY_EVALUATION_COUNT(m3.noalias() -= m3 + m1 * m2.transpose(), 0);