mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-15 05:05:58 +08:00
Fix self-adjoint products when multiplying by a compile-time vector.
This commit is contained in:
parent
6854da2ea0
commit
bd0cd1d67b
@ -846,7 +846,7 @@ struct generic_product_impl<Lhs, Rhs, SelfAdjointShape, DenseShape, ProductTag>
|
|||||||
|
|
||||||
template <typename Dest>
|
template <typename Dest>
|
||||||
static EIGEN_DEVICE_FUNC void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) {
|
static EIGEN_DEVICE_FUNC void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) {
|
||||||
selfadjoint_product_impl<typename Lhs::MatrixType, Lhs::Mode, false, Rhs, 0, Rhs::IsVectorAtCompileTime>::run(
|
selfadjoint_product_impl<typename Lhs::MatrixType, Lhs::Mode, false, Rhs, 0, Rhs::ColsAtCompileTime == 1>::run(
|
||||||
dst, lhs.nestedExpression(), rhs, alpha);
|
dst, lhs.nestedExpression(), rhs, alpha);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -858,7 +858,7 @@ struct generic_product_impl<Lhs, Rhs, DenseShape, SelfAdjointShape, ProductTag>
|
|||||||
|
|
||||||
template <typename Dest>
|
template <typename Dest>
|
||||||
static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) {
|
static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) {
|
||||||
selfadjoint_product_impl<Lhs, 0, Lhs::IsVectorAtCompileTime, typename Rhs::MatrixType, Rhs::Mode, false>::run(
|
selfadjoint_product_impl<Lhs, 0, Lhs::RowsAtCompileTime == 1, typename Rhs::MatrixType, Rhs::Mode, false>::run(
|
||||||
dst, lhs, rhs.nestedExpression(), alpha);
|
dst, lhs, rhs.nestedExpression(), alpha);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -164,6 +164,11 @@ struct selfadjoint_product_impl<Lhs, LhsMode, false, Rhs, 0, true> {
|
|||||||
|
|
||||||
enum { LhsUpLo = LhsMode & (Upper | Lower) };
|
enum { LhsUpLo = LhsMode & (Upper | Lower) };
|
||||||
|
|
||||||
|
// Verify that the Rhs is a vector in the correct orientation.
|
||||||
|
// Otherwise, we break the assumption that we are multiplying
|
||||||
|
// MxN * Nx1.
|
||||||
|
static_assert(Rhs::ColsAtCompileTime == 1, "The RHS must be a column vector.");
|
||||||
|
|
||||||
template <typename Dest>
|
template <typename Dest>
|
||||||
static EIGEN_DEVICE_FUNC void run(Dest& dest, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) {
|
static EIGEN_DEVICE_FUNC void run(Dest& dest, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) {
|
||||||
typedef typename Dest::Scalar ResScalar;
|
typedef typename Dest::Scalar ResScalar;
|
||||||
@ -173,11 +178,6 @@ struct selfadjoint_product_impl<Lhs, LhsMode, false, Rhs, 0, true> {
|
|||||||
|
|
||||||
eigen_assert(dest.rows() == a_lhs.rows() && dest.cols() == a_rhs.cols());
|
eigen_assert(dest.rows() == a_lhs.rows() && dest.cols() == a_rhs.cols());
|
||||||
|
|
||||||
if (a_lhs.rows() == 1) {
|
|
||||||
dest = (alpha * a_lhs.coeff(0, 0)) * a_rhs;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
add_const_on_value_type_t<ActualLhsType> lhs = LhsBlasTraits::extract(a_lhs);
|
add_const_on_value_type_t<ActualLhsType> lhs = LhsBlasTraits::extract(a_lhs);
|
||||||
add_const_on_value_type_t<ActualRhsType> rhs = RhsBlasTraits::extract(a_rhs);
|
add_const_on_value_type_t<ActualRhsType> rhs = RhsBlasTraits::extract(a_rhs);
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user