Implement vectorized versions of log1p and expm1 in Eigen using Kahan's formulas, and change the scalar implementations to properly handle infinite arguments.

Depending on instruction set, significant speedups are observed for the vectorized path:
log1p wall time is reduced 60-93% (2.5x - 15x speedup)
expm1 wall time is reduced 0-85% (1x - 7x speedup)

The scalar path is slower by 20-30% due to the extra branch needed to handle +infinity correctly.

Full benchmarks measured on Intel(R) Xeon(R) Gold 6154 here: https://bitbucket.org/snippets/rmlarsen/MXBkpM
This commit is contained in:
Rasmus Munk Larsen 2019-08-12 13:53:28 -07:00
parent d55d392e7b
commit a3298b22ec
9 changed files with 95 additions and 6 deletions

View File

@ -501,7 +501,8 @@ namespace std_fallback {
}
EIGEN_USING_STD_MATH(log);
return (u - RealScalar(1)) * x / log(u);
Scalar logu = log(u);
return numext::equal_strict(u, logu) ? u : (u - RealScalar(1)) * x / logu;
}
}
@ -548,7 +549,10 @@ namespace std_fallback {
typedef typename NumTraits<Scalar>::Real RealScalar;
EIGEN_USING_STD_MATH(log);
Scalar x1p = RealScalar(1) + x;
return numext::equal_strict(x1p, Scalar(1)) ? x : x * ( log(x1p) / (x1p - RealScalar(1)) );
Scalar log_1p = log(x1p);
const bool is_inf = numext::equal_strict(x1p, log_1p);
const bool is_small = numext::equal_strict(x1p, Scalar(1));
return (is_inf || is_small) ? x : x * (log_1p / (x1p - RealScalar(1)));
}
}

View File

@ -36,6 +36,16 @@ plog<Packet8f>(const Packet8f& _x) {
return plog_float(_x);
}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet8f plog1p<Packet8f>(const Packet8f& _x) {
return generic_plog1p(_x);
}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet8f pexpm1<Packet8f>(const Packet8f& _x) {
return generic_expm1(_x);
}
// Exponential function. Works by writing "x = m*log(2) + r" where
// "m = floor(x/log(2)+1/2)" and "r" is the remainder. The result is then
// "exp(x) = 2^m*exp(r)" where exp(r) is in the range [-1,1).

View File

@ -65,6 +65,8 @@ template<> struct packet_traits<float> : default_packet_traits
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasLog = 1,
HasLog1p = 1,
HasExpm1 = 1,
HasExp = 1,
HasSqrt = 1,
HasRsqrt = 1,

View File

@ -393,6 +393,18 @@ pcos<Packet16f>(const Packet16f& _x) {
return pcos_float(_x);
}
#if defined(EIGEN_VECTORIZE_AVX512DQ)
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet16f plog1p<Packet16f>(const Packet16f& _x) {
return generic_plog1p(_x);
}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet16f pexpm1<Packet16f>(const Packet16f& _x) {
return generic_expm1(_x);
}
#endif
} // end namespace internal
} // end namespace Eigen

View File

@ -60,6 +60,8 @@ template<> struct packet_traits<float> : default_packet_traits
#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
#ifdef EIGEN_VECTORIZE_AVX512DQ
HasLog = 1,
HasLog1p = 1,
HasExpm1 = 1,
#endif
HasExp = 1,
HasSqrt = EIGEN_FAST_MATH,

View File

