Updated pfrexp implementation.

The original implementation fails for 0, denormals, inf, and NaN.

See #2150
This commit is contained in:
Antonio Sanchez 2021-02-12 11:32:29 -08:00 committed by Rasmus Munk Larsen
parent 9ad4096ccb
commit 7ff0b7a980
9 changed files with 225 additions and 144 deletions

View File

@ -734,12 +734,13 @@ template<> EIGEN_STRONG_INLINE Packet4d pabs(const Packet4d& a)
}
template<> EIGEN_STRONG_INLINE Packet8f pfrexp<Packet8f>(const Packet8f& a, Packet8f& exponent) {
return pfrexp_float(a,exponent);
return pfrexp_generic(a,exponent);
}
template<> EIGEN_STRONG_INLINE Packet4d pfrexp<Packet4d>(const Packet4d& a, Packet4d& exponent) {
const Packet4d cst_1022d = pset1<Packet4d>(1022.0);
const Packet4d cst_half = pset1<Packet4d>(0.5);
// Extract exponent without existence of Packet4l.
template<>
EIGEN_STRONG_INLINE
Packet4d pfrexp_generic_get_biased_exponent(const Packet4d& a) {
const Packet4d cst_exp_mask = pset1frombits<Packet4d>(static_cast<uint64_t>(0x7ff0000000000000ull));
__m256i a_expo = _mm256_castpd_si256(pand(a, cst_exp_mask));
#ifdef EIGEN_VECTORIZE_AVX2
@ -754,15 +755,18 @@ template<> EIGEN_STRONG_INLINE Packet4d pfrexp<Packet4d>(const Packet4d& a, Pack
#endif
Packet2d exponent_lo = _mm_cvtepi32_pd(vec4i_swizzle1(lo, 0, 2, 1, 3));
Packet2d exponent_hi = _mm_cvtepi32_pd(vec4i_swizzle1(hi, 0, 2, 1, 3));
exponent = _mm256_insertf128_pd(exponent, exponent_lo, 0);
Packet4d exponent = _mm256_insertf128_pd(exponent, exponent_lo, 0);
exponent = _mm256_insertf128_pd(exponent, exponent_hi, 1);
exponent = psub(exponent, cst_1022d);
const Packet4d cst_mant_mask = pset1frombits<Packet4d>(static_cast<uint64_t>(~0x7ff0000000000000ull));
return por(pand(a, cst_mant_mask), cst_half);
return exponent;
}
template<> EIGEN_STRONG_INLINE Packet4d pfrexp<Packet4d>(const Packet4d& a, Packet4d& exponent) {
return pfrexp_generic(a, exponent);
}
template<> EIGEN_STRONG_INLINE Packet8f pldexp<Packet8f>(const Packet8f& a, const Packet8f& exponent) {
return pldexp_float(a,exponent);
return pldexp_generic(a, exponent);
}
template<> EIGEN_STRONG_INLINE Packet4d pldexp<Packet4d>(const Packet4d& a, const Packet4d& exponent) {

View File

@ -895,25 +895,28 @@ EIGEN_STRONG_INLINE Packet8d pabs(const Packet8d& a) {
template<>
EIGEN_STRONG_INLINE Packet16f pfrexp<Packet16f>(const Packet16f& a, Packet16f& exponent){
return pfrexp_float(a, exponent);
return pfrexp_generic(a, exponent);
}
// Extract exponent without existence of Packet8l.
template<>
EIGEN_STRONG_INLINE
Packet8d pfrexp_generic_get_biased_exponent(const Packet8d& a) {
const Packet8d cst_exp_mask = pset1frombits<Packet8d>(static_cast<uint64_t>(0x7ff0000000000000ull));
#ifdef EIGEN_VECTORIZE_AVX512DQ
return _mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(pand(a, cst_exp_mask)), 52));
#else
return _mm512_cvtepi32_pd(_mm512_cvtepi64_epi32(_mm512_srli_epi64(_mm512_castpd_si512(pand(a, cst_exp_mask)), 52)));
#endif
}
template<>
EIGEN_STRONG_INLINE Packet8d pfrexp<Packet8d>(const Packet8d& a, Packet8d& exponent) {
const Packet8d cst_1022d = pset1<Packet8d>(1022.0);
#ifdef EIGEN_TEST_AVX512DQ
exponent = psub(_mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(a), 52)), cst_1022d);
#else
exponent = psub(_mm512_cvtepi32_pd(_mm512_cvtepi64_epi32(_mm512_srli_epi64(_mm512_castpd_si512(a), 52))),
cst_1022d);
#endif
const Packet8d cst_half = pset1<Packet8d>(0.5);
const Packet8d cst_inv_mant_mask = pset1frombits<Packet8d>(static_cast<uint64_t>(~0x7ff0000000000000ull));
return por(pand(a, cst_inv_mant_mask), cst_half);
return pfrexp_generic(a, exponent);
}
template<> EIGEN_STRONG_INLINE Packet16f pldexp<Packet16f>(const Packet16f& a, const Packet16f& exponent) {
return pldexp_float(a,exponent);
return pldexp_generic(a, exponent);
}
template<> EIGEN_STRONG_INLINE Packet8d pldexp<Packet8d>(const Packet8d& a, const Packet8d& exponent) {

View File

@ -1160,43 +1160,48 @@ template<> EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) {
return pand<Packet8us>(p8us_abs_mask, a);
}
template<int N> EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(Packet4i a)
template<int N> EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(const Packet4i& a)
{ return vec_sra(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_right(Packet4i a)
template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_right(const Packet4i& a)
{ return vec_sr(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_left(Packet4i a)
template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_left(const Packet4i& a)
{ return vec_sl(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
template<int N> EIGEN_STRONG_INLINE Packet4f plogical_shift_left(Packet4f a)
template<int N> EIGEN_STRONG_INLINE Packet4f plogical_shift_left(const Packet4f& a)
{
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
Packet4ui r = vec_sl(reinterpret_cast<Packet4ui>(a), p4ui_mask);
return reinterpret_cast<Packet4f>(r);
}
template<int N> EIGEN_STRONG_INLINE Packet4f plogical_shift_right(Packet4f a)
template<int N> EIGEN_STRONG_INLINE Packet4f plogical_shift_right(const Packet4f& a)
{
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
Packet4ui r = vec_sr(reinterpret_cast<Packet4ui>(a), p4ui_mask);
return reinterpret_cast<Packet4f>(r);
}
template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_right(Packet4ui a)
template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_right(const Packet4ui& a)
{
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
return vec_sr(a, p4ui_mask);
}
template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_left(Packet4ui a)
template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_left(const Packet4ui& a)
{
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
return vec_sl(a, p4ui_mask);
}
template<int N> EIGEN_STRONG_INLINE Packet8us plogical_shift_left(Packet8us a)
template<int N> EIGEN_STRONG_INLINE Packet8us plogical_shift_left(const Packet8us& a)
{
const _EIGEN_DECLARE_CONST_FAST_Packet8us(mask, N);
return vec_sl(a, p8us_mask);
}
template<int N> EIGEN_STRONG_INLINE Packet8us plogical_shift_right(const Packet8us& a)
{
const _EIGEN_DECLARE_CONST_FAST_Packet8us(mask, N);
return vec_sr(a, p8us_mask);
}
EIGEN_STRONG_INLINE Packet4f Bf16ToF32Even(const Packet8bf& bf){
return plogical_shift_left<16>(reinterpret_cast<Packet4f>(bf.m_val));
@ -1323,14 +1328,14 @@ template<> EIGEN_STRONG_INLINE Packet8bf pexp<Packet8bf> (const Packet8bf& a){
}
template<> EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f& exponent) {
return pldexp_float(a,exponent);
return pldexp_generic(a,exponent);
}
template<> EIGEN_STRONG_INLINE Packet8bf pldexp<Packet8bf> (const Packet8bf& a, const Packet8bf& exponent){
BF16_TO_F32_BINARY_OP_WRAPPER(pldexp_float, a, exponent);
BF16_TO_F32_BINARY_OP_WRAPPER(pldexp<Packet4f>, a, exponent);
}
template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent) {
return pfrexp_float(a,exponent);
return pfrexp_generic(a,exponent);
}
template<> EIGEN_STRONG_INLINE Packet8bf pfrexp<Packet8bf> (const Packet8bf& a, Packet8bf& e){
Packet4f a_even = Bf16ToF32Even(a);
@ -2324,6 +2329,11 @@ template<> EIGEN_STRONG_INLINE Packet2d pset1<Packet2d>(const double& from) {
return v;
}
template<> EIGEN_STRONG_INLINE Packet2d pset1frombits<Packet2d>(unsigned long from) {
Packet2l v = {static_cast<long long>(from), static_cast<long long>(from)};
return reinterpret_cast<Packet2d>(v);
}
template<> EIGEN_STRONG_INLINE void
pbroadcast4<Packet2d>(const double *a,
Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3)
@ -2439,7 +2449,8 @@ template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) { return vec_abs
// a slow version that works with older compilers.
// Update: apparently vec_cts/vec_ctf intrinsics for 64-bit doubles
// are buggy, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70963
static inline Packet2l ConvertToPacket2l(const Packet2d& x) {
template<>
inline Packet2l pcast<Packet2d, Packet2l>(const Packet2d& x) {
#if EIGEN_GNUC_AT_LEAST(5, 4) || \
(EIGEN_GNUC_AT(6, 1) && __GNUC_PATCHLEVEL__ >= 1)
return vec_cts(x, 0); // TODO: check clang version.
@ -2452,6 +2463,15 @@ static inline Packet2l ConvertToPacket2l(const Packet2d& x) {
#endif
}
template<>
inline Packet2d pcast<Packet2l, Packet2d>(const Packet2l& x) {
unsigned long long tmp[2];
memcpy(tmp, &x, sizeof(tmp));
Packet2d d = { static_cast<double>(tmp[0]),
static_cast<double>(tmp[1]) };
return d;
}
// Packet2l shifts.
// For POWER8 we simply use vec_sr/l.
@ -2569,7 +2589,7 @@ EIGEN_STRONG_INLINE Packet2l plogical_shift_right(const Packet2l& a) {
template<> EIGEN_STRONG_INLINE Packet2d pldexp<Packet2d>(const Packet2d& a, const Packet2d& exponent) {
// Clamp exponent to [-2099, 2099]
const Packet2d max_exponent = pset1<Packet2d>(2099.0);
const Packet2l e = ConvertToPacket2l(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
const Packet2l e = pcast<Packet2d, Packet2l>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
// Split 2^e into four factors and multiply:
const Packet2l bias = { 1023, 1023 };
@ -2582,14 +2602,16 @@ template<> EIGEN_STRONG_INLINE Packet2d pldexp<Packet2d>(const Packet2d& a, cons
return out;
}
// Extract exponent without existence of Packet2l.
template<>
EIGEN_STRONG_INLINE
Packet2d pfrexp_generic_get_biased_exponent(const Packet2d& a) {
return pcast<Packet2l, Packet2d>(plogical_shift_right<52>(reinterpret_cast<Packet2l>(pabs(a))));
}
template<> EIGEN_STRONG_INLINE Packet2d pfrexp<Packet2d> (const Packet2d& a, Packet2d& exponent) {
double exp[2] = { exponent[0], exponent[1] };
Packet2d ret = { pfrexp<double>(a[0], exp[0]), pfrexp<double>(a[1], exp[1]) };
exponent[0] = exp[0];
exponent[1] = exp[1];
return ret;
// This doesn't currently work (no integer_packet for Packet2d - but adding it causes other problems)
// return pfrexp_double(a, exponent);
return pfrexp_generic(a, exponent);
}
template<> EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a)

View File

@ -25,80 +25,114 @@ pset(const typename unpacket_traits<Packet>::type (&a)[N] /* a */) {
return pload<Packet>(a);
}
template<typename Packet> EIGEN_STRONG_INLINE Packet
pfrexp_float(const Packet& a, Packet& exponent) {
// Creates a Scalar integer type with same bit-width.
template<typename T> struct make_integer;
template<> struct make_integer<float> { typedef numext::int32_t type; };
template<> struct make_integer<double> { typedef numext::int64_t type; };
template<> struct make_integer<half> { typedef numext::int16_t type; };
template<> struct make_integer<bfloat16> { typedef numext::int16_t type; };
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet pfrexp_generic_get_biased_exponent(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
const Packet cst_126f = pset1<Packet>(126.0f);
const Packet cst_half = pset1<Packet>(0.5f);
const Packet cst_inv_mant_mask = pset1frombits<Packet>(~0x7f800000u);
exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<23>(preinterpret<PacketI>(pabs<Packet>(a)))), cst_126f);
return por(pand(a, cst_inv_mant_mask), cst_half);
EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1;
return pcast<PacketI, Packet>(plogical_shift_right<mantissa_bits>(preinterpret<PacketI>(pabs(a))));
}
template<typename Packet> EIGEN_STRONG_INLINE Packet
pfrexp_double(const Packet& a, Packet& exponent) {
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
const Packet cst_1022d = pset1<Packet>(1022.0);
const Packet cst_half = pset1<Packet>(0.5);
const Packet cst_inv_mant_mask = pset1frombits<Packet, uint64_t>(static_cast<uint64_t>(~0x7ff0000000000000ull));
exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<52>(preinterpret<PacketI>(pabs<Packet>(a)))), cst_1022d);
return por(pand(a, cst_inv_mant_mask), cst_half);
// Safely applies frexp, correctly handles denormals.
// Assumes IEEE floating point format.
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet pfrexp_generic(const Packet& a, Packet& exponent) {
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename make_unsigned<typename make_integer<Scalar>::type>::type ScalarUI;
EIGEN_CONSTEXPR int total_bits = sizeof(Scalar) * CHAR_BIT;
EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1;
EIGEN_CONSTEXPR int exponent_bits = total_bits - mantissa_bits - 1;
EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask =
~(((ScalarUI(1) << exponent_bits) - ScalarUI(1)) << mantissa_bits); // ~0x7f800000
const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask));
const Packet half = pset1<Packet>(Scalar(0.5));
const Packet zero = pzero(a);
const Packet normal_min = pset1<Packet>((numext::numeric_limits<Scalar>::min)()); // Minimum normal value, 2^-126
// To handle denormals, normalize by multiplying by 2^(mantissa_bits+1).
const Packet is_denormal = pcmp_lt(pabs(a), normal_min);
EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(mantissa_bits + 1); // 24
// The following cannot be constexpr because bfloat16(uint16_t) is not constexpr.
const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24
const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor);
const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a);
// Determine exponent offset: -126 if normal, -126-24 if denormal
const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(exponent_bits-1)) - ScalarUI(2)); // -126
Packet exponent_offset = pset1<Packet>(scalar_exponent_offset);
const Packet normalization_offset = pset1<Packet>(-Scalar(scalar_normalization_offset)); // -24
exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset);
// Determine exponent and mantissa from normalized_a.
exponent = pfrexp_generic_get_biased_exponent(normalized_a);
// Zero, Inf and NaN return 'a' unmodified, exponent is zero
// (technically the exponent is unspecified for inf/NaN, but GCC/Clang set it to zero)
const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << exponent_bits) - ScalarUI(1)); // 255
const Packet non_finite_exponent = pset1<Packet>(scalar_non_finite_exponent);
const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent));
const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half));
exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset));
return m;
}
// Safely applies ldexp, correctly handles overflows, underflows and denormals.
// Assumes IEEE floating point format.
template<typename Packet>
struct pldexp_impl {
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet pldexp_generic(const Packet& a, const Packet& exponent) {
// We want to return a * 2^exponent, allowing for all possible integer
// exponents without overflowing or underflowing in intermediate
// computations.
//
// Since 'a' and the output can be denormal, the maximum range of 'exponent'
// to consider for a float is:
// -255-23 -> 255+23
// Below -278 any finite float 'a' will become zero, and above +278 any
// finite float will become inf, including when 'a' is the smallest possible
// denormal.
//
// Unfortunately, 2^(278) cannot be represented using either one or two
// finite normal floats, so we must split the scale factor into at least
// three parts. It turns out to be faster to split 'exponent' into four
// factors, since [exponent>>2] is much faster to compute that [exponent/3].
//
// Set e = min(max(exponent, -278), 278);
// b = floor(e/4);
// out = ((((a * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b))
//
// This will avoid any intermediate overflows and correctly handle 0, inf,
// NaN cases.
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
typedef typename unpacket_traits<Packet>::type Scalar;
typedef typename unpacket_traits<PacketI>::type ScalarI;
enum {
TotalBits = sizeof(Scalar) * CHAR_BIT,
MantissaBits = std::numeric_limits<Scalar>::digits - 1,
ExponentBits = int(TotalBits) - int(MantissaBits) - 1
};
static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet run(const Packet& a, const Packet& exponent) {
// We want to return a * 2^exponent, allowing for all possible integer
// exponents without overflowing or underflowing in intermediate
// computations.
//
// Since 'a' and the output can be denormal, the maximum range of 'exponent'
// to consider for a float is:
// -255-23 -> 255+23
// Below -278 any finite float 'a' will become zero, and above +278 any
// finite float will become inf, including when 'a' is the smallest possible
// denormal.
//
// Unfortunately, 2^(278) cannot be represented using either one or two
// finite normal floats, so we must split the scale factor into at least
// three parts. It turns out to be faster to split 'exponent' into four
// factors, since [exponent>>2] is much faster to compute that [exponent/3].
//
// Set e = min(max(exponent, -278), 278);
// b = floor(e/4);
// out = ((((a * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b))
//
// This will avoid any intermediate overflows and correctly handle 0, inf,
// NaN cases.
const Packet max_exponent = pset1<Packet>(Scalar( (ScalarI(1)<<int(ExponentBits)) + ScalarI(MantissaBits) - ScalarI(1))); // 278
const PacketI bias = pset1<PacketI>((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1)); // 127
const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
PacketI b = parithmetic_shift_right<2>(e); // floor(e/4);
Packet c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^b
Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
b = psub(psub(psub(e, b), b), b); // e - 3b
c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^(e-3*b)
out = pmul(out, c);
return out;
}
};
EIGEN_CONSTEXPR int total_bits = sizeof(Scalar) * CHAR_BIT;
EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1;
EIGEN_CONSTEXPR int exponent_bits = total_bits - mantissa_bits - 1;
const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<exponent_bits) + ScalarI(mantissa_bits - 1))); // 278
const PacketI bias = pset1<PacketI>((ScalarI(1)<<(exponent_bits-1)) - ScalarI(1)); // 127
const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
PacketI b = parithmetic_shift_right<2>(e); // floor(e/4);
Packet c = preinterpret<Packet>(plogical_shift_left<mantissa_bits>(padd(b, bias))); // 2^b
Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
b = psub(psub(psub(e, b), b), b); // e - 3b
c = preinterpret<Packet>(plogical_shift_left<mantissa_bits>(padd(b, bias))); // 2^(e-3*b)
out = pmul(out, c);
return out;
}
// Explicitly multiplies
// a * (2^e)
// clamping e to the range
// [std::numeric_limits<Scalar>::min_exponent-2, std::numeric_limits<Scalar>::max_exponent]
// [numeric_limits<Scalar>::min_exponent-2, numeric_limits<Scalar>::max_exponent]
//
// This is approx 7x faster than pldexp_impl, but will prematurely over/underflow
// if 2^e doesn't fit into a normal floating-point Scalar.
@ -111,7 +145,7 @@ struct pldexp_fast_impl {
typedef typename unpacket_traits<PacketI>::type ScalarI;
enum {
TotalBits = sizeof(Scalar) * CHAR_BIT,
MantissaBits = std::numeric_limits<Scalar>::digits - 1,
MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
ExponentBits = int(TotalBits) - int(MantissaBits) - 1
};
@ -126,14 +160,6 @@ struct pldexp_fast_impl {
}
};
template<typename Packet> EIGEN_STRONG_INLINE Packet
pldexp_float(const Packet& a, const Packet& exponent)
{ return pldexp_impl<Packet>::run(a, exponent); }
template<typename Packet> EIGEN_STRONG_INLINE Packet
pldexp_double(const Packet& a, const Packet& exponent)
{ return pldexp_impl<Packet>::run(a, exponent); }
// Natural or base 2 logarithm.
// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2)
// and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can

