From 23299632c246b77937fb78e8607863a2f02e191b Mon Sep 17 00:00:00 2001 From: Lianhuang Li Date: Wed, 21 Sep 2022 16:36:40 +0000 Subject: [PATCH] Use 3px8/2px8/1px8/1x8 gebp_kernel on arm64-neon --- .../Core/arch/NEON/GeneralBlockPanelKernel.h | 94 ++- .../Core/products/GeneralBlockPanelKernel.h | 799 ++++++++++++++++-- 2 files changed, 796 insertions(+), 97 deletions(-) diff --git a/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h b/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h index 6cd6edd56..5022205fc 100644 --- a/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h +++ b/Eigen/src/Core/arch/NEON/GeneralBlockPanelKernel.h @@ -49,7 +49,9 @@ struct gebp_traits { typedef float RhsPacket; typedef float32x4_t RhsPacketx4; - + enum { + nr = 8 + }; EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const { dest = *b; @@ -77,7 +79,6 @@ struct gebp_traits { c = vfmaq_n_f32(c, a, b); } - // NOTE: Template parameter inference failed when compiled with Android NDK: // "candidate template ignored: could not match 'FixedInt' against 'Eigen::internal::FixedInt<0>". @@ -94,9 +95,10 @@ struct gebp_traits template EIGEN_STRONG_INLINE void madd_helper(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c) const { - #if EIGEN_COMP_GNUC_STRICT && !(EIGEN_GNUC_AT_LEAST(9,0)) - // workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101 - // vfmaq_laneq_f32 is implemented through a costly dup + #if EIGEN_COMP_GNUC_STRICT + // 1. workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101 + // vfmaq_laneq_f32 is implemented through a costly dup, which was fixed in gcc9 + // 2. workaround the gcc register split problem on arm64-neon if(LaneID==0) asm("fmla %0.4s, %1.4s, %2.s[0]\n" : "+w" (c) : "w" (a), "w" (b) : ); else if(LaneID==1) asm("fmla %0.4s, %1.4s, %2.s[1]\n" : "+w" (c) : "w" (a), "w" (b) : ); else if(LaneID==2) asm("fmla %0.4s, %1.4s, %2.s[2]\n" : "+w" (c) : "w" (a), "w" (b) : ); @@ -113,7 +115,9 @@ struct gebp_traits : gebp_traits { typedef double RhsPacket; - + enum { + nr = 8 + }; struct RhsPacketx4 { float64x2_t B_0, B_1; }; @@ -163,9 +167,10 @@ struct gebp_traits template EIGEN_STRONG_INLINE void madd_helper(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c) const { - #if EIGEN_COMP_GNUC_STRICT && !(EIGEN_GNUC_AT_LEAST(9,0)) - // workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101 - // vfmaq_laneq_f64 is implemented through a costly dup + #if EIGEN_COMP_GNUC_STRICT + // 1. workaround gcc issue https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101 + // vfmaq_laneq_f64 is implemented through a costly dup, which was fixed in gcc9 + // 2. workaround the gcc register split problem on arm64-neon if(LaneID==0) asm("fmla %0.2d, %1.2d, %2.d[0]\n" : "+w" (c) : "w" (a), "w" (b.B_0) : ); else if(LaneID==1) asm("fmla %0.2d, %1.2d, %2.d[1]\n" : "+w" (c) : "w" (a), "w" (b.B_0) : ); else if(LaneID==2) asm("fmla %0.2d, %1.2d, %2.d[0]\n" : "+w" (c) : "w" (a), "w" (b.B_1) : ); @@ -179,6 +184,77 @@ struct gebp_traits } }; +#if EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC + +template<> +struct gebp_traits + : gebp_traits +{ + typedef half RhsPacket; + typedef float16x4_t RhsPacketx4; + typedef float16x4_t PacketHalf; + enum { + nr = 8 + }; + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacket& dest) const + { + dest = *b; + } + + EIGEN_STRONG_INLINE void loadRhs(const RhsScalar* b, RhsPacketx4& dest) const + { + dest = vld1_f16((const __fp16 *)b); + } + + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar* b, RhsPacket& dest) const + { + dest = *b; + } + + EIGEN_STRONG_INLINE void updateRhs(const RhsScalar*, RhsPacketx4&) const + {} + + EIGEN_STRONG_INLINE void loadRhsQuad(const RhsScalar* b, RhsPacket& dest) const + { + loadRhs(b,dest); + } + + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const + { + c = vfmaq_n_f16(c, a, b); + } + EIGEN_STRONG_INLINE void madd(const PacketHalf& a, const RhsPacket& b, PacketHalf& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const + { + c = vfma_n_f16(c, a, b); + } + + // NOTE: Template parameter inference failed when compiled with Android NDK: + // "candidate template ignored: could not match 'FixedInt' against 'Eigen::internal::FixedInt<0>". + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const + { madd_helper<0>(a, b, c); } + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<1>&) const + { madd_helper<1>(a, b, c); } + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<2>&) const + { madd_helper<2>(a, b, c); } + EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<3>&) const + { madd_helper<3>(a, b, c); } + private: + template + EIGEN_STRONG_INLINE void madd_helper(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c) const + { + #if EIGEN_COMP_GNUC_STRICT + // 1. vfmaq_lane_f16 is implemented through a costly dup + // 2. workaround the gcc register split problem on arm64-neon + if(LaneID==0) asm("fmla %0.8h, %1.8h, %2.h[0]\n" : "+w" (c) : "w" (a), "w" (b) : ); + else if(LaneID==1) asm("fmla %0.8h, %1.8h, %2.h[1]\n" : "+w" (c) : "w" (a), "w" (b) : ); + else if(LaneID==2) asm("fmla %0.8h, %1.8h, %2.h[2]\n" : "+w" (c) : "w" (a), "w" (b) : ); + else if(LaneID==3) asm("fmla %0.8h, %1.8h, %2.h[3]\n" : "+w" (c) : "w" (a), "w" (b) : ); + #else + c = vfmaq_lane_f16(c, a, b, LaneID); + #endif + } +}; +#endif // EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC #endif // EIGEN_ARCH_ARM64 } // namespace internal diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h index b1a127754..0b076510b 100644 --- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h +++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h @@ -1070,6 +1070,7 @@ struct gebp_kernel typedef typename Traits::RhsPacketx4 RhsPacketx4; typedef typename RhsPanelHelper::type RhsPanel15; + typedef typename RhsPanelHelper::type RhsPanel27; typedef gebp_traits SwappedTraits; @@ -1215,13 +1216,135 @@ struct lhs_process_one_packet int prefetch_res_offset, Index peeled_kc, Index pk, Index cols, Index depth, Index packet_cols4) { GEBPTraits traits; - + Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0; // loops on each largest micro horizontal panel of lhs // (LhsProgress x depth) for(Index i=peelStart; i); \ + traits.updateRhs(&blB[(1 + 8 * K) * RhsProgress], rhs_panel); \ + traits.madd(A0, rhs_panel, C1, T0, fix<1>); \ + traits.updateRhs(&blB[(2 + 8 * K) * RhsProgress], rhs_panel); \ + traits.madd(A0, rhs_panel, C2, T0, fix<2>); \ + traits.updateRhs(&blB[(3 + 8 * K) * RhsProgress], rhs_panel); \ + traits.madd(A0, rhs_panel, C3, T0, fix<3>); \ + traits.loadRhs(&blB[(4 + 8 * K) * RhsProgress], rhs_panel); \ + traits.madd(A0, rhs_panel, C4, T0, fix<0>); \ + traits.updateRhs(&blB[(5 + 8 * K) * RhsProgress], rhs_panel); \ + traits.madd(A0, rhs_panel, C5, T0, fix<1>); \ + traits.updateRhs(&blB[(6 + 8 * K) * RhsProgress], rhs_panel); \ + traits.madd(A0, rhs_panel, C6, T0, fix<2>); \ + traits.updateRhs(&blB[(7 + 8 * K) * RhsProgress], rhs_panel); \ + traits.madd(A0, rhs_panel, C7, T0, fix<3>); \ + EIGEN_ASM_COMMENT("end step of gebp micro kernel 1pX8"); \ + } while (false) + + EIGEN_ASM_COMMENT("begin gebp micro kernel 1pX8"); + + EIGEN_GEBGP_ONESTEP(0); + EIGEN_GEBGP_ONESTEP(1); + EIGEN_GEBGP_ONESTEP(2); + EIGEN_GEBGP_ONESTEP(3); + EIGEN_GEBGP_ONESTEP(4); + EIGEN_GEBGP_ONESTEP(5); + EIGEN_GEBGP_ONESTEP(6); + EIGEN_GEBGP_ONESTEP(7); + + blB += pk*8*RhsProgress; + blA += pk*(1*LhsProgress); + + EIGEN_ASM_COMMENT("end gebp micro kernel 1pX8"); + } + // process remaining peeled loop + for(Index k=peeled_kc; k(alpha); + + R0 = r0.template loadPacket(0); + R1 = r1.template loadPacket(0); + traits.acc(C0, alphav, R0); + traits.acc(C1, alphav, R1); + r0.storePacket(0, R0); + r1.storePacket(0, R1); + + R0 = r2.template loadPacket(0); + R1 = r3.template loadPacket(0); + traits.acc(C2, alphav, R0); + traits.acc(C3, alphav, R1); + r2.storePacket(0, R0); + r3.storePacket(0, R1); + + R0 = r4.template loadPacket(0); + R1 = r5.template loadPacket(0); + traits.acc(C4, alphav, R0); + traits.acc(C5, alphav, R1); + r4.storePacket(0, R0); + r5.storePacket(0, R1); + + R0 = r6.template loadPacket(0); + R1 = r7.template loadPacket(0); + traits.acc(C6, alphav, R0); + traits.acc(C7, alphav, R1); + r6.storePacket(0, R0); + r7.storePacket(0, R1); + } // loops on each largest micro vertical panel of rhs (depth * nr) - for(Index j2=0; j2 cj; Index packet_cols4 = nr>=4 ? (cols/4) * 4 : 0; + Index packet_cols8 = nr>=8 ? (cols/8) * 8 : 0; const Index peeled_mc3 = mr>=3*Traits::LhsProgress ? (rows/(3*LhsProgress))*(3*LhsProgress) : 0; const Index peeled_mc2 = mr>=2*Traits::LhsProgress ? peeled_mc3+((rows-peeled_mc3)/(2*LhsProgress))*(2*LhsProgress) : 0; const Index peeled_mc1 = mr>=1*Traits::LhsProgress ? peeled_mc2+((rows-peeled_mc2)/(1*LhsProgress))*(1*LhsProgress) : 0; @@ -1443,7 +1567,219 @@ void gebp_kernel= 8 + for(Index j2=0; j2); \ + traits.madd(A1, rhs_panel, C8, T0, fix<0>); \ + traits.madd(A2, rhs_panel, C16, T0, fix<0>); \ + traits.updateRhs(blB + (1 + 8 * K) * Traits::RhsProgress, rhs_panel); \ + traits.madd(A0, rhs_panel, C1, T0, fix<1>); \ + traits.madd(A1, rhs_panel, C9, T0, fix<1>); \ + traits.madd(A2, rhs_panel, C17, T0, fix<1>); \ + traits.updateRhs(blB + (2 + 8 * K) * Traits::RhsProgress, rhs_panel); \ + traits.madd(A0, rhs_panel, C2, T0, fix<2>); \ + traits.madd(A1, rhs_panel, C10, T0, fix<2>); \ + traits.madd(A2, rhs_panel, C18, T0, fix<2>); \ + traits.updateRhs(blB + (3 + 8 * K) * Traits::RhsProgress, rhs_panel); \ + traits.madd(A0, rhs_panel, C3, T0, fix<3>); \ + traits.madd(A1, rhs_panel, C11, T0, fix<3>); \ + traits.madd(A2, rhs_panel, C19, T0, fix<3>); \ + traits.loadRhs(blB + (4 + 8 * K) * Traits::RhsProgress, rhs_panel); \ + traits.madd(A0, rhs_panel, C4, T0, fix<0>); \ + traits.madd(A1, rhs_panel, C12, T0, fix<0>); \ + traits.madd(A2, rhs_panel, C20, T0, fix<0>); \ + traits.updateRhs(blB + (5 + 8 * K) * Traits::RhsProgress, rhs_panel); \ + traits.madd(A0, rhs_panel, C5, T0, fix<1>); \ + traits.madd(A1, rhs_panel, C13, T0, fix<1>); \ + traits.madd(A2, rhs_panel, C21, T0, fix<1>); \ + traits.updateRhs(blB + (6 + 8 * K) * Traits::RhsProgress, rhs_panel); \ + traits.madd(A0, rhs_panel, C6, T0, fix<2>); \ + traits.madd(A1, rhs_panel, C14, T0, fix<2>); \ + traits.madd(A2, rhs_panel, C22, T0, fix<2>); \ + traits.updateRhs(blB + (7 + 8 * K) * Traits::RhsProgress, rhs_panel); \ + traits.madd(A0, rhs_panel, C7, T0, fix<3>); \ + traits.madd(A1, rhs_panel, C15, T0, fix<3>); \ + traits.madd(A2, rhs_panel, C23, T0, fix<3>); \ + EIGEN_ASM_COMMENT("end step of gebp micro kernel 3pX8"); \ + } while (false) + + EIGEN_GEBP_ONESTEP(0); + EIGEN_GEBP_ONESTEP(1); + EIGEN_GEBP_ONESTEP(2); + EIGEN_GEBP_ONESTEP(3); + EIGEN_GEBP_ONESTEP(4); + EIGEN_GEBP_ONESTEP(5); + EIGEN_GEBP_ONESTEP(6); + EIGEN_GEBP_ONESTEP(7); + + blB += pk * 8 * RhsProgress; + blA += pk * 3 * Traits::LhsProgress; + EIGEN_ASM_COMMENT("end gebp micro kernel 3pX8"); + } + + // process remaining peeled loop + for (Index k = peeled_kc; k < depth; k++) + { + + RhsPanel27 rhs_panel; + RhsPacket T0; + LhsPacket A2; + EIGEN_GEBP_ONESTEP(0); + blB += 8 * RhsProgress; + blA += 3 * Traits::LhsProgress; + } + + #undef EIGEN_GEBP_ONESTEP + + ResPacket R0, R1, R2; + ResPacket alphav = pset1(alpha); + + R0 = r0.template loadPacket(0 * Traits::ResPacketSize); + R1 = r0.template loadPacket(1 * Traits::ResPacketSize); + R2 = r0.template loadPacket(2 * Traits::ResPacketSize); + traits.acc(C0, alphav, R0); + traits.acc(C8, alphav, R1); + traits.acc(C16, alphav, R2); + r0.storePacket(0 * Traits::ResPacketSize, R0); + r0.storePacket(1 * Traits::ResPacketSize, R1); + r0.storePacket(2 * Traits::ResPacketSize, R2); + + R0 = r1.template loadPacket(0 * Traits::ResPacketSize); + R1 = r1.template loadPacket(1 * Traits::ResPacketSize); + R2 = r1.template loadPacket(2 * Traits::ResPacketSize); + traits.acc(C1, alphav, R0); + traits.acc(C9, alphav, R1); + traits.acc(C17, alphav, R2); + r1.storePacket(0 * Traits::ResPacketSize, R0); + r1.storePacket(1 * Traits::ResPacketSize, R1); + r1.storePacket(2 * Traits::ResPacketSize, R2); + + R0 = r2.template loadPacket(0 * Traits::ResPacketSize); + R1 = r2.template loadPacket(1 * Traits::ResPacketSize); + R2 = r2.template loadPacket(2 * Traits::ResPacketSize); + traits.acc(C2, alphav, R0); + traits.acc(C10, alphav, R1); + traits.acc(C18, alphav, R2); + r2.storePacket(0 * Traits::ResPacketSize, R0); + r2.storePacket(1 * Traits::ResPacketSize, R1); + r2.storePacket(2 * Traits::ResPacketSize, R2); + + R0 = r3.template loadPacket(0 * Traits::ResPacketSize); + R1 = r3.template loadPacket(1 * Traits::ResPacketSize); + R2 = r3.template loadPacket(2 * Traits::ResPacketSize); + traits.acc(C3, alphav, R0); + traits.acc(C11, alphav, R1); + traits.acc(C19, alphav, R2); + r3.storePacket(0 * Traits::ResPacketSize, R0); + r3.storePacket(1 * Traits::ResPacketSize, R1); + r3.storePacket(2 * Traits::ResPacketSize, R2); + + R0 = r4.template loadPacket(0 * Traits::ResPacketSize); + R1 = r4.template loadPacket(1 * Traits::ResPacketSize); + R2 = r4.template loadPacket(2 * Traits::ResPacketSize); + traits.acc(C4, alphav, R0); + traits.acc(C12, alphav, R1); + traits.acc(C20, alphav, R2); + r4.storePacket(0 * Traits::ResPacketSize, R0); + r4.storePacket(1 * Traits::ResPacketSize, R1); + r4.storePacket(2 * Traits::ResPacketSize, R2); + + R0 = r5.template loadPacket(0 * Traits::ResPacketSize); + R1 = r5.template loadPacket(1 * Traits::ResPacketSize); + R2 = r5.template loadPacket(2 * Traits::ResPacketSize); + traits.acc(C5, alphav, R0); + traits.acc(C13, alphav, R1); + traits.acc(C21, alphav, R2); + r5.storePacket(0 * Traits::ResPacketSize, R0); + r5.storePacket(1 * Traits::ResPacketSize, R1); + r5.storePacket(2 * Traits::ResPacketSize, R2); + + R0 = r6.template loadPacket(0 * Traits::ResPacketSize); + R1 = r6.template loadPacket(1 * Traits::ResPacketSize); + R2 = r6.template loadPacket(2 * Traits::ResPacketSize); + traits.acc(C6, alphav, R0); + traits.acc(C14, alphav, R1); + traits.acc(C22, alphav, R2); + r6.storePacket(0 * Traits::ResPacketSize, R0); + r6.storePacket(1 * Traits::ResPacketSize, R1); + r6.storePacket(2 * Traits::ResPacketSize, R2); + + R0 = r7.template loadPacket(0 * Traits::ResPacketSize); + R1 = r7.template loadPacket(1 * Traits::ResPacketSize); + R2 = r7.template loadPacket(2 * Traits::ResPacketSize); + traits.acc(C7, alphav, R0); + traits.acc(C15, alphav, R1); + traits.acc(C23, alphav, R2); + r7.storePacket(0 * Traits::ResPacketSize, R0); + r7.storePacket(1 * Traits::ResPacketSize, R1); + r7.storePacket(2 * Traits::ResPacketSize, R2); + } + } + + for(Index j2=packet_cols8; j2=6 without FMA (bug 1637) + #if EIGEN_GNUC_AT_LEAST(6,0) && defined(EIGEN_VECTORIZE_SSE) + #define EIGEN_GEBP_2Px8_SPILLING_WORKAROUND __asm__ ("" : [a0] "+x,m" (A0),[a1] "+x,m" (A1)); + #else + #define EIGEN_GEBP_2Px8_SPILLING_WORKAROUND + #endif +#define EIGEN_GEBGP_ONESTEP(K) \ + do { \ + EIGEN_ASM_COMMENT("begin step of gebp micro kernel 2pX8"); \ + traits.loadLhs(&blA[(0 + 2 * K) * LhsProgress], A0); \ + traits.loadLhs(&blA[(1 + 2 * K) * LhsProgress], A1); \ + traits.loadRhs(&blB[(0 + 8 * K) * RhsProgress], rhs_panel); \ + traits.madd(A0, rhs_panel, C0, T0, fix<0>); \ + traits.madd(A1, rhs_panel, C8, T0, fix<0>); \ + traits.updateRhs(&blB[(1 + 8 * K) * RhsProgress], rhs_panel); \ + traits.madd(A0, rhs_panel, C1, T0, fix<1>); \ + traits.madd(A1, rhs_panel, C9, T0, fix<1>); \ + traits.updateRhs(&blB[(2 + 8 * K) * RhsProgress], rhs_panel); \ + traits.madd(A0, rhs_panel, C2, T0, fix<2>); \ + traits.madd(A1, rhs_panel, C10, T0, fix<2>); \ + traits.updateRhs(&blB[(3 + 8 * K) * RhsProgress], rhs_panel); \ + traits.madd(A0, rhs_panel, C3, T0, fix<3>); \ + traits.madd(A1, rhs_panel, C11, T0, fix<3>); \ + traits.loadRhs(&blB[(4 + 8 * K) * RhsProgress], rhs_panel); \ + traits.madd(A0, rhs_panel, C4, T0, fix<0>); \ + traits.madd(A1, rhs_panel, C12, T0, fix<0>); \ + traits.updateRhs(&blB[(5 + 8 * K) * RhsProgress], rhs_panel); \ + traits.madd(A0, rhs_panel, C5, T0, fix<1>); \ + traits.madd(A1, rhs_panel, C13, T0, fix<1>); \ + traits.updateRhs(&blB[(6 + 8 * K) * RhsProgress], rhs_panel); \ + traits.madd(A0, rhs_panel, C6, T0, fix<2>); \ + traits.madd(A1, rhs_panel, C14, T0, fix<2>); \ + traits.updateRhs(&blB[(7 + 8 * K) * RhsProgress], rhs_panel); \ + traits.madd(A0, rhs_panel, C7, T0, fix<3>); \ + traits.madd(A1, rhs_panel, C15, T0, fix<3>); \ + EIGEN_GEBP_2Px8_SPILLING_WORKAROUND \ + EIGEN_ASM_COMMENT("end step of gebp micro kernel 2pX8"); \ + } while (false) + + EIGEN_ASM_COMMENT("begin gebp micro kernel 2pX8"); + + EIGEN_GEBGP_ONESTEP(0); + EIGEN_GEBGP_ONESTEP(1); + EIGEN_GEBGP_ONESTEP(2); + EIGEN_GEBGP_ONESTEP(3); + EIGEN_GEBGP_ONESTEP(4); + EIGEN_GEBGP_ONESTEP(5); + EIGEN_GEBGP_ONESTEP(6); + EIGEN_GEBGP_ONESTEP(7); + + blB += pk*8*RhsProgress; + blA += pk*(2*Traits::LhsProgress); + + EIGEN_ASM_COMMENT("end gebp micro kernel 2pX8"); + } + // process remaining peeled loop + for(Index k=peeled_kc; k(alpha); + + R0 = r0.template loadPacket(0 * Traits::ResPacketSize); + R1 = r0.template loadPacket(1 * Traits::ResPacketSize); + R2 = r1.template loadPacket(0 * Traits::ResPacketSize); + R3 = r1.template loadPacket(1 * Traits::ResPacketSize); + traits.acc(C0, alphav, R0); + traits.acc(C8, alphav, R1); + traits.acc(C1, alphav, R2); + traits.acc(C9, alphav, R3); + r0.storePacket(0 * Traits::ResPacketSize, R0); + r0.storePacket(1 * Traits::ResPacketSize, R1); + r1.storePacket(0 * Traits::ResPacketSize, R2); + r1.storePacket(1 * Traits::ResPacketSize, R3); + + R0 = r2.template loadPacket(0 * Traits::ResPacketSize); + R1 = r2.template loadPacket(1 * Traits::ResPacketSize); + R2 = r3.template loadPacket(0 * Traits::ResPacketSize); + R3 = r3.template loadPacket(1 * Traits::ResPacketSize); + traits.acc(C2, alphav, R0); + traits.acc(C10, alphav, R1); + traits.acc(C3, alphav, R2); + traits.acc(C11, alphav, R3); + r2.storePacket(0 * Traits::ResPacketSize, R0); + r2.storePacket(1 * Traits::ResPacketSize, R1); + r3.storePacket(0 * Traits::ResPacketSize, R2); + r3.storePacket(1 * Traits::ResPacketSize, R3); + + R0 = r4.template loadPacket(0 * Traits::ResPacketSize); + R1 = r4.template loadPacket(1 * Traits::ResPacketSize); + R2 = r5.template loadPacket(0 * Traits::ResPacketSize); + R3 = r5.template loadPacket(1 * Traits::ResPacketSize); + traits.acc(C4, alphav, R0); + traits.acc(C12, alphav, R1); + traits.acc(C5, alphav, R2); + traits.acc(C13, alphav, R3); + r4.storePacket(0 * Traits::ResPacketSize, R0); + r4.storePacket(1 * Traits::ResPacketSize, R1); + r5.storePacket(0 * Traits::ResPacketSize, R2); + r5.storePacket(1 * Traits::ResPacketSize, R3); + + R0 = r6.template loadPacket(0 * Traits::ResPacketSize); + R1 = r6.template loadPacket(1 * Traits::ResPacketSize); + R2 = r7.template loadPacket(0 * Traits::ResPacketSize); + R3 = r7.template loadPacket(1 * Traits::ResPacketSize); + traits.acc(C6, alphav, R0); + traits.acc(C14, alphav, R1); + traits.acc(C7, alphav, R2); + traits.acc(C15, alphav, R3); + r6.storePacket(0 * Traits::ResPacketSize, R0); + r6.storePacket(1 * Traits::ResPacketSize, R1); + r7.storePacket(0 * Traits::ResPacketSize, R2); + r7.storePacket(1 * Traits::ResPacketSize, R3); + } + } + + for(Index j2=packet_cols8; j2=4 ? (cols/4) * 4 : 0; Index count = 0; const Index peeled_k = (depth/PacketSize)*PacketSize; -// if(nr>=8) -// { -// for(Index j2=0; j2 kernel; -// for (int p = 0; p < PacketSize; ++p) { -// kernel.packet[p] = ploadu(&rhs[(j2+p)*rhsStride+k]); -// } -// ptranspose(kernel); -// for (int p = 0; p < PacketSize; ++p) { -// pstoreu(blockB+count, cj.pconj(kernel.packet[p])); -// count+=PacketSize; -// } -// } -// } -// for(; k=8) + { + for(Index j2=0; j2 kernel0, kernel1, kernel2, kernel3; + kernel0.packet[0%PacketSize] = dm0.template loadPacket(k); + kernel0.packet[1%PacketSize] = dm1.template loadPacket(k); + kernel1.packet[0%PacketSize] = dm2.template loadPacket(k); + kernel1.packet[1%PacketSize] = dm3.template loadPacket(k); + kernel2.packet[0%PacketSize] = dm4.template loadPacket(k); + kernel2.packet[1%PacketSize] = dm5.template loadPacket(k); + kernel3.packet[0%PacketSize] = dm6.template loadPacket(k); + kernel3.packet[1%PacketSize] = dm7.template loadPacket(k); + ptranspose(kernel0); + ptranspose(kernel1); + ptranspose(kernel2); + ptranspose(kernel3); + + pstoreu(blockB + count + 0 * PacketSize, cj.pconj(kernel0.packet[0 % PacketSize])); + pstoreu(blockB + count + 1 * PacketSize, cj.pconj(kernel1.packet[0 % PacketSize])); + pstoreu(blockB + count + 2 * PacketSize, cj.pconj(kernel2.packet[0 % PacketSize])); + pstoreu(blockB + count + 3 * PacketSize, cj.pconj(kernel3.packet[0 % PacketSize])); + + pstoreu(blockB + count + 4 * PacketSize, cj.pconj(kernel0.packet[1 % PacketSize])); + pstoreu(blockB + count + 5 * PacketSize, cj.pconj(kernel1.packet[1 % PacketSize])); + pstoreu(blockB + count + 6 * PacketSize, cj.pconj(kernel2.packet[1 % PacketSize])); + pstoreu(blockB + count + 7 * PacketSize, cj.pconj(kernel3.packet[1 % PacketSize])); + count+=8*PacketSize; + } + else if (PacketSize == 4) + { + PacketBlock kernel0, kernel1; + + kernel0.packet[0%PacketSize] = dm0.template loadPacket(k); + kernel0.packet[1%PacketSize] = dm1.template loadPacket(k); + kernel0.packet[2%PacketSize] = dm2.template loadPacket(k); + kernel0.packet[3%PacketSize] = dm3.template loadPacket(k); + kernel1.packet[0%PacketSize] = dm4.template loadPacket(k); + kernel1.packet[1%PacketSize] = dm5.template loadPacket(k); + kernel1.packet[2%PacketSize] = dm6.template loadPacket(k); + kernel1.packet[3%PacketSize] = dm7.template loadPacket(k); + ptranspose(kernel0); + ptranspose(kernel1); + + pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel0.packet[0%PacketSize])); + pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel1.packet[0%PacketSize])); + pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel0.packet[1%PacketSize])); + pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel1.packet[1%PacketSize])); + pstoreu(blockB+count+4*PacketSize, cj.pconj(kernel0.packet[2%PacketSize])); + pstoreu(blockB+count+5*PacketSize, cj.pconj(kernel1.packet[2%PacketSize])); + pstoreu(blockB+count+6*PacketSize, cj.pconj(kernel0.packet[3%PacketSize])); + pstoreu(blockB+count+7*PacketSize, cj.pconj(kernel1.packet[3%PacketSize])); + count+=8*PacketSize; + } + else if (PacketSize == 8) + { + PacketBlock kernel0; + + kernel0.packet[0%PacketSize] = dm0.template loadPacket(k); + kernel0.packet[1%PacketSize] = dm1.template loadPacket(k); + kernel0.packet[2%PacketSize] = dm2.template loadPacket(k); + kernel0.packet[3%PacketSize] = dm3.template loadPacket(k); + kernel0.packet[4%PacketSize] = dm4.template loadPacket(k); + kernel0.packet[5%PacketSize] = dm5.template loadPacket(k); + kernel0.packet[6%PacketSize] = dm6.template loadPacket(k); + kernel0.packet[7%PacketSize] = dm7.template loadPacket(k); + ptranspose(kernel0); + + pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel0.packet[0%PacketSize])); + pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel0.packet[1%PacketSize])); + pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel0.packet[2%PacketSize])); + pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel0.packet[3%PacketSize])); + pstoreu(blockB+count+4*PacketSize, cj.pconj(kernel0.packet[4%PacketSize])); + pstoreu(blockB+count+5*PacketSize, cj.pconj(kernel0.packet[5%PacketSize])); + pstoreu(blockB+count+6*PacketSize, cj.pconj(kernel0.packet[6%PacketSize])); + pstoreu(blockB+count+7*PacketSize, cj.pconj(kernel0.packet[7%PacketSize])); + count+=8*PacketSize; + } + } + } + + for(; k=4) { @@ -2522,39 +3143,41 @@ struct gemm_pack_rhs=4 ? (cols/4) * 4 : 0; Index count = 0; - // if(nr>=8) - // { - // for(Index j2=0; j2(&rhs[k*rhsStride + j2]); - // pstoreu(blockB+count, cj.pconj(A)); - // } else if (PacketSize==4) { - // Packet A = ploadu(&rhs[k*rhsStride + j2]); - // Packet B = ploadu(&rhs[k*rhsStride + j2 + PacketSize]); - // pstoreu(blockB+count, cj.pconj(A)); - // pstoreu(blockB+count+PacketSize, cj.pconj(B)); - // } else { - // const Scalar* b0 = &rhs[k*rhsStride + j2]; - // blockB[count+0] = cj(b0[0]); - // blockB[count+1] = cj(b0[1]); - // blockB[count+2] = cj(b0[2]); - // blockB[count+3] = cj(b0[3]); - // blockB[count+4] = cj(b0[4]); - // blockB[count+5] = cj(b0[5]); - // blockB[count+6] = cj(b0[6]); - // blockB[count+7] = cj(b0[7]); - // } - // count += 8; - // } - // // skip what we have after - // if(PanelMode) count += 8 * (stride-offset-depth); - // } - // } + if(nr>=8) + { + for(Index j2=0; j2(k, j2); + pstoreu(blockB+count, cj.pconj(A)); + count += PacketSize; + } else if (PacketSize==4) { + Packet A = rhs.template loadPacket(k, j2); + Packet B = rhs.template loadPacket(k, j2 + 4); + pstoreu(blockB+count, cj.pconj(A)); + pstoreu(blockB+count+PacketSize, cj.pconj(B)); + count += 2*PacketSize; + } else { + const LinearMapper dm0 = rhs.getLinearMapper(k, j2); + blockB[count+0] = cj(dm0(0)); + blockB[count+1] = cj(dm0(1)); + blockB[count+2] = cj(dm0(2)); + blockB[count+3] = cj(dm0(3)); + blockB[count+4] = cj(dm0(4)); + blockB[count+5] = cj(dm0(5)); + blockB[count+6] = cj(dm0(6)); + blockB[count+7] = cj(dm0(7)); + count += 8; + } + } + // skip what we have after + if(PanelMode) count += 8 * (stride-offset-depth); + } + } if(nr>=4) { for(Index j2=packet_cols8; j2