Optimize casting for x86_64.

This commit is contained in:
Rasmus Munk Larsen 2023-03-21 18:24:16 +00:00
parent 8f9b8e3630
commit 09945f2cc1
3 changed files with 85 additions and 32 deletions

View File

@ -16,29 +16,7 @@ namespace Eigen {
namespace internal {
// For now we use SSE to handle integers, so we can't use AVX instructions to cast
// from int to float
template <>
struct type_casting_traits<float, int> {
enum {
VectorizedCast = 0,
SrcCoeffRatio = 1,
TgtCoeffRatio = 1
};
};
template <>
struct type_casting_traits<int, float> {
enum {
VectorizedCast = 0,
SrcCoeffRatio = 1,
TgtCoeffRatio = 1
};
};
#ifndef EIGEN_VECTORIZE_AVX512
template <>
struct type_casting_traits<Eigen::half, float> {
enum {
@ -76,8 +54,17 @@ struct type_casting_traits<float, bfloat16> {
};
};
template <>
struct type_casting_traits<float, bool> {
enum {
VectorizedCast = 1,
SrcCoeffRatio = 2,
TgtCoeffRatio = 1
};
};
#endif // EIGEN_VECTORIZE_AVX512
template<> EIGEN_STRONG_INLINE Packet8i pcast<Packet8f, Packet8i>(const Packet8f& a) {
return _mm256_cvttps_epi32(a);
}
@ -86,6 +73,43 @@ template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet8i, Packet8f>(const Packet8i
return _mm256_cvtepi32_ps(a);
}
template<> EIGEN_STRONG_INLINE Packet8f pcast<Packet4d, Packet8f>(const Packet4d& a, const Packet4d& b) {
return _mm256_set_m128(_mm256_cvtpd_ps(a), _mm256_cvtpd_ps(b));
}
template <>
EIGEN_STRONG_INLINE Packet16b pcast<Packet8f, Packet16b>(const Packet8f& a,
const Packet8f& b) {
__m256 nonzero_a = _mm256_cmp_ps(a, pzero(a), _CMP_NEQ_UQ);
__m256 nonzero_b = _mm256_cmp_ps(b, pzero(b), _CMP_NEQ_UQ);
constexpr char kFF = '\255';
#ifndef EIGEN_VECTORIZE_AVX2
__m128i shuffle_mask128_a_lo = _mm_set_epi8(kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, 12, 8, 4, 0);
__m128i shuffle_mask128_a_hi = _mm_set_epi8(kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, 12, 8, 4, 0, kFF, kFF, kFF, kFF);
__m128i shuffle_mask128_b_lo = _mm_set_epi8(kFF, kFF, kFF, kFF, 12, 8, 4, 0, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF);
__m128i shuffle_mask128_b_hi = _mm_set_epi8(12, 8, 4, 0, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF);
__m128i a_hi = _mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castps_si256(nonzero_a), 1), shuffle_mask128_a_hi);
__m128i a_lo = _mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castps_si256(nonzero_a), 0), shuffle_mask128_a_lo);
__m128i b_hi = _mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castps_si256(nonzero_b), 1), shuffle_mask128_b_hi);
__m128i b_lo = _mm_shuffle_epi8(_mm256_extractf128_si256(_mm256_castps_si256(nonzero_b), 0), shuffle_mask128_b_lo);
#else
__m256i a_shuffle_mask = _mm256_set_epi8(kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, 12, 8, 4, 0, kFF, kFF, kFF, kFF,
kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, 12, 8, 4, 0);
__m256i b_shuffle_mask = _mm256_set_epi8( 12, 8, 4, 0, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF,
kFF, kFF, kFF, kFF, 12, 8, 4, 0, kFF, kFF, kFF, kFF, kFF, kFF, kFF, kFF);
__m256i a_shuff = _mm256_shuffle_epi8(_mm256_castps_si256(nonzero_a), a_shuffle_mask);
__m256i b_shuff = _mm256_shuffle_epi8(_mm256_castps_si256(nonzero_b), b_shuffle_mask);
__m128i a_hi = _mm256_extractf128_si256(a_shuff, 1);
__m128i a_lo = _mm256_extractf128_si256(a_shuff, 0);
__m128i b_hi = _mm256_extractf128_si256(b_shuff, 1);
__m128i b_lo = _mm256_extractf128_si256(b_shuff, 0);
#endif
__m128i merged = _mm_or_si128(_mm_or_si128(b_lo, b_hi), _mm_or_si128(a_lo, a_hi));
return _mm_and_si128(merged, _mm_set1_epi8(1));
}
template<> EIGEN_STRONG_INLINE Packet8i preinterpret<Packet8i,Packet8f>(const Packet8f& a) {
return _mm256_castps_si256(a);
}