@ -126,6 +126,52 @@ Packet plog_float(const Packet _x)
por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask));
}
/** \internal \returns log(1 + x) computed using W. Kahan's formula.
See: http://www.plunk.org/~hatch/rightway.php
*/
template<typename Packet>
Packet generic_plog1p(const Packet& x)
{
typedef typename unpacket_traits<Packet>::type ScalarType;
const Packet one = pset1<Packet>(ScalarType(1));
Packet xp1 = padd(x, one);
Packet small_mask = pcmp_eq(xp1, one);
Packet log1 = plog(xp1);
// Add a check to handle x == +inf.
Packet pos_inf_mask = pcmp_eq(x, log1);
Packet log_large = pmul(x, pdiv(log1, psub(xp1, one)));
return pselect(por(small_mask, pos_inf_mask), x, log_large);
}
/** \internal \returns exp(x)-1 computed using W. Kahan's formula.
See: http://www.plunk.org/~hatch/rightway.php
*/
template<typename Packet>
Packet generic_expm1(const Packet& x)
{
typedef typename unpacket_traits<Packet>::type ScalarType;
const Packet one = pset1<Packet>(ScalarType(1));
const Packet neg_one = pset1<Packet>(ScalarType(-1));
Packet u = pexp(x);
Packet one_mask = pcmp_eq(u, one);
Packet u_minus_one = psub(u, one);
Packet neg_one_mask = pcmp_eq(u_minus_one, neg_one);
Packet logu = plog(u);
// The following comparison is to catch the case where
// exp(x) = +inf. It is written in this way to avoid having
// to form the constant +inf, which depends on the packet
// type.
Packet pos_inf_mask = pcmp_eq(logu, u);
Packet expm1 = pmul(u_minus_one, pdiv(x, logu));
expm1 = pselect(pos_inf_mask, u, expm1);
return pselect(one_mask,
x,
pselect(neg_one_mask,
neg_one,
expm1));
}
// Exponential function. Works by writing "x = m*log(2) + r" where
// "m = floor(x/log(2)+1/2)" and "r" is the remainder. The result is then
// "exp(x) = 2^m*exp(r)" where exp(r) is in the range [-1,1).

View File

@ -22,11 +22,20 @@ namespace Eigen {
namespace internal {
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f plog<Packet4f>(const Packet4f& _x)
{
Packet4f plog<Packet4f>(const Packet4f& _x) {
return plog_float(_x);
}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f plog1p<Packet4f>(const Packet4f& _x) {
return generic_plog1p(_x);
}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f pexpm1<Packet4f>(const Packet4f& _x) {
return generic_expm1(_x);
}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f pexp<Packet4f>(const Packet4f& _x)
{

View File

@ -110,6 +110,8 @@ template<> struct packet_traits<float> : default_packet_traits
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasLog = 1,
HasLog1p = 1,
HasExpm1 = 1,
HasExp = 1,
HasSqrt = 1,
HasRsqrt = 1,

View File

@ -604,11 +604,13 @@ template<typename Scalar,typename Packet> void packetmath_real()
CHECK_CWISE1_IF(PacketTraits::HasSqrt, Scalar(1)/std::sqrt, internal::prsqrt);
CHECK_CWISE1_IF(PacketTraits::HasLog, std::log, internal::plog);
#if EIGEN_HAS_C99_MATH && (__cplusplus > 199711L)
CHECK_CWISE1_IF(PacketTraits::HasExpm1, std::expm1, internal::pexpm1);
CHECK_CWISE1_IF(PacketTraits::HasLog1p, std::log1p, internal::plog1p);
CHECK_CWISE1_IF(internal::packet_traits<Scalar>::HasLGamma, std::lgamma, internal::plgamma);
CHECK_CWISE1_IF(internal::packet_traits<Scalar>::HasErf, std::erf, internal::perf);
CHECK_CWISE1_IF(internal::packet_traits<Scalar>::HasErfc, std::erfc, internal::perfc);
data1[0] = std::numeric_limits<Scalar>::infinity();
data1[1] = std::numeric_limits<Scalar>::denorm_min();
CHECK_CWISE1_IF(PacketTraits::HasExpm1, std::expm1, internal::pexpm1);
CHECK_CWISE1_IF(PacketTraits::HasLog1p, std::log1p, internal::plog1p);
#endif
if(PacketSize>=2)