mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-19 16:19:37 +08:00
Add plog ops support packet2d for NEON
This commit is contained in:
parent
e4fb0ddf78
commit
3012e755e9
@ -29,6 +29,16 @@ pfrexp_float(const Packet& a, Packet& exponent) {
|
||||
return por(pand(a, cst_inv_mant_mask), cst_half);
|
||||
}
|
||||
|
||||
/** \internal
  * Generic frexp for double-precision packets.
  *
  * Splits every element of \a a into a fraction and a power-of-two exponent by
  * operating directly on the IEEE-754 binary64 bit pattern.
  *
  * \param a         input packet
  * \param exponent  output: (exponent field of \a a) - 1022, stored as floating-point values
  * \returns \a a with its exponent field replaced by the one of 0.5; the sign and
  *          mantissa bits are kept, so the magnitude is in [0.5,1) for normal inputs.
  *
  * There is no special handling of zero, subnormal, infinite or NaN inputs.
  */
template<typename Packet> EIGEN_STRONG_INLINE Packet
pfrexp_double(const Packet& a, Packet& exponent) {
  typedef typename unpacket_traits<Packet>::integer_packet PacketI;
  const Packet cst_1022d = pset1<Packet>(1022.0);
  const Packet cst_half = pset1<Packet>(0.5);
  const Packet cst_exp_mask = pset1frombits<Packet>(0x7ff0000000000000u);
  const Packet cst_inv_mant_mask = pset1frombits<Packet>(~0x7ff0000000000000u);
  // Isolate the 11 exponent bits before shifting: the shift is logical, so the sign
  // bit would otherwise end up in bit 11 and add 2048 to the exponent of every
  // negative input.
  const Packet exp_bits = pand(a, cst_exp_mask);
  exponent = psub(pcast<PacketI,Packet>(plogical_shift_right<52>(preinterpret<PacketI>(exp_bits))), cst_1022d);
  // Keep sign and mantissa, force the exponent field to the one of 0.5.
  return por(pand(a, cst_inv_mant_mask), cst_half);
}
|
||||
|
||||
template<typename Packet> EIGEN_STRONG_INLINE Packet
|
||||
pldexp_float(Packet a, Packet exponent)
|
||||
{
|
||||
@ -139,6 +149,114 @@ Packet plog_float(const Packet _x)
|
||||
por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask));
|
||||
}
|
||||
|
||||
|
||||
/* Returns the base e (2.718...) logarithm of x.
|
||||
* The argument is separated into its exponent and fractional
|
||||
* parts. If the exponent is between -1 and +1, the logarithm
|
||||
* of the fraction is approximated by
|
||||
*
|
||||
* log(1+x) = x - 0.5 x**2 + x**3 P(x)/Q(x).
|
||||
*
|
||||
* Otherwise, setting z = 2(x-1)/(x+1),
|
||||
* log(x) = z + z**3 P(z)/Q(z).
|
||||
*
|
||||
* for more detail see: http://www.netlib.org/cephes/
|
||||
*/
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
EIGEN_UNUSED
|
||||
Packet plog_double(const Packet _x)
|
||||
{
|
||||
Packet x = _x;
|
||||
|
||||
const Packet cst_1 = pset1<Packet>(1.0);
|
||||
const Packet cst_half = pset1<Packet>(0.5);
|
||||
// The smallest non denormalized float number.
|
||||
const Packet cst_min_norm_pos = pset1frombits<Packet>( 0x0010000000000000u);
|
||||
const Packet cst_minus_inf = pset1frombits<Packet>( 0xfff0000000000000u);
|
||||
const Packet cst_pos_inf = pset1frombits<Packet>( 0x7ff0000000000000u);
|
||||
|
||||
// Polynomial Coefficients for log(1+x) = x - x**2/2 + x**3 P(x)/Q(x)
|
||||
// 1/sqrt(2) <= x < sqrt(2)
|
||||
const Packet cst_cephes_SQRTHF = pset1<Packet>(0.70710678118654752440E0);
|
||||
const Packet cst_cephes_log_p0 = pset1<Packet>(1.01875663804580931796E-4);
|
||||
const Packet cst_cephes_log_p1 = pset1<Packet>(4.97494994976747001425E-1);
|
||||
const Packet cst_cephes_log_p2 = pset1<Packet>(4.70579119878881725854E0);
|
||||
const Packet cst_cephes_log_p3 = pset1<Packet>(1.44989225341610930846E1);
|
||||
const Packet cst_cephes_log_p4 = pset1<Packet>(1.79368678507819816313E1);
|
||||
const Packet cst_cephes_log_p5 = pset1<Packet>(7.70838733755885391666E0);
|
||||
|
||||
const Packet cst_cephes_log_r0 = pset1<Packet>(1.0);
|
||||
const Packet cst_cephes_log_r1 = pset1<Packet>(1.12873587189167450590E1);
|
||||
const Packet cst_cephes_log_r2 = pset1<Packet>(4.52279145837532221105E1);
|
||||
const Packet cst_cephes_log_r3 = pset1<Packet>(8.29875266912776603211E1);
|
||||
const Packet cst_cephes_log_r4 = pset1<Packet>(7.11544750618563894466E1);
|
||||
const Packet cst_cephes_log_r5 = pset1<Packet>(2.31251620126765340583E1);
|
||||
|
||||
const Packet cst_cephes_log_q1 = pset1<Packet>(-2.121944400546905827679e-4);
|
||||
const Packet cst_cephes_log_q2 = pset1<Packet>(0.693359375);
|
||||
|
||||
// Truncate input values to the minimum positive normal.
|
||||
x = pmax(x, cst_min_norm_pos);
|
||||
|
||||
Packet e;
|
||||
// extract significant in the range [0.5,1) and exponent
|
||||
x = pfrexp(x,e);
|
||||
|
||||
// Shift the inputs from the range [0.5,1) to [sqrt(1/2),sqrt(2))
|
||||
// and shift by -1. The values are then centered around 0, which improves
|
||||
// the stability of the polynomial evaluation.
|
||||
// if( x < SQRTHF ) {
|
||||
// e -= 1;
|
||||
// x = x + x - 1.0;
|
||||
// } else { x = x - 1.0; }
|
||||
Packet mask = pcmp_lt(x, cst_cephes_SQRTHF);
|
||||
Packet tmp = pand(x, mask);
|
||||
x = psub(x, cst_1);
|
||||
e = psub(e, pand(cst_1, mask));
|
||||
x = padd(x, tmp);
|
||||
|
||||
Packet x2 = pmul(x, x);
|
||||
Packet x3 = pmul(x2, x);
|
||||
|
||||
// Evaluate the polynomial approximant , probably to improve instruction-level parallelism.
|
||||
// y = x * ( z * polevl( x, P, 5 ) / p1evl( x, Q, 5 ) );
|
||||
Packet y, y1, y2,y_;
|
||||
y = pmadd(cst_cephes_log_p0, x, cst_cephes_log_p1);
|
||||
y1 = pmadd(cst_cephes_log_p3, x, cst_cephes_log_p4);
|
||||
y = pmadd(y, x, cst_cephes_log_p2);
|
||||
y1 = pmadd(y1, x, cst_cephes_log_p5);
|
||||
y_ = pmadd(y, x3, y1);
|
||||
|
||||
y = pmadd(cst_cephes_log_r0, x, cst_cephes_log_r1);
|
||||
y1 = pmadd(cst_cephes_log_r3, x, cst_cephes_log_r4);
|
||||
y = pmadd(y, x, cst_cephes_log_r2);
|
||||
y1 = pmadd(y1, x, cst_cephes_log_r5);
|
||||
y = pmadd(y, x3, y1);
|
||||
|
||||
y_ = pmul(y_, x3);
|
||||
y = pdiv(y_, y);
|
||||
|
||||
// Add the logarithm of the exponent back to the result of the interpolation.
|
||||
y1 = pmul(e, cst_cephes_log_q1);
|
||||
tmp = pmul(x2, cst_half);
|
||||
y = padd(y, y1);
|
||||
x = psub(x, tmp);
|
||||
y2 = pmul(e, cst_cephes_log_q2);
|
||||
x = padd(x, y);
|
||||
x = padd(x, y2);
|
||||
|
||||
Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
|
||||
Packet iszero_mask = pcmp_eq(_x,pzero(_x));
|
||||
Packet pos_inf_mask = pcmp_eq(_x,cst_pos_inf);
|
||||
// Filter out invalid inputs, i.e.:
|
||||
// - negative arg will be NAN
|
||||
// - 0 will be -INF
|
||||
// - +INF will be +INF
|
||||
return pselect(iszero_mask, cst_minus_inf,
|
||||
por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask));
|
||||
}
|
||||
|
||||
/** \internal \returns log(1 + x) computed using W. Kahan's formula.
|
||||
See: http://www.plunk.org/~hatch/rightway.php
|
||||
*/
|
||||
|
@ -20,6 +20,9 @@ namespace internal {
|
||||
/** \internal Forward declaration of the generic pfrexp_float helper.
  * \a exponent is taken by non-const reference, i.e. it is an output.
  * NOTE(review): presumably the single-precision counterpart of pfrexp_double;
  * the full definition is not visible here - confirm. */
template<typename Packet> EIGEN_STRONG_INLINE Packet
|
||||
pfrexp_float(const Packet& a, Packet& exponent);
|
||||
|
||||
/** \internal Forward declaration of the generic frexp for double-precision packets.
  * \param a         input packet
  * \param exponent  output: receives the exponent of each element of \a a
  * \returns \a a with its exponent field replaced by the one of 0.5 */
template<typename Packet> EIGEN_STRONG_INLINE Packet
|
||||
pfrexp_double(const Packet& a, Packet& exponent);
|
||||
|
||||
/** \internal Forward declaration of the generic pldexp_float helper; both arguments
  * are taken by value.
  * NOTE(review): presumably computes a * 2^exponent for single-precision packets;
  * the definition is not visible here - confirm. */
template<typename Packet> EIGEN_STRONG_INLINE Packet
|
||||
pldexp_float(Packet a, Packet exponent);
|
||||
|
||||
@ -29,6 +32,12 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
EIGEN_UNUSED
|
||||
Packet plog_float(const Packet _x);
|
||||
|
||||
/** \internal \returns log(x) for double precision float */
|
||||
template <typename Packet>
|
||||
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||
EIGEN_UNUSED
|
||||
Packet plog_double(const Packet _x);
|
||||
|
||||
/** \internal \returns log(1 + x)
  * NOTE(review): the definition is not visible here; a nearby comment mentions
  * W. Kahan's formula for log(1 + x), presumably for this function - confirm. */
|
||||
template<typename Packet>
|
||||
Packet generic_plog1p(const Packet& x);
|
||||
|
@ -51,6 +51,9 @@ BF16_PACKET_FUNCTION(Packet4f, Packet4bf, ptanh)
|
||||
// pexp specialization for Packet2d: forwards to the generic pexp_double.
template<>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet2d pexp<Packet2d>(const Packet2d& x)
{
  return pexp_double(x);
}
|
||||
|
||||
// plog specialization for Packet2d: forwards to the generic plog_double.
template<>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet2d plog<Packet2d>(const Packet2d& x)
{
  return plog_double(x);
}
|
||||
|
||||
#endif
|
||||
|
||||
} // end namespace internal
|
||||
|
@ -3579,7 +3579,7 @@ template<> struct packet_traits<double> : default_packet_traits
|
||||
|
||||
HasSin = 0,
|
||||
HasCos = 0,
|
||||
HasLog = 0,
|
||||
HasLog = 1,
|
||||
HasExp = 1,
|
||||
HasSqrt = 1,
|
||||
HasTanh = 0,
|
||||
@ -3753,6 +3753,12 @@ template<> EIGEN_DEVICE_FUNC inline Packet2d pselect( const Packet2d& mask, cons
|
||||
// pldexp specialization for Packet2d: forwards to the generic pldexp_double.
template<>
EIGEN_STRONG_INLINE Packet2d pldexp<Packet2d>(const Packet2d& a, const Packet2d& exponent)
{
  return pldexp_double(a, exponent);
}
|
||||
|
||||
// pfrexp specialization for Packet2d: forwards to the generic pfrexp_double,
// which writes the exponents into `exponent` and returns the fractions.
template<>
EIGEN_STRONG_INLINE Packet2d pfrexp<Packet2d>(const Packet2d& a, Packet2d& exponent)
{
  return pfrexp_double(a, exponent);
}
|
||||
|
||||
// Builds a Packet2d by duplicating the unsigned integer `from` into a u64 vector
// and reinterpreting that vector as f64.
// NOTE(review): `unsigned long` is only 32 bits wide on LLP64 targets, where a
// 64-bit pattern would be truncated before reaching this function - confirm
// whether such targets are supported here.
template<>
EIGEN_STRONG_INLINE Packet2d pset1frombits<Packet2d>(unsigned long from)
{
  return vreinterpretq_f64_u64(
      vdupq_n_u64(from));
}
|
||||
|
||||
#if EIGEN_FAST_MATH
|
||||
|
||||
// Functions for sqrt support packet2d.
|
||||
|
@ -24,6 +24,11 @@ Packet4f plog<Packet4f>(const Packet4f& _x) {
|
||||
return plog_float(_x);
|
||||
}
|
||||
|
||||
// plog specialization for Packet2d: forwards to the generic plog_double.
template<>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet2d plog<Packet2d>(const Packet2d& _x)
{
  return plog_double(_x);
}
|
||||
|
||||
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
|
||||
Packet4f plog1p<Packet4f>(const Packet4f& _x) {
|
||||
return generic_plog1p(_x);
|
||||
|
@ -132,6 +132,7 @@ struct packet_traits<double> : default_packet_traits {
|
||||
|
||||
HasCmp = 1,
|
||||
HasDiv = 1,
|
||||
HasLog = 1,
|
||||
HasExp = 1,
|
||||
HasSqrt = 1,
|
||||
HasRsqrt = 1,
|
||||
@ -227,6 +228,7 @@ template<> EIGEN_STRONG_INLINE Packet4i pset1<Packet4i>(const int& from) { re
|
||||
// pset1 for Packet16b: passes `from`, converted to char, to _mm_set1_epi8.
template<>
EIGEN_STRONG_INLINE Packet16b pset1<Packet16b>(const bool& from)
{
  return _mm_set1_epi8(
      static_cast<char>(from));
}
|
||||
|
||||
// pset1frombits for Packet4f: builds a Packet4i from `from` with pset1 and
// casts the resulting integer vector to a float vector.
template<>
EIGEN_STRONG_INLINE Packet4f pset1frombits<Packet4f>(unsigned int from)
{
  return _mm_castsi128_ps(
      pset1<Packet4i>(from));
}
|
||||
// pset1frombits for Packet2d: passes `from` to _mm_set1_epi64x and casts the
// resulting integer vector to a double vector.
// NOTE(review): `unsigned long` is only 32 bits wide on LLP64 targets, where a
// 64-bit pattern would be truncated before reaching this function - confirm
// whether such targets are supported here.
template<>
EIGEN_STRONG_INLINE Packet2d pset1frombits<Packet2d>(unsigned long from)
{
  return _mm_castsi128_pd(
      _mm_set1_epi64x(from));
}
|
||||
|
||||
// pzero for Packet4f: the argument is unused; returns _mm_setzero_ps().
template<>
EIGEN_STRONG_INLINE Packet4f pzero(const Packet4f& /*a*/)
{
  return _mm_setzero_ps();
}
|
||||
// pzero for Packet2d: the argument is unused; returns _mm_setzero_pd().
template<>
EIGEN_STRONG_INLINE Packet2d pzero(const Packet2d& /*a*/)
{
  return _mm_setzero_pd();
}
|
||||
@ -753,6 +755,10 @@ template<> EIGEN_STRONG_INLINE Packet4f pfrexp<Packet4f>(const Packet4f& a, Pack
|
||||
return pfrexp_float(a,exponent);
|
||||
}
|
||||
|
||||
// pfrexp specialization for Packet2d: forwards to the generic pfrexp_double,
// which writes the exponents into `exponent` and returns the fractions.
template<>
EIGEN_STRONG_INLINE Packet2d pfrexp<Packet2d>(const Packet2d& a, Packet2d& exponent)
{
  return pfrexp_double(a, exponent);
}
|
||||
|
||||
// pldexp specialization for Packet4f: forwards to the generic pldexp_float.
template<>
EIGEN_STRONG_INLINE Packet4f pldexp<Packet4f>(const Packet4f& a, const Packet4f& exponent)
{
  return pldexp_float(a, exponent);
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user