mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 02:33:59 +08:00
Updated pfrexp implementation.
The original implementation fails for 0, denormals, inf, and NaN. See #2150
This commit is contained in:
parent
9ad4096ccb
commit
7ff0b7a980
@ -734,12 +734,13 @@ template<> EIGEN_STRONG_INLINE Packet4d pabs(const Packet4d& a)
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8f pfrexp<Packet8f>(const Packet8f& a, Packet8f& exponent) {
|
||||
return pfrexp_float(a,exponent);
|
||||
return pfrexp_generic(a,exponent);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4d pfrexp<Packet4d>(const Packet4d& a, Packet4d& exponent) {
|
||||
const Packet4d cst_1022d = pset1<Packet4d>(1022.0);
|
||||
const Packet4d cst_half = pset1<Packet4d>(0.5);
|
||||
// Extract exponent without existence of Packet4l.
|
||||
template<>
|
||||
EIGEN_STRONG_INLINE
|
||||
Packet4d pfrexp_generic_get_biased_exponent(const Packet4d& a) {
|
||||
const Packet4d cst_exp_mask = pset1frombits<Packet4d>(static_cast<uint64_t>(0x7ff0000000000000ull));
|
||||
__m256i a_expo = _mm256_castpd_si256(pand(a, cst_exp_mask));
|
||||
#ifdef EIGEN_VECTORIZE_AVX2
|
||||
@ -754,15 +755,18 @@ template<> EIGEN_STRONG_INLINE Packet4d pfrexp<Packet4d>(const Packet4d& a, Pack
|
||||
#endif
|
||||
Packet2d exponent_lo = _mm_cvtepi32_pd(vec4i_swizzle1(lo, 0, 2, 1, 3));
|
||||
Packet2d exponent_hi = _mm_cvtepi32_pd(vec4i_swizzle1(hi, 0, 2, 1, 3));
|
||||
exponent = _mm256_insertf128_pd(exponent, exponent_lo, 0);
|
||||
Packet4d exponent = _mm256_insertf128_pd(exponent, exponent_lo, 0);
|
||||
exponent = _mm256_insertf128_pd(exponent, exponent_hi, 1);
|
||||
exponent = psub(exponent, cst_1022d);
|
||||
const Packet4d cst_mant_mask = pset1frombits<Packet4d>(static_cast<uint64_t>(~0x7ff0000000000000ull));
|
||||
return por(pand(a, cst_mant_mask), cst_half);
|
||||
return exponent;
|
||||
}
|
||||
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4d pfrexp<Packet4d>(const Packet4d& a, Packet4d& exponent) {
|
||||
return pfrexp_generic(a, exponent);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8f pldexp<Packet8f>(const Packet8f& a, const Packet8f& exponent) {
|
||||
return pldexp_float(a,exponent);
|
||||
return pldexp_generic(a, exponent);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4d pldexp<Packet4d>(const Packet4d& a, const Packet4d& exponent) {
|
||||
|
@ -895,25 +895,28 @@ EIGEN_STRONG_INLINE Packet8d pabs(const Packet8d& a) {
|
||||
|
||||
template<>
|
||||
EIGEN_STRONG_INLINE Packet16f pfrexp<Packet16f>(const Packet16f& a, Packet16f& exponent){
|
||||
return pfrexp_float(a, exponent);
|
||||
return pfrexp_generic(a, exponent);
|
||||
}
|
||||
|
||||
// Extract exponent without existence of Packet8l.
|
||||
template<>
|
||||
EIGEN_STRONG_INLINE
|
||||
Packet8d pfrexp_generic_get_biased_exponent(const Packet8d& a) {
|
||||
const Packet8d cst_exp_mask = pset1frombits<Packet8d>(static_cast<uint64_t>(0x7ff0000000000000ull));
|
||||
#ifdef EIGEN_VECTORIZE_AVX512DQ
|
||||
return _mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(pand(a, cst_exp_mask)), 52));
|
||||
#else
|
||||
return _mm512_cvtepi32_pd(_mm512_cvtepi64_epi32(_mm512_srli_epi64(_mm512_castpd_si512(pand(a, cst_exp_mask)), 52)));
|
||||
#endif
|
||||
}
|
||||
|
||||
template<>
|
||||
EIGEN_STRONG_INLINE Packet8d pfrexp<Packet8d>(const Packet8d& a, Packet8d& exponent) {
|
||||
const Packet8d cst_1022d = pset1<Packet8d>(1022.0);
|
||||
#ifdef EIGEN_TEST_AVX512DQ
|
||||
exponent = psub(_mm512_cvtepi64_pd(_mm512_srli_epi64(_mm512_castpd_si512(a), 52)), cst_1022d);
|
||||
#else
|
||||
exponent = psub(_mm512_cvtepi32_pd(_mm512_cvtepi64_epi32(_mm512_srli_epi64(_mm512_castpd_si512(a), 52))),
|
||||
cst_1022d);
|
||||
#endif
|
||||
const Packet8d cst_half = pset1<Packet8d>(0.5);
|
||||
const Packet8d cst_inv_mant_mask = pset1frombits<Packet8d>(static_cast<uint64_t>(~0x7ff0000000000000ull));
|
||||
return por(pand(a, cst_inv_mant_mask), cst_half);
|
||||
return pfrexp_generic(a, exponent);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet16f pldexp<Packet16f>(const Packet16f& a, const Packet16f& exponent) {
|
||||
return pldexp_float(a,exponent);
|
||||
return pldexp_generic(a, exponent);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8d pldexp<Packet8d>(const Packet8d& a, const Packet8d& exponent) {
|
||||
|
@ -1160,43 +1160,48 @@ template<> EIGEN_STRONG_INLINE Packet8bf pabs(const Packet8bf& a) {
|
||||
return pand<Packet8us>(p8us_abs_mask, a);
|
||||
}
|
||||
|
||||
template<int N> EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(Packet4i a)
|
||||
template<int N> EIGEN_STRONG_INLINE Packet4i parithmetic_shift_right(const Packet4i& a)
|
||||
{ return vec_sra(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
|
||||
template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_right(Packet4i a)
|
||||
template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_right(const Packet4i& a)
|
||||
{ return vec_sr(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
|
||||
template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_left(Packet4i a)
|
||||
template<int N> EIGEN_STRONG_INLINE Packet4i plogical_shift_left(const Packet4i& a)
|
||||
{ return vec_sl(a,reinterpret_cast<Packet4ui>(pset1<Packet4i>(N))); }
|
||||
template<int N> EIGEN_STRONG_INLINE Packet4f plogical_shift_left(Packet4f a)
|
||||
template<int N> EIGEN_STRONG_INLINE Packet4f plogical_shift_left(const Packet4f& a)
|
||||
{
|
||||
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
|
||||
Packet4ui r = vec_sl(reinterpret_cast<Packet4ui>(a), p4ui_mask);
|
||||
return reinterpret_cast<Packet4f>(r);
|
||||
}
|
||||
|
||||
template<int N> EIGEN_STRONG_INLINE Packet4f plogical_shift_right(Packet4f a)
|
||||
template<int N> EIGEN_STRONG_INLINE Packet4f plogical_shift_right(const Packet4f& a)
|
||||
{
|
||||
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
|
||||
Packet4ui r = vec_sr(reinterpret_cast<Packet4ui>(a), p4ui_mask);
|
||||
return reinterpret_cast<Packet4f>(r);
|
||||
}
|
||||
|
||||
template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_right(Packet4ui a)
|
||||
template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_right(const Packet4ui& a)
|
||||
{
|
||||
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
|
||||
return vec_sr(a, p4ui_mask);
|
||||
}
|
||||
|
||||
template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_left(Packet4ui a)
|
||||
template<int N> EIGEN_STRONG_INLINE Packet4ui plogical_shift_left(const Packet4ui& a)
|
||||
{
|
||||
const _EIGEN_DECLARE_CONST_FAST_Packet4ui(mask, N);
|
||||
return vec_sl(a, p4ui_mask);
|
||||
}
|
||||
|
||||
template<int N> EIGEN_STRONG_INLINE Packet8us plogical_shift_left(Packet8us a)
|
||||
template<int N> EIGEN_STRONG_INLINE Packet8us plogical_shift_left(const Packet8us& a)
|
||||
{
|
||||
const _EIGEN_DECLARE_CONST_FAST_Packet8us(mask, N);
|
||||
return vec_sl(a, p8us_mask);
|
||||
}
|
||||
template<int N> EIGEN_STRONG_INLINE Packet8us plogical_shift_right(const Packet8us& a)
|
||||
{
|
||||
const _EIGEN_DECLARE_CONST_FAST_Packet8us(mask, N);
|
||||
return vec_sr(a, p8us_mask);
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE Packet4f Bf16ToF32Even(const Packet8bf& bf){
|
||||
return plogical_shift_left<16>(reinterpret_cast<Packet4f>(bf.m_val));
|
||||
@ -1323,14 +1328,14 @@ template<> EIGEN_STRONG_INLINE Packet8bf pexp<Packet8bf> (const Packet8bf& a){
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f& exponent) {
|
||||
return pldexp_float(a,exponent);
|
||||
return pldexp_generic(a,exponent);
|
||||
}
|
||||
template<> EIGEN_STRONG_INLINE Packet8bf pldexp<Packet8bf> (const Packet8bf& a, const Packet8bf& exponent){
|
||||
BF16_TO_F32_BINARY_OP_WRAPPER(pldexp_float, a, exponent);
|
||||
BF16_TO_F32_BINARY_OP_WRAPPER(pldexp<Packet4f>, a, exponent);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent) {
|
||||
return pfrexp_float(a,exponent);
|
||||
return pfrexp_generic(a,exponent);
|
||||
}
|
||||
template<> EIGEN_STRONG_INLINE Packet8bf pfrexp<Packet8bf> (const Packet8bf& a, Packet8bf& e){
|
||||
Packet4f a_even = Bf16ToF32Even(a);
|
||||
@ -2324,6 +2329,11 @@ template<> EIGEN_STRONG_INLINE Packet2d pset1<Packet2d>(const double& from) {
|
||||
return v;
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pset1frombits<Packet2d>(unsigned long from) {
|
||||
Packet2l v = {static_cast<long long>(from), static_cast<long long>(from)};
|
||||
return reinterpret_cast<Packet2d>(v);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE void
|
||||
pbroadcast4<Packet2d>(const double *a,
|
||||
Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3)
|
||||
@ -2439,7 +2449,8 @@ template<> EIGEN_STRONG_INLINE Packet2d pabs(const Packet2d& a) { return vec_abs
|
||||
// a slow version that works with older compilers.
|
||||
// Update: apparently vec_cts/vec_ctf intrinsics for 64-bit doubles
|
||||
// are buggy, https://gcc.gnu.org/bugzilla/show_bug.cgi?id=70963
|
||||
static inline Packet2l ConvertToPacket2l(const Packet2d& x) {
|
||||
template<>
|
||||
inline Packet2l pcast<Packet2d, Packet2l>(const Packet2d& x) {
|
||||
#if EIGEN_GNUC_AT_LEAST(5, 4) || \
|
||||
(EIGEN_GNUC_AT(6, 1) && __GNUC_PATCHLEVEL__ >= 1)
|
||||
return vec_cts(x, 0); // TODO: check clang version.
|
||||
@ -2452,6 +2463,15 @@ static inline Packet2l ConvertToPacket2l(const Packet2d& x) {
|
||||
#endif
|
||||
}
|
||||
|
||||
template<>
|
||||
inline Packet2d pcast<Packet2l, Packet2d>(const Packet2l& x) {
|
||||
unsigned long long tmp[2];
|
||||
memcpy(tmp, &x, sizeof(tmp));
|
||||
Packet2d d = { static_cast<double>(tmp[0]),
|
||||
static_cast<double>(tmp[1]) };
|
||||
return d;
|
||||
}
|
||||
|
||||
|
||||
// Packet2l shifts.
|
||||
// For POWER8 we simply use vec_sr/l.
|
||||
@ -2569,7 +2589,7 @@ EIGEN_STRONG_INLINE Packet2l plogical_shift_right(const Packet2l& a) {
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pldexp<Packet2d>(const Packet2d& a, const Packet2d& exponent) {
|
||||
// Clamp exponent to [-2099, 2099]
|
||||
const Packet2d max_exponent = pset1<Packet2d>(2099.0);
|
||||
const Packet2l e = ConvertToPacket2l(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
|
||||
const Packet2l e = pcast<Packet2d, Packet2l>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
|
||||
|
||||
// Split 2^e into four factors and multiply:
|
||||
const Packet2l bias = { 1023, 1023 };
|
||||
@ -2582,14 +2602,16 @@ template<> EIGEN_STRONG_INLINE Packet2d pldexp<Packet2d>(const Packet2d& a, cons
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
// Extract exponent without existence of Packet2l.
|
||||
template<>
|
||||
EIGEN_STRONG_INLINE
|
||||
Packet2d pfrexp_generic_get_biased_exponent(const Packet2d& a) {
|
||||
return pcast<Packet2l, Packet2d>(plogical_shift_right<52>(reinterpret_cast<Packet2l>(pabs(a))));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pfrexp<Packet2d> (const Packet2d& a, Packet2d& exponent) {
|
||||
double exp[2] = { exponent[0], exponent[1] };
|
||||
Packet2d ret = { pfrexp<double>(a[0], exp[0]), pfrexp<double>(a[1], exp[1]) };
|
||||
exponent[0] = exp[0];
|
||||
exponent[1] = exp[1];
|
||||
return ret;
|
||||
// This doesn't currently work (no integer_packet for Packet2d - but adding it causes other problems)
|
||||
// return pfrexp_double(a, exponent);
|
||||
return pfrexp_generic(a, exponent);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a)
|
||||
|
@ -25,80 +25,114 @@ pset(const typename unpacket_traits<Packet>::type (&a)[N] /* a */) {
|
||||
return pload<Packet>(a);
|
||||
}
|
||||
|
||||
template<typename Packet> EIGEN_STRONG_INLINE Packet
|
||||
pfrexp_float(const Packet& a, Packet& exponent) {
|
||||
// Creates a Scalar integer type with same bit-width.
|
||||
template<typename T> struct make_integer;
|
||||
template<> struct make_integer<float> { typedef numext::int32_t type; };
|
||||
template<> struct make_integer<double> { typedef numext::int64_t type; };
|
||||
template<> struct make_integer<half> { typedef numext::int16_t type; };
|
||||
template<> struct make_integer<bfloat16> { typedef numext::int16_t type; };
|
||||
|
||||
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
|
||||
Packet pfrexp_generic_get_biased_exponent(const Packet& a) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
|
||||
const Packet cst_126f = pset1<Packet>(126.0f);
|
||||
const Packet cst_half = pset1<Packet>(0.5f);
|
||||
const Packet cst_inv_mant_mask = pset1frombits<Packet>(~0x7f800000u);
|
||||
exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<23>(preinterpret<PacketI>(pabs<Packet>(a)))), cst_126f);
|
||||
return por(pand(a, cst_inv_mant_mask), cst_half);
|
||||
EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1;
|
||||
return pcast<PacketI, Packet>(plogical_shift_right<mantissa_bits>(preinterpret<PacketI>(pabs(a))));
|
||||
}
|
||||
|
||||
template<typename Packet> EIGEN_STRONG_INLINE Packet
|
||||
pfrexp_double(const Packet& a, Packet& exponent) {
|
||||
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
|
||||
const Packet cst_1022d = pset1<Packet>(1022.0);
|
||||
const Packet cst_half = pset1<Packet>(0.5);
|
||||
const Packet cst_inv_mant_mask = pset1frombits<Packet, uint64_t>(static_cast<uint64_t>(~0x7ff0000000000000ull));
|
||||
exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<52>(preinterpret<PacketI>(pabs<Packet>(a)))), cst_1022d);
|
||||
return por(pand(a, cst_inv_mant_mask), cst_half);
|
||||
// Safely applies frexp, correctly handles denormals.
|
||||
// Assumes IEEE floating point format.
|
||||
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
|
||||
Packet pfrexp_generic(const Packet& a, Packet& exponent) {
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
typedef typename make_unsigned<typename make_integer<Scalar>::type>::type ScalarUI;
|
||||
|
||||
EIGEN_CONSTEXPR int total_bits = sizeof(Scalar) * CHAR_BIT;
|
||||
EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1;
|
||||
EIGEN_CONSTEXPR int exponent_bits = total_bits - mantissa_bits - 1;
|
||||
|
||||
EIGEN_CONSTEXPR ScalarUI scalar_sign_mantissa_mask =
|
||||
~(((ScalarUI(1) << exponent_bits) - ScalarUI(1)) << mantissa_bits); // ~0x7f800000
|
||||
const Packet sign_mantissa_mask = pset1frombits<Packet>(static_cast<ScalarUI>(scalar_sign_mantissa_mask));
|
||||
const Packet half = pset1<Packet>(Scalar(0.5));
|
||||
const Packet zero = pzero(a);
|
||||
const Packet normal_min = pset1<Packet>((numext::numeric_limits<Scalar>::min)()); // Minimum normal value, 2^-126
|
||||
|
||||
// To handle denormals, normalize by multiplying by 2^(mantissa_bits+1).
|
||||
const Packet is_denormal = pcmp_lt(pabs(a), normal_min);
|
||||
EIGEN_CONSTEXPR ScalarUI scalar_normalization_offset = ScalarUI(mantissa_bits + 1); // 24
|
||||
// The following cannot be constexpr because bfloat16(uint16_t) is not constexpr.
|
||||
const Scalar scalar_normalization_factor = Scalar(ScalarUI(1) << int(scalar_normalization_offset)); // 2^24
|
||||
const Packet normalization_factor = pset1<Packet>(scalar_normalization_factor);
|
||||
const Packet normalized_a = pselect(is_denormal, pmul(a, normalization_factor), a);
|
||||
|
||||
// Determine exponent offset: -126 if normal, -126-24 if denormal
|
||||
const Scalar scalar_exponent_offset = -Scalar((ScalarUI(1)<<(exponent_bits-1)) - ScalarUI(2)); // -126
|
||||
Packet exponent_offset = pset1<Packet>(scalar_exponent_offset);
|
||||
const Packet normalization_offset = pset1<Packet>(-Scalar(scalar_normalization_offset)); // -24
|
||||
exponent_offset = pselect(is_denormal, padd(exponent_offset, normalization_offset), exponent_offset);
|
||||
|
||||
// Determine exponent and mantissa from normalized_a.
|
||||
exponent = pfrexp_generic_get_biased_exponent(normalized_a);
|
||||
// Zero, Inf and NaN return 'a' unmodified, exponent is zero
|
||||
// (technically the exponent is unspecified for inf/NaN, but GCC/Clang set it to zero)
|
||||
const Scalar scalar_non_finite_exponent = Scalar((ScalarUI(1) << exponent_bits) - ScalarUI(1)); // 255
|
||||
const Packet non_finite_exponent = pset1<Packet>(scalar_non_finite_exponent);
|
||||
const Packet is_zero_or_not_finite = por(pcmp_eq(a, zero), pcmp_eq(exponent, non_finite_exponent));
|
||||
const Packet m = pselect(is_zero_or_not_finite, a, por(pand(normalized_a, sign_mantissa_mask), half));
|
||||
exponent = pselect(is_zero_or_not_finite, zero, padd(exponent, exponent_offset));
|
||||
return m;
|
||||
}
|
||||
|
||||
// Safely applies ldexp, correctly handles overflows, underflows and denormals.
|
||||
// Assumes IEEE floating point format.
|
||||
template<typename Packet>
|
||||
struct pldexp_impl {
|
||||
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
|
||||
Packet pldexp_generic(const Packet& a, const Packet& exponent) {
|
||||
// We want to return a * 2^exponent, allowing for all possible integer
|
||||
// exponents without overflowing or underflowing in intermediate
|
||||
// computations.
|
||||
//
|
||||
// Since 'a' and the output can be denormal, the maximum range of 'exponent'
|
||||
// to consider for a float is:
|
||||
// -255-23 -> 255+23
|
||||
// Below -278 any finite float 'a' will become zero, and above +278 any
|
||||
// finite float will become inf, including when 'a' is the smallest possible
|
||||
// denormal.
|
||||
//
|
||||
// Unfortunately, 2^(278) cannot be represented using either one or two
|
||||
// finite normal floats, so we must split the scale factor into at least
|
||||
// three parts. It turns out to be faster to split 'exponent' into four
|
||||
// factors, since [exponent>>2] is much faster to compute that [exponent/3].
|
||||
//
|
||||
// Set e = min(max(exponent, -278), 278);
|
||||
// b = floor(e/4);
|
||||
// out = ((((a * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b))
|
||||
//
|
||||
// This will avoid any intermediate overflows and correctly handle 0, inf,
|
||||
// NaN cases.
|
||||
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
|
||||
typedef typename unpacket_traits<Packet>::type Scalar;
|
||||
typedef typename unpacket_traits<PacketI>::type ScalarI;
|
||||
enum {
|
||||
TotalBits = sizeof(Scalar) * CHAR_BIT,
|
||||
MantissaBits = std::numeric_limits<Scalar>::digits - 1,
|
||||
ExponentBits = int(TotalBits) - int(MantissaBits) - 1
|
||||
};
|
||||
|
||||
static EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
|
||||
Packet run(const Packet& a, const Packet& exponent) {
|
||||
// We want to return a * 2^exponent, allowing for all possible integer
|
||||
// exponents without overflowing or underflowing in intermediate
|
||||
// computations.
|
||||
//
|
||||
// Since 'a' and the output can be denormal, the maximum range of 'exponent'
|
||||
// to consider for a float is:
|
||||
// -255-23 -> 255+23
|
||||
// Below -278 any finite float 'a' will become zero, and above +278 any
|
||||
// finite float will become inf, including when 'a' is the smallest possible
|
||||
// denormal.
|
||||
//
|
||||
// Unfortunately, 2^(278) cannot be represented using either one or two
|
||||
// finite normal floats, so we must split the scale factor into at least
|
||||
// three parts. It turns out to be faster to split 'exponent' into four
|
||||
// factors, since [exponent>>2] is much faster to compute that [exponent/3].
|
||||
//
|
||||
// Set e = min(max(exponent, -278), 278);
|
||||
// b = floor(e/4);
|
||||
// out = ((((a * 2^(b)) * 2^(b)) * 2^(b)) * 2^(e-3*b))
|
||||
//
|
||||
// This will avoid any intermediate overflows and correctly handle 0, inf,
|
||||
// NaN cases.
|
||||
const Packet max_exponent = pset1<Packet>(Scalar( (ScalarI(1)<<int(ExponentBits)) + ScalarI(MantissaBits) - ScalarI(1))); // 278
|
||||
const PacketI bias = pset1<PacketI>((ScalarI(1)<<(int(ExponentBits)-1)) - ScalarI(1)); // 127
|
||||
const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
|
||||
PacketI b = parithmetic_shift_right<2>(e); // floor(e/4);
|
||||
Packet c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^b
|
||||
Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
|
||||
b = psub(psub(psub(e, b), b), b); // e - 3b
|
||||
c = preinterpret<Packet>(plogical_shift_left<int(MantissaBits)>(padd(b, bias))); // 2^(e-3*b)
|
||||
out = pmul(out, c);
|
||||
return out;
|
||||
}
|
||||
};
|
||||
EIGEN_CONSTEXPR int total_bits = sizeof(Scalar) * CHAR_BIT;
|
||||
EIGEN_CONSTEXPR int mantissa_bits = numext::numeric_limits<Scalar>::digits - 1;
|
||||
EIGEN_CONSTEXPR int exponent_bits = total_bits - mantissa_bits - 1;
|
||||
|
||||
const Packet max_exponent = pset1<Packet>(Scalar((ScalarI(1)<<exponent_bits) + ScalarI(mantissa_bits - 1))); // 278
|
||||
const PacketI bias = pset1<PacketI>((ScalarI(1)<<(exponent_bits-1)) - ScalarI(1)); // 127
|
||||
const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
|
||||
PacketI b = parithmetic_shift_right<2>(e); // floor(e/4);
|
||||
Packet c = preinterpret<Packet>(plogical_shift_left<mantissa_bits>(padd(b, bias))); // 2^b
|
||||
Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b)
|
||||
b = psub(psub(psub(e, b), b), b); // e - 3b
|
||||
c = preinterpret<Packet>(plogical_shift_left<mantissa_bits>(padd(b, bias))); // 2^(e-3*b)
|
||||
out = pmul(out, c);
|
||||
return out;
|
||||
}
|
||||
|
||||
// Explicitly multiplies
|
||||
// a * (2^e)
|
||||
// clamping e to the range
|
||||
// [std::numeric_limits<Scalar>::min_exponent-2, std::numeric_limits<Scalar>::max_exponent]
|
||||
// [numeric_limits<Scalar>::min_exponent-2, numeric_limits<Scalar>::max_exponent]
|
||||
//
|
||||
// This is approx 7x faster than pldexp_impl, but will prematurely over/underflow
|
||||
// if 2^e doesn't fit into a normal floating-point Scalar.
|
||||
@ -111,7 +145,7 @@ struct pldexp_fast_impl {
|
||||
typedef typename unpacket_traits<PacketI>::type ScalarI;
|
||||
enum {
|
||||
TotalBits = sizeof(Scalar) * CHAR_BIT,
|
||||
MantissaBits = std::numeric_limits<Scalar>::digits - 1,
|
||||
MantissaBits = numext::numeric_limits<Scalar>::digits - 1,
|
||||
ExponentBits = int(TotalBits) - int(MantissaBits) - 1
|
||||
};
|
||||
|
||||
@ -126,14 +160,6 @@ struct pldexp_fast_impl {
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Packet> EIGEN_STRONG_INLINE Packet
|
||||
pldexp_float(const Packet& a, const Packet& exponent)
|
||||
{ return pldexp_impl<Packet>::run(a, exponent); }
|
||||
|
||||
template<typename Packet> EIGEN_STRONG_INLINE Packet
|
||||
pldexp_double(const Packet& a, const Packet& exponent)
|
||||
{ return pldexp_impl<Packet>::run(a, exponent); }
|
||||
|
||||
// Natural or base 2 logarithm.
|
||||
// Computes log(x) as log(2^e * m) = C*e + log(m), where the constant C =log(2)
|
||||
// and m is in the range [sqrt(1/2),sqrt(2)). In this range, the logarithm can
|
||||
|
@ -25,29 +25,23 @@ pset(const typename unpacket_traits<Packet>::type (&a)[N] /* a */);
|
||||
* Some generic implementations to be used by implementors
|
||||
***************************************************************************/
|
||||
|
||||
/** Default implementation of pfrexp for float.
|
||||
/** Default implementation of pfrexp.
|
||||
* It is expected to be called by implementers of template<> pfrexp.
|
||||
*/
|
||||
template<typename Packet> EIGEN_STRONG_INLINE Packet
|
||||
pfrexp_float(const Packet& a, Packet& exponent);
|
||||
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
|
||||
Packet pfrexp_generic(const Packet& a, Packet& exponent);
|
||||
|
||||
/** Default implementation of pfrexp for double.
|
||||
* It is expected to be called by implementers of template<> pfrexp.
|
||||
*/
|
||||
template<typename Packet> EIGEN_STRONG_INLINE Packet
|
||||
pfrexp_double(const Packet& a, Packet& exponent);
|
||||
// Extracts the biased exponent value from Packet p, and casts the results to
|
||||
// a floating-point Packet type. Used by pfrexp_generic. Override this if
|
||||
// there is no unpacket_traits<Packet>::integer_packet.
|
||||
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
|
||||
Packet pfrexp_generic_get_biased_exponent(const Packet& p);
|
||||
|
||||
/** Default implementation of pldexp for float.
|
||||
/** Default implementation of pldexp.
|
||||
* It is expected to be called by implementers of template<> pldexp.
|
||||
*/
|
||||
template<typename Packet> EIGEN_STRONG_INLINE Packet
|
||||
pldexp_float(const Packet& a, const Packet& exponent);
|
||||
|
||||
/** Default implementation of pldexp for double.
|
||||
* It is expected to be called by implementers of template<> pldexp.
|
||||
*/
|
||||
template<typename Packet> EIGEN_STRONG_INLINE Packet
|
||||
pldexp_double(const Packet& a, const Packet& exponent);
|
||||
template<typename Packet> EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
|
||||
Packet pldexp_generic(const Packet& a, const Packet& exponent);
|
||||
|
||||
/** \internal \returns log(x) for single precision float */
|
||||
template <typename Packet>
|
||||
|
@ -2402,14 +2402,14 @@ template<> EIGEN_STRONG_INLINE Packet2l pabs(const Packet2l& a) {
|
||||
template<> EIGEN_STRONG_INLINE Packet2ul pabs(const Packet2ul& a) { return a; }
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet2f pfrexp<Packet2f>(const Packet2f& a, Packet2f& exponent)
|
||||
{ return pfrexp_float(a,exponent); }
|
||||
{ return pfrexp_generic(a,exponent); }
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent)
|
||||
{ return pfrexp_float(a,exponent); }
|
||||
{ return pfrexp_generic(a,exponent); }
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet2f pldexp<Packet2f>(const Packet2f& a, const Packet2f& exponent)
|
||||
{ return pldexp_float(a,exponent); }
|
||||
{ return pldexp_generic(a,exponent); }
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f& exponent)
|
||||
{ return pldexp_float(a,exponent); }
|
||||
{ return pldexp_generic(a,exponent); }
|
||||
|
||||
template<> EIGEN_STRONG_INLINE float predux<Packet2f>(const Packet2f& a) { return vget_lane_f32(vpadd_f32(a,a), 0); }
|
||||
template<> EIGEN_STRONG_INLINE float predux<Packet4f>(const Packet4f& a)
|
||||
@ -3907,10 +3907,10 @@ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet2d pselect( const Packet2
|
||||
{ return vbslq_f64(vreinterpretq_u64_f64(mask), a, b); }
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pldexp<Packet2d>(const Packet2d& a, const Packet2d& exponent)
|
||||
{ return pldexp_double(a, exponent); }
|
||||
{ return pldexp_generic(a, exponent); }
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pfrexp<Packet2d>(const Packet2d& a, Packet2d& exponent)
|
||||
{ return pfrexp_double(a,exponent); }
|
||||
{ return pfrexp_generic(a,exponent); }
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pset1frombits<Packet2d>(uint64_t from)
|
||||
{ return vreinterpretq_f64_u64(vdupq_n_u64(from)); }
|
||||
|
@ -887,21 +887,24 @@ template<> EIGEN_STRONG_INLINE Packet4i pabs(const Packet4i& a)
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Packet4f& exponent) {
|
||||
return pfrexp_float(a,exponent);
|
||||
return pfrexp_generic(a,exponent);
|
||||
}
|
||||
|
||||
// Extract exponent without existence of Packet2l.
|
||||
template<>
|
||||
EIGEN_STRONG_INLINE
|
||||
Packet2d pfrexp_generic_get_biased_exponent(const Packet2d& a) {
|
||||
const Packet2d cst_exp_mask = pset1frombits<Packet2d>(static_cast<uint64_t>(0x7ff0000000000000ull));
|
||||
__m128i a_expo = _mm_srli_epi64(_mm_castpd_si128(pand(a, cst_exp_mask)), 52);
|
||||
return _mm_cvtepi32_pd(vec4i_swizzle1(a_expo, 0, 2, 1, 3));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pfrexp<Packet2d>(const Packet2d& a, Packet2d& exponent) {
|
||||
const Packet2d cst_1022d = pset1<Packet2d>(1022.0);
|
||||
const Packet2d cst_half = pset1<Packet2d>(0.5);
|
||||
const Packet2d cst_exp_mask = pset1frombits<Packet2d>(static_cast<uint64_t>(0x7ff0000000000000ull));
|
||||
__m128i a_expo = _mm_srli_epi64(_mm_castpd_si128(pand(a, cst_exp_mask)), 52);
|
||||
exponent = psub(_mm_cvtepi32_pd(vec4i_swizzle1(a_expo, 0, 2, 1, 3)), cst_1022d);
|
||||
const Packet2d cst_mant_mask = pset1frombits<Packet2d>(static_cast<uint64_t>(~0x7ff0000000000000ull));
|
||||
return por(pand(a, cst_mant_mask), cst_half);
|
||||
return pfrexp_generic(a, exponent);
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f& exponent) {
|
||||
return pldexp_float(a,exponent);
|
||||
return pldexp_generic(a,exponent);
|
||||
}
|
||||
|
||||
// We specialize pldexp here, since the generic implementation uses Packet2l, which is not well
|
||||
|
@ -669,7 +669,7 @@ EIGEN_STRONG_INLINE PacketXf pabs(const PacketXf& a)
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE PacketXf pfrexp<PacketXf>(const PacketXf& a, PacketXf& exponent)
|
||||
{
|
||||
return pfrexp_float(a, exponent);
|
||||
return pfrexp_generic(a, exponent);
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -747,7 +747,7 @@ EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<PacketXf, N>& kernel)
|
||||
template<>
|
||||
EIGEN_STRONG_INLINE PacketXf pldexp<PacketXf>(const PacketXf& a, const PacketXf& exponent)
|
||||
{
|
||||
return pldexp_float(a, exponent);
|
||||
return pldexp_generic(a, exponent);
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
|
@ -567,7 +567,36 @@ void packetmath_real() {
|
||||
data2[i] = Scalar(internal::random<double>(-87, 88));
|
||||
}
|
||||
CHECK_CWISE1_IF(PacketTraits::HasExp, std::exp, internal::pexp);
|
||||
|
||||
CHECK_CWISE1_BYREF1_IF(PacketTraits::HasExp, REF_FREXP, internal::pfrexp);
|
||||
if (PacketTraits::HasExp) {
|
||||
// Check denormals:
|
||||
for (int j=0; j<3; ++j) {
|
||||
data1[0] = Scalar(std::ldexp(1, std::numeric_limits<Scalar>::min_exponent-j));
|
||||
CHECK_CWISE1_BYREF1_IF(PacketTraits::HasExp, REF_FREXP, internal::pfrexp);
|
||||
data1[0] = -data1[0];
|
||||
CHECK_CWISE1_BYREF1_IF(PacketTraits::HasExp, REF_FREXP, internal::pfrexp);
|
||||
}
|
||||
|
||||
// zero
|
||||
data1[0] = Scalar(0);
|
||||
CHECK_CWISE1_BYREF1_IF(PacketTraits::HasExp, REF_FREXP, internal::pfrexp);
|
||||
|
||||
// inf and NaN only compare output fraction, not exponent.
|
||||
test::packet_helper<PacketTraits::HasExp,Packet> h;
|
||||
Packet pout;
|
||||
Scalar sout;
|
||||
Scalar special[] = { NumTraits<Scalar>::infinity(),
|
||||
-NumTraits<Scalar>::infinity(),
|
||||
NumTraits<Scalar>::quiet_NaN()};
|
||||
for (int i=0; i<3; ++i) {
|
||||
data1[0] = special[i];
|
||||
ref[0] = Scalar(REF_FREXP(data1[0], ref[PacketSize]));
|
||||
h.store(data2, internal::pfrexp(h.load(data1), h.forward_reference(pout, sout)));
|
||||
VERIFY(test::areApprox(ref, data2, 1) && "internal::pfrexp");
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < PacketSize; ++i) {
|
||||
data1[i] = Scalar(internal::random<double>(-1, 1));
|
||||
data2[i] = Scalar(internal::random<double>(-1, 1));
|
||||
|
Loading…
x
Reference in New Issue
Block a user