Improve accuracy of fast approximate tanh and the logistic functions in Eigen, such that they preserve relative accuracy to within a few ULPs where their function values tend to zero (around x=0 for tanh, and for large negative x for the logistic function).

This change re-instates the fast rational approximation of the logistic function for float32 in Eigen (removed in 66f07efeae), but uses the more accurate approximation 1/(1+exp(-1)) ~= exp(x) below -9. The exponential is only calculated on the vectorized path if at least one element in the SIMD input vector is less than -9.

This change also contains a few improvements to speed up the original float specialization of logistic:
  - Introduce EIGEN_PREDICT_{FALSE,TRUE} for __builtin_predict and use it to predict that the logistic-only path is most likely (~2-3% speedup for the common case).
  - Carefully set the upper clipping point to the smallest x where the approximation evaluates to exactly 1. This saves the explicit clamping of the output (~7% speedup).

The increased accuracy for tanh comes at a cost of 10-20% depending on instruction set.

The benchmarks below repeated calls

   u = v.logistic()  (u = v.tanh(), respectively)

where u and v are of type Eigen::ArrayXf, have length 8k, and v contains random numbers in [-1,1].

Benchmark numbers for logistic:

Before:
Benchmark                  Time(ns)        CPU(ns)     Iterations
-----------------------------------------------------------------
SSE
BM_eigen_logistic_float        4467           4468         155835  model_time: 4827
AVX
BM_eigen_logistic_float        2347           2347         299135  model_time: 2926
AVX+FMA
BM_eigen_logistic_float        1467           1467         476143  model_time: 2926
AVX512
BM_eigen_logistic_float         805            805         858696  model_time: 1463

After:
Benchmark                  Time(ns)        CPU(ns)     Iterations
-----------------------------------------------------------------
SSE
BM_eigen_logistic_float        2589           2590         270264  model_time: 4827
AVX
BM_eigen_logistic_float        1428           1428         489265  model_time: 2926
AVX+FMA
BM_eigen_logistic_float        1059           1059         662255  model_time: 2926
AVX512
BM_eigen_logistic_float         673            673        1000000  model_time: 1463

Benchmark numbers for tanh:

Before:
Benchmark                  Time(ns)        CPU(ns)     Iterations
-----------------------------------------------------------------
SSE
BM_eigen_tanh_float        2391           2391         292624  model_time: 4242
AVX
BM_eigen_tanh_float        1256           1256         554662  model_time: 2633
AVX+FMA
BM_eigen_tanh_float         823            823         866267  model_time: 1609
AVX512
BM_eigen_tanh_float         443            443        1578999  model_time: 805

After:
Benchmark                  Time(ns)        CPU(ns)     Iterations
-----------------------------------------------------------------
SSE
BM_eigen_tanh_float        2588           2588         273531  model_time: 4242
AVX
BM_eigen_tanh_float        1536           1536         452321  model_time: 2633
AVX+FMA
BM_eigen_tanh_float        1007           1007         694681  model_time: 1609
AVX512
BM_eigen_tanh_float         471            471        1472178  model_time: 805
This commit is contained in:
Rasmus Munk Larsen 2019-12-16 21:33:42 +00:00
parent 8e5da71466
commit a566074480
9 changed files with 191 additions and 23 deletions

View File

