Add AVX vector path to float2half/half2float

Makes e.g. matrix multiplication 2x faster:
name         old cpu/op  new cpu/op  delta
BM_convers   181ms ± 1%    62ms ± 9%  -65.82%  (p=0.016 n=4+5)

Tested on all possible input values (not adding tests, since they
take a long time).
Ilya Tokar 2021-10-01 17:37:59 -04:00
parent 03d4cbb307
commit e1cb6369b0
2 changed files with 106 additions and 23 deletions
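
The commit message notes that the conversions were checked on all possible input values. A minimal sketch of such an exhaustive round-trip check (illustration only, not part of this commit; it drives the scalar Eigen::half conversions and uses Eigen::numext::bit_cast, which the diff below also relies on):

// Hypothetical exhaustive round-trip check (illustration only, not part of
// this commit): run every 16-bit half pattern through half -> float -> half
// and require the bits to survive (NaNs only need to stay NaN).
#include <Eigen/Core>
#include <cstdint>
#include <cstdio>

int main() {
  long mismatches = 0;
  for (uint32_t bits = 0; bits <= 0xffff; ++bits) {
    const Eigen::half h =
        Eigen::numext::bit_cast<Eigen::half>(static_cast<uint16_t>(bits));
    const float f = static_cast<float>(h);   // half -> float
    const Eigen::half back(f);               // float -> half
    const uint16_t out = Eigen::numext::bit_cast<uint16_t>(back);
    const bool in_nan = (bits & 0x7fff) > 0x7c00;
    const bool out_nan = (out & 0x7fff) > 0x7c00;
    if (in_nan ? !out_nan : out != static_cast<uint16_t>(bits)) ++mismatches;
  }
  std::printf("mismatches: %ld\n", mismatches);
  return mismatches == 0 ? 0 : 1;
}

Sweeping all 65536 half patterns is quick; the slow part is presumably the analogous sweep over all 2^32 float inputs for the float->half direction.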


@@ -1158,18 +1158,9 @@ EIGEN_STRONG_INLINE Packet8f half2float(const Packet8h& a) {
#ifdef EIGEN_HAS_FP16_C
return _mm256_cvtph_ps(a);
#else
EIGEN_ALIGN32 Eigen::half aux[8];
pstore(aux, a);
float f0(aux[0]);
float f1(aux[1]);
float f2(aux[2]);
float f3(aux[3]);
float f4(aux[4]);
float f5(aux[5]);
float f6(aux[6]);
float f7(aux[7]);
return _mm256_set_ps(f7, f6, f5, f4, f3, f2, f1, f0);
Eigen::internal::Packet8f pp = _mm256_castsi256_ps(_mm256_insertf128_si256(
_mm256_castsi128_si256(half2floatsse(a)), half2floatsse(_mm_srli_si128(a, 8)), 1));
return pp;
#endif
}
@@ -1177,17 +1168,9 @@ EIGEN_STRONG_INLINE Packet8h float2half(const Packet8f& a) {
#ifdef EIGEN_HAS_FP16_C
return _mm256_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC);
#else
EIGEN_ALIGN32 float aux[8];
pstore(aux, a);
const numext::uint16_t s0 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[0]));
const numext::uint16_t s1 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[1]));
const numext::uint16_t s2 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[2]));
const numext::uint16_t s3 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[3]));
const numext::uint16_t s4 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[4]));
const numext::uint16_t s5 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[5]));
const numext::uint16_t s6 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[6]));
const numext::uint16_t s7 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[7]));
return _mm_set_epi16(s7, s6, s5, s4, s3, s2, s1, s0);
__m128i lo = float2half(_mm256_extractf128_ps(a, 0));
__m128i hi = float2half(_mm256_extractf128_ps(a, 1));
return _mm_packus_epi32(lo, hi);
#endif
}
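
For context (not part of the diff): the Packet8h kernels above are what vectorized Eigen::half expressions lower to on AVX, so any coefficient-wise half computation exercises these conversions. A hypothetical usage sketch:

// Hypothetical usage (not part of the commit): a vectorizable Eigen::half
// expression; on AVX without F16C the Packet8h kernels above convert each
// packet to float, do the arithmetic, and convert back.
#include <Eigen/Dense>

int main() {
  using ArrXh = Eigen::Array<Eigen::half, Eigen::Dynamic, 1>;
  ArrXh x = ArrXh::Constant(1024, Eigen::half(1.25f));
  ArrXh y = x * Eigen::half(2.f) + Eigen::half(0.5f);  // pmul/padd on Packet8h
  return y[0] == Eigen::half(3.f) ? 0 : 1;
}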


@@ -1281,6 +1281,106 @@ template<> EIGEN_STRONG_INLINE double pmadd(const double& a, const double& b, co
}
#endif
#ifdef EIGEN_VECTORIZE_SSE4_1
// Helpers for half->float and float->half conversions.
// Currently only used by the AVX code.
EIGEN_STRONG_INLINE __m128i half2floatsse(__m128i h) {
__m128i input = _mm_cvtepu16_epi32(h);
// Direct vectorization of half_to_float; the scalar C steps are quoted in the comments.
__m128i shifted_exp = _mm_set1_epi32(0x7c00 << 13);
// o.u = (h.x & 0x7fff) << 13; // exponent/mantissa bits
__m128i ou = _mm_slli_epi32(_mm_and_si128(input, _mm_set1_epi32(0x7fff)), 13);
// exp = shifted_exp & o.u; // just the exponent
__m128i exp = _mm_and_si128(ou, shifted_exp);
// o.u += (127 - 15) << 23;
ou = _mm_add_epi32(ou, _mm_set1_epi32((127 - 15) << 23));
// Inf/NaN?
__m128i naninf_mask = _mm_cmpeq_epi32(exp, shifted_exp);
// Inf/NaN adjust
__m128i naninf_adj =
_mm_and_si128(_mm_set1_epi32((128 - 16) << 23), naninf_mask);
// extra exp adjust for Inf/NaN
ou = _mm_add_epi32(ou, naninf_adj);
// Zero/Denormal?
__m128i zeroden_mask = _mm_cmpeq_epi32(exp, _mm_setzero_si128());
__m128i zeroden_adj = _mm_and_si128(zeroden_mask, _mm_set1_epi32(1 << 23));
// o.u += 1 << 23;
ou = _mm_add_epi32(ou, zeroden_adj);
// magic.u = 113 << 23
__m128i magic = _mm_and_si128(zeroden_mask, _mm_set1_epi32(113 << 23));
// o.f -= magic.f
ou = _mm_castps_si128(
_mm_sub_ps(_mm_castsi128_ps(ou), _mm_castsi128_ps(magic)));
__m128i sign =
_mm_slli_epi32(_mm_and_si128(input, _mm_set1_epi32(0x8000)), 16);
// o.u |= (h.x & 0x8000) << 16; // sign bit
ou = _mm_or_si128(ou, sign);
// return o.f;
// We actually return the uint (bit-pattern) version so that the caller's
// _mm256_insertf128_si256 works.
return ou;
}
EIGEN_STRONG_INLINE __m128i float2half(__m128 f) {
__m128i o = _mm_setzero_si128();
// unsigned int sign_mask = 0x80000000u;
__m128i sign = _mm_set1_epi32(0x80000000u);
// unsigned int sign = f.u & sign_mask;
sign = _mm_and_si128(sign, _mm_castps_si128(f));
// f.u ^= sign;
f = _mm_xor_ps(f, _mm_castsi128_ps(sign));
__m128i fu = _mm_castps_si128(f);
__m128i f16max = _mm_set1_epi32((127 + 16) << 23);
__m128i f32infty = _mm_set1_epi32(255 << 23);
// if (f.u >= f16max.u) // result is Inf or NaN (all exponent bits set)
// there is no _mm_cmpge_epi32, so use lt and swap operands
__m128i infnan_mask = _mm_cmplt_epi32(f16max, _mm_castps_si128(f));
// f.u > f32infty.u means the input is NaN; the remaining Inf/overflow lanes map to Inf.
__m128i nan_mask = _mm_cmpgt_epi32(_mm_castps_si128(f), f32infty);
__m128i inf_mask = _mm_andnot_si128(nan_mask, infnan_mask);
__m128i nan_value = _mm_and_si128(nan_mask, _mm_set1_epi32(0x7e00));
__m128i inf_value = _mm_and_si128(inf_mask, _mm_set1_epi32(0x7c00));
// o.x = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
__m128i naninf_value = _mm_or_si128(inf_value, nan_value);
__m128i denorm_magic = _mm_set1_epi32(((127 - 15) + (23 - 10) + 1) << 23);
__m128i subnorm_mask =
_mm_cmplt_epi32(_mm_castps_si128(f), _mm_set1_epi32(113 << 23));
// f.f += denorm_magic.f;
f = _mm_add_ps(f, _mm_castsi128_ps(denorm_magic));
// f.u - denorm_magic.u
o = _mm_sub_epi32(_mm_castps_si128(f), denorm_magic);
o = _mm_and_si128(o, subnorm_mask);
// Correct result for inf/nan/zero/subnormal, 0 otherwise
o = _mm_or_si128(o, naninf_value);
__m128i mask = _mm_or_si128(infnan_mask, subnorm_mask);
o = _mm_and_si128(o, mask);
// mant_odd = (f.u >> 13) & 1;
__m128i mant_odd = _mm_and_si128(_mm_srli_epi32(fu, 13), _mm_set1_epi32(0x1));
// f.u += 0xc8000fffU;
fu = _mm_add_epi32(fu, _mm_set1_epi32(0xc8000fffU));
// f.u += mant_odd;
fu = _mm_add_epi32(fu, mant_odd);
fu = _mm_andnot_si128(mask, fu);
// f.u >> 13
fu = _mm_srli_epi32(fu, 13);
o = _mm_or_si128(fu, o);
// o.x |= static_cast<numext::uint16_t>(sign >> 16);
o = _mm_or_si128(o, _mm_srli_epi32(sign, 16));
// 16 bit values
return _mm_and_si128(o, _mm_set1_epi32(0xffff));
}
#endif
// Packet math for Eigen::half
// Disable the following code since it's broken on too many platforms / compilers.
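
The C snippets quoted in the comments of half2floatsse and float2half correspond to a scalar bit-twiddling conversion along the following lines. This is a reference sketch reconstructed from those comments (helper names are illustrative; it is not code from the commit):

// Scalar reference for the bit tricks vectorized above (sketch reconstructed
// from the quoted comments; not code from the commit).
#include <cstdint>
#include <cstring>

static inline float bits_to_float(uint32_t u) { float f; std::memcpy(&f, &u, sizeof f); return f; }
static inline uint32_t float_to_bits(float f) { uint32_t u; std::memcpy(&u, &f, sizeof u); return u; }

// half -> float (what half2floatsse vectorizes)
float half_to_float_ref(uint16_t h) {
  uint32_t ou = (uint32_t)(h & 0x7fff) << 13;   // exponent/mantissa bits
  const uint32_t exp = ou & (0x7c00u << 13);    // just the exponent
  ou += (127 - 15) << 23;                       // exponent adjust
  if (exp == (0x7c00u << 13)) {                 // Inf/NaN?
    ou += (128 - 16) << 23;                     // extra exponent adjust
  } else if (exp == 0) {                        // zero/denormal?
    ou += 1 << 23;                              // extra exponent adjust
    // renormalize by subtracting the magic bit pattern as a float
    ou = float_to_bits(bits_to_float(ou) - bits_to_float(113u << 23));
  }
  ou |= (uint32_t)(h & 0x8000) << 16;           // sign bit
  return bits_to_float(ou);
}

// float -> half, round to nearest even (what float2half(__m128) vectorizes)
uint16_t float_to_half_ref(float ff) {
  uint32_t fu = float_to_bits(ff);
  const uint32_t sign = fu & 0x80000000u;
  fu ^= sign;                                   // work on the absolute value
  uint16_t o;
  if (fu >= ((127u + 16) << 23)) {              // result is Inf or NaN
    o = (fu > (255u << 23)) ? 0x7e00 : 0x7c00;  // NaN -> qNaN, Inf/overflow -> Inf
  } else if (fu < (113u << 23)) {               // result is zero or subnormal
    const uint32_t denorm_magic = ((127 - 15) + (23 - 10) + 1) << 23;
    // the float add shifts the mantissa into place; subtracting the magic
    // bit pattern leaves the half result
    fu = float_to_bits(bits_to_float(fu) + bits_to_float(denorm_magic));
    o = (uint16_t)(fu - denorm_magic);
  } else {                                      // normal result
    const uint32_t mant_odd = (fu >> 13) & 1;   // resulting mantissa is odd
    fu += 0xc8000fffu;                          // exponent adjust + rounding bias part 1
    fu += mant_odd;                             // rounding bias part 2
    o = (uint16_t)(fu >> 13);
  }
  return (uint16_t)(o | (sign >> 16));
}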