Vectorize atanh<double>. Make atanh(x) standard compliant for |x| >= 1.

This commit is contained in:
Rasmus Munk Larsen 2024-08-30 17:27:55 +00:00
parent 26e2c4f617
commit bbdabebf44
8 changed files with 64 additions and 3 deletions

View File

@ -23,6 +23,7 @@ namespace internal {
EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_FLOAT(Packet8f) EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_FLOAT(Packet8f)
EIGEN_DOUBLE_PACKET_FUNCTION(atanh, Packet4d)
EIGEN_DOUBLE_PACKET_FUNCTION(log, Packet4d) EIGEN_DOUBLE_PACKET_FUNCTION(log, Packet4d)
EIGEN_DOUBLE_PACKET_FUNCTION(log2, Packet4d) EIGEN_DOUBLE_PACKET_FUNCTION(log2, Packet4d)
EIGEN_DOUBLE_PACKET_FUNCTION(exp, Packet4d) EIGEN_DOUBLE_PACKET_FUNCTION(exp, Packet4d)

View File

@ -148,6 +148,7 @@ struct packet_traits<double> : default_packet_traits {
HasSqrt = 1, HasSqrt = 1,
HasRsqrt = 1, HasRsqrt = 1,
HasATan = 1, HasATan = 1,
HasATanh = 1,
HasBlend = 1 HasBlend = 1
}; };
}; };

View File

@ -154,6 +154,7 @@ struct packet_traits<double> : default_packet_traits {
HasExp = 1, HasExp = 1,
HasATan = 1, HasATan = 1,
HasTanh = EIGEN_FAST_MATH, HasTanh = EIGEN_FAST_MATH,
HasATanh = 1,
HasCmp = 1, HasCmp = 1,
HasDiv = 1 HasDiv = 1
}; };

View File

