mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-21 20:09:06 +08:00
Merged in glchaves/eigen (pull request PR-635)
Speed up GEMV on AVX-512 builds, just as done for GEBP previously. Approved-by: Rasmus Larsen <rmlarsen@google.com>
This commit is contained in:
commit
bf9cbed8d0
@ -15,13 +15,13 @@ namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
|
||||
enum PacketSizeType {
|
||||
PacketFull = 0,
|
||||
PacketHalf,
|
||||
PacketQuarter
|
||||
enum GEBPPacketSizeType {
|
||||
GEBPPacketFull = 0,
|
||||
GEBPPacketHalf,
|
||||
GEBPPacketQuarter
|
||||
};
|
||||
|
||||
template<typename _LhsScalar, typename _RhsScalar, bool _ConjLhs=false, bool _ConjRhs=false, int Arch=Architecture::Target, int _PacketSize=PacketFull>
|
||||
template<typename _LhsScalar, typename _RhsScalar, bool _ConjLhs=false, bool _ConjRhs=false, int Arch=Architecture::Target, int _PacketSize=GEBPPacketFull>
|
||||
class gebp_traits;
|
||||
|
||||
|
||||
@ -375,10 +375,10 @@ template <int N, typename T1, typename T2, typename T3>
|
||||
struct packet_conditional { typedef T3 type; };
|
||||
|
||||
template <typename T1, typename T2, typename T3>
|
||||
struct packet_conditional<PacketFull, T1, T2, T3> { typedef T1 type; };
|
||||
struct packet_conditional<GEBPPacketFull, T1, T2, T3> { typedef T1 type; };
|
||||
|
||||
template <typename T1, typename T2, typename T3>
|
||||
struct packet_conditional<PacketHalf, T1, T2, T3> { typedef T2 type; };
|
||||
struct packet_conditional<GEBPPacketHalf, T1, T2, T3> { typedef T2 type; };
|
||||
|
||||
#define PACKET_DECL_COND_PREFIX(prefix, name, packet_size) \
|
||||
typedef typename packet_conditional<packet_size, \
|
||||
@ -1054,8 +1054,8 @@ protected:
|
||||
#if EIGEN_ARCH_ARM64 && defined EIGEN_VECTORIZE_NEON
|
||||
|
||||
template<>
|
||||
struct gebp_traits <float, float, false, false,Architecture::NEON,PacketFull>
|
||||
: gebp_traits<float,float,false,false,Architecture::Generic,PacketFull>
|
||||
struct gebp_traits <float, float, false, false,Architecture::NEON,GEBPPacketFull>
|
||||
: gebp_traits<float,float,false,false,Architecture::Generic,GEBPPacketFull>
|
||||
{
|
||||
typedef float RhsPacket;
|
||||
|
||||
@ -1203,8 +1203,8 @@ template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMa
|
||||
struct gebp_kernel
|
||||
{
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target> Traits;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target,PacketHalf> HalfTraits;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target,PacketQuarter> QuarterTraits;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target,GEBPPacketHalf> HalfTraits;
|
||||
typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs,Architecture::Target,GEBPPacketQuarter> QuarterTraits;
|
||||
|
||||
typedef typename Traits::ResScalar ResScalar;
|
||||
typedef typename Traits::LhsPacket LhsPacket;
|
||||
|
@ -14,6 +14,54 @@ namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
|
||||
// Selector for the packet width used by the GEMV traits below.
// It is passed as the integer template argument of gemv_packet_cond and
// gemv_traits; GEMVPacketFull (0) is the default for gemv_traits.
enum GEMVPacketSizeType {
|
||||
GEMVPacketFull = 0,
|
||||
GEMVPacketHalf,
|
||||
GEMVPacketQuarter
|
||||
};
|
||||
|
||||
// Compile-time three-way type selector keyed on a GEMVPacketSizeType value N.
// Primary template: any N without a specialization below (in practice
// GEMVPacketQuarter) selects T3.
template <int N, typename T1, typename T2, typename T3>
|
||||
struct gemv_packet_cond { typedef T3 type; };
|
||||
|
||||
// GEMVPacketFull selects T1.
template <typename T1, typename T2, typename T3>
|
||||
struct gemv_packet_cond<GEMVPacketFull, T1, T2, T3> { typedef T1 type; };
|
||||
|
||||
// GEMVPacketHalf selects T2.
template <typename T1, typename T2, typename T3>
|
||||
struct gemv_packet_cond<GEMVPacketHalf, T1, T2, T3> { typedef T2 type; };
|
||||
|
||||
// Traits describing the packet types and packet sizes for a given
// (LhsScalar, RhsScalar) pair at the packet width selected by _PacketSize
// (a GEMVPacketSizeType value; defaults to the full packet).
template<typename LhsScalar, typename RhsScalar, int _PacketSize=GEMVPacketFull>
|
||||
class gemv_traits
|
||||
{
|
||||
// Scalar type of the product, as reported by ScalarBinaryOpTraits.
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
|
||||
|
||||
// Helper macro: declares the typedef `prefix##name##Packet` as one of three
// candidates for `name##Scalar`, chosen by packet_size via gemv_packet_cond:
//   - packet_traits<...>::type                          (full)
//   - packet_traits<...>::half                          (half)
//   - unpacket_traits<packet_traits<...>::half>::half   (half of the half)
// The macro is local to this class and is #undef'd right after use.
#define PACKET_DECL_COND_PREFIX(prefix, name, packet_size) \
|
||||
typedef typename gemv_packet_cond<packet_size, \
|
||||
typename packet_traits<name ## Scalar>::type, \
|
||||
typename packet_traits<name ## Scalar>::half, \
|
||||
typename unpacket_traits<typename packet_traits<name ## Scalar>::half>::half>::type \
|
||||
prefix ## name ## Packet
|
||||
|
||||
// Private candidate packet types: _LhsPacket, _RhsPacket, _ResPacket.
PACKET_DECL_COND_PREFIX(_, Lhs, _PacketSize);
|
||||
PACKET_DECL_COND_PREFIX(_, Rhs, _PacketSize);
|
||||
PACKET_DECL_COND_PREFIX(_, Res, _PacketSize);
|
||||
#undef PACKET_DECL_COND_PREFIX
|
||||
|
||||
public:
|
||||
// Vectorizable holds only when both the lhs and rhs candidate packets are
// vectorizable and have the same number of elements. When it does not hold,
// every packet size is reported as 1.
enum {
|
||||
Vectorizable = unpacket_traits<_LhsPacket>::vectorizable &&
|
||||
unpacket_traits<_RhsPacket>::vectorizable &&
|
||||
int(unpacket_traits<_LhsPacket>::size)==int(unpacket_traits<_RhsPacket>::size),
|
||||
LhsPacketSize = Vectorizable ? unpacket_traits<_LhsPacket>::size : 1,
|
||||
RhsPacketSize = Vectorizable ? unpacket_traits<_RhsPacket>::size : 1,
|
||||
ResPacketSize = Vectorizable ? unpacket_traits<_ResPacket>::size : 1
|
||||
};
|
||||
|
||||
// Public packet types: the candidate packet when Vectorizable, otherwise the
// plain scalar type.
typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
|
||||
typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
|
||||
typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
|
||||
};
|
||||
|
||||
|
||||
/* Optimized col-major matrix * vector product:
|
||||
* This algorithm processes the matrix per vertical panels,
|
||||
* which are then processed horizontally per chunk of 8*PacketSize x 1 vertical segments.
|
||||
@ -30,23 +78,23 @@ namespace internal {
|
||||
template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
|
||||
struct general_matrix_vector_product<Index,LhsScalar,LhsMapper,ColMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>
|
||||
{
|
||||
typedef gemv_traits<LhsScalar,RhsScalar> Traits;
|
||||
typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketHalf> HalfTraits;
|
||||
typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketQuarter> QuarterTraits;
|
||||
|
||||
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
|
||||
|
||||
enum {
|
||||
Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable
|
||||
&& int(packet_traits<LhsScalar>::size)==int(packet_traits<RhsScalar>::size),
|
||||
LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
|
||||
RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
|
||||
ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1
|
||||
};
|
||||
typedef typename Traits::LhsPacket LhsPacket;
|
||||
typedef typename Traits::RhsPacket RhsPacket;
|
||||
typedef typename Traits::ResPacket ResPacket;
|
||||
|
||||
typedef typename packet_traits<LhsScalar>::type _LhsPacket;
|
||||
typedef typename packet_traits<RhsScalar>::type _RhsPacket;
|
||||
typedef typename packet_traits<ResScalar>::type _ResPacket;
|
||||
typedef typename HalfTraits::LhsPacket LhsPacketHalf;
|
||||
typedef typename HalfTraits::RhsPacket RhsPacketHalf;
|
||||
typedef typename HalfTraits::ResPacket ResPacketHalf;
|
||||
|
||||
typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
|
||||
typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
|
||||
typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
|
||||
typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
|
||||
typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
|
||||
typedef typename QuarterTraits::ResPacket ResPacketQuarter;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(
|
||||
Index rows, Index cols,
|
||||
@ -73,19 +121,33 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
|
||||
|
||||
conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
|
||||
conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
|
||||
conj_helper<LhsPacketHalf,RhsPacketHalf,ConjugateLhs,ConjugateRhs> pcj_half;
|
||||
conj_helper<LhsPacketQuarter,RhsPacketQuarter,ConjugateLhs,ConjugateRhs> pcj_quarter;
|
||||
|
||||
const Index lhsStride = lhs.stride();
|
||||
// TODO: for padded aligned inputs, we could enable aligned reads
|
||||
enum { LhsAlignment = Unaligned };
|
||||
enum { LhsAlignment = Unaligned,
|
||||
ResPacketSize = Traits::ResPacketSize,
|
||||
ResPacketSizeHalf = HalfTraits::ResPacketSize,
|
||||
ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
|
||||
LhsPacketSize = Traits::LhsPacketSize,
|
||||
HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize,
|
||||
HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf
|
||||
};
|
||||
|
||||
const Index n8 = rows-8*ResPacketSize+1;
|
||||
const Index n4 = rows-4*ResPacketSize+1;
|
||||
const Index n3 = rows-3*ResPacketSize+1;
|
||||
const Index n2 = rows-2*ResPacketSize+1;
|
||||
const Index n1 = rows-1*ResPacketSize+1;
|
||||
const Index n_half = rows-1*ResPacketSizeHalf+1;
|
||||
const Index n_quarter = rows-1*ResPacketSizeQuarter+1;
|
||||
|
||||
// TODO: improve the following heuristic:
|
||||
const Index block_cols = cols<128 ? cols : (lhsStride*sizeof(LhsScalar)<32000?16:4);
|
||||
ResPacket palpha = pset1<ResPacket>(alpha);
|
||||
ResPacketHalf palpha_half = pset1<ResPacketHalf>(alpha);
|
||||
ResPacketQuarter palpha_quarter = pset1<ResPacketQuarter>(alpha);
|
||||
|
||||
for(Index j2=0; j2<cols; j2+=block_cols)
|
||||
{
|
||||
@ -190,6 +252,28 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
|
||||
pstoreu(res+i+ResPacketSize*0, pmadd(c0,palpha,ploadu<ResPacket>(res+i+ResPacketSize*0)));
|
||||
i+=ResPacketSize;
|
||||
}
|
||||
if(HasHalf && i<n_half)
|
||||
{
|
||||
ResPacketHalf c0 = pset1<ResPacketHalf>(ResScalar(0));
|
||||
for(Index j=j2; j<jend; j+=1)
|
||||
{
|
||||
RhsPacketHalf b0 = pset1<RhsPacketHalf>(rhs(j,0));
|
||||
c0 = pcj_half.pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(i+0,j),b0,c0);
|
||||
}
|
||||
pstoreu(res+i+ResPacketSizeHalf*0, pmadd(c0,palpha_half,ploadu<ResPacketHalf>(res+i+ResPacketSizeHalf*0)));
|
||||
i+=ResPacketSizeHalf;
|
||||
}
|
||||
if(HasQuarter && i<n_quarter)
|
||||
{
|
||||
ResPacketQuarter c0 = pset1<ResPacketQuarter>(ResScalar(0));
|
||||
for(Index j=j2; j<jend; j+=1)
|
||||
{
|
||||
RhsPacketQuarter b0 = pset1<RhsPacketQuarter>(rhs(j,0));
|
||||
c0 = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(i+0,j),b0,c0);
|
||||
}
|
||||
pstoreu(res+i+ResPacketSizeQuarter*0, pmadd(c0,palpha_quarter,ploadu<ResPacketQuarter>(res+i+ResPacketSizeQuarter*0)));
|
||||
i+=ResPacketSizeQuarter;
|
||||
}
|
||||
for(;i<rows;++i)
|
||||
{
|
||||
ResScalar c0(0);
|
||||
@ -213,23 +297,24 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
|
||||
template<typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version>
|
||||
struct general_matrix_vector_product<Index,LhsScalar,LhsMapper,RowMajor,ConjugateLhs,RhsScalar,RhsMapper,ConjugateRhs,Version>
|
||||
{
|
||||
typedef gemv_traits<LhsScalar,RhsScalar> Traits;
|
||||
typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketHalf> HalfTraits;
|
||||
typedef gemv_traits<LhsScalar,RhsScalar,GEMVPacketQuarter> QuarterTraits;
|
||||
|
||||
typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
|
||||
|
||||
enum {
|
||||
Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable
|
||||
&& int(packet_traits<LhsScalar>::size)==int(packet_traits<RhsScalar>::size),
|
||||
LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
|
||||
RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
|
||||
ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1
|
||||
};
|
||||
typedef typename Traits::LhsPacket LhsPacket;
|
||||
static const Index LhsPacketSize = Traits::LhsPacketSize;
|
||||
typedef typename Traits::RhsPacket RhsPacket;
|
||||
typedef typename Traits::ResPacket ResPacket;
|
||||
|
||||
typedef typename packet_traits<LhsScalar>::type _LhsPacket;
|
||||
typedef typename packet_traits<RhsScalar>::type _RhsPacket;
|
||||
typedef typename packet_traits<ResScalar>::type _ResPacket;
|
||||
typedef typename HalfTraits::LhsPacket LhsPacketHalf;
|
||||
typedef typename HalfTraits::RhsPacket RhsPacketHalf;
|
||||
typedef typename HalfTraits::ResPacket ResPacketHalf;
|
||||
|
||||
typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
|
||||
typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
|
||||
typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
|
||||
typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
|
||||
typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
|
||||
typedef typename QuarterTraits::ResPacket ResPacketQuarter;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(
|
||||
Index rows, Index cols,
|
||||
@ -254,6 +339,8 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
|
||||
eigen_internal_assert(rhs.stride()==1);
|
||||
conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
|
||||
conj_helper<LhsPacket,RhsPacket,ConjugateLhs,ConjugateRhs> pcj;
|
||||
conj_helper<LhsPacketHalf,RhsPacketHalf,ConjugateLhs,ConjugateRhs> pcj_half;
|
||||
conj_helper<LhsPacketQuarter,RhsPacketQuarter,ConjugateLhs,ConjugateRhs> pcj_quarter;
|
||||
|
||||
// TODO: fine tune the following heuristic. The rationale is that if the matrix is very large,
|
||||
// processing 8 rows at once might be counter productive wrt cache.
|
||||
@ -262,7 +349,16 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
|
||||
const Index n2 = rows-1;
|
||||
|
||||
// TODO: for padded aligned inputs, we could enable aligned reads
|
||||
enum { LhsAlignment = Unaligned };
|
||||
enum { LhsAlignment = Unaligned,
|
||||
ResPacketSize = Traits::ResPacketSize,
|
||||
ResPacketSizeHalf = HalfTraits::ResPacketSize,
|
||||
ResPacketSizeQuarter = QuarterTraits::ResPacketSize,
|
||||
LhsPacketSize = Traits::LhsPacketSize,
|
||||
LhsPacketSizeHalf = HalfTraits::LhsPacketSize,
|
||||
LhsPacketSizeQuarter = QuarterTraits::LhsPacketSize,
|
||||
HasHalf = (int)ResPacketSizeHalf < (int)ResPacketSize,
|
||||
HasQuarter = (int)ResPacketSizeQuarter < (int)ResPacketSizeHalf
|
||||
};
|
||||
|
||||
Index i=0;
|
||||
for(; i<n8; i+=8)
|
||||
@ -383,6 +479,8 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
|
||||
for(; i<rows; ++i)
|
||||
{
|
||||
ResPacket c0 = pset1<ResPacket>(ResScalar(0));
|
||||
ResPacketHalf c0_h = pset1<ResPacketHalf>(ResScalar(0));
|
||||
ResPacketQuarter c0_q = pset1<ResPacketQuarter>(ResScalar(0));
|
||||
Index j=0;
|
||||
for(; j+LhsPacketSize<=cols; j+=LhsPacketSize)
|
||||
{
|
||||
@ -390,6 +488,22 @@ EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void general_matrix_vector_product<Index,Lhs
|
||||
c0 = pcj.pmadd(lhs.template load<LhsPacket,LhsAlignment>(i,j),b0,c0);
|
||||
}
|
||||
ResScalar cc0 = predux(c0);
|
||||
if (HasHalf) {
|
||||
for(; j+LhsPacketSizeHalf<=cols; j+=LhsPacketSizeHalf)
|
||||
{
|
||||
RhsPacketHalf b0 = rhs.template load<RhsPacketHalf,Unaligned>(j,0);
|
||||
c0_h = pcj_half.pmadd(lhs.template load<LhsPacketHalf,LhsAlignment>(i,j),b0,c0_h);
|
||||
}
|
||||
cc0 += predux(c0_h);
|
||||
}
|
||||
if (HasQuarter) {
|
||||
for(; j+LhsPacketSizeQuarter<=cols; j+=LhsPacketSizeQuarter)
|
||||
{
|
||||
RhsPacketQuarter b0 = rhs.template load<RhsPacketQuarter,Unaligned>(j,0);
|
||||
c0_q = pcj_quarter.pmadd(lhs.template load<LhsPacketQuarter,LhsAlignment>(i,j),b0,c0_q);
|
||||
}
|
||||
cc0 += predux(c0_q);
|
||||
}
|
||||
for(; j<cols; ++j)
|
||||
{
|
||||
cc0 += cj.pmul(lhs(i,j), rhs(j,0));
|
||||
|
Loading…
x
Reference in New Issue
Block a user