Merged in rmlarsen/eigen (pull request PR-680)

Implement vectorized versions of log1p and expm1 in Eigen using Kahan's formulas, and change the scalar implementations to properly handle infinite arguments.
This commit is contained in:
Rasmus Larsen 2019-08-22 00:25:29 +00:00
commit 57f6b62597
9 changed files with 95 additions and 6 deletions

View File

@ -501,7 +501,8 @@ namespace std_fallback {
}
EIGEN_USING_STD_MATH(log);
return (u - RealScalar(1)) * x / log(u);
Scalar logu = log(u);
return numext::equal_strict(u, logu) ? u : (u - RealScalar(1)) * x / logu;
}
}
@ -548,7 +549,10 @@ namespace std_fallback {
typedef typename NumTraits<Scalar>::Real RealScalar;
EIGEN_USING_STD_MATH(log);
Scalar x1p = RealScalar(1) + x;
return numext::equal_strict(x1p, Scalar(1)) ? x : x * ( log(x1p) / (x1p - RealScalar(1)) );
Scalar log_1p = log(x1p);
const bool is_inf = numext::equal_strict(x1p, log_1p);
const bool is_small = numext::equal_strict(x1p, Scalar(1));
return (is_inf || is_small) ? x : x * (log_1p / (x1p - RealScalar(1)));
}
}

View File

@ -36,6 +36,16 @@ plog<Packet8f>(const Packet8f& _x) {
return plog_float(_x);
}
/** \internal Packet8f specialization of plog1p; forwards to generic_plog1p. */
template<>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
plog1p<Packet8f>(const Packet8f& _x)
{
  return generic_plog1p(_x);
}
/** \internal Packet8f specialization of pexpm1; forwards to generic_expm1. */
template<>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet8f
pexpm1<Packet8f>(const Packet8f& _x)
{
  return generic_expm1(_x);
}
// Exponential function. Works by writing "x = m*log(2) + r" where
// "m = floor(x/log(2)+1/2)" and "r" is the remainder. The result is then
// "exp(x) = 2^m*exp(r)" where r is in the range [-log(2)/2, log(2)/2).

View File

@ -65,6 +65,8 @@ template<> struct packet_traits<float> : default_packet_traits
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasLog = 1,
HasLog1p = 1,
HasExpm1 = 1,
HasExp = 1,
HasSqrt = 1,
HasRsqrt = 1,

View File

@ -393,6 +393,18 @@ pcos<Packet16f>(const Packet16f& _x) {
return pcos_float(_x);
}
#if defined(EIGEN_VECTORIZE_AVX512DQ)
/** \internal Packet16f specialization of plog1p; forwards to generic_plog1p. */
template<>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
plog1p<Packet16f>(const Packet16f& _x)
{
  return generic_plog1p(_x);
}
/** \internal Packet16f specialization of pexpm1; forwards to generic_expm1. */
template<>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet16f
pexpm1<Packet16f>(const Packet16f& _x)
{
  return generic_expm1(_x);
}
#endif
} // end namespace internal
} // end namespace Eigen

View File

@ -60,6 +60,8 @@ template<> struct packet_traits<float> : default_packet_traits
#if EIGEN_GNUC_AT_LEAST(5, 3) || (!EIGEN_COMP_GNUC_STRICT)
#ifdef EIGEN_VECTORIZE_AVX512DQ
HasLog = 1,
HasLog1p = 1,
HasExpm1 = 1,
#endif
HasExp = 1,
HasSqrt = EIGEN_FAST_MATH,

View File

@ -126,6 +126,52 @@ Packet plog_float(const Packet _x)
por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask));
}
/** \internal \returns log(1 + x), evaluated with W. Kahan's formula
    x * log(1 + x) / ((1 + x) - 1).
    Lanes where 1 + x compares equal to 1, or where x compares equal to
    log(1 + x), return x unchanged.
    See: http://www.plunk.org/~hatch/rightway.php
  */
