Enable vectorization of pack_rhs with a column-major RHS.

Rename and generalize Kernel<*> to PacketBlock<*,N>.
This commit is contained in:
Gael Guennebaud 2014-04-25 10:56:18 +02:00
parent b0e19db1cf
commit 3d8d0f6269
11 changed files with 64 additions and 43 deletions

View File

@ -417,15 +417,15 @@ template<> inline std::complex<double> pmul(const std::complex<double>& a, const
/***************************************************************************
* Kernel, that is a collection of N packets where N is the number of words
* in the packet.
* PacketBlock, that is a collection of N packets where the number of words
* in the packet is a multiple of N.
***************************************************************************/
template <typename Packet> struct Kernel {
Packet packet[unpacket_traits<Packet>::size];
template <typename Packet,int N=unpacket_traits<Packet>::size> struct PacketBlock {
Packet packet[N];
};
template<typename Packet> EIGEN_DEVICE_FUNC inline void
ptranspose(Kernel<Packet>& /*kernel*/) {
ptranspose(PacketBlock<Packet,1>& /*kernel*/) {
// Nothing to do in the scalar case, i.e. a 1x1 matrix.
}

View File

@ -431,8 +431,8 @@ template<> EIGEN_STRONG_INLINE Packet2cd pcplxflip<Packet2cd>(const Packet2cd& x
return Packet2cd(_mm256_shuffle_pd(x.v, x.v, 0x5));
}
template<> EIGEN_DEVICE_FUNC inline void
ptranspose(Kernel<Packet4cf>& kernel) {
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet4cf,4>& kernel) {
__m256d P0 = _mm256_castps_pd(kernel.packet[0].v);
__m256d P1 = _mm256_castps_pd(kernel.packet[1].v);
__m256d P2 = _mm256_castps_pd(kernel.packet[2].v);
@ -449,8 +449,8 @@ ptranspose(Kernel<Packet4cf>& kernel) {
kernel.packet[2].v = _mm256_castpd_ps(_mm256_permute2f128_pd(T1, T3, 49));
}
template<> EIGEN_DEVICE_FUNC inline void
ptranspose(Kernel<Packet2cd>& kernel) {
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet2cd,2>& kernel) {
__m256d tmp = _mm256_permute2f128_pd(kernel.packet[0].v, kernel.packet[1].v, 0+(2<<4));
kernel.packet[1].v = _mm256_permute2f128_pd(kernel.packet[0].v, kernel.packet[1].v, 1+(3<<4));
kernel.packet[0].v = tmp;

View File

@ -498,8 +498,8 @@ struct palign_impl<Offset,Packet4d>
}
};
template<> EIGEN_DEVICE_FUNC inline void
ptranspose(Kernel<Packet8f>& kernel) {
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet8f,8>& kernel) {
__m256 T0 = _mm256_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
__m256 T1 = _mm256_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
__m256 T2 = _mm256_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
@ -526,8 +526,26 @@ ptranspose(Kernel<Packet8f>& kernel) {
kernel.packet[7] = _mm256_permute2f128_ps(S3, S7, 0x31);
}
template<> EIGEN_DEVICE_FUNC inline void
ptranspose(Kernel<Packet4d>& kernel) {
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet8f,4>& kernel) {
__m256 T0 = _mm256_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
__m256 T1 = _mm256_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
__m256 T2 = _mm256_unpacklo_ps(kernel.packet[2], kernel.packet[3]);
__m256 T3 = _mm256_unpackhi_ps(kernel.packet[2], kernel.packet[3]);
__m256 S0 = _mm256_shuffle_ps(T0,T2,_MM_SHUFFLE(1,0,1,0));
__m256 S1 = _mm256_shuffle_ps(T0,T2,_MM_SHUFFLE(3,2,3,2));
__m256 S2 = _mm256_shuffle_ps(T1,T3,_MM_SHUFFLE(1,0,1,0));
__m256 S3 = _mm256_shuffle_ps(T1,T3,_MM_SHUFFLE(3,2,3,2));
kernel.packet[0] = _mm256_permute2f128_ps(S0, S1, 0x20);
kernel.packet[1] = _mm256_permute2f128_ps(S2, S3, 0x20);
kernel.packet[2] = _mm256_permute2f128_ps(S0, S1, 0x31);
kernel.packet[3] = _mm256_permute2f128_ps(S2, S3, 0x31);
}
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet4d,4>& kernel) {
__m256d T0 = _mm256_shuffle_pd(kernel.packet[0], kernel.packet[1], 15);
__m256d T1 = _mm256_shuffle_pd(kernel.packet[0], kernel.packet[1], 0);
__m256d T2 = _mm256_shuffle_pd(kernel.packet[2], kernel.packet[3], 15);

View File

@ -229,7 +229,7 @@ template<> EIGEN_STRONG_INLINE Packet2cf pcplxflip<Packet2cf>(const Packet2cf& x
return Packet2cf(vec_perm(x.v, x.v, p16uc_COMPLEX_REV));
}
template<> EIGEN_STRONG_INLINE void ptranspose(Kernel<Packet2cf>& kernel)
template<> EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet2cf,2>& kernel)
{
Packet4f tmp = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_COMPLEX_TRANSPOSE_0);
kernel.packet[1].v = vec_perm(kernel.packet[0].v, kernel.packet[1].v, p16uc_COMPLEX_TRANSPOSE_1);

View File

@ -539,7 +539,7 @@ struct palign_impl<Offset,Packet4i>
};
template<> EIGEN_DEVICE_FUNC inline void
ptranspose(Kernel<Packet4f>& kernel) {
ptranspose(PacketBlock<Packet4f,4>& kernel) {
Packet4f t0, t1, t2, t3;
t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]);
t1 = vec_mergel(kernel.packet[0], kernel.packet[2]);
@ -552,7 +552,7 @@ ptranspose(Kernel<Packet4f>& kernel) {
}
template<> EIGEN_DEVICE_FUNC inline void
ptranspose(Kernel<Packet4i>& kernel) {
ptranspose(PacketBlock<Packet4i,4>& kernel) {
Packet4i t0, t1, t2, t3;
t0 = vec_mergeh(kernel.packet[0], kernel.packet[2]);
t1 = vec_mergel(kernel.packet[0], kernel.packet[2]);

View File

@ -264,7 +264,7 @@ template<> EIGEN_STRONG_INLINE Packet2cf pdiv<Packet2cf>(const Packet2cf& a, con
}
template<> EIGEN_DEVICE_FUNC inline void
ptranspose(Kernel<Packet2cf>& kernel) {
ptranspose(PacketBlock<Packet2cf,2>& kernel) {
float32x4_t tmp = vcombine_f32(vget_high_f32(kernel.packet[0].v), vget_high_f32(kernel.packet[1].v));
kernel.packet[0].v = vcombine_f32(vget_low_f32(kernel.packet[0].v), vget_low_f32(kernel.packet[1].v));
kernel.packet[1].v = tmp;

View File

@ -452,7 +452,7 @@ PALIGN_NEON(3,Packet4i,vextq_s32)
#undef PALIGN_NEON
template<> EIGEN_DEVICE_FUNC inline void
ptranspose(Kernel<Packet4f>& kernel) {
ptranspose(PacketBlock<Packet4f,4>& kernel) {
float32x4x2_t tmp1 = vzipq_f32(kernel.packet[0], kernel.packet[1]);
float32x4x2_t tmp2 = vzipq_f32(kernel.packet[2], kernel.packet[3]);
@ -463,7 +463,7 @@ ptranspose(Kernel<Packet4f>& kernel) {
}
template<> EIGEN_DEVICE_FUNC inline void
ptranspose(Kernel<Packet4i>& kernel) {
ptranspose(PacketBlock<Packet4i,4>& kernel) {
int32x4x2_t tmp1 = vzipq_s32(kernel.packet[0], kernel.packet[1]);
int32x4x2_t tmp2 = vzipq_s32(kernel.packet[2], kernel.packet[3]);
kernel.packet[0] = vcombine_s32(vget_low_s32(tmp1.val[0]), vget_low_s32(tmp2.val[0]));

View File

@ -462,8 +462,8 @@ EIGEN_STRONG_INLINE Packet1cd pcplxflip/*<Packet1cd>*/(const Packet1cd& x)
return Packet1cd(preverse(Packet2d(x.v)));
}
template<> EIGEN_DEVICE_FUNC inline void
ptranspose(Kernel<Packet2cf>& kernel) {
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet2cf,2>& kernel) {
__m128d w1 = _mm_castps_pd(kernel.packet[0].v);
__m128d w2 = _mm_castps_pd(kernel.packet[1].v);

View File

@ -784,20 +784,20 @@ struct palign_impl<Offset,Packet2d>
};
#endif
template<> EIGEN_DEVICE_FUNC inline void
ptranspose(Kernel<Packet4f>& kernel) {
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet4f,4>& kernel) {
_MM_TRANSPOSE4_PS(kernel.packet[0], kernel.packet[1], kernel.packet[2], kernel.packet[3]);
}
template<> EIGEN_DEVICE_FUNC inline void
ptranspose(Kernel<Packet2d>& kernel) {
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet2d,2>& kernel) {
__m128d tmp = _mm_unpackhi_pd(kernel.packet[0], kernel.packet[1]);
kernel.packet[0] = _mm_unpacklo_pd(kernel.packet[0], kernel.packet[1]);
kernel.packet[1] = tmp;
}
template<> EIGEN_DEVICE_FUNC inline void
ptranspose(Kernel<Packet4i>& kernel) {
EIGEN_DEVICE_FUNC inline void
ptranspose(PacketBlock<Packet4i,4>& kernel) {
__m128i T0 = _mm_unpacklo_epi32(kernel.packet[0], kernel.packet[1]);
__m128i T1 = _mm_unpacklo_epi32(kernel.packet[2], kernel.packet[3]);
__m128i T2 = _mm_unpackhi_epi32(kernel.packet[0], kernel.packet[1]);

View File

@ -1585,7 +1585,7 @@ EIGEN_DONT_INLINE void gemm_pack_lhs<Scalar, Index, Pack1, Pack2, RowMajor, Conj
{
for (Index m = 0; m < pack; m += PacketSize)
{
Kernel<Packet> kernel;
PacketBlock<Packet> kernel;
for (int p = 0; p < PacketSize; ++p) kernel.packet[p] = ploadu<Packet>(&lhs(i+p+m, k));
ptranspose(kernel);
for (int p = 0; p < PacketSize; ++p) pstore(blockA+count+m+(pack)*p, cj.pconj(kernel.packet[p]));
@ -1675,7 +1675,7 @@ EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, nr, ColMajor, Conjugate, Pan
// if(PacketSize==8) // TODO enbale vectorized transposition for PacketSize==4
// {
// for(; k<peeled_k; k+=PacketSize) {
// Kernel<Packet> kernel;
// PacketBlock<Packet> kernel;
// for (int p = 0; p < PacketSize; ++p) {
// kernel.packet[p] = ploadu<Packet>(&rhs[(j2+p)*rhsStride+k]);
// }
@ -1713,19 +1713,22 @@ EIGEN_DONT_INLINE void gemm_pack_rhs<Scalar, Index, nr, ColMajor, Conjugate, Pan
const Scalar* b1 = &rhs[(j2+1)*rhsStride];
const Scalar* b2 = &rhs[(j2+2)*rhsStride];
const Scalar* b3 = &rhs[(j2+3)*rhsStride];
Index k=0;
if(PacketSize==4) // TODO enbale vectorized transposition for PacketSize==2 ??
if((PacketSize%4)==0) // TODO enbale vectorized transposition for PacketSize==2 ??
{
for(; k<peeled_k; k+=PacketSize) {
Kernel<Packet> kernel;
for (int p = 0; p < PacketSize; ++p) {
kernel.packet[p] = ploadu<Packet>(&rhs[(j2+p)*rhsStride+k]);
}
PacketBlock<Packet,(PacketSize%4)==0?4:PacketSize> kernel;
kernel.packet[0] = ploadu<Packet>(&b0[k]);
kernel.packet[1] = ploadu<Packet>(&b1[k]);
kernel.packet[2] = ploadu<Packet>(&b2[k]);
kernel.packet[3] = ploadu<Packet>(&b3[k]);
ptranspose(kernel);
for (int p = 0; p < PacketSize; ++p) {
pstoreu(blockB+count, cj.pconj(kernel.packet[p]));
count+=PacketSize;
}
pstoreu(blockB+count+0*PacketSize, cj.pconj(kernel.packet[0]));
pstoreu(blockB+count+1*PacketSize, cj.pconj(kernel.packet[1]));
pstoreu(blockB+count+2*PacketSize, cj.pconj(kernel.packet[2]));
pstoreu(blockB+count+3*PacketSize, cj.pconj(kernel.packet[3]));
count+=4*PacketSize;
}
}
for(; k<depth; k++)

View File

@ -228,7 +228,7 @@ template<typename Scalar> void packetmath()
internal::pstore(data2, internal::preverse(internal::pload<Packet>(data1)));
VERIFY(areApprox(ref, data2, PacketSize) && "internal::preverse");
internal::Kernel<Packet> kernel;
internal::PacketBlock<Packet> kernel;
for (int i=0; i<PacketSize; ++i) {
kernel.packet[i] = internal::pload<Packet>(data1+i*PacketSize);
}
@ -236,7 +236,7 @@ template<typename Scalar> void packetmath()
for (int i=0; i<PacketSize; ++i) {
internal::pstore(data2, kernel.packet[i]);
for (int j = 0; j < PacketSize; ++j) {
VERIFY(isApproxAbs(data2[j], data1[i+j*PacketSize], refvalue));
VERIFY(isApproxAbs(data2[j], data1[i+j*PacketSize], refvalue) && "ptranspose");
}
}
}
@ -393,9 +393,9 @@ template<typename Scalar> void packetmath_scatter_gather() {
for (int i = 0; i < PacketSize*11; ++i) {
if ((i%11) == 0) {
VERIFY(isApproxAbs(buffer[i], data1[i/11], refvalue));
VERIFY(isApproxAbs(buffer[i], data1[i/11], refvalue) && "pscatter");
} else {
VERIFY(isApproxAbs(buffer[i], Scalar(0), refvalue));
VERIFY(isApproxAbs(buffer[i], Scalar(0), refvalue) && "pscatter");
}
}
@ -405,7 +405,7 @@ template<typename Scalar> void packetmath_scatter_gather() {
packet = internal::pgather<Scalar, Packet>(buffer, 7);
internal::pstore(data1, packet);
for (int i = 0; i < PacketSize; ++i) {
VERIFY(isApproxAbs(data1[i], buffer[i*7], refvalue));
VERIFY(isApproxAbs(data1[i], buffer[i*7], refvalue) && "pgather");
}
}