View File

@ -25,29 +25,23 @@ pset(const typename unpacket_traits<Packet>::type (&a)[N] /* a */);
* Some generic implementations to be used by implementors
***************************************************************************/
/** Default implementation of pfrexp for float.
/** Default implementation of pfrexp.
* It is expected to be called by implementers of template<> pfrexp.
*/
template<typename Packet> EIGEN_STRONG_INLINE Packet
pfrexp_float(const Packet& a, Packet& exponent);
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet pfrexp_generic(const Packet& a, Packet& exponent);
/** Default implementation of pfrexp for double.
* It is expected to be called by implementers of template<> pfrexp.
*/
template<typename Packet> EIGEN_STRONG_INLINE Packet
pfrexp_double(const Packet& a, Packet& exponent);
// Extracts the biased exponent value from Packet p, and casts the results to
// a floating-point Packet type. Used by pfrexp_generic. Override this if
// there is no unpacket_traits<Packet>::integer_packet.
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet pfrexp_generic_get_biased_exponent(const Packet& p);
/** Default implementation of pldexp for float.
/** Default implementation of pldexp.
* It is expected to be called by implementers of template<> pldexp.
*/
template<typename Packet> EIGEN_STRONG_INLINE Packet
pldexp_float(const Packet& a, const Packet& exponent);
/** Default implementation of pldexp for double.
* It is expected to be called by implementers of template<> pldexp.
*/
template<typename Packet> EIGEN_STRONG_INLINE Packet
pldexp_double(const Packet& a, const Packet& exponent);
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
Packet pldexp_generic(const Packet& a, const Packet& exponent);
/** \internal \returns log(x) for single precision float */
template <typename Packet>

