Fix self-adjoint products when multiplying by a compile-time vector.

This commit is contained in:
Antonio Sánchez 2025-07-08 21:48:59 +00:00
parent 6854da2ea0
commit bd0cd1d67b
2 changed files with 7 additions and 7 deletions

View File

@ -846,7 +846,7 @@ struct generic_product_impl<Lhs, Rhs, SelfAdjointShape, DenseShape, ProductTag>
template <typename Dest>
static EIGEN_DEVICE_FUNC void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) {
selfadjoint_product_impl<typename Lhs::MatrixType, Lhs::Mode, false, Rhs, 0, Rhs::IsVectorAtCompileTime>::run(
selfadjoint_product_impl<typename Lhs::MatrixType, Lhs::Mode, false, Rhs, 0, Rhs::ColsAtCompileTime == 1>::run(
dst, lhs.nestedExpression(), rhs, alpha);
}
};
@ -858,7 +858,7 @@ struct generic_product_impl<Lhs, Rhs, DenseShape, SelfAdjointShape, ProductTag>
template <typename Dest>
static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) {
selfadjoint_product_impl<Lhs, 0, Lhs::IsVectorAtCompileTime, typename Rhs::MatrixType, Rhs::Mode, false>::run(
selfadjoint_product_impl<Lhs, 0, Lhs::RowsAtCompileTime == 1, typename Rhs::MatrixType, Rhs::Mode, false>::run(
dst, lhs, rhs.nestedExpression(), alpha);
}
};

View File

@ -164,6 +164,11 @@ struct selfadjoint_product_impl<Lhs, LhsMode, false, Rhs, 0, true> {
enum { LhsUpLo = LhsMode & (Upper | Lower) };
// Verify that the Rhs is a vector in the correct orientation.
// Otherwise, we break the assumption that we are multiplying
// MxN * Nx1.
static_assert(Rhs::ColsAtCompileTime == 1, "The RHS must be a column vector.");
template <typename Dest>
static EIGEN_DEVICE_FUNC void run(Dest& dest, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) {
typedef typename Dest::Scalar ResScalar;
@ -173,11 +178,6 @@ struct selfadjoint_product_impl<Lhs, LhsMode, false, Rhs, 0, true> {
eigen_assert(dest.rows() == a_lhs.rows() && dest.cols() == a_rhs.cols());
if (a_lhs.rows() == 1) {
dest = (alpha * a_lhs.coeff(0, 0)) * a_rhs;
return;
}
add_const_on_value_type_t<ActualLhsType> lhs = LhsBlasTraits::extract(a_lhs);
add_const_on_value_type_t<ActualRhsType> rhs = RhsBlasTraits::extract(a_rhs);