template<typename Packet>
Packet generic_plog1p(const Packet& x)
{
  typedef typename unpacket_traits<Packet>::type ScalarType;
  const Packet unity = pset1<Packet>(ScalarType(1));
  const Packet shifted = padd(x, unity);
  const Packet log_shifted = plog(shifted);

  // Lanes in which 1 + x is indistinguishable from 1.
  const Packet rounds_to_one = pcmp_eq(shifted, unity);
  // Lanes in which x equals log(1 + x); this check is there to handle
  // x == +inf.
  const Packet equals_own_log = pcmp_eq(x, log_shifted);
  const Packet passthrough = por(rounds_to_one, equals_own_log);

  // Kahan's expression for every remaining lane.
  const Packet denominator = psub(shifted, unity);
  const Packet kahan = pmul(x, pdiv(log_shifted, denominator));

  return pselect(passthrough, x, kahan);
}
/** \internal \returns exp(x)-1, evaluated with W. Kahan's formula
    (u - 1) * x / log(u) with u = exp(x).
    Special lanes: u == 1 returns x, u - 1 == -1 returns -1, and
    log(u) == u (the exp(x) = +inf case) returns u.
    See: http://www.plunk.org/~hatch/rightway.php
  */
template<typename Packet>
Packet generic_expm1(const Packet& x)
{
  typedef typename unpacket_traits<Packet>::type ScalarType;
  const Packet unity = pset1<Packet>(ScalarType(1));
  const Packet minus_unity = pset1<Packet>(ScalarType(-1));

  const Packet exp_x = pexp(x);
  const Packet exp_x_minus_one = psub(exp_x, unity);
  const Packet log_exp_x = plog(exp_x);

  const Packet is_one = pcmp_eq(exp_x, unity);
  const Packet is_neg_one = pcmp_eq(exp_x_minus_one, minus_unity);
  // Detects exp(x) = +inf by comparing log(u) with u, which avoids having
  // to form the constant +inf (that constant depends on the packet type).
  const Packet is_pos_inf = pcmp_eq(log_exp_x, exp_x);

  // Kahan's expression, then the special lanes layered on top of it,
  // innermost selection first.
  const Packet kahan = pmul(exp_x_minus_one, pdiv(x, log_exp_x));
  const Packet kahan_or_inf = pselect(is_pos_inf, exp_x, kahan);
  const Packet not_one_result = pselect(is_neg_one, minus_unity, kahan_or_inf);
  return pselect(is_one, x, not_one_result);
}
// Exponential function. Works by writing "x = m*log(2) + r" where
// "m = floor(x/log(2)+1/2)" and "r" is the remainder. The result is then
// "exp(x) = 2^m*exp(r)" where r is in the range [-log(2)/2, log(2)/2).

View File

@ -22,11 +22,20 @@ namespace Eigen {
namespace internal {
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f plog<Packet4f>(const Packet4f& _x)
{
Packet4f plog<Packet4f>(const Packet4f& _x) {
return plog_float(_x);
}
/** \internal Packet4f specialization of plog1p; forwards to generic_plog1p. */
template<>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f
plog1p<Packet4f>(const Packet4f& _x)
{
  return generic_plog1p(_x);
}
/** \internal Packet4f specialization of pexpm1; forwards to generic_expm1. */
template<>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f
pexpm1<Packet4f>(const Packet4f& _x)
{
  return generic_expm1(_x);
}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED
Packet4f pexp<Packet4f>(const Packet4f& _x)
{

View File

@ -110,6 +110,8 @@ template<> struct packet_traits<float> : default_packet_traits
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasLog = 1,
HasLog1p = 1,
HasExpm1 = 1,
HasExp = 1,
HasSqrt = 1,
HasRsqrt = 1,

View File

@ -604,11 +604,13 @@ template<typename Scalar,typename Packet> void packetmath_real()
CHECK_CWISE1_IF(PacketTraits::HasSqrt, Scalar(1)/std::sqrt, internal::prsqrt);
CHECK_CWISE1_IF(PacketTraits::HasLog, std::log, internal::plog);
#if EIGEN_HAS_C99_MATH && (__cplusplus > 199711L)
CHECK_CWISE1_IF(PacketTraits::HasExpm1, std::expm1, internal::pexpm1);
CHECK_CWISE1_IF(PacketTraits::HasLog1p, std::log1p, internal::plog1p);
CHECK_CWISE1_IF(internal::packet_traits<Scalar>::HasLGamma, std::lgamma, internal::plgamma);
CHECK_CWISE1_IF(internal::packet_traits<Scalar>::HasErf, std::erf, internal::perf);
CHECK_CWISE1_IF(internal::packet_traits<Scalar>::HasErfc, std::erfc, internal::perfc);
data1[0] = std::numeric_limits<Scalar>::infinity();
data1[1] = std::numeric_limits<Scalar>::denorm_min();
CHECK_CWISE1_IF(PacketTraits::HasExpm1, std::expm1, internal::pexpm1);
CHECK_CWISE1_IF(PacketTraits::HasLog1p, std::log1p, internal::plog1p);
#endif
if(PacketSize>=2)