diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h index b012691c1..e7cab4720 100644 --- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h +++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h @@ -891,6 +891,86 @@ struct gebp_kernel Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0); }; +template::LhsProgress> +struct last_row_process_16_packets +{ + typedef gebp_traits Traits; + typedef gebp_traits SwappedTraits; + + typedef typename Traits::ResScalar ResScalar; + typedef typename SwappedTraits::LhsPacket SLhsPacket; + typedef typename SwappedTraits::RhsPacket SRhsPacket; + typedef typename SwappedTraits::ResPacket SResPacket; + typedef typename SwappedTraits::AccPacket SAccPacket; + + EIGEN_STRONG_INLINE void operator()(const DataMapper& res, SwappedTraits &straits, const LhsScalar* blA, + const RhsScalar* blB, Index depth, const Index endk, Index i, Index j2, + ResScalar alpha, SAccPacket &C0) + { + EIGEN_UNUSED_VARIABLE(res); + EIGEN_UNUSED_VARIABLE(straits); + EIGEN_UNUSED_VARIABLE(blA); + EIGEN_UNUSED_VARIABLE(blB); + EIGEN_UNUSED_VARIABLE(depth); + EIGEN_UNUSED_VARIABLE(endk); + EIGEN_UNUSED_VARIABLE(i); + EIGEN_UNUSED_VARIABLE(j2); + EIGEN_UNUSED_VARIABLE(alpha); + EIGEN_UNUSED_VARIABLE(C0); + } +}; + + +template +struct last_row_process_16_packets { + typedef gebp_traits Traits; + typedef gebp_traits SwappedTraits; + + typedef typename Traits::ResScalar ResScalar; + typedef typename SwappedTraits::LhsPacket SLhsPacket; + typedef typename SwappedTraits::RhsPacket SRhsPacket; + typedef typename SwappedTraits::ResPacket SResPacket; + typedef typename SwappedTraits::AccPacket SAccPacket; + + EIGEN_STRONG_INLINE void operator()(const DataMapper& res, SwappedTraits &straits, const LhsScalar* blA, + const RhsScalar* blB, Index depth, const Index endk, Index i, Index j2, + ResScalar alpha, SAccPacket &C0) + { + typedef typename unpacket_traits::half>::half SResPacketQuarter; + typedef typename unpacket_traits::half>::half SLhsPacketQuarter; + typedef typename unpacket_traits::half>::half SRhsPacketQuarter; + typedef typename unpacket_traits::half>::half SAccPacketQuarter; + + SResPacketQuarter R = res.template gatherPacket(i, j2); + SResPacketQuarter alphav = pset1(alpha); + + if (depth - endk > 0) + { + // We have to handle the last row(s) of the rhs, which + // correspond to a half-packet + SAccPacketQuarter c0 = predux_half_dowto4(predux_half_dowto4(C0)); + + for (Index kk = endk; kk < depth; kk++) + { + SLhsPacketQuarter a0; + SRhsPacketQuarter b0; + straits.loadLhsUnaligned(blB, a0); + straits.loadRhs(blA, b0); + straits.madd(a0,b0,c0,b0); + blB += SwappedTraits::LhsProgress/4; + blA += 1; + } + straits.acc(c0, alphav, R); + } + else + { + straits.acc(predux_half_dowto4(predux_half_dowto4(C0)), alphav, R); + } + res.scatterPacket(i, j2, R); + } +}; + template EIGEN_DONT_INLINE void gebp_kernel @@ -1527,13 +1607,15 @@ void gebp_kernel::half>::size; + const int SResPacketQuarterSize = unpacket_traits::half>::half>::size; if ((SwappedTraits::LhsProgress % 4) == 0 && - (SwappedTraits::LhsProgress <= 8) && - (SwappedTraits::LhsProgress!=8 || SResPacketHalfSize==nr)) + (SwappedTraits::LhsProgress<=16) && + (SwappedTraits::LhsProgress!=8 || SResPacketHalfSize==nr) && + (SwappedTraits::LhsProgress!=16 || SResPacketQuarterSize==nr)) { SAccPacket C0, C1, C2, C3; straits.initAcc(C0); @@ -1610,6 +1692,15 @@ void gebp_kernel p; + p(res, straits, blA, blB, depth, endk, i, j2,alpha, C0); + } else { SResPacket R = res.template gatherPacket(i, j2);