diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index 792b1811c..6b32c508b 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -256,7 +256,7 @@ struct generic_product_impl { dst.coeffRef(0,0) += (lhs.transpose().cwiseProduct(rhs)).sum(); } - + template static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) { dst.coeffRef(0,0) -= (lhs.transpose().cwiseProduct(rhs)).sum(); } @@ -375,6 +375,11 @@ struct generic_product_impl template static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) { + // Fall back to an inner product if both the lhs and rhs are runtime vectors (the result is 1x1). + if (lhs.rows() == 1 && rhs.cols() == 1) { + dst.coeffRef(0,0) += alpha * (lhs.row(0).transpose().cwiseProduct(rhs.col(0)).sum()); + return; + } LhsNested actual_lhs(lhs); RhsNested actual_rhs(rhs); internal::gemv_dense_selector if(a_lhs.cols()==0 || a_lhs.rows()==0 || a_rhs.cols()==0) return; - // Fallback to GEMV if either the lhs or rhs is a runtime vector - if (dst.cols() == 1) + if (dst.cols() == 1 && dst.rows() == 1) { + // Fall back to an inner product if both the lhs and rhs are runtime vectors (dst is 1x1). + dst.coeffRef(0,0) += alpha * (a_lhs.row(0).transpose().cwiseProduct(a_rhs.col(0)).sum()); + return; + } + else if (dst.cols() == 1) { + // Fall back to GEMV if either the lhs or rhs is a runtime vector typename Dest::ColXpr dst_vec(dst.col(0)); return internal::generic_product_impl ::scaleAndAddTo(dst_vec, a_lhs, a_rhs.col(0), alpha); } else if (dst.rows() == 1) { + // Fall back to GEMV if either the lhs or rhs is a runtime vector typename Dest::RowXpr dst_vec(dst.row(0)); return internal::generic_product_impl ::scaleAndAddTo(dst_vec, a_lhs.row(0), a_rhs, alpha);