Use native AVX512 types instead of Eigen Packets whenever possible.

Author: Benoit Steiner
Date: 2016-12-21 20:06:18 -08:00
parent 660da83e18
commit d7825b6707
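
The change follows one mechanical pattern throughout: Eigen's packet typedefs (Packet4f, Packet8f, Packet4d) and generic wrappers (padd, pfirst) are replaced by the intrinsic types they alias (__m128, __m256, __m256d) and the intrinsics those wrappers dispatch to (_mm256_add_ps, _mm_cvtss_f32, and so on). A minimal before/after sketch of the idea, as a standalone illustration rather than code from the commit:

#include <immintrin.h>
#include <cstdio>

// Before: Packet8f sum = padd(lane0, lane1);   (goes through Eigen's
// padd template and relies on Packet8f being a typedef of __m256)
// After: the native type and intrinsic are spelled out directly.
static __m256 add8(__m256 a, __m256 b) {
  return _mm256_add_ps(a, b);  // what padd resolves to for eight floats
}

int main() {
  __m256 v = _mm256_set1_ps(2.0f);
  float out[8];
  _mm256_storeu_ps(out, add8(v, v));
  printf("%g\n", out[0]);  // prints 4
  return 0;
}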


@@ -59,8 +59,8 @@ template<> struct packet_traits<float> : default_packet_traits
     HasLog = 1,
 #endif
     HasExp = 1,
-    HasSqrt = 1,
-    HasRsqrt = 1,
+    HasSqrt = EIGEN_FAST_MATH,
+    HasRsqrt = EIGEN_FAST_MATH,
 #endif
     HasDiv = 1
   };
@@ -75,7 +75,7 @@ template<> struct packet_traits<double> : default_packet_traits
     size = 8,
     HasHalfPacket = 1,
 #if EIGEN_GNUC_AT_LEAST(5, 3)
-    HasSqrt = 1,
+    HasSqrt = EIGEN_FAST_MATH,
     HasRsqrt = EIGEN_FAST_MATH,
 #endif
     HasDiv = 1
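
The two packet_traits hunks above demote HasSqrt and HasRsqrt from an unconditional 1 to EIGEN_FAST_MATH, so the vectorized sqrt/rsqrt paths are only advertised when fast math is enabled. Roughly how such a 0/1 trait flag gates path selection, as an illustrative sketch (MY_HAS_SQRT stands in for packet_traits<float>::HasSqrt; Eigen's real dispatch happens through templates, not the preprocessor):

#include <immintrin.h>
#include <cmath>

#define MY_FAST_MATH 1            // stand-in for EIGEN_FAST_MATH
#define MY_HAS_SQRT MY_FAST_MATH  // stand-in for the HasSqrt trait

// With the flag at 1 the vector unit is used; with it at 0 callers
// fall back to scalar sqrt -- the on/off effect the hunks above encode.
static float my_sqrt(float x) {
#if MY_HAS_SQRT
  return _mm_cvtss_f32(_mm_sqrt_ss(_mm_set_ss(x)));  // SSE path
#else
  return std::sqrt(x);                               // scalar fallback
#endif
}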
@@ -719,7 +719,7 @@ vecs)
   blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
   blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);
-  final = padd(final, _mm256_blend_ps(blend1, blend2, 0xf0));
+  final = _mm256_add_ps(final, _mm256_blend_ps(blend1, blend2, 0xf0));
 
   hsum1 = _mm256_hadd_ps(vecs8_0, vecs9_0);
   hsum2 = _mm256_hadd_ps(vecs10_0, vecs11_0);
@@ -769,7 +769,7 @@ vecs)
   blend1 = _mm256_blend_ps(sum1, sum2, 0xcc);
   blend2 = _mm256_blend_ps(sum3, sum4, 0xcc);
-  final_1 = padd(final_1, _mm256_blend_ps(blend1, blend2, 0xf0));
+  final_1 = _mm256_add_ps(final_1, _mm256_blend_ps(blend1, blend2, 0xf0));
 
   __m512 final_output;
@@ -819,7 +819,7 @@ template<> EIGEN_STRONG_INLINE Packet8d preduxp<Packet8d>(const Packet8d* vecs)
   tmp1 = _mm256_hadd_pd(vecs2_1, vecs3_1);
   tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));
-  final_0 = padd(final_0, _mm256_blend_pd(tmp0, tmp1, 0xC));
+  final_0 = _mm256_add_pd(final_0, _mm256_blend_pd(tmp0, tmp1, 0xC));
 
   tmp0 = _mm256_hadd_pd(vecs4_0, vecs5_0);
   tmp0 = _mm256_add_pd(tmp0, _mm256_permute2f128_pd(tmp0, tmp0, 1));
@@ -835,7 +835,7 @@ template<> EIGEN_STRONG_INLINE Packet8d preduxp<Packet8d>(const Packet8d* vecs)
   tmp1 = _mm256_hadd_pd(vecs6_1, vecs7_1);
   tmp1 = _mm256_add_pd(tmp1, _mm256_permute2f128_pd(tmp1, tmp1, 1));
-  final_1 = padd(final_1, _mm256_blend_pd(tmp0, tmp1, 0xC));
+  final_1 = _mm256_add_pd(final_1, _mm256_blend_pd(tmp0, tmp1, 0xC));
 
   __m512d final_output = _mm512_insertf64x4(final_output, final_0, 0);
@@ -844,55 +844,54 @@ template<> EIGEN_STRONG_INLINE Packet8d preduxp<Packet8d>(const Packet8d* vecs)
 template <>
 EIGEN_STRONG_INLINE float predux<Packet16f>(const Packet16f& a) {
-//#ifdef EIGEN_VECTORIZE_AVX512DQ
-#if 0
-  Packet8f lane0 = _mm512_extractf32x8_ps(a, 0);
-  Packet8f lane1 = _mm512_extractf32x8_ps(a, 1);
-  Packet8f sum = padd(lane0, lane1);
-  Packet8f tmp0 = _mm256_hadd_ps(sum, _mm256_permute2f128_ps(a, a, 1));
+#ifdef EIGEN_VECTORIZE_AVX512DQ
+  __m256 lane0 = _mm512_extractf32x8_ps(a, 0);
+  __m256 lane1 = _mm512_extractf32x8_ps(a, 1);
+  __m256 sum = _mm256_add_ps(lane0, lane1);
+  __m256 tmp0 = _mm256_hadd_ps(sum, _mm256_permute2f128_ps(a, a, 1));
   tmp0 = _mm256_hadd_ps(tmp0, tmp0);
-  return pfirst(_mm256_hadd_ps(tmp0, tmp0));
+  return _mm_cvtss_f32(_mm256_castps256_ps128(_mm256_hadd_ps(tmp0, tmp0)));
 #else
-  Packet4f lane0 = _mm512_extractf32x4_ps(a, 0);
-  Packet4f lane1 = _mm512_extractf32x4_ps(a, 1);
-  Packet4f lane2 = _mm512_extractf32x4_ps(a, 2);
-  Packet4f lane3 = _mm512_extractf32x4_ps(a, 3);
-  Packet4f sum = padd(padd(lane0, lane1), padd(lane2, lane3));
+  __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
+  __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
+  __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
+  __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
+  __m128 sum = _mm_add_ps(_mm_add_ps(lane0, lane1), _mm_add_ps(lane2, lane3));
   sum = _mm_hadd_ps(sum, sum);
   sum = _mm_hadd_ps(sum, _mm_permute_ps(sum, 1));
-  return pfirst(sum);
+  return _mm_cvtss_f32(sum);
 #endif
 }
 
 template <>
 EIGEN_STRONG_INLINE double predux<Packet8d>(const Packet8d& a) {
-  Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
-  Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
-  Packet4d sum = padd(lane0, lane1);
-  Packet4d tmp0 = _mm256_hadd_pd(sum, _mm256_permute2f128_pd(sum, sum, 1));
-  return pfirst(_mm256_hadd_pd(tmp0, tmp0));
+  __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
+  __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
+  __m256d sum = _mm256_add_pd(lane0, lane1);
+  __m256d tmp0 = _mm256_hadd_pd(sum, _mm256_permute2f128_pd(sum, sum, 1));
+  return _mm_cvtsd_f64(_mm256_castpd256_pd128(_mm256_hadd_pd(tmp0, tmp0)));
 }
 
 template <>
 EIGEN_STRONG_INLINE Packet8f predux_downto4<Packet16f>(const Packet16f& a) {
 #ifdef EIGEN_VECTORIZE_AVX512DQ
-  Packet8f lane0 = _mm512_extractf32x8_ps(a, 0);
-  Packet8f lane1 = _mm512_extractf32x8_ps(a, 1);
-  return padd(lane0, lane1);
+  __m256 lane0 = _mm512_extractf32x8_ps(a, 0);
+  __m256 lane1 = _mm512_extractf32x8_ps(a, 1);
+  return _mm256_add_ps(lane0, lane1);
 #else
-  Packet4f lane0 = _mm512_extractf32x4_ps(a, 0);
-  Packet4f lane1 = _mm512_extractf32x4_ps(a, 1);
-  Packet4f lane2 = _mm512_extractf32x4_ps(a, 2);
-  Packet4f lane3 = _mm512_extractf32x4_ps(a, 3);
-  Packet4f sum0 = padd(lane0, lane2);
-  Packet4f sum1 = padd(lane1, lane3);
+  __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
+  __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
+  __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
+  __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
+  __m128 sum0 = _mm_add_ps(lane0, lane2);
+  __m128 sum1 = _mm_add_ps(lane1, lane3);
   return _mm256_insertf128_ps(_mm256_castps128_ps256(sum0), sum1, 1);
 #endif
 }
 
 template <>
 EIGEN_STRONG_INLINE Packet4d predux_downto4<Packet8d>(const Packet8d& a) {
-  Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
-  Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
-  Packet4d res = padd(lane0, lane1);
+  __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
+  __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
+  __m256d res = _mm256_add_pd(lane0, lane1);
   return res;
 }
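
Every predux kernel above follows the same shape: split the 512-bit register into 256- or 128-bit lanes, combine the lanes with full-width adds, then fold the survivor down to a scalar. A self-contained version of the non-DQ float path (assuming an AVX-512F target; it mirrors the rewritten code above, except the final fold uses two plain hadds instead of hadd plus a permute):

#include <immintrin.h>
#include <cstdio>

// Horizontal sum of 16 floats: extract the four 128-bit lanes, add
// them pairwise, then collapse the remaining four floats with two
// horizontal adds.
static float hsum16(__m512 a) {
  __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
  __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
  __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
  __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
  __m128 sum = _mm_add_ps(_mm_add_ps(lane0, lane1), _mm_add_ps(lane2, lane3));
  sum = _mm_hadd_ps(sum, sum);  // [s01, s23, s01, s23]
  sum = _mm_hadd_ps(sum, sum);  // every slot now holds the total
  return _mm_cvtss_f32(sum);
}

int main() {
  __m512 v = _mm512_set1_ps(1.0f);
  printf("%g\n", hsum16(v));  // prints 16 (build with e.g. -mavx512f)
  return 0;
}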
@@ -907,58 +906,59 @@ EIGEN_STRONG_INLINE float predux_mul<Packet16f>(const Packet16f& a) {
   res = pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
   return pfirst(pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
 #else
-  Packet4f lane0 = _mm512_extractf32x4_ps(a, 0);
-  Packet4f lane1 = _mm512_extractf32x4_ps(a, 1);
-  Packet4f lane2 = _mm512_extractf32x4_ps(a, 2);
-  Packet4f lane3 = _mm512_extractf32x4_ps(a, 3);
-  Packet4f res = pmul(pmul(lane0, lane1), pmul(lane2, lane3));
+  __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
+  __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
+  __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
+  __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
+  __m128 res = pmul(pmul(lane0, lane1), pmul(lane2, lane3));
   res = pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
   return pfirst(pmul(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
 #endif
 }
 
 template <>
 EIGEN_STRONG_INLINE double predux_mul<Packet8d>(const Packet8d& a) {
-  Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
-  Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
-  Packet4d res = pmul(lane0, lane1);
+  __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
+  __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
+  __m256d res = pmul(lane0, lane1);
   res = pmul(res, _mm256_permute2f128_pd(res, res, 1));
   return pfirst(pmul(res, _mm256_shuffle_pd(res, res, 1)));
 }
 
 template <>
 EIGEN_STRONG_INLINE float predux_min<Packet16f>(const Packet16f& a) {
-  Packet4f lane0 = _mm512_extractf32x4_ps(a, 0);
-  Packet4f lane1 = _mm512_extractf32x4_ps(a, 1);
-  Packet4f lane2 = _mm512_extractf32x4_ps(a, 2);
-  Packet4f lane3 = _mm512_extractf32x4_ps(a, 3);
-  Packet4f res = _mm_min_ps(_mm_min_ps(lane0, lane1), _mm_min_ps(lane2, lane3));
+  __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
+  __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
+  __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
+  __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
+  __m128 res = _mm_min_ps(_mm_min_ps(lane0, lane1), _mm_min_ps(lane2, lane3));
   res = _mm_min_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
   return pfirst(_mm_min_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
 }
 
 template <>
 EIGEN_STRONG_INLINE double predux_min<Packet8d>(const Packet8d& a) {
-  Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
-  Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
-  Packet4d res = _mm256_min_pd(lane0, lane1);
+  __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
+  __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
+  __m256d res = _mm256_min_pd(lane0, lane1);
   res = _mm256_min_pd(res, _mm256_permute2f128_pd(res, res, 1));
   return pfirst(_mm256_min_pd(res, _mm256_shuffle_pd(res, res, 1)));
 }
 
 template <>
 EIGEN_STRONG_INLINE float predux_max<Packet16f>(const Packet16f& a) {
-  Packet4f lane0 = _mm512_extractf32x4_ps(a, 0);
-  Packet4f lane1 = _mm512_extractf32x4_ps(a, 1);
-  Packet4f lane2 = _mm512_extractf32x4_ps(a, 2);
-  Packet4f lane3 = _mm512_extractf32x4_ps(a, 3);
-  Packet4f res = _mm_max_ps(_mm_max_ps(lane0, lane1), _mm_max_ps(lane2, lane3));
+  __m128 lane0 = _mm512_extractf32x4_ps(a, 0);
+  __m128 lane1 = _mm512_extractf32x4_ps(a, 1);
+  __m128 lane2 = _mm512_extractf32x4_ps(a, 2);
+  __m128 lane3 = _mm512_extractf32x4_ps(a, 3);
+  __m128 res = _mm_max_ps(_mm_max_ps(lane0, lane1), _mm_max_ps(lane2, lane3));
   res = _mm_max_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 3, 2)));
   return pfirst(_mm_max_ps(res, _mm_permute_ps(res, _MM_SHUFFLE(0, 0, 0, 1))));
 }
 
 template <>
 EIGEN_STRONG_INLINE double predux_max<Packet8d>(const Packet8d& a) {
-  Packet4d lane0 = _mm512_extractf64x4_pd(a, 0);
-  Packet4d lane1 = _mm512_extractf64x4_pd(a, 1);
-  Packet4d res = _mm256_max_pd(lane0, lane1);
+  __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
+  __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
+  __m256d res = _mm256_max_pd(lane0, lane1);
   res = _mm256_max_pd(res, _mm256_permute2f128_pd(res, res, 1));
   return pfirst(_mm256_max_pd(res, _mm256_shuffle_pd(res, res, 1)));
 }
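
The mul/min/max reductions in the last hunk still route the tail of the fold through Eigen's pmul/pfirst helpers on the narrowed registers. Written entirely with intrinsics, the double-precision min reduction looks like this (a sketch under the same AVX-512F assumption, with _mm_cvtsd_f64 in place of pfirst):

#include <immintrin.h>

// Horizontal min of 8 doubles: fold 512 -> 256 -> 128 -> 64 bits,
// mirroring predux_min<Packet8d> above.
static double hmin8(__m512d a) {
  __m256d lane0 = _mm512_extractf64x4_pd(a, 0);
  __m256d lane1 = _mm512_extractf64x4_pd(a, 1);
  __m256d m = _mm256_min_pd(lane0, lane1);                // 4 candidates left
  m = _mm256_min_pd(m, _mm256_permute2f128_pd(m, m, 1));  // 2 candidates
  m = _mm256_min_pd(m, _mm256_shuffle_pd(m, m, 1));       // 1 candidate
  return _mm_cvtsd_f64(_mm256_castpd256_pd128(m));
}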