From a4760548793811ee1accf8de05ff791a43d54be5 Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Fri, 23 Nov 2018 10:25:19 +0100 Subject: [PATCH] bug #1624: improve matrix-matrix product on ARM 64, 20% speedup --- .../Core/products/GeneralBlockPanelKernel.h | 53 +++++++++++++++---- 1 file changed, 44 insertions(+), 9 deletions(-) diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h index e7cab4720..b8b83c320 100644 --- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h +++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h @@ -15,7 +15,7 @@ namespace Eigen { namespace internal { -template +template class gebp_traits; @@ -347,7 +347,7 @@ inline void computeProductBlockingSizes(Index& k, Index& m, Index& n, Index num_ * cplx*real : unpack rhs to constant packets, ... * real*cplx : load lhs as (a0,a0,a1,a1), and mul as usual */ -template +template class gebp_traits { public: @@ -461,8 +461,8 @@ public: }; -template -class gebp_traits, RealScalar, _ConjLhs, false> +template +class gebp_traits, RealScalar, _ConjLhs, false, Arch> { public: typedef std::complex LhsScalar; @@ -597,8 +597,8 @@ template struct unpacket_traits > { typede // return res; // } -template -class gebp_traits, std::complex, _ConjLhs, _ConjRhs > +template +class gebp_traits, std::complex, _ConjLhs, _ConjRhs,Arch> { public: typedef std::complex Scalar; @@ -746,8 +746,8 @@ protected: conj_helper cj; }; -template -class gebp_traits, false, _ConjRhs > +template +class gebp_traits, false, _ConjRhs, Arch> { public: typedef std::complex Scalar; @@ -852,7 +852,42 @@ protected: conj_helper cj; }; -/* optimized GEneral packed Block * packed Panel product kernel + +#if EIGEN_ARCH_ARM64 + +template<> +struct gebp_traits + : gebp_traits +{ + typedef float32x2_t RhsPacket; + + EIGEN_STRONG_INLINE void broadcastRhs(const RhsScalar* b, RhsPacket& b0, RhsPacket& b1, RhsPacket& b2, RhsPacket& b3) + { + loadRhs(b+0, b0); + loadRhs(b+1, b1); + loadRhs(b+2, b2); + loadRhs(b+3, b3); + } + + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const + { + dest = vld1_f32(b); + } + + EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const + { + loadRhs(b,dest); + } + + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/) const + { + c = vfmaq_lane_f32(c, a, b, 0); + } +}; + +#endif + +/* optimized General packed Block * packed Panel product kernel * * Mixing type logic: C += A * B * | A | B | comments