View File

@ -26,7 +26,7 @@ struct type_casting_traits<float, bool> {
};
template <>
struct type_casting_traits<bool,float> {
struct type_casting_traits<bool, float> {
enum {
VectorizedCast = 1,
SrcCoeffRatio = 1,
@ -40,7 +40,7 @@ template<> EIGEN_STRONG_INLINE Packet16b pcast<Packet16f, Packet16b>(const Packe
}
template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16b, Packet16f>(const Packet16b& a) {
return _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(a));
return _mm512_cvtepi32_ps(_mm512_and_si512(_mm512_cvtepi8_epi32(a), _mm512_set1_epi32(1)));
}
template<> EIGEN_STRONG_INLINE Packet16i pcast<Packet16f, Packet16i>(const Packet16f& a) {
@ -51,6 +51,10 @@ template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet16i, Packet16f>(const Packe
return _mm512_cvtepi32_ps(a);
}
template<> EIGEN_STRONG_INLINE Packet16f pcast<Packet8d, Packet16f>(const Packet8d& a, const Packet8d& b) {
return cat256(_mm512_cvtpd_ps(a), _mm512_cvtpd_ps(b));
}
template<> EIGEN_STRONG_INLINE Packet16i preinterpret<Packet16i, Packet16f>(const Packet16f& a) {
return _mm512_castps_si512(a);
}

View File

@ -17,6 +17,15 @@ namespace Eigen {
namespace internal {
#ifndef EIGEN_VECTORIZE_AVX
template <>
struct type_casting_traits<float, bool> {
enum {
VectorizedCast = 1,
SrcCoeffRatio = 4,
TgtCoeffRatio = 1
};
};
template <>
struct type_casting_traits<float, int> {
enum {
@ -35,6 +44,16 @@ struct type_casting_traits<int, float> {
};
};
template <>
struct type_casting_traits<float, double> {
enum {
VectorizedCast = 1,
SrcCoeffRatio = 1,
TgtCoeffRatio = 2
};
};
#endif
template <>
struct type_casting_traits<double, float> {
enum {
@ -45,14 +64,20 @@ struct type_casting_traits<double, float> {
};
template <>
struct type_casting_traits<float, double> {
enum {
VectorizedCast = 1,
SrcCoeffRatio = 1,
TgtCoeffRatio = 2
};
};
#endif
EIGEN_STRONG_INLINE Packet16b pcast<Packet4f, Packet16b>(const Packet4f& a,
const Packet4f& b,
const Packet4f& c,
const Packet4f& d) {
__m128 zero = pzero(a);
__m128 nonzero_a = _mm_cmpneq_ps(a, zero);
__m128 nonzero_b = _mm_cmpneq_ps(b, zero);
__m128 nonzero_c = _mm_cmpneq_ps(c, zero);
__m128 nonzero_d = _mm_cmpneq_ps(d, zero);
__m128i ab_bytes = _mm_packs_epi32(_mm_castps_si128(nonzero_a), _mm_castps_si128(nonzero_b));
__m128i cd_bytes = _mm_packs_epi32(_mm_castps_si128(nonzero_c), _mm_castps_si128(nonzero_d));
__m128i merged = _mm_packs_epi16(ab_bytes, cd_bytes);
return _mm_and_si128(merged, _mm_set1_epi8(1));
}
template<> EIGEN_STRONG_INLINE Packet4i pcast<Packet4f, Packet4i>(const Packet4f& a) {
return _mm_cvttps_epi32(a);