diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h
index cdf0fdfb0..fdf1afbf0 100644
--- a/Eigen/src/Core/arch/AVX/PacketMath.h
+++ b/Eigen/src/Core/arch/AVX/PacketMath.h
@@ -1158,18 +1158,9 @@ EIGEN_STRONG_INLINE Packet8f half2float(const Packet8h& a) {
 #ifdef EIGEN_HAS_FP16_C
   return _mm256_cvtph_ps(a);
 #else
-  EIGEN_ALIGN32 Eigen::half aux[8];
-  pstore(aux, a);
-  float f0(aux[0]);
-  float f1(aux[1]);
-  float f2(aux[2]);
-  float f3(aux[3]);
-  float f4(aux[4]);
-  float f5(aux[5]);
-  float f6(aux[6]);
-  float f7(aux[7]);
-
-  return _mm256_set_ps(f7, f6, f5, f4, f3, f2, f1, f0);
+  Eigen::internal::Packet8f pp = _mm256_castsi256_ps(_mm256_insertf128_si256(
+      _mm256_castsi128_si256(half2floatsse(a)), half2floatsse(_mm_srli_si128(a, 8)), 1));
+  return pp;
 #endif
 }
 
@@ -1177,17 +1168,9 @@ EIGEN_STRONG_INLINE Packet8h float2half(const Packet8f& a) {
 #ifdef EIGEN_HAS_FP16_C
   return _mm256_cvtps_ph(a, _MM_FROUND_TO_NEAREST_INT|_MM_FROUND_NO_EXC);
 #else
-  EIGEN_ALIGN32 float aux[8];
-  pstore(aux, a);
-  const numext::uint16_t s0 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[0]));
-  const numext::uint16_t s1 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[1]));
-  const numext::uint16_t s2 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[2]));
-  const numext::uint16_t s3 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[3]));
-  const numext::uint16_t s4 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[4]));
-  const numext::uint16_t s5 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[5]));
-  const numext::uint16_t s6 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[6]));
-  const numext::uint16_t s7 = numext::bit_cast<numext::uint16_t>(Eigen::half(aux[7]));
-  return _mm_set_epi16(s7, s6, s5, s4, s3, s2, s1, s0);
+  __m128i lo = float2half(_mm256_extractf128_ps(a, 0));
+  __m128i hi = float2half(_mm256_extractf128_ps(a, 1));
+  return _mm_packus_epi32(lo, hi);
 #endif
 }
 
diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h
index 3d915026d..f2d266778 100755
--- a/Eigen/src/Core/arch/SSE/PacketMath.h
+++ b/Eigen/src/Core/arch/SSE/PacketMath.h
@@ -1281,6 +1281,106 @@ template<> EIGEN_STRONG_INLINE double pmadd(const double& a, const double& b, co
 }
 #endif
 
+#ifdef EIGEN_VECTORIZE_SSE4_1
+// Helpers for half->float and float->half conversions.
+// Currently only used by the AVX code.
+EIGEN_STRONG_INLINE __m128i half2floatsse(__m128i h) {
+  __m128i input = _mm_cvtepu16_epi32(h);
+
+  // Direct vectorization of half_to_float, C parts in the comments.
+  __m128i shifted_exp = _mm_set1_epi32(0x7c00 << 13);
+  // o.u = (h.x & 0x7fff) << 13; // exponent/mantissa bits
+  __m128i ou = _mm_slli_epi32(_mm_and_si128(input, _mm_set1_epi32(0x7fff)), 13);
+  // exp = shifted_exp & o.u; // just the exponent
+  __m128i exp = _mm_and_si128(ou, shifted_exp);
+  // o.u += (127 - 15) << 23;
+  ou = _mm_add_epi32(ou, _mm_set1_epi32((127 - 15) << 23));
+
+  // Inf/NaN?
+  __m128i naninf_mask = _mm_cmpeq_epi32(exp, shifted_exp);
+  // Inf/NaN adjust
+  __m128i naninf_adj =
+      _mm_and_si128(_mm_set1_epi32((128 - 16) << 23), naninf_mask);
+  // extra exp adjust for Inf/NaN
+  ou = _mm_add_epi32(ou, naninf_adj);
+
+  // Zero/Denormal?
+  __m128i zeroden_mask = _mm_cmpeq_epi32(exp, _mm_setzero_si128());
+  __m128i zeroden_adj = _mm_and_si128(zeroden_mask, _mm_set1_epi32(1 << 23));
+  // o.u += 1 << 23;
+  ou = _mm_add_epi32(ou, zeroden_adj);
+  // magic.u = 113 << 23
+  __m128i magic = _mm_and_si128(zeroden_mask, _mm_set1_epi32(113 << 23));
+  // o.f -= magic.f
+  ou = _mm_castps_si128(
+      _mm_sub_ps(_mm_castsi128_ps(ou), _mm_castsi128_ps(magic)));
+
+  __m128i sign =
+      _mm_slli_epi32(_mm_and_si128(input, _mm_set1_epi32(0x8000)), 16);
+  // o.u |= (h.x & 0x8000) << 16; // sign bit
+  ou = _mm_or_si128(ou, sign);
+  // return o.f;
+  // We are actually returning the uint version, to make
+  // _mm256_insertf128_si256 work.
+  return ou;
+}
+
+EIGEN_STRONG_INLINE __m128i float2half(__m128 f) {
+  __m128i o = _mm_setzero_si128();
+
+  // unsigned int sign_mask = 0x80000000u;
+  __m128i sign = _mm_set1_epi32(0x80000000u);
+  // unsigned int sign = f.u & sign_mask;
+  sign = _mm_and_si128(sign, _mm_castps_si128(f));
+  // f.u ^= sign;
+  f = _mm_xor_ps(f, _mm_castsi128_ps(sign));
+
+  __m128i fu = _mm_castps_si128(f);
+
+  __m128i f16max = _mm_set1_epi32((127 + 16) << 23);
+  __m128i f32infty = _mm_set1_epi32(255 << 23);
+  // if (f.u >= f16max.u) // result is Inf or NaN (all exponent bits set)
+  // there is no _mm_cmpge_epi32, so use lt and swap operands
+  __m128i infnan_mask = _mm_cmplt_epi32(f16max, _mm_castps_si128(f));
+  __m128i inf_mask = _mm_cmpgt_epi32(_mm_castps_si128(f), f32infty);
+  __m128i nan_mask = _mm_andnot_si128(inf_mask, infnan_mask);
+  __m128i inf_value = _mm_and_si128(inf_mask, _mm_set1_epi32(0x7e00));
+  __m128i nan_value = _mm_and_si128(nan_mask, _mm_set1_epi32(0x7c00));
+  // o.x = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
+  __m128i naninf_value = _mm_or_si128(inf_value, nan_value);
+
+  __m128i denorm_magic = _mm_set1_epi32(((127 - 15) + (23 - 10) + 1) << 23);
+  __m128i subnorm_mask =
+      _mm_cmplt_epi32(_mm_castps_si128(f), _mm_set1_epi32(113 << 23));
+  // f.f += denorm_magic.f;
+  f = _mm_add_ps(f, _mm_castsi128_ps(denorm_magic));
+  // f.u - denorm_magic.u
+  o = _mm_sub_epi32(_mm_castps_si128(f), denorm_magic);
+  o = _mm_and_si128(o, subnorm_mask);
+  // Correct result for inf/nan/zero/subnormal, 0 otherwise
+  o = _mm_or_si128(o, naninf_value);
+
+  __m128i mask = _mm_or_si128(infnan_mask, subnorm_mask);
+  o = _mm_and_si128(o, mask);
+
+  // mant_odd = (f.u >> 13) & 1;
+  __m128i mant_odd = _mm_and_si128(_mm_srli_epi32(fu, 13), _mm_set1_epi32(0x1));
+  // f.u += 0xc8000fffU;
+  fu = _mm_add_epi32(fu, _mm_set1_epi32(0xc8000fffU));
+  // f.u += mant_odd;
+  fu = _mm_add_epi32(fu, mant_odd);
+  fu = _mm_andnot_si128(mask, fu);
+  // f.u >> 13
+  fu = _mm_srli_epi32(fu, 13);
+  o = _mm_or_si128(fu, o);
+
+  // o.x |= static_cast<numext::uint16_t>(sign >> 16);
+  o = _mm_or_si128(o, _mm_srli_epi32(sign, 16));
+
+  // 16-bit values
+  return _mm_and_si128(o, _mm_set1_epi32(0xffff));
+}
+#endif
 
 // Packet math for Eigen::half
 // Disable the following code since it's broken on too many platforms / compilers.
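
For reference, the bit manipulation that half2floatsse vectorizes follows the usual scalar half-to-float trick quoted in the "C parts" comments above. Below is a minimal scalar sketch of one lane; the function name and the memcpy-based bit casts are illustrative only and are not Eigen's actual code.

#include <cstdint>
#include <cstring>

// Hypothetical scalar reference for half -> float, mirroring one lane of
// half2floatsse. Not Eigen's implementation, just the same bit-twiddling.
static inline float half_to_float_scalar(std::uint16_t h) {
  const std::uint32_t shifted_exp = 0x7c00u << 13;   // exponent mask after shift
  std::uint32_t o = (h & 0x7fffu) << 13;             // exponent/mantissa bits
  const std::uint32_t exp = shifted_exp & o;         // just the exponent
  o += (127 - 15) << 23;                             // exponent bias adjust

  if (exp == shifted_exp) {                          // Inf/NaN?
    o += (128 - 16) << 23;                           // extra exponent adjust
  } else if (exp == 0) {                             // Zero/Denormal?
    o += 1 << 23;                                    // extra exponent adjust
    const std::uint32_t magic_u = 113u << 23;
    float f, magic;
    std::memcpy(&f, &o, sizeof(f));
    std::memcpy(&magic, &magic_u, sizeof(magic));
    f -= magic;                                      // renormalize denormals
    std::memcpy(&o, &f, sizeof(o));
  }

  o |= std::uint32_t(h & 0x8000u) << 16;             // sign bit
  float result;
  std::memcpy(&result, &o, sizeof(result));
  return result;
}

The SSE helper performs the same steps branch-free on four lanes at once, replacing the two branches with compare masks. The new float2half goes the other way with the matching round-to-nearest-even scheme visible in its comments: overflow collapses to Inf/qNaN, subnormal outputs use the denorm_magic add/subtract, and normal values get the 0xc8000fffU bias plus the mantissa-odd bit before the final shift by 13.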