Fix plog(+INF): it returned ~87 instead of +INF

This commit is contained in:
Gael Guennebaud 2018-12-23 15:40:52 +01:00
parent 6dd93f7e3b
commit 5713fb7feb
2 changed files with 16 additions and 6 deletions

View File

@ -54,6 +54,7 @@ Packet plog_float(const Packet _x)
// The smallest non denormalized float number.
const Packet cst_min_norm_pos = pset1frombits<Packet>( 0x00800000u);
const Packet cst_minus_inf = pset1frombits<Packet>( 0xff800000u);
const Packet cst_pos_inf = pset1frombits<Packet>( 0x7f800000u);
// Polynomial coefficients.
const Packet cst_cephes_SQRTHF = pset1<Packet>(0.707106781186547524f);
@ -69,9 +70,6 @@ Packet plog_float(const Packet _x)
const Packet cst_cephes_log_q1 = pset1<Packet>(-2.12194440e-4f);
const Packet cst_cephes_log_q2 = pset1<Packet>(0.693359375f);
Packet invalid_mask = pcmp_lt_or_nan(x, pzero(x));
Packet iszero_mask = pcmp_eq(x,pzero(x));
// Truncate input values to the minimum positive normal.
x = pmax(x, cst_min_norm_pos);
@ -117,8 +115,15 @@ Packet plog_float(const Packet _x)
x = padd(x, y);
x = padd(x, y2);
// Filter out invalid inputs, i.e. negative arg will be NAN, 0 will be -INF.
return pselect(iszero_mask, cst_minus_inf, por(x, invalid_mask));
Packet invalid_mask = pcmp_lt_or_nan(_x, pzero(_x));
Packet iszero_mask = pcmp_eq(_x,pzero(_x));
Packet pos_inf_mask = pcmp_eq(_x,cst_pos_inf);
// Filter out invalid inputs, i.e.:
// - negative arg will be NAN
// - 0 will be -INF
// - +INF will be +INF
return pselect(iszero_mask, cst_minus_inf,
por(pselect(pos_inf_mask,cst_pos_inf,x), invalid_mask));
}
// Exponential function. Works by writing "x = m*log(2) + r" where

View File

@ -520,10 +520,11 @@ template<typename Scalar,typename Packet> void packetmath_real()
CHECK_CWISE1_IF(internal::packet_traits<Scalar>::HasErfc, std::erfc, internal::perfc);
#endif
if(PacketTraits::HasLog && PacketSize>=2)
if(PacketSize>=2)
{
data1[0] = std::numeric_limits<Scalar>::quiet_NaN();
data1[1] = std::numeric_limits<Scalar>::epsilon();
if(PacketTraits::HasLog)
{
packet_helper<PacketTraits::HasLog,Packet> h;
h.store(data2, internal::plog(h.load(data1)));
@ -551,6 +552,10 @@ template<typename Scalar,typename Packet> void packetmath_real()
data1[0] = Scalar(-1.0f);
h.store(data2, internal::plog(h.load(data1)));
VERIFY((numext::isnan)(data2[0]));
data1[0] = std::numeric_limits<Scalar>::infinity();
h.store(data2, internal::plog(h.load(data1)));
VERIFY((numext::isinf)(data2[0]));
}
{
packet_helper<PacketTraits::HasSqrt,Packet> h;