From b430eb31e23b13173365c8ebf632855243d6b02b Mon Sep 17 00:00:00 2001 From: Charles Schlosser Date: Sat, 15 Jun 2024 17:45:02 +0000 Subject: [PATCH] AVX512F double->int64_t cast --- Eigen/src/Core/arch/AVX512/TypeCasting.h | 36 ++++++++++++++++++------ 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/Eigen/src/Core/arch/AVX512/TypeCasting.h b/Eigen/src/Core/arch/AVX512/TypeCasting.h index b16e9f6c8..e471684cc 100644 --- a/Eigen/src/Core/arch/AVX512/TypeCasting.h +++ b/Eigen/src/Core/arch/AVX512/TypeCasting.h @@ -82,14 +82,34 @@ EIGEN_STRONG_INLINE Packet8d pcast(const Packet8f& a) { template <> EIGEN_STRONG_INLINE Packet8l pcast(const Packet8d& a) { -#if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVS512VL) +#if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVX512VL) return _mm512_cvttpd_epi64(a); #else - EIGEN_ALIGN16 double aux[8]; - pstore(aux, a); - return _mm512_set_epi64(static_cast(aux[7]), static_cast(aux[6]), static_cast(aux[5]), - static_cast(aux[4]), static_cast(aux[3]), static_cast(aux[2]), - static_cast(aux[1]), static_cast(aux[0])); + constexpr int kTotalBits = sizeof(double) * CHAR_BIT, kMantissaBits = std::numeric_limits::digits - 1, + kExponentBits = kTotalBits - kMantissaBits - 1, kBias = (1 << (kExponentBits - 1)) - 1; + + const __m512i cst_one = _mm512_set1_epi64(1); + const __m512i cst_total_bits = _mm512_set1_epi64(kTotalBits); + const __m512i cst_bias = _mm512_set1_epi64(kBias); + + __m512i a_bits = _mm512_castpd_si512(a); + // shift left by 1 to clear the sign bit, and shift right by kMantissaBits + 1 to recover biased exponent + __m512i biased_e = _mm512_srli_epi64(_mm512_slli_epi64(a_bits, 1), kMantissaBits + 1); + __m512i e = _mm512_sub_epi64(biased_e, cst_bias); + + // shift to the left by kExponentBits + 1 to clear the sign and exponent bits + __m512i shifted_mantissa = _mm512_slli_epi64(a_bits, kExponentBits + 1); + // shift to the right by kTotalBits - e to convert the significand to an integer + __m512i result_significand = _mm512_srlv_epi64(shifted_mantissa, _mm512_sub_epi64(cst_total_bits, e)); + + // add the implied bit + __m512i result_exponent = _mm512_sllv_epi64(cst_one, e); + // e <= 0 is interpreted as a large positive shift (2's complement), which also conveniently results in zero + __m512i result = _mm512_add_epi64(result_significand, result_exponent); + // handle negative arguments + __mmask8 sign_mask = _mm512_cmplt_epi64_mask(a_bits, _mm512_setzero_si512()); + result = _mm512_mask_sub_epi64(result, sign_mask, _mm512_setzero_si512(), result); + return result; #endif } @@ -110,10 +130,10 @@ EIGEN_STRONG_INLINE Packet8d pcast(const Packet8i& a) { template <> EIGEN_STRONG_INLINE Packet8d pcast(const Packet8l& a) { -#if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVS512VL) +#if defined(EIGEN_VECTORIZE_AVX512DQ) && defined(EIGEN_VECTORIZE_AVX512VL) return _mm512_cvtepi64_pd(a); #else - EIGEN_ALIGN16 int64_t aux[8]; + EIGEN_ALIGN64 int64_t aux[8]; pstore(aux, a); return _mm512_set_pd(static_cast(aux[7]), static_cast(aux[6]), static_cast(aux[5]), static_cast(aux[4]), static_cast(aux[3]), static_cast(aux[2]),