From cf12474a8b7c121f7ca95e28f8ee3f8d9405cd41 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Thu, 12 Nov 2020 18:02:37 +0000 Subject: [PATCH] Optimize matrix*matrix and matrix*vector products when they correspond to inner products at runtime. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This speeds up inner products where the one or or both arguments is dynamic for small and medium-sized vectors (up to 32k). name old time/op new time/op delta BM_VecVecStatStat/1 1.64ns ± 0% 1.64ns ± 0% ~ BM_VecVecStatStat/8 2.99ns ± 0% 2.99ns ± 0% ~ BM_VecVecStatStat/64 7.00ns ± 1% 7.04ns ± 0% +0.66% BM_VecVecStatStat/512 61.6ns ± 0% 61.6ns ± 0% ~ BM_VecVecStatStat/4k 551ns ± 0% 553ns ± 1% +0.26% BM_VecVecStatStat/32k 4.45µs ± 0% 4.45µs ± 0% ~ BM_VecVecStatStat/256k 77.9µs ± 0% 78.1µs ± 1% ~ BM_VecVecStatStat/1M 312µs ± 0% 312µs ± 1% ~ BM_VecVecDynStat/1 13.3ns ± 1% 4.6ns ± 0% -65.35% BM_VecVecDynStat/8 14.4ns ± 0% 6.2ns ± 0% -57.00% BM_VecVecDynStat/64 24.0ns ± 0% 10.2ns ± 3% -57.57% BM_VecVecDynStat/512 138ns ± 0% 68ns ± 0% -50.52% BM_VecVecDynStat/4k 1.11µs ± 0% 0.56µs ± 0% -49.72% BM_VecVecDynStat/32k 8.89µs ± 0% 4.46µs ± 0% -49.89% BM_VecVecDynStat/256k 78.2µs ± 0% 78.1µs ± 1% ~ BM_VecVecDynStat/1M 313µs ± 0% 312µs ± 1% ~ BM_VecVecDynDyn/1 10.4ns ± 0% 10.5ns ± 0% +0.91% BM_VecVecDynDyn/8 12.0ns ± 3% 11.9ns ± 0% ~ BM_VecVecDynDyn/64 37.4ns ± 0% 19.6ns ± 1% -47.57% BM_VecVecDynDyn/512 159ns ± 0% 81ns ± 0% -49.07% BM_VecVecDynDyn/4k 1.13µs ± 0% 0.58µs ± 1% -49.11% BM_VecVecDynDyn/32k 8.91µs ± 0% 5.06µs ±12% -43.23% BM_VecVecDynDyn/256k 78.2µs ± 0% 78.2µs ± 1% ~ BM_VecVecDynDyn/1M 313µs ± 0% 312µs ± 1% ~ --- Eigen/src/Core/ProductEvaluators.h | 7 ++++++- Eigen/src/Core/products/GeneralMatrixMatrix.h | 10 ++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h index 792b1811c..6b32c508b 100644 --- a/Eigen/src/Core/ProductEvaluators.h +++ b/Eigen/src/Core/ProductEvaluators.h @@ -256,7 +256,7 @@ struct generic_product_impl { dst.coeffRef(0,0) += (lhs.transpose().cwiseProduct(rhs)).sum(); } - + template static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) { dst.coeffRef(0,0) -= (lhs.transpose().cwiseProduct(rhs)).sum(); } @@ -375,6 +375,11 @@ struct generic_product_impl template static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) { + // Fallback to inner product if both the lhs and rhs is a runtime vector. + if (lhs.rows() == 1 && rhs.cols() == 1) { + dst.coeffRef(0,0) += alpha * (lhs.row(0).transpose().cwiseProduct(rhs.col(0)).sum()); + return; + } LhsNested actual_lhs(lhs); RhsNested actual_rhs(rhs); internal::gemv_dense_selector if(a_lhs.cols()==0 || a_lhs.rows()==0 || a_rhs.cols()==0) return; - // Fallback to GEMV if either the lhs or rhs is a runtime vector - if (dst.cols() == 1) + if (dst.cols() == 1 && dst.rows() == 1) { + // Fallback to inner product if both the lhs and rhs is a runtime vector. + dst.coeffRef(0,0) += alpha * (a_lhs.row(0).transpose().cwiseProduct(a_rhs.col(0)).sum()); + return; + } + else if (dst.cols() == 1) { + // Fallback to GEMV if either the lhs or rhs is a runtime vector typename Dest::ColXpr dst_vec(dst.col(0)); return internal::generic_product_impl ::scaleAndAddTo(dst_vec, a_lhs, a_rhs.col(0), alpha); } else if (dst.rows() == 1) { + // Fallback to GEMV if either the lhs or rhs is a runtime vector typename Dest::RowXpr dst_vec(dst.row(0)); return internal::generic_product_impl ::scaleAndAddTo(dst_vec, a_lhs.row(0), a_rhs, alpha);