diff --git a/Eigen/src/Core/products/GeneralBlockPanelKernel.h b/Eigen/src/Core/products/GeneralBlockPanelKernel.h index fdd0ec0e9..6c1d882fd 100644 --- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h +++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h @@ -15,13 +15,13 @@ namespace Eigen { namespace internal { -enum PacketSizeType { - PacketFull = 0, - PacketHalf, - PacketQuarter +enum GEBPPacketSizeType { + GEBPPacketFull = 0, + GEBPPacketHalf, + GEBPPacketQuarter }; -template +template class gebp_traits; @@ -375,10 +375,10 @@ template struct packet_conditional { typedef T3 type; }; template -struct packet_conditional { typedef T1 type; }; +struct packet_conditional { typedef T1 type; }; template -struct packet_conditional { typedef T2 type; }; +struct packet_conditional { typedef T2 type; }; #define PACKET_DECL_COND_PREFIX(prefix, name, packet_size) \ typedef typename packet_conditional -struct gebp_traits - : gebp_traits +struct gebp_traits + : gebp_traits { typedef float RhsPacket; @@ -1203,8 +1203,8 @@ template Traits; - typedef gebp_traits HalfTraits; - typedef gebp_traits QuarterTraits; + typedef gebp_traits HalfTraits; + typedef gebp_traits QuarterTraits; typedef typename Traits::ResScalar ResScalar; typedef typename Traits::LhsPacket LhsPacket; diff --git a/Eigen/src/Core/products/GeneralMatrixVector.h b/Eigen/src/Core/products/GeneralMatrixVector.h index 767feb99d..eb1d924e5 100644 --- a/Eigen/src/Core/products/GeneralMatrixVector.h +++ b/Eigen/src/Core/products/GeneralMatrixVector.h @@ -14,6 +14,54 @@ namespace Eigen { namespace internal { +enum GEMVPacketSizeType { + GEMVPacketFull = 0, + GEMVPacketHalf, + GEMVPacketQuarter +}; + +template +struct gemv_packet_cond { typedef T3 type; }; + +template +struct gemv_packet_cond { typedef T1 type; }; + +template +struct gemv_packet_cond { typedef T2 type; }; + +template +class gemv_traits +{ + typedef typename ScalarBinaryOpTraits::ReturnType ResScalar; + +#define PACKET_DECL_COND_PREFIX(prefix, name, packet_size) \ + typedef typename gemv_packet_cond::type, \ + typename packet_traits::half, \ + typename unpacket_traits::half>::half>::type \ + prefix ## name ## Packet + + PACKET_DECL_COND_PREFIX(_, Lhs, _PacketSize); + PACKET_DECL_COND_PREFIX(_, Rhs, _PacketSize); + PACKET_DECL_COND_PREFIX(_, Res, _PacketSize); +#undef PACKET_DECL_COND_PREFIX + +public: + enum { + Vectorizable = unpacket_traits<_LhsPacket>::vectorizable && + unpacket_traits<_RhsPacket>::vectorizable && + int(unpacket_traits<_LhsPacket>::size)==int(unpacket_traits<_RhsPacket>::size), + LhsPacketSize = Vectorizable ? unpacket_traits<_LhsPacket>::size : 1, + RhsPacketSize = Vectorizable ? unpacket_traits<_RhsPacket>::size : 1, + ResPacketSize = Vectorizable ? unpacket_traits<_ResPacket>::size : 1 + }; + + typedef typename conditional::type LhsPacket; + typedef typename conditional::type RhsPacket; + typedef typename conditional::type ResPacket; +}; + + /* Optimized col-major matrix * vector product: * This algorithm processes the matrix per vertical panels, * which are then processed horizontaly per chunck of 8*PacketSize x 1 vertical segments. @@ -30,23 +78,23 @@ namespace internal { template struct general_matrix_vector_product { + typedef gemv_traits Traits; + typedef gemv_traits HalfTraits; + typedef gemv_traits QuarterTraits; + typedef typename ScalarBinaryOpTraits::ReturnType ResScalar; -enum { - Vectorizable = packet_traits::Vectorizable && packet_traits::Vectorizable - && int(packet_traits::size)==int(packet_traits::size), - LhsPacketSize = Vectorizable ? packet_traits::size : 1, - RhsPacketSize = Vectorizable ? packet_traits::size : 1, - ResPacketSize = Vectorizable ? packet_traits::size : 1 -}; + typedef typename Traits::LhsPacket LhsPacket; + typedef typename Traits::RhsPacket RhsPacket; + typedef typename Traits::ResPacket ResPacket; -typedef typename packet_traits::type _LhsPacket; -typedef typename packet_traits::type _RhsPacket; -typedef typename packet_traits::type _ResPacket; + typedef typename HalfTraits::LhsPacket LhsPacketHalf; + typedef typename HalfTraits::RhsPacket RhsPacketHalf; + typedef typename HalfTraits::ResPacket ResPacketHalf; -typedef typename conditional::type LhsPacket; -typedef typename conditional::type RhsPacket; -typedef typename conditional::type ResPacket; + typedef typename QuarterTraits::LhsPacket LhsPacketQuarter; + typedef typename QuarterTraits::RhsPacket RhsPacketQuarter; + typedef typename QuarterTraits::ResPacket ResPacketQuarter; EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( Index rows, Index cols, @@ -73,19 +121,33 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product cj; conj_helper pcj; + conj_helper pcj_half; + conj_helper pcj_quarter; + const Index lhsStride = lhs.stride(); // TODO: for padded aligned inputs, we could enable aligned reads - enum { LhsAlignment = Unaligned }; + enum { LhsAlignment = Unaligned, + ResPacketSize = Traits::ResPacketSize, + ResPacketSizeHalf = HalfTraits::ResPacketSize, + ResPacketSizeQuarter = QuarterTraits::ResPacketSize, + LhsPacketSize = Traits::LhsPacketSize, + HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize, + HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf + }; const Index n8 = rows-8*ResPacketSize+1; const Index n4 = rows-4*ResPacketSize+1; const Index n3 = rows-3*ResPacketSize+1; const Index n2 = rows-2*ResPacketSize+1; const Index n1 = rows-1*ResPacketSize+1; + const Index n_half = rows-1*ResPacketSizeHalf+1; + const Index n_quarter = rows-1*ResPacketSizeQuarter+1; // TODO: improve the following heuristic: const Index block_cols = cols<128 ? cols : (lhsStride*sizeof(LhsScalar)<32000?16:4); ResPacket palpha = pset1(alpha); + ResPacketHalf palpha_half = pset1(alpha); + ResPacketQuarter palpha_quarter = pset1(alpha); for(Index j2=0; j2(res+i+ResPacketSize*0))); i+=ResPacketSize; } + if(HasHalf && i(ResScalar(0)); + for(Index j=j2; j(rhs(j,0)); + c0 = pcj_half.pmadd(lhs.template load(i+0,j),b0,c0); + } + pstoreu(res+i+ResPacketSizeHalf*0, pmadd(c0,palpha_half,ploadu(res+i+ResPacketSizeHalf*0))); + i+=ResPacketSizeHalf; + } + if(HasQuarter && i(ResScalar(0)); + for(Index j=j2; j(rhs(j,0)); + c0 = pcj_quarter.pmadd(lhs.template load(i+0,j),b0,c0); + } + pstoreu(res+i+ResPacketSizeQuarter*0, pmadd(c0,palpha_quarter,ploadu(res+i+ResPacketSizeQuarter*0))); + i+=ResPacketSizeQuarter; + } for(;i struct general_matrix_vector_product { -typedef typename ScalarBinaryOpTraits::ReturnType ResScalar; + typedef gemv_traits Traits; + typedef gemv_traits HalfTraits; + typedef gemv_traits QuarterTraits; -enum { - Vectorizable = packet_traits::Vectorizable && packet_traits::Vectorizable - && int(packet_traits::size)==int(packet_traits::size), - LhsPacketSize = Vectorizable ? packet_traits::size : 1, - RhsPacketSize = Vectorizable ? packet_traits::size : 1, - ResPacketSize = Vectorizable ? packet_traits::size : 1 -}; + typedef typename ScalarBinaryOpTraits::ReturnType ResScalar; -typedef typename packet_traits::type _LhsPacket; -typedef typename packet_traits::type _RhsPacket; -typedef typename packet_traits::type _ResPacket; + typedef typename Traits::LhsPacket LhsPacket; + static const Index LhsPacketSize = Traits::LhsPacketSize; + typedef typename Traits::RhsPacket RhsPacket; + typedef typename Traits::ResPacket ResPacket; -typedef typename conditional::type LhsPacket; -typedef typename conditional::type RhsPacket; -typedef typename conditional::type ResPacket; + typedef typename HalfTraits::LhsPacket LhsPacketHalf; + typedef typename HalfTraits::RhsPacket RhsPacketHalf; + typedef typename HalfTraits::ResPacket ResPacketHalf; + + typedef typename QuarterTraits::LhsPacket LhsPacketQuarter; + typedef typename QuarterTraits::RhsPacket RhsPacketQuarter; + typedef typename QuarterTraits::ResPacket ResPacketQuarter; EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run( Index rows, Index cols, @@ -254,6 +339,8 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product cj; conj_helper pcj; + conj_helper pcj_half; + conj_helper pcj_quarter; // TODO: fine tune the following heuristic. The rationale is that if the matrix is very large, // processing 8 rows at once might be counter productive wrt cache. @@ -262,7 +349,16 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product(ResScalar(0)); + ResPacketHalf c0_h = pset1(ResScalar(0)); + ResPacketQuarter c0_q = pset1(ResScalar(0)); Index j=0; for(; j+LhsPacketSize<=cols; j+=LhsPacketSize) { @@ -390,6 +488,22 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product(i,j),b0,c0); } ResScalar cc0 = predux(c0); + if (HasHalf) { + for(; j+LhsPacketSizeHalf<=cols; j+=LhsPacketSizeHalf) + { + RhsPacketHalf b0 = rhs.template load(j,0); + c0_h = pcj_half.pmadd(lhs.template load(i,j),b0,c0_h); + } + cc0 += predux(c0_h); + } + if (HasQuarter) { + for(; j+LhsPacketSizeQuarter<=cols; j+=LhsPacketSizeQuarter) + { + RhsPacketQuarter b0 = rhs.template load(j,0); + c0_q = pcj_quarter.pmadd(lhs.template load(i,j),b0,c0_q); + } + cc0 += predux(c0_q); + } for(; j