@ -3177,6 +3177,7 @@ struct packet_traits<double> : default_packet_traits {
HasSin = EIGEN_FAST_MATH, HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH, HasCos = EIGEN_FAST_MATH,
HasTanh = EIGEN_FAST_MATH, HasTanh = EIGEN_FAST_MATH,
HasATanh = 1,
HasATan = 0, HasATan = 0,
HasLog = 0, HasLog = 0,
HasExp = 1, HasExp = 1,

View File

@ -1198,8 +1198,7 @@ template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patanh_float(const Packet& x) { EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patanh_float(const Packet& x) {
typedef typename unpacket_traits<Packet>::type Scalar; typedef typename unpacket_traits<Packet>::type Scalar;
static_assert(std::is_same<Scalar, float>::value, "Scalar type must be float"); static_assert(std::is_same<Scalar, float>::value, "Scalar type must be float");
const Packet half = pset1<Packet>(0.5f);
const Packet x_gt_half = pcmp_le(half, pabs(x));
// For |x| in [0:0.5] we use a polynomial approximation of the form // For |x| in [0:0.5] we use a polynomial approximation of the form
// P(x) = x + x^3*(c3 + x^2 * (c5 + x^2 * (... x^2 * c11) ... )). // P(x) = x + x^3*(c3 + x^2 * (c5 + x^2 * (... x^2 * c11) ... )).
const Packet C3 = pset1<Packet>(0.3333373963832855224609375f); const Packet C3 = pset1<Packet>(0.3333373963832855224609375f);
@ -1215,10 +1214,61 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patanh_float(const Pa
p = pmadd(pmul(x, x2), p, x); p = pmadd(pmul(x, x2), p, x);
// For |x| in ]0.5:1.0] we use atanh = 0.5*ln((1+x)/(1-x)); // For |x| in ]0.5:1.0] we use atanh = 0.5*ln((1+x)/(1-x));
const Packet half = pset1<Packet>(0.5f);
const Packet one = pset1<Packet>(1.0f); const Packet one = pset1<Packet>(1.0f);
Packet r = pdiv(padd(one, x), psub(one, x)); Packet r = pdiv(padd(one, x), psub(one, x));
r = pmul(half, plog(r)); r = pmul(half, plog(r));
return pselect(x_gt_half, r, p);
const Packet x_gt_half = pcmp_le(half, pabs(x));
const Packet x_eq_one = pcmp_eq(one, pabs(x));
const Packet x_gt_one = pcmp_lt(one, pabs(x));
const Packet inf = pset1<Packet>(std::numeric_limits<float>::infinity());
return por(x_gt_one, pselect(x_eq_one, por(psignbit(x), inf), pselect(x_gt_half, r, p)));
}
/** \internal \returns atanh(x) for a packet of double precision values.
 *
 * Evaluation strategy, applied lane by lane:
 *   - |x| <  0.5 : rational approximation R(x) = x + x^3 * P(x^2) / Q(x^2).
 *   - |x| >= 0.5 : closed form 0.5 * log((1 + x) / (1 - x)).
 *   - |x| == 1   : infinity carrying the sign of x.
 *   - |x| >  1   : NaN (the all-ones comparison mask is OR'ed into the result).
 *
 * Both approximations are evaluated for every lane; the final result is
 * assembled with branch-free selects.
 */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patanh_double(const Packet& x) {
  typedef typename unpacket_traits<Packet>::type Scalar;
  static_assert(std::is_same<Scalar, double>::value, "Scalar type must be double");

  // For x in [-0.5:0.5] we use a rational approximation of the form
  // R(x) = x + x^3*P(x^2)/Q(x^2), where P is of order 4 and Q is of order 5.
  const Packet p0 = pset1<Packet>(1.2306328729812676e-01);
  const Packet p2 = pset1<Packet>(-2.5949536095445679e-01);
  const Packet p4 = pset1<Packet>(1.8185306179826699e-01);
  const Packet p6 = pset1<Packet>(-4.7129526768798737e-02);
  const Packet p8 = pset1<Packet>(3.3071338469301391e-03);
  const Packet q0 = pset1<Packet>(3.6918986189438030e-01);
  const Packet q2 = pset1<Packet>(-1.0000000000000000e+00);
  const Packet q4 = pset1<Packet>(9.8733495886883648e-01);
  const Packet q6 = pset1<Packet>(-4.2828141436397615e-01);
  const Packet q8 = pset1<Packet>(7.6391885763341910e-02);
  const Packet q10 = pset1<Packet>(-3.8679974580640881e-03);

  const Packet x2 = pmul(x, x);
  const Packet x3 = pmul(x, x2);

  // Horner evaluation of P(x^2) and Q(x^2); the two chains are interleaved
  // because they are independent of each other.
  Packet q = pmadd(q10, x2, q8);
  Packet p = pmadd(p8, x2, p6);
  q = pmadd(x2, q, q6);
  p = pmadd(x2, p, p4);
  q = pmadd(x2, q, q4);
  p = pmadd(x2, p, p2);
  q = pmadd(x2, q, q2);
  p = pmadd(x2, p, p0);
  q = pmadd(x2, q, q0);
  const Packet y_small = pmadd(x3, pdiv(p, q), x);

  // For |x| in [0.5:1.0] we use atanh = 0.5*ln((1+x)/(1-x));
  const Packet half = pset1<Packet>(0.5);
  const Packet one = pset1<Packet>(1.0);
  Packet y_large = pdiv(padd(one, x), psub(one, x));
  y_large = pmul(half, plog(y_large));

  const Packet abs_x = pabs(x);
  const Packet abs_x_ge_half = pcmp_le(half, abs_x);
  const Packet abs_x_eq_one = pcmp_eq(one, abs_x);
  const Packet abs_x_gt_one = pcmp_lt(one, abs_x);

  // Build +/-inf for |x| == 1 by transplanting only the sign bit of x onto
  // +inf. psignbit() must not be used here: it yields a full-width mask for
  // negative lanes, and OR-ing that mask with inf produces NaN instead of -inf.
  const Packet sign_mask = pset1<Packet>(-0.0);
  const Packet x_sign = pand(sign_mask, x);
  const Packet inf = pset1<Packet>(std::numeric_limits<double>::infinity());
  const Packet signed_inf = por(x_sign, inf);

  // OR-ing with the |x| > 1 mask turns those lanes into NaN (all bits set)
  // and leaves the remaining lanes untouched.
  return por(abs_x_gt_one, pselect(abs_x_eq_one, signed_inf, pselect(abs_x_ge_half, y_large, y_small)));
}
template <typename Packet> template <typename Packet>

View File

@ -114,6 +114,10 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet ptanh_double(const Pa
template <typename Packet> template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patanh_float(const Packet& x); EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patanh_float(const Packet& x);
/** \internal \returns atanh(x) for double precision float.
 *  The definition static_asserts that the scalar type of Packet is double. */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet patanh_double(const Packet& x);
/** \internal \returns sqrt(x) for complex types */ /** \internal \returns sqrt(x) for complex types */
template <typename Packet> template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psqrt_complex(const Packet& a); EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psqrt_complex(const Packet& a);
@ -182,6 +186,7 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_round(const Packet& a);
} }
#define EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_DOUBLE(PACKET) \ #define EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_DOUBLE(PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(atanh, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(log, PACKET) \ EIGEN_DOUBLE_PACKET_FUNCTION(log, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(sin, PACKET) \ EIGEN_DOUBLE_PACKET_FUNCTION(sin, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(cos, PACKET) \ EIGEN_DOUBLE_PACKET_FUNCTION(cos, PACKET) \

View File

@ -5133,6 +5133,7 @@ struct packet_traits<double> : default_packet_traits {
HasExp = 1, HasExp = 1,
HasLog = 1, HasLog = 1,
HasATan = 1, HasATan = 1,
HasATanh = 1,
#endif #endif
HasSin = EIGEN_FAST_MATH, HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH, HasCos = EIGEN_FAST_MATH,

View File

@ -220,6 +220,7 @@ struct packet_traits<double> : default_packet_traits {
HasSqrt = 1, HasSqrt = 1,
HasRsqrt = 1, HasRsqrt = 1,
HasATan = 1, HasATan = 1,
HasATanh = 1,
HasBlend = 1 HasBlend = 1
}; };
}; };