Speed up matrix multiplication for small to medium size matrices by using half- or quarter-packet vectorized loads in gemm_pack_rhs if they have size 4, instead of dropping down the the scalar path.

Benchmark measurements below are for computing ```c.noalias() = a.transpose() * b;``` for square RowMajor matrices of varying size.

Measured improvement with AVX+FMA:

name                           old time/op             new time/op             delta
BM_MatMul_ATB/8                 139ns ± 1%              129ns ± 1%   -7.49%          (p=0.008 n=5+5)
BM_MatMul_ATB/32               1.46µs ± 1%             1.22µs ± 0%  -16.72%          (p=0.008 n=5+5)
BM_MatMul_ATB/64               8.43µs ± 1%             7.41µs ± 0%  -12.04%          (p=0.008 n=5+5)
BM_MatMul_ATB/128              56.8µs ± 1%             52.9µs ± 1%   -6.83%          (p=0.008 n=5+5)
BM_MatMul_ATB/256               407µs ± 1%              395µs ± 3%   -2.94%          (p=0.032 n=5+5)
BM_MatMul_ATB/512              3.27ms ± 3%             3.18ms ± 1%     ~             (p=0.056 n=5+5)


Measured improvement for AVX512:

name                          old time/op             new time/op             delta
BM_MatMul_ATB/8                167ns ± 1%              154ns ± 1%   -7.63%          (p=0.008 n=5+5)
BM_MatMul_ATB/32              1.08µs ± 1%             0.83µs ± 3%  -23.58%          (p=0.008 n=5+5)
BM_MatMul_ATB/64              6.21µs ± 1%             5.06µs ± 1%  -18.47%          (p=0.008 n=5+5)
BM_MatMul_ATB/128             36.1µs ± 2%             31.3µs ± 1%  -13.32%          (p=0.008 n=5+5)
BM_MatMul_ATB/256              263µs ± 2%              242µs ± 2%   -7.92%          (p=0.008 n=5+5)
BM_MatMul_ATB/512             1.95ms ± 2%             1.91ms ± 2%     ~             (p=0.095 n=5+5)
BM_MatMul_ATB/1k              15.4ms ± 4%             14.8ms ± 2%     ~             (p=0.095 n=5+5)
This commit is contained in:
Rasmus Munk Larsen 2020-04-07 22:09:51 +00:00
parent 8e875719b3
commit f0577a2bfd

View File

@ -2653,8 +2653,14 @@ template<typename Scalar, typename Index, typename DataMapper, int nr, bool Conj
struct gemm_pack_rhs<Scalar, Index, DataMapper, nr, RowMajor, Conjugate, PanelMode>
{
typedef typename packet_traits<Scalar>::type Packet;
typedef typename unpacket_traits<Packet>::half HalfPacket;
typedef typename unpacket_traits<typename unpacket_traits<Packet>::half>::half QuarterPacket;
typedef typename DataMapper::LinearMapper LinearMapper;
enum { PacketSize = packet_traits<Scalar>::size };
enum { PacketSize = packet_traits<Scalar>::size,
HalfPacketSize = unpacket_traits<HalfPacket>::size,
QuarterPacketSize = unpacket_traits<QuarterPacket>::size,
HasHalf = (int)HalfPacketSize < (int)PacketSize,
HasQuarter = (int)QuarterPacketSize < (int)HalfPacketSize };
EIGEN_DONT_INLINE void operator()(Scalar* blockB, const DataMapper& rhs, Index depth, Index cols, Index stride=0, Index offset=0);
};
@ -2716,6 +2722,14 @@ EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, DataMapper, nr, RowMajor, Co
Packet A = rhs.template loadPacket<Packet>(k, j2);
pstoreu(blockB+count, cj.pconj(A));
count += PacketSize;
} else if (HasHalf && HalfPacketSize==4) {
HalfPacket A = rhs.template loadPacket<HalfPacket>(k, j2);
pstoreu(blockB+count, cj.pconj(A));
count += HalfPacketSize;
} else if (HasQuarter && QuarterPacketSize==4) {
QuarterPacket A = rhs.template loadPacket<QuarterPacket>(k, j2);
pstoreu(blockB+count, cj.pconj(A));
count += QuarterPacketSize;
} else {
const LinearMapper dm0 = rhs.getLinearMapper(k, j2);
blockB[count+0] = cj(dm0(0));