From bd0cd1d67b7d93605da49924588ce4bd6a816124 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20S=C3=A1nchez?= Date: Tue, 8 Jul 2025 21:48:59 +0000 Subject: [PATCH] Fix self-adjoint products when multiplying by a compile-time vector. --- Eigen/src/Core/ProductEvaluators.h | 4 ++-- Eigen/src/Core/products/SelfadjointMatrixVector.h | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index ce8d954bf..a23004458 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -846,7 +846,7 @@ struct generic_product_impl template static EIGEN_DEVICE_FUNC void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) { - selfadjoint_product_impl::run( + selfadjoint_product_impl::run( dst, lhs.nestedExpression(), rhs, alpha); } }; @@ -858,7 +858,7 @@ struct generic_product_impl template static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) { - selfadjoint_product_impl::run( + selfadjoint_product_impl::run( dst, lhs, rhs.nestedExpression(), alpha); } }; diff --git a/Eigen/src/Core/products/SelfadjointMatrixVector.h b/Eigen/src/Core/products/SelfadjointMatrixVector.h index f7387601f..580f6a850 100644 --- a/Eigen/src/Core/products/SelfadjointMatrixVector.h +++ b/Eigen/src/Core/products/SelfadjointMatrixVector.h @@ -164,6 +164,11 @@ struct selfadjoint_product_impl { enum { LhsUpLo = LhsMode & (Upper | Lower) }; + // Verify that the Rhs is a vector in the correct orientation. + // Otherwise, we break the assumption that we are multiplying + // MxN * Nx1. + static_assert(Rhs::ColsAtCompileTime == 1, "The RHS must be a column vector."); + template static EIGEN_DEVICE_FUNC void run(Dest& dest, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) { typedef typename Dest::Scalar ResScalar; @@ -173,11 +178,6 @@ struct selfadjoint_product_impl { eigen_assert(dest.rows() == a_lhs.rows() && dest.cols() == a_rhs.cols()); - if (a_lhs.rows() == 1) { - dest = (alpha * a_lhs.coeff(0, 0)) * a_rhs; - return; - } - add_const_on_value_type_t lhs = LhsBlasTraits::extract(a_lhs); add_const_on_value_type_t rhs = RhsBlasTraits::extract(a_rhs);