mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-14 12:46:00 +08:00
Improve accuracy of fast approximate tanh and the logistic functions in Eigen, such that they preserve relative accuracy to within a few ULPs where their function values tend to zero (around x=0 for tanh, and for large negative x for the logistic function).
This change re-instates the fast rational approximation of the logistic function for float32 in Eigen (removed in 66f07efeae
), but uses the more accurate approximation 1/(1+exp(-1)) ~= exp(x) below -9. The exponential is only calculated on the vectorized path if at least one element in the SIMD input vector is less than -9.
This change also contains a few improvements to speed up the original float specialization of logistic:
- Introduce EIGEN_PREDICT_{FALSE,TRUE} for __builtin_predict and use it to predict that the logistic-only path is most likely (~2-3% speedup for the common case).
- Carefully set the upper clipping point to the smallest x where the approximation evaluates to exactly 1. This saves the explicit clamping of the output (~7% speedup).
The increased accuracy for tanh comes at a cost of 10-20% depending on instruction set.
The benchmarks below repeated calls
u = v.logistic() (u = v.tanh(), respectively)
where u and v are of type Eigen::ArrayXf, have length 8k, and v contains random numbers in [-1,1].
Benchmark numbers for logistic:
Before:
Benchmark Time(ns) CPU(ns) Iterations
-----------------------------------------------------------------
SSE
BM_eigen_logistic_float 4467 4468 155835 model_time: 4827
AVX
BM_eigen_logistic_float 2347 2347 299135 model_time: 2926
AVX+FMA
BM_eigen_logistic_float 1467 1467 476143 model_time: 2926
AVX512
BM_eigen_logistic_float 805 805 858696 model_time: 1463
After:
Benchmark Time(ns) CPU(ns) Iterations
-----------------------------------------------------------------
SSE
BM_eigen_logistic_float 2589 2590 270264 model_time: 4827
AVX
BM_eigen_logistic_float 1428 1428 489265 model_time: 2926
AVX+FMA
BM_eigen_logistic_float 1059 1059 662255 model_time: 2926
AVX512
BM_eigen_logistic_float 673 673 1000000 model_time: 1463
Benchmark numbers for tanh:
Before:
Benchmark Time(ns) CPU(ns) Iterations
-----------------------------------------------------------------
SSE
BM_eigen_tanh_float 2391 2391 292624 model_time: 4242
AVX
BM_eigen_tanh_float 1256 1256 554662 model_time: 2633
AVX+FMA
BM_eigen_tanh_float 823 823 866267 model_time: 1609
AVX512
BM_eigen_tanh_float 443 443 1578999 model_time: 805
After:
Benchmark Time(ns) CPU(ns) Iterations
-----------------------------------------------------------------
SSE
BM_eigen_tanh_float 2588 2588 273531 model_time: 4242
AVX
BM_eigen_tanh_float 1536 1536 452321 model_time: 2633
AVX+FMA
BM_eigen_tanh_float 1007 1007 694681 model_time: 1609
AVX512
BM_eigen_tanh_float 471 471 1472178 model_time: 805
This commit is contained in:
parent
8e5da71466
commit
a566074480
@ -18,9 +18,10 @@ namespace internal {
|
||||
/** \internal \returns the hyperbolic tan of \a a (coeff-wise)
|
||||
Doesn't do anything fancy, just a 13/6-degree rational interpolant which
|
||||
is accurate up to a couple of ulps in the (approximate) range [-8, 8],
|
||||
outside of which tanh(x) = +/-1 in single precision. This is done by
|
||||
Clamp the inputs to the range [-c, c]. The value c is chosen as the smallest
|
||||
value where the approximation evaluates to exactly 1.
|
||||
outside of which tanh(x) = +/-1 in single precision. The input is clamped
|
||||
to the range [-c, c]. The value c is chosen as the smallest value where
|
||||
the approximation evaluates to exactly 1. In the reange [-0.0004, 0.0004]
|
||||
the approxmation tanh(x) ~= x is used for better accuracy as x tends to zero.
|
||||
|
||||
This implementation works on both scalars and packets.
|
||||
*/
|
||||
@ -29,13 +30,15 @@ T generic_fast_tanh_float(const T& a_x)
|
||||
{
|
||||
// Clamp the inputs to the range [-c, c]
|
||||
#ifdef EIGEN_VECTORIZE_FMA
|
||||
const T plus_clamp = pset1<T>(7.99881172180175781);
|
||||
const T minus_clamp = pset1<T>(-7.99881172180175781);
|
||||
const T plus_clamp = pset1<T>(7.99881172180175781f);
|
||||
const T minus_clamp = pset1<T>(-7.99881172180175781f);
|
||||
#else
|
||||
const T plus_clamp = pset1<T>(7.90531110763549805);
|
||||
const T minus_clamp = pset1<T>(-7.90531110763549805);
|
||||
const T plus_clamp = pset1<T>(7.90531110763549805f);
|
||||
const T minus_clamp = pset1<T>(-7.90531110763549805f);
|
||||
#endif
|
||||
const T tiny = pset1<T>(0.0004f);
|
||||
const T x = pmax(pmin(a_x, plus_clamp), minus_clamp);
|
||||
const T tiny_mask = pcmp_lt(pabs(a_x), tiny);
|
||||
// The monomial coefficients of the numerator polynomial (odd).
|
||||
const T alpha_1 = pset1<T>(4.89352455891786e-03f);
|
||||
const T alpha_3 = pset1<T>(6.37261928875436e-04f);
|
||||
@ -63,13 +66,13 @@ T generic_fast_tanh_float(const T& a_x)
|
||||
p = pmadd(x2, p, alpha_1);
|
||||
p = pmul(x, p);
|
||||
|
||||
// Evaluate the denominator polynomial p.
|
||||
// Evaluate the denominator polynomial q.
|
||||
T q = pmadd(x2, beta_6, beta_4);
|
||||
q = pmadd(x2, q, beta_2);
|
||||
q = pmadd(x2, q, beta_0);
|
||||
|
||||
// Divide the numerator by the denominator.
|
||||
return pdiv(p, q);
|
||||
return pselect(tiny_mask, x, pdiv(p, q));
|
||||
}
|
||||
|
||||
template<typename RealScalar>
|
||||
|
@ -297,9 +297,14 @@ template<> EIGEN_STRONG_INLINE Packet4d pmax<Packet4d>(const Packet4d& a, const
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8f pcmp_le(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LE_OQ); }
|
||||
template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LT_OQ); }
|
||||
template<> EIGEN_STRONG_INLINE Packet8f pcmp_eq(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_EQ_OQ); }
|
||||
template<> EIGEN_STRONG_INLINE Packet4d pcmp_eq(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_EQ_OQ); }
|
||||
template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt_or_nan(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a, b, _CMP_NGE_UQ); }
|
||||
template<> EIGEN_STRONG_INLINE Packet8f pcmp_eq(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_EQ_OQ); }
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4d pcmp_le(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LE_OQ); }
|
||||
template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LT_OQ); }
|
||||
template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt_or_nan(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a, b, _CMP_NGE_UQ); }
|
||||
template<> EIGEN_STRONG_INLINE Packet4d pcmp_eq(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_EQ_OQ); }
|
||||
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8i pcmp_eq(const Packet8i& a, const Packet8i& b) {
|
||||
#ifdef EIGEN_VECTORIZE_AVX2
|
||||
@ -311,6 +316,7 @@ template<> EIGEN_STRONG_INLINE Packet8i pcmp_eq(const Packet8i& a, const Packet8
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet8f pceil<Packet8f>(const Packet8f& a) { return _mm256_ceil_ps(a); }
|
||||
template<> EIGEN_STRONG_INLINE Packet4d pceil<Packet4d>(const Packet4d& a) { return _mm256_ceil_pd(a); }
|
||||
|
||||
|
@ -365,6 +365,12 @@ EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) {
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) {
|
||||
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ);
|
||||
return _mm512_castsi512_ps(
|
||||
_mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu));
|
||||
}
|
||||
template<> EIGEN_STRONG_INLINE Packet16f pcmp_le(const Packet16f& a, const Packet16f& b) {
|
||||
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LE_OQ);
|
||||
return _mm512_castsi512_ps(
|
||||
@ -388,12 +394,6 @@ template<> EIGEN_STRONG_INLINE Packet16i pcmp_eq(const Packet16i& a, const Packe
|
||||
return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) {
|
||||
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ);
|
||||
return _mm512_castsi512_ps(
|
||||
_mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8d pcmp_eq(const Packet8d& a, const Packet8d& b) {
|
||||
@ -401,6 +401,24 @@ EIGEN_STRONG_INLINE Packet8d pcmp_eq(const Packet8d& a, const Packet8d& b) {
|
||||
return _mm512_castsi512_pd(
|
||||
_mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8d pcmp_le(const Packet8d& a, const Packet8d& b) {
|
||||
__mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LE_OQ);
|
||||
return _mm512_castsi512_pd(
|
||||
_mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8d pcmp_lt(const Packet8d& a, const Packet8d& b) {
|
||||
__mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LT_OQ);
|
||||
return _mm512_castsi512_pd(
|
||||
_mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
|
||||
}
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet8d pcmp_lt_or_nan(const Packet8d& a, const Packet8d& b) {
|
||||
__mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_NGT_UQ);
|
||||
return _mm512_castsi512_pd(
|
||||
_mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE Packet16i ptrue<Packet16i>(const Packet16i& /*a*/) {
|
||||
|
@ -156,6 +156,15 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double eq_mask(const double& a,
|
||||
return __longlong_as_double(a == b ? 0xffffffffffffffffull : 0ull);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float lt_mask(const float& a,
|
||||
const float& b) {
|
||||
return __int_as_float(a < b ? 0xffffffffu : 0u);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double lt_mask(const double& a,
|
||||
const double& b) {
|
||||
return __longlong_as_double(a < b ? 0xffffffffffffffffull : 0ull);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <>
|
||||
@ -213,10 +222,21 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcmp_eq<float4>(const float4& a,
|
||||
eq_mask(a.w, b.w));
|
||||
}
|
||||
template <>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcmp_lt<float4>(const float4& a,
|
||||
const float4& b) {
|
||||
return make_float4(lt_mask(a.x, b.x), lt_mask(a.y, b.y), lt_mask(a.z, b.z),
|
||||
lt_mask(a.w, b.w));
|
||||
}
|
||||
template <>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2
|
||||
pcmp_eq<double2>(const double2& a, const double2& b) {
|
||||
return make_double2(eq_mask(a.x, b.x), eq_mask(a.y, b.y));
|
||||
}
|
||||
template <>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2
|
||||
pcmp_lt<double2>(const double2& a, const double2& b) {
|
||||
return make_double2(lt_mask(a.x, b.x), lt_mask(a.y, b.y));
|
||||
}
|
||||
#endif // EIGEN_CUDA_ARCH || defined(EIGEN_HIP_DEVICE_COMPILE)
|
||||
|
||||
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 plset<float4>(const float& a) {
|
||||
@ -646,6 +666,20 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_eq<half2>(const half2& a,
|
||||
return __halves2half2(eq1, eq2);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_lt<half2>(const half2& a,
|
||||
const half2& b) {
|
||||
half true_half = half_impl::raw_uint16_to_half(0xffffu);
|
||||
half false_half = half_impl::raw_uint16_to_half(0x0000u);
|
||||
half a1 = __low2half(a);
|
||||
half a2 = __high2half(a);
|
||||
half b1 = __low2half(b);
|
||||
half b2 = __high2half(b);
|
||||
half eq1 = __half2float(a1) < __half2float(b1) ? true_half : false_half;
|
||||
half eq2 = __half2float(a2) < __half2float(b2) ? true_half : false_half;
|
||||
return __halves2half2(eq1, eq2);
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pand<half2>(const half2& a,
|
||||
const half2& b) {
|
||||
|
@ -713,6 +713,8 @@ template<> EIGEN_STRONG_INLINE Packet2d pandnot<Packet2d>(const Packet2d& a, con
|
||||
return vreinterpretq_f64_u64(vbicq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b)));
|
||||
}
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pcmp_le(const Packet2d& a, const Packet2d& b) { return vreinterpretq_f64_u64(vcleq_f64(a,b)); }
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt(const Packet2d& a, const Packet2d& b) { return vreinterpretq_f64_u64(vcltq_f64(a,b)); }
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pcmp_eq(const Packet2d& a, const Packet2d& b) { return vreinterpretq_f64_u64(vceqq_f64(a,b)); }
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pload<Packet2d>(const double* from) { EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f64(from); }
|
||||
|
@ -384,10 +384,17 @@ template<> EIGEN_STRONG_INLINE Packet4i pmax<Packet4i>(const Packet4i& a, const
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pcmp_le(const Packet4f& a, const Packet4f& b) { return _mm_cmple_ps(a,b); }
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt(const Packet4f& a, const Packet4f& b) { return _mm_cmplt_ps(a,b); }
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) { return _mm_cmpeq_ps(a,b); }
|
||||
template<> EIGEN_STRONG_INLINE Packet4i pcmp_eq(const Packet4i& a, const Packet4i& b) { return _mm_cmpeq_epi32(a,b); }
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pcmp_eq(const Packet2d& a, const Packet2d& b) { return _mm_cmpeq_pd(a,b); }
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan(const Packet4f& a, const Packet4f& b) { return _mm_cmpnge_ps(a,b); }
|
||||
template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) { return _mm_cmpeq_ps(a,b); }
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pcmp_le(const Packet2d& a, const Packet2d& b) { return _mm_cmple_pd(a,b); }
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt(const Packet2d& a, const Packet2d& b) { return _mm_cmplt_pd(a,b); }
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt_or_nan(const Packet2d& a, const Packet2d& b) { return _mm_cmpnge_pd(a,b); }
|
||||
template<> EIGEN_STRONG_INLINE Packet2d pcmp_eq(const Packet2d& a, const Packet2d& b) { return _mm_cmpeq_pd(a,b); }
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4i pcmp_lt(const Packet4i& a, const Packet4i& b) { return _mm_cmplt_epi32(a,b); }
|
||||
template<> EIGEN_STRONG_INLINE Packet4i pcmp_eq(const Packet4i& a, const Packet4i& b) { return _mm_cmpeq_epi32(a,b); }
|
||||
|
||||
|
||||
template<> EIGEN_STRONG_INLINE Packet4i ptrue<Packet4i>(const Packet4i& a) { return _mm_cmpeq_epi32(a, a); }
|
||||
template<> EIGEN_STRONG_INLINE Packet4f
|
||||
|
@ -905,14 +905,106 @@ struct scalar_logistic_op {
|
||||
}
|
||||
};
|
||||
|
||||
/** \internal
|
||||
* \brief Template specialization of the logistic function for float.
|
||||
*
|
||||
* Uses just a 9/10-degree rational interpolant which
|
||||
* interpolates 1/(1+exp(-x)) - 0.5 up to a couple of ulps in the range
|
||||
* [-9, 18]. Below -9 we use the more accurate approximation
|
||||
* 1/(1+exp(-x)) ~= exp(x), and above 18 the logistic function is 1 withing
|
||||
* one ulp. The shifted logistic is interpolated because it was easier to
|
||||
* make the fit converge.
|
||||
*
|
||||
*/
|
||||
template <>
|
||||
struct scalar_logistic_op<float> {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op)
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(const float& x) const {
|
||||
// The upper cut-off is the smallest x for which the rational approximation evaluates to 1.
|
||||
// Choosing this value saves us a few instructions clamping the results at the end.
|
||||
#ifdef EIGEN_VECTORIZE_FMA
|
||||
const float cutoff_upper = 16.285715103149414062f;
|
||||
#else
|
||||
const float cutoff_upper = 16.619047164916992188f;
|
||||
#endif
|
||||
const float cutoff_lower = -9.f;
|
||||
if (x > cutoff_upper) return 1.0f;
|
||||
else if (x < cutoff_lower) return numext::exp(x);
|
||||
else return 1.0f / (1.0f + numext::exp(-x));
|
||||
}
|
||||
|
||||
template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
Packet packetOp(const Packet& _x) const {
|
||||
const Packet cutoff_lower = pset1<Packet>(-9.f);
|
||||
const Packet lt_mask = pcmp_lt<Packet>(_x, cutoff_lower);
|
||||
const bool any_small = predux(lt_mask);
|
||||
|
||||
// Clamp the input to be at most 'cutoff_upper'.
|
||||
#ifdef EIGEN_VECTORIZE_FMA
|
||||
const Packet cutoff_upper = pset1<Packet>(16.285715103149414062f);
|
||||
#else
|
||||
const Packet cutoff_upper = pset1<Packet>(16.619047164916992188f);
|
||||
#endif
|
||||
const Packet x = pmin(_x, cutoff_upper);
|
||||
|
||||
// The monomial coefficients of the numerator polynomial (odd).
|
||||
const Packet alpha_1 = pset1<Packet>(2.48287947061529e-01f);
|
||||
const Packet alpha_3 = pset1<Packet>(8.51377133304701e-03f);
|
||||
const Packet alpha_5 = pset1<Packet>(6.08574864600143e-05f);
|
||||
const Packet alpha_7 = pset1<Packet>(1.15627324459942e-07f);
|
||||
const Packet alpha_9 = pset1<Packet>(4.37031012579801e-11f);
|
||||
|
||||
// The monomial coefficients of the denominator polynomial (even).
|
||||
const Packet beta_0 = pset1<Packet>(9.93151921023180e-01f);
|
||||
const Packet beta_2 = pset1<Packet>(1.16817656904453e-01f);
|
||||
const Packet beta_4 = pset1<Packet>(1.70198817374094e-03f);
|
||||
const Packet beta_6 = pset1<Packet>(6.29106785017040e-06f);
|
||||
const Packet beta_8 = pset1<Packet>(5.76102136993427e-09f);
|
||||
const Packet beta_10 = pset1<Packet>(6.10247389755681e-13f);
|
||||
|
||||
// Since the polynomials are odd/even, we need x^2.
|
||||
const Packet x2 = pmul(x, x);
|
||||
|
||||
// Evaluate the numerator polynomial p.
|
||||
Packet p = pmadd(x2, alpha_9, alpha_7);
|
||||
p = pmadd(x2, p, alpha_5);
|
||||
p = pmadd(x2, p, alpha_3);
|
||||
p = pmadd(x2, p, alpha_1);
|
||||
p = pmul(x, p);
|
||||
|
||||
// Evaluate the denominator polynomial q.
|
||||
Packet q = pmadd(x2, beta_10, beta_8);
|
||||
q = pmadd(x2, q, beta_6);
|
||||
q = pmadd(x2, q, beta_4);
|
||||
q = pmadd(x2, q, beta_2);
|
||||
q = pmadd(x2, q, beta_0);
|
||||
// Divide the numerator by the denominator and shift it up.
|
||||
const Packet logistic = padd(pdiv(p, q), pset1<Packet>(0.5f));
|
||||
if (EIGEN_PREDICT_FALSE(any_small)) {
|
||||
const Packet exponential = pexp(_x);
|
||||
return pselect(lt_mask, exponential, logistic);
|
||||
} else {
|
||||
return logistic;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct functor_traits<scalar_logistic_op<T> > {
|
||||
enum {
|
||||
// The cost estimate for float here here is for the common(?) case where
|
||||
// all arguments are greater than -9.
|
||||
Cost = scalar_div_cost<T, packet_traits<T>::HasDiv>::value +
|
||||
NumTraits<T>::AddCost * 2 + functor_traits<scalar_exp_op<T> >::Cost,
|
||||
(internal::is_same<T, float>::value
|
||||
? NumTraits<T>::AddCost * 15 + NumTraits<T>::MulCost * 11
|
||||
: NumTraits<T>::AddCost * 2 +
|
||||
functor_traits<scalar_exp_op<T> >::Cost),
|
||||
PacketAccess =
|
||||
packet_traits<T>::HasAdd && packet_traits<T>::HasDiv &&
|
||||
packet_traits<T>::HasNegate && packet_traits<T>::HasExp
|
||||
(internal::is_same<T, float>::value
|
||||
? packet_traits<T>::HasMul && packet_traits<T>::HasMax &&
|
||||
packet_traits<T>::HasMin
|
||||
: packet_traits<T>::HasNegate && packet_traits<T>::HasExp)
|
||||
};
|
||||
};
|
||||
|
||||
|
@ -1112,6 +1112,11 @@ namespace Eigen {
|
||||
|
||||
#define EIGEN_IMPLIES(a,b) (!(a) || (b))
|
||||
|
||||
#if EIGEN_HAS_BUILTIN(__builtin_expect) || EIGEN_COMP_GNUC
|
||||
#define EIGEN_PREDICT_FALSE(x) (__builtin_expect(x, false))
|
||||
#define EIGEN_PREDICT_TRUE(x) (__builtin_expect(false || (x), true))
|
||||
#endif
|
||||
|
||||
// the expression type of a standard coefficient wise binary operation
|
||||
#define EIGEN_CWISE_BINARY_RETURN_TYPE(LHS,RHS,OPNAME) \
|
||||
CwiseBinaryOp< \
|
||||
|
@ -545,6 +545,7 @@ template<typename Scalar,typename Packet> void packetmath_real()
|
||||
data1[i] = internal::random<Scalar>(-1,1) * std::pow(Scalar(10), internal::random<Scalar>(-6,6));
|
||||
data2[i] = internal::random<Scalar>(-1,1) * std::pow(Scalar(10), internal::random<Scalar>(-6,6));
|
||||
}
|
||||
data1[0] = 1e-20;
|
||||
CHECK_CWISE1_IF(PacketTraits::HasTanh, std::tanh, internal::ptanh);
|
||||
if(PacketTraits::HasExp && PacketSize>=2)
|
||||
{
|
||||
|
Loading…
x
Reference in New Issue
Block a user