mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-14 20:56:00 +08:00
Fix self-adjoint products when multiplying by a compile-time vector.
This commit is contained in:
parent
6854da2ea0
commit
bd0cd1d67b
@ -846,7 +846,7 @@ struct generic_product_impl<Lhs, Rhs, SelfAdjointShape, DenseShape, ProductTag>
|
||||
|
||||
template <typename Dest>
|
||||
static EIGEN_DEVICE_FUNC void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) {
|
||||
selfadjoint_product_impl<typename Lhs::MatrixType, Lhs::Mode, false, Rhs, 0, Rhs::IsVectorAtCompileTime>::run(
|
||||
selfadjoint_product_impl<typename Lhs::MatrixType, Lhs::Mode, false, Rhs, 0, Rhs::ColsAtCompileTime == 1>::run(
|
||||
dst, lhs.nestedExpression(), rhs, alpha);
|
||||
}
|
||||
};
|
||||
@ -858,7 +858,7 @@ struct generic_product_impl<Lhs, Rhs, DenseShape, SelfAdjointShape, ProductTag>
|
||||
|
||||
template <typename Dest>
|
||||
static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) {
|
||||
selfadjoint_product_impl<Lhs, 0, Lhs::IsVectorAtCompileTime, typename Rhs::MatrixType, Rhs::Mode, false>::run(
|
||||
selfadjoint_product_impl<Lhs, 0, Lhs::RowsAtCompileTime == 1, typename Rhs::MatrixType, Rhs::Mode, false>::run(
|
||||
dst, lhs, rhs.nestedExpression(), alpha);
|
||||
}
|
||||
};
|
||||
|
@ -164,6 +164,11 @@ struct selfadjoint_product_impl<Lhs, LhsMode, false, Rhs, 0, true> {
|
||||
|
||||
enum { LhsUpLo = LhsMode & (Upper | Lower) };
|
||||
|
||||
// Verify that the Rhs is a vector in the correct orientation.
|
||||
// Otherwise, we break the assumption that we are multiplying
|
||||
// MxN * Nx1.
|
||||
static_assert(Rhs::ColsAtCompileTime == 1, "The RHS must be a column vector.");
|
||||
|
||||
template <typename Dest>
|
||||
static EIGEN_DEVICE_FUNC void run(Dest& dest, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) {
|
||||
typedef typename Dest::Scalar ResScalar;
|
||||
@ -173,11 +178,6 @@ struct selfadjoint_product_impl<Lhs, LhsMode, false, Rhs, 0, true> {
|
||||
|
||||
eigen_assert(dest.rows() == a_lhs.rows() && dest.cols() == a_rhs.cols());
|
||||
|
||||
if (a_lhs.rows() == 1) {
|
||||
dest = (alpha * a_lhs.coeff(0, 0)) * a_rhs;
|
||||
return;
|
||||
}
|
||||
|
||||
add_const_on_value_type_t<ActualLhsType> lhs = LhsBlasTraits::extract(a_lhs);
|
||||
add_const_on_value_type_t<ActualRhsType> rhs = RhsBlasTraits::extract(a_rhs);
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user