@ -17,10 +17,11 @@ namespace internal {
/** \internal \returns the hyperbolic tan of \a a (coeff-wise)
Doesn't do anything fancy, just a 13/6-degree rational interpolant which
is accurate up to a couple of ulps in the (approximate) range [-8, 8],
outside of which tanh(x) = +/-1 in single precision. This is done by
Clamp the inputs to the range [-c, c]. The value c is chosen as the smallest
value where the approximation evaluates to exactly 1.
is accurate up to a couple of ulps in the (approximate) range [-8, 8],
outside of which tanh(x) = +/-1 in single precision. The input is clamped
to the range [-c, c]. The value c is chosen as the smallest value where
the approximation evaluates to exactly 1. In the reange [-0.0004, 0.0004]
the approxmation tanh(x) ~= x is used for better accuracy as x tends to zero.
This implementation works on both scalars and packets.
*/
@ -29,13 +30,15 @@ T generic_fast_tanh_float(const T& a_x)
{
// Clamp the inputs to the range [-c, c]
#ifdef EIGEN_VECTORIZE_FMA
const T plus_clamp = pset1<T>(7.99881172180175781);
const T minus_clamp = pset1<T>(-7.99881172180175781);
const T plus_clamp = pset1<T>(7.99881172180175781f);
const T minus_clamp = pset1<T>(-7.99881172180175781f);
#else
const T plus_clamp = pset1<T>(7.90531110763549805);
const T minus_clamp = pset1<T>(-7.90531110763549805);
const T plus_clamp = pset1<T>(7.90531110763549805f);
const T minus_clamp = pset1<T>(-7.90531110763549805f);
#endif
const T tiny = pset1<T>(0.0004f);
const T x = pmax(pmin(a_x, plus_clamp), minus_clamp);
const T tiny_mask = pcmp_lt(pabs(a_x), tiny);
// The monomial coefficients of the numerator polynomial (odd).
const T alpha_1 = pset1<T>(4.89352455891786e-03f);
const T alpha_3 = pset1<T>(6.37261928875436e-04f);
@ -63,13 +66,13 @@ T generic_fast_tanh_float(const T& a_x)
p = pmadd(x2, p, alpha_1);
p = pmul(x, p);
// Evaluate the denominator polynomial p.
// Evaluate the denominator polynomial q.
T q = pmadd(x2, beta_6, beta_4);
q = pmadd(x2, q, beta_2);
q = pmadd(x2, q, beta_0);
// Divide the numerator by the denominator.
return pdiv(p, q);
return pselect(tiny_mask, x, pdiv(p, q));
}
template<typename RealScalar>

View File

@ -297,9 +297,14 @@ template<> EIGEN_STRONG_INLINE Packet4d pmax<Packet4d>(const Packet4d& a, const
template<> EIGEN_STRONG_INLINE Packet8f pcmp_le(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LE_OQ); }
template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LT_OQ); }
template<> EIGEN_STRONG_INLINE Packet8f pcmp_eq(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_EQ_OQ); }
template<> EIGEN_STRONG_INLINE Packet4d pcmp_eq(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_EQ_OQ); }
template<> EIGEN_STRONG_INLINE Packet8f pcmp_lt_or_nan(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a, b, _CMP_NGE_UQ); }
template<> EIGEN_STRONG_INLINE Packet8f pcmp_eq(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_EQ_OQ); }
template<> EIGEN_STRONG_INLINE Packet4d pcmp_le(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LE_OQ); }
template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_LT_OQ); }
template<> EIGEN_STRONG_INLINE Packet4d pcmp_lt_or_nan(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a, b, _CMP_NGE_UQ); }
template<> EIGEN_STRONG_INLINE Packet4d pcmp_eq(const Packet4d& a, const Packet4d& b) { return _mm256_cmp_pd(a,b,_CMP_EQ_OQ); }
template<> EIGEN_STRONG_INLINE Packet8i pcmp_eq(const Packet8i& a, const Packet8i& b) {
#ifdef EIGEN_VECTORIZE_AVX2
@ -311,6 +316,7 @@ template<> EIGEN_STRONG_INLINE Packet8i pcmp_eq(const Packet8i& a, const Packet8
#endif
}
template<> EIGEN_STRONG_INLINE Packet8f pceil<Packet8f>(const Packet8f& a) { return _mm256_ceil_ps(a); }
template<> EIGEN_STRONG_INLINE Packet4d pceil<Packet4d>(const Packet4d& a) { return _mm256_ceil_pd(a); }

View File

@ -365,6 +365,12 @@ EIGEN_STRONG_INLINE Packet16f cat256(Packet8f a, Packet8f b) {
}
#endif
template <>
EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) {
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ);
return _mm512_castsi512_ps(
_mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu));
}
template<> EIGEN_STRONG_INLINE Packet16f pcmp_le(const Packet16f& a, const Packet16f& b) {
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_LE_OQ);
return _mm512_castsi512_ps(
@ -388,12 +394,6 @@ template<> EIGEN_STRONG_INLINE Packet16i pcmp_eq(const Packet16i& a, const Packe
return _mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu);
}
template <>
EIGEN_STRONG_INLINE Packet16f pcmp_eq(const Packet16f& a, const Packet16f& b) {
__mmask16 mask = _mm512_cmp_ps_mask(a, b, _CMP_EQ_OQ);
return _mm512_castsi512_ps(
_mm512_mask_set1_epi32(_mm512_set1_epi32(0), mask, 0xffffffffu));
}
template <>
EIGEN_STRONG_INLINE Packet8d pcmp_eq(const Packet8d& a, const Packet8d& b) {
@ -401,6 +401,24 @@ EIGEN_STRONG_INLINE Packet8d pcmp_eq(const Packet8d& a, const Packet8d& b) {
return _mm512_castsi512_pd(
_mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
}
template <>
EIGEN_STRONG_INLINE Packet8d pcmp_le(const Packet8d& a, const Packet8d& b) {
__mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LE_OQ);
return _mm512_castsi512_pd(
_mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
}
template <>
EIGEN_STRONG_INLINE Packet8d pcmp_lt(const Packet8d& a, const Packet8d& b) {
__mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_LT_OQ);
return _mm512_castsi512_pd(
_mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
}
template <>
EIGEN_STRONG_INLINE Packet8d pcmp_lt_or_nan(const Packet8d& a, const Packet8d& b) {
__mmask8 mask = _mm512_cmp_pd_mask(a, b, _CMP_NGT_UQ);
return _mm512_castsi512_pd(
_mm512_mask_set1_epi64(_mm512_set1_epi64(0), mask, 0xffffffffffffffffu));
}
template <>
EIGEN_STRONG_INLINE Packet16i ptrue<Packet16i>(const Packet16i& /*a*/) {

View File

@ -156,6 +156,15 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double eq_mask(const double& a,
return __longlong_as_double(a == b ? 0xffffffffffffffffull : 0ull);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float lt_mask(const float& a,
const float& b) {
return __int_as_float(a < b ? 0xffffffffu : 0u);
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double lt_mask(const double& a,
const double& b) {
return __longlong_as_double(a < b ? 0xffffffffffffffffull : 0ull);
}
} // namespace
template <>
@ -213,10 +222,21 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcmp_eq<float4>(const float4& a,
eq_mask(a.w, b.w));
}
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 pcmp_lt<float4>(const float4& a,
const float4& b) {
return make_float4(lt_mask(a.x, b.x), lt_mask(a.y, b.y), lt_mask(a.z, b.z),
lt_mask(a.w, b.w));
}
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2
pcmp_eq<double2>(const double2& a, const double2& b) {
return make_double2(eq_mask(a.x, b.x), eq_mask(a.y, b.y));
}
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double2
pcmp_lt<double2>(const double2& a, const double2& b) {
return make_double2(lt_mask(a.x, b.x), lt_mask(a.y, b.y));
}
#endif // EIGEN_CUDA_ARCH || defined(EIGEN_HIP_DEVICE_COMPILE)
template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 plset<float4>(const float& a) {
@ -646,6 +666,20 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_eq<half2>(const half2& a,
return __halves2half2(eq1, eq2);
}
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pcmp_lt<half2>(const half2& a,
const half2& b) {
half true_half = half_impl::raw_uint16_to_half(0xffffu);
half false_half = half_impl::raw_uint16_to_half(0x0000u);
half a1 = __low2half(a);
half a2 = __high2half(a);
half b1 = __low2half(b);
half b2 = __high2half(b);
half eq1 = __half2float(a1) < __half2float(b1) ? true_half : false_half;
half eq2 = __half2float(a2) < __half2float(b2) ? true_half : false_half;
return __halves2half2(eq1, eq2);
}
template <>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE half2 pand<half2>(const half2& a,
const half2& b) {

View File

@ -713,6 +713,8 @@ template<> EIGEN_STRONG_INLINE Packet2d pandnot<Packet2d>(const Packet2d& a, con
return vreinterpretq_f64_u64(vbicq_u64(vreinterpretq_u64_f64(a),vreinterpretq_u64_f64(b)));
}
template<> EIGEN_STRONG_INLINE Packet2d pcmp_le(const Packet2d& a, const Packet2d& b) { return vreinterpretq_f64_u64(vcleq_f64(a,b)); }
template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt(const Packet2d& a, const Packet2d& b) { return vreinterpretq_f64_u64(vcltq_f64(a,b)); }
template<> EIGEN_STRONG_INLINE Packet2d pcmp_eq(const Packet2d& a, const Packet2d& b) { return vreinterpretq_f64_u64(vceqq_f64(a,b)); }
template<> EIGEN_STRONG_INLINE Packet2d pload<Packet2d>(const double* from) { EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f64(from); }

View File

@ -384,10 +384,17 @@ template<> EIGEN_STRONG_INLINE Packet4i pmax<Packet4i>(const Packet4i& a, const
template<> EIGEN_STRONG_INLINE Packet4f pcmp_le(const Packet4f& a, const Packet4f& b) { return _mm_cmple_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt(const Packet4f& a, const Packet4f& b) { return _mm_cmplt_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) { return _mm_cmpeq_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i pcmp_eq(const Packet4i& a, const Packet4i& b) { return _mm_cmpeq_epi32(a,b); }
template<> EIGEN_STRONG_INLINE Packet2d pcmp_eq(const Packet2d& a, const Packet2d& b) { return _mm_cmpeq_pd(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f pcmp_lt_or_nan(const Packet4f& a, const Packet4f& b) { return _mm_cmpnge_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet4f pcmp_eq(const Packet4f& a, const Packet4f& b) { return _mm_cmpeq_ps(a,b); }
template<> EIGEN_STRONG_INLINE Packet2d pcmp_le(const Packet2d& a, const Packet2d& b) { return _mm_cmple_pd(a,b); }
template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt(const Packet2d& a, const Packet2d& b) { return _mm_cmplt_pd(a,b); }
template<> EIGEN_STRONG_INLINE Packet2d pcmp_lt_or_nan(const Packet2d& a, const Packet2d& b) { return _mm_cmpnge_pd(a,b); }
template<> EIGEN_STRONG_INLINE Packet2d pcmp_eq(const Packet2d& a, const Packet2d& b) { return _mm_cmpeq_pd(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i pcmp_lt(const Packet4i& a, const Packet4i& b) { return _mm_cmplt_epi32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i pcmp_eq(const Packet4i& a, const Packet4i& b) { return _mm_cmpeq_epi32(a,b); }
template<> EIGEN_STRONG_INLINE Packet4i ptrue<Packet4i>(const Packet4i& a) { return _mm_cmpeq_epi32(a, a); }
template<> EIGEN_STRONG_INLINE Packet4f

View File

@ -905,14 +905,106 @@ struct scalar_logistic_op {
}
};
/** \internal
* \brief Template specialization of the logistic function for float.
*
* Uses just a 9/10-degree rational interpolant which
* interpolates 1/(1+exp(-x)) - 0.5 up to a couple of ulps in the range
* [-9, 18]. Below -9 we use the more accurate approximation
* 1/(1+exp(-x)) ~= exp(x), and above 18 the logistic function is 1 withing
* one ulp. The shifted logistic is interpolated because it was easier to
* make the fit converge.
*
*/
template <>
struct scalar_logistic_op<float> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(const float& x) const {
// The upper cut-off is the smallest x for which the rational approximation evaluates to 1.
// Choosing this value saves us a few instructions clamping the results at the end.
#ifdef EIGEN_VECTORIZE_FMA
const float cutoff_upper = 16.285715103149414062f;
#else
const float cutoff_upper = 16.619047164916992188f;
#endif
const float cutoff_lower = -9.f;
if (x > cutoff_upper) return 1.0f;
else if (x < cutoff_lower) return numext::exp(x);
else return 1.0f / (1.0f + numext::exp(-x));
}
template <typename Packet> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Packet packetOp(const Packet& _x) const {
const Packet cutoff_lower = pset1<Packet>(-9.f);
const Packet lt_mask = pcmp_lt<Packet>(_x, cutoff_lower);
const bool any_small = predux(lt_mask);
// Clamp the input to be at most 'cutoff_upper'.
#ifdef EIGEN_VECTORIZE_FMA
const Packet cutoff_upper = pset1<Packet>(16.285715103149414062f);
#else
const Packet cutoff_upper = pset1<Packet>(16.619047164916992188f);
#endif
const Packet x = pmin(_x, cutoff_upper);
// The monomial coefficients of the numerator polynomial (odd).
const Packet alpha_1 = pset1<Packet>(2.48287947061529e-01f);
const Packet alpha_3 = pset1<Packet>(8.51377133304701e-03f);
const Packet alpha_5 = pset1<Packet>(6.08574864600143e-05f);
const Packet alpha_7 = pset1<Packet>(1.15627324459942e-07f);
const Packet alpha_9 = pset1<Packet>(4.37031012579801e-11f);
// The monomial coefficients of the denominator polynomial (even).
const Packet beta_0 = pset1<Packet>(9.93151921023180e-01f);
const Packet beta_2 = pset1<Packet>(1.16817656904453e-01f);
const Packet beta_4 = pset1<Packet>(1.70198817374094e-03f);
const Packet beta_6 = pset1<Packet>(6.29106785017040e-06f);
const Packet beta_8 = pset1<Packet>(5.76102136993427e-09f);
const Packet beta_10 = pset1<Packet>(6.10247389755681e-13f);
// Since the polynomials are odd/even, we need x^2.
const Packet x2 = pmul(x, x);
// Evaluate the numerator polynomial p.
Packet p = pmadd(x2, alpha_9, alpha_7);
p = pmadd(x2, p, alpha_5);
p = pmadd(x2, p, alpha_3);
p = pmadd(x2, p, alpha_1);
p = pmul(x, p);
// Evaluate the denominator polynomial q.
Packet q = pmadd(x2, beta_10, beta_8);
q = pmadd(x2, q, beta_6);
q = pmadd(x2, q, beta_4);
q = pmadd(x2, q, beta_2);
q = pmadd(x2, q, beta_0);
// Divide the numerator by the denominator and shift it up.
const Packet logistic = padd(pdiv(p, q), pset1<Packet>(0.5f));
if (EIGEN_PREDICT_FALSE(any_small)) {
const Packet exponential = pexp(_x);
return pselect(lt_mask, exponential, logistic);
} else {
return logistic;
}
}
};
template <typename T>
struct functor_traits<scalar_logistic_op<T> > {
enum {
// The cost estimate for float here here is for the common(?) case where
// all arguments are greater than -9.
Cost = scalar_div_cost<T, packet_traits<T>::HasDiv>::value +
NumTraits<T>::AddCost * 2 + functor_traits<scalar_exp_op<T> >::Cost,
(internal::is_same<T, float>::value
? NumTraits<T>::AddCost * 15 + NumTraits<T>::MulCost * 11
: NumTraits<T>::AddCost * 2 +
functor_traits<scalar_exp_op<T> >::Cost),
PacketAccess =
packet_traits<T>::HasAdd && packet_traits<T>::HasDiv &&
packet_traits<T>::HasNegate && packet_traits<T>::HasExp
(internal::is_same<T, float>::value
? packet_traits<T>::HasMul && packet_traits<T>::HasMax &&
packet_traits<T>::HasMin
: packet_traits<T>::HasNegate && packet_traits<T>::HasExp)
};
};

View File

@ -1112,6 +1112,11 @@ namespace Eigen {
#define EIGEN_IMPLIES(a,b) (!(a) || (b))
#if EIGEN_HAS_BUILTIN(__builtin_expect) || EIGEN_COMP_GNUC
#define EIGEN_PREDICT_FALSE(x) (__builtin_expect(x, false))
#define EIGEN_PREDICT_TRUE(x) (__builtin_expect(false || (x), true))
#endif
// the expression type of a standard coefficient wise binary operation
#define EIGEN_CWISE_BINARY_RETURN_TYPE(LHS,RHS,OPNAME) \
CwiseBinaryOp< \

View File

@ -545,6 +545,7 @@ template<typename Scalar,typename Packet> void packetmath_real()
data1[i] = internal::random<Scalar>(-1,1) * std::pow(Scalar(10), internal::random<Scalar>(-6,6));
data2[i] = internal::random<Scalar>(-1,1) * std::pow(Scalar(10), internal::random<Scalar>(-6,6));
}
data1[0] = 1e-20;
CHECK_CWISE1_IF(PacketTraits::HasTanh, std::tanh, internal::ptanh);
if(PacketTraits::HasExp && PacketSize>=2)
{