View File

@ -2402,14 +2402,14 @@ template<> EIGEN_STRONG_INLINE Packet2l pabs(const Packet2l& a) {
template<> EIGEN_STRONG_INLINE Packet2ul pabs(const Packet2ul& a) { return a; }
template<> EIGEN_STRONG_INLINE Packet2f pfrexp<Packet2f>(const Packet2f& a, Packet2f& exponent)
{ return pfrexp_float(a,exponent); }
{ return pfrexp_generic(a,exponent); }
template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent)
{ return pfrexp_float(a,exponent); }
{ return pfrexp_generic(a,exponent); }
template<> EIGEN_STRONG_INLINE Packet2f pldexp<Packet2f>(const Packet2f& a, const Packet2f& exponent)
{ return pldexp_float(a,exponent); }
{ return pldexp_generic(a,exponent); }
template<> EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f& exponent)
{ return pldexp_float(a,exponent); }
{ return pldexp_generic(a,exponent); }
template<> EIGEN_STRONG_INLINE float predux<Packet2f>(const Packet2f& a) { return vget_lane_f32(vpadd_f32(a,a), 0); }
template<> EIGEN_STRONG_INLINE float predux<Packet4f>(const Packet4f& a)
@ -3907,10 +3907,10 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2d pselect( const Packet2
{ return vbslq_f64(vreinterpretq_u64_f64(mask), a, b); }
template<> EIGEN_STRONG_INLINE Packet2d pldexp<Packet2d>(const Packet2d& a, const Packet2d& exponent)
{ return pldexp_double(a, exponent); }
{ return pldexp_generic(a, exponent); }
template<> EIGEN_STRONG_INLINE Packet2d pfrexp<Packet2d>(const Packet2d& a, Packet2d& exponent)
{ return pfrexp_double(a,exponent); }
{ return pfrexp_generic(a,exponent); }
template<> EIGEN_STRONG_INLINE Packet2d pset1frombits<Packet2d>(uint64_t from)
{ return vreinterpretq_f64_u64(vdupq_n_u64(from)); }

View File

@ -887,21 +887,24 @@ template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a)
}
template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent) {
return pfrexp_float(a,exponent);
return pfrexp_generic(a,exponent);
}
// Extract exponent without existence of Packet2l.
template<>
EIGEN_STRONG_INLINE
Packet2d pfrexp_generic_get_biased_exponent(const Packet2d& a) {
const Packet2d cst_exp_mask = pset1frombits<Packet2d>(static_cast<uint64_t>(0x7ff0000000000000ull));
__m128i a_expo = _mm_srli_epi64(_mm_castpd_si128(pand(a, cst_exp_mask)), 52);
return _mm_cvtepi32_pd(vec4i_swizzle1(a_expo, 0, 2, 1, 3));
}
template<> EIGEN_STRONG_INLINE Packet2d pfrexp<Packet2d>(const Packet2d& a, Packet2d& exponent) {
const Packet2d cst_1022d = pset1<Packet2d>(1022.0);
const Packet2d cst_half = pset1<Packet2d>(0.5);
const Packet2d cst_exp_mask = pset1frombits<Packet2d>(static_cast<uint64_t>(0x7ff0000000000000ull));
__m128i a_expo = _mm_srli_epi64(_mm_castpd_si128(pand(a, cst_exp_mask)), 52);
exponent = psub(_mm_cvtepi32_pd(vec4i_swizzle1(a_expo, 0, 2, 1, 3)), cst_1022d);
const Packet2d cst_mant_mask = pset1frombits<Packet2d>(static_cast<uint64_t>(~0x7ff0000000000000ull));
return por(pand(a, cst_mant_mask), cst_half);
return pfrexp_generic(a, exponent);
}
template<> EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f& exponent) {
return pldexp_float(a,exponent);
return pldexp_generic(a,exponent);
}
// We specialize pldexp here, since the generic implementation uses Packet2l, which is not well

