Fix the specialization of pfrexp for AVX to be faster when AVX2/AVX512DQ is not available, and avoid undefined behavior in C++. Also mask off the sign bit when extracting the exponent.

This commit is contained in:
Rasmus Munk Larsen 2020-10-15 18:39:58 -07:00
parent 011e0db31d
commit 21edea5edd
2 changed files with 20 additions and 18 deletions

View File

@ -691,26 +691,27 @@ template<> EIGEN_STRONG_INLINE Packet8f pfrexp<Packet8f>(const Packet8f& a, Pack
template<> EIGEN_STRONG_INLINE Packet4d pfrexp<Packet4d>(const Packet4d& a, Packet4d& exponent) {
const Packet4d cst_1022d = pset1<Packet4d>(1022.0);
const Packet4d cst_half = pset1<Packet4d>(0.5);
const Packet4d cst_inv_mant_mask = pset1frombits<Packet4d>(static_cast<uint64_t>(~0x7ff0000000000000ull));
__m256i a_expo = _mm256_castpd_si256(a);
const Packet4d cst_exp_mask = pset1frombits<Packet4d>(static_cast<uint64_t>(0x7ff0000000000000ull));
__m256i a_expo = _mm256_castpd_si256(pand(a, cst_exp_mask));
#ifdef EIGEN_VECTORIZE_AVX2
a_expo = _mm256_srli_epi64(a_expo, 52);
#else
__m128i lo = _mm_srli_epi64(_mm256_extractf128_si256(a_expo, 0), 52);
__m128i hi = _mm_srli_epi64(_mm256_extractf128_si256(a_expo, 1), 52);
a_expo = _mm256_insertf128_si256(_mm256_castsi128_si256(lo), (hi), 1);
#endif
#ifdef EIGEN_VECTORIZE_AVX512DQ
// AVX512DQ finally provides an instruction for this
#if defined(EIGEN_VECTORIZE_AVX2) && defined(EIGEN_VECTORIZE_AVX512DQ)
exponent = _mm256_cvtepi64_pd(a_expo);
#else
exponent = _mm256_set_pd(static_cast<double>(_mm256_extract_epi64(a_expo, 3)),
static_cast<double>(_mm256_extract_epi64(a_expo, 2)),
static_cast<double>(_mm256_extract_epi64(a_expo, 1)),
static_cast<double>(_mm256_extract_epi64(a_expo, 0)));
#endif
#else
__m128i lo = _mm256_extractf128_si256(a_expo, 0);
__m128i hi = _mm256_extractf128_si256(a_expo, 1);
#ifndef EIGEN_VECTORIZE_AVX2
lo = _mm_srli_epi64(lo, 52);
hi = _mm_srli_epi64(hi, 52);
#endif
Packet2d exponent_lo = _mm_cvtepi32_pd(vec4i_swizzle1(lo, 0, 2, 1, 3));
Packet2d exponent_hi = _mm_cvtepi32_pd(vec4i_swizzle1(hi, 0, 2, 1, 3));
exponent = _mm256_set_m128d(exponent_hi, exponent_lo);
#endif // EIGEN_VECTORIZE_AVX512DQ
exponent = psub(exponent, cst_1022d);
return por(pand(a, cst_inv_mant_mask), cst_half);
const Packet4d cst_mant_mask = pset1frombits<Packet4d>(static_cast<uint64_t>(~0x7ff0000000000000ull));
return por(pand(a, cst_mant_mask), cst_half);
}
template<> EIGEN_STRONG_INLINE Packet8f pldexp<Packet8f>(const Packet8f& a, const Packet8f& exponent) {

View File

@ -805,10 +805,11 @@ template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Pack
template<> EIGEN_STRONG_INLINE Packet2d pfrexp<Packet2d>(const Packet2d& a, Packet2d& exponent) {
const Packet2d cst_1022d = pset1<Packet2d>(1022.0);
const Packet2d cst_half = pset1<Packet2d>(0.5);
const Packet2d cst_inv_mant_mask = pset1frombits<Packet2d>(static_cast<uint64_t>(~0x7ff0000000000000ull));
__m128i a_expo = _mm_srli_epi64(_mm_castpd_si128(a), 52);
const Packet2d cst_exp_mask = pset1frombits<Packet2d>(static_cast<uint64_t>(0x7ff0000000000000ull));
__m128i a_expo = _mm_srli_epi64(_mm_castpd_si128(pand(a, cst_exp_mask)), 52);
exponent = psub(_mm_cvtepi32_pd(vec4i_swizzle1(a_expo, 0, 2, 1, 3)), cst_1022d);
return por(pand(a, cst_inv_mant_mask), cst_half);
const Packet2d cst_mant_mask = pset1frombits<Packet2d>(static_cast<uint64_t>(~0x7ff0000000000000ull));
return por(pand(a, cst_mant_mask), cst_half);
}
template<> EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f& exponent) {