View File

@ -669,7 +669,7 @@ EIGEN_STRONG_INLINE PacketXf pabs(const PacketXf& a)
template <>
EIGEN_STRONG_INLINE PacketXf pfrexp<PacketXf>(const PacketXf& a, PacketXf& exponent)
{
return pfrexp_float(a, exponent);
return pfrexp_generic(a, exponent);
}
template <>
@ -747,7 +747,7 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXf, N>& kernel)
template<>
EIGEN_STRONG_INLINE PacketXf pldexp<PacketXf>(const PacketXf& a, const PacketXf& exponent)
{
return pldexp_float(a, exponent);
return pldexp_generic(a, exponent);
}
} // namespace internal

View File

@ -567,7 +567,36 @@ void packetmath_real() {
data2[i] = Scalar(internal::random<double>(-87, 88));
}
CHECK_CWISE1_IF(PacketTraits::HasExp, std::exp, internal::pexp);
CHECK_CWISE1_BYREF1_IF(PacketTraits::HasExp, REF_FREXP, internal::pfrexp);
if (PacketTraits::HasExp) {
// Check denormals:
for (int j=0; j<3; ++j) {
data1[0] = Scalar(std::ldexp(1, std::numeric_limits<Scalar>::min_exponent-j));
CHECK_CWISE1_BYREF1_IF(PacketTraits::HasExp, REF_FREXP, internal::pfrexp);
data1[0] = -data1[0];
CHECK_CWISE1_BYREF1_IF(PacketTraits::HasExp, REF_FREXP, internal::pfrexp);
}
// zero
data1[0] = Scalar(0);
CHECK_CWISE1_BYREF1_IF(PacketTraits::HasExp, REF_FREXP, internal::pfrexp);
// inf and NaN only compare output fraction, not exponent.
test::packet_helper<PacketTraits::HasExp,Packet> h;
Packet pout;
Scalar sout;
Scalar special[] = { NumTraits<Scalar>::infinity(),
-NumTraits<Scalar>::infinity(),
NumTraits<Scalar>::quiet_NaN()};
for (int i=0; i<3; ++i) {
data1[0] = special[i];
ref[0] = Scalar(REF_FREXP(data1[0], ref[PacketSize]));
h.store(data2, internal::pfrexp(h.load(data1), h.forward_reference(pout, sout)));
VERIFY(test::areApprox(ref, data2, 1) && "internal::pfrexp");
}
}
for (int i = 0; i < PacketSize; ++i) {
data1[i] = Scalar(internal::random<double>(-1, 1));
data2[i] = Scalar(internal::random<double>(-1, 1));