Vectorize atan() for double.

This commit is contained in:
Rasmus Munk Larsen 2022-10-01 01:49:30 +00:00
parent 1e1848fdb1
commit c475228b28
12 changed files with 104 additions and 4 deletions

View File

@ -50,6 +50,12 @@ patan<Packet8f>(const Packet8f& _x) {
return patan_float(_x); return patan_float(_x);
} }
// Specialization of patan for Packet4d: forwards to the generic
// double-precision kernel patan_double.
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet4d
patan<Packet4d>(const Packet4d& _x) {
return patan_double(_x);
}
template <> template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8f EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8f
plog<Packet8f>(const Packet8f& _x) { plog<Packet8f>(const Packet8f& _x) {

View File

@ -109,6 +109,7 @@ template<> struct packet_traits<double> : default_packet_traits
HasExp = 1, HasExp = 1,
HasSqrt = 1, HasSqrt = 1,
HasRsqrt = 1, HasRsqrt = 1,
HasATan = 1,
HasBlend = 1, HasBlend = 1,
HasRound = 1, HasRound = 1,
HasFloor = 1, HasFloor = 1,

View File

@ -275,6 +275,12 @@ patan<Packet16f>(const Packet16f& _x) {
return patan_float(_x); return patan_float(_x);
} }
// Specialization of patan for Packet8d: forwards to the generic
// double-precision kernel patan_double.
template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet8d
patan<Packet8d>(const Packet8d& _x) {
return patan_double(_x);
}
template <> template <>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet16f
ptanh<Packet16f>(const Packet16f& _x) { ptanh<Packet16f>(const Packet16f& _x) {

View File

@ -161,6 +161,7 @@ template<> struct packet_traits<double> : default_packet_traits
HasSqrt = EIGEN_FAST_MATH, HasSqrt = EIGEN_FAST_MATH,
HasRsqrt = EIGEN_FAST_MATH, HasRsqrt = EIGEN_FAST_MATH,
#endif #endif
HasATan = 1,
HasCmp = 1, HasCmp = 1,
HasDiv = 1, HasDiv = 1,
HasRound = 1, HasRound = 1,

View File

@ -60,6 +60,12 @@ Packet4f patan<Packet4f>(const Packet4f& _x)
return patan_float(_x); return patan_float(_x);
} }
// Specialization of patan for Packet2d: forwards to the generic
// double-precision kernel patan_double.
// NOTE(review): this specialization is placed just before the
// `#ifdef __VSX__` region that follows; confirm that Packet2d is declared
// in builds where __VSX__ is not defined, otherwise it belongs inside
// that guard.
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet2d patan<Packet2d>(const Packet2d& _x)
{
return patan_double(_x);
}
#ifdef __VSX__ #ifdef __VSX__
#ifndef EIGEN_COMP_CLANG #ifndef EIGEN_COMP_CLANG
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS

View File

@ -2708,6 +2708,7 @@ template<> struct packet_traits<double> : default_packet_traits
HasAbs = 1, HasAbs = 1,
HasSin = 0, HasSin = 0,
HasCos = 0, HasCos = 0,
HasATan = 1,
HasLog = 0, HasLog = 0,
HasExp = 1, HasExp = 1,
HasSqrt = 1, HasSqrt = 1,

View File

@ -864,6 +864,68 @@ Packet patan_float(const Packet& x_in) {
return pselect(neg_mask, pnegate(p), p); return pselect(neg_mask, pnegate(p), p);
} }
/** \internal \returns atan(x) evaluated coefficient-wise on a packet of doubles.
  *
  * |x| is reduced to [0, tan(pi/8)] with the same scheme as the Cephes library,
  * atan is approximated there by an odd polynomial, the range reduction is
  * undone, and finally the sign bit of the input is transferred to the result.
  * Because the sign is copied bitwise (rather than derived from `x < 0`),
  * atan(-0.0) yields -0.0, matching std::atan.
  *
  * \tparam Packet a packet type whose scalar type is double (checked by a
  *                static_assert).
  * \param x_in   the packet of arguments.
  */
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet patan_double(const Packet& x_in) {
  typedef typename unpacket_traits<Packet>::type Scalar;
  static_assert(std::is_same<Scalar, double>::value, "Scalar type must be double");

  const Packet cst_one = pset1<Packet>(1.0);
  // -0.0 has only the sign bit set, so it doubles as a sign-bit mask.
  const Packet cst_signmask = pset1<Packet>(-0.0);
  // pi/2 and pi/4 are spelled out as literals: the M_PI_2 / M_PI_4 macros are
  // not part of standard C++ and are unavailable on some toolchains unless
  // extra feature macros are defined before including <cmath>.
  constexpr double kPiOverTwo = 1.57079632679489661923;
  const Packet cst_pi_over_two = pset1<Packet>(kPiOverTwo);
  constexpr double kPiOverFour = 0.78539816339744830962;
  const Packet cst_pi_over_four = pset1<Packet>(kPiOverFour);
  const Packet cst_large = pset1<Packet>(2.4142135623730950488016887);   // tan(3*pi/8);
  const Packet cst_medium = pset1<Packet>(0.4142135623730950488016887);  // tan(pi/8);

  // Coefficients of Q (see below), indexed by the power of x they multiply.
  const Packet q0 = pset1<Packet>(-0.33333333333330028569463365784031338989734649658203);
  const Packet q2 = pset1<Packet>(0.199999999990664090177006073645316064357757568359375);
  const Packet q4 = pset1<Packet>(-0.142857141937123677255527809393242932856082916259766);
  const Packet q6 = pset1<Packet>(0.111111065991039953404495577160560060292482376098633);
  const Packet q8 = pset1<Packet>(-9.0907812986129224452902519715280504897236824035645e-2);
  const Packet q10 = pset1<Packet>(7.6900542950704739442180368769186316058039665222168e-2);
  const Packet q12 = pset1<Packet>(-6.6410112986494976294871150912513257935643196105957e-2);
  const Packet q14 = pset1<Packet>(5.6920144995467943094258345126945641823112964630127e-2);
  const Packet q16 = pset1<Packet>(-4.3577020814990513608577771265117917209863662719727e-2);
  const Packet q18 = pset1<Packet>(2.1244050233624342527427586446719942614436149597168e-2);

  // atan is odd: work on |x| and remember the sign bit of every lane so it can
  // be re-applied at the end. Extracting the bit (instead of testing x < 0)
  // also preserves the sign of negative zero.
  const Packet x_signmask = pand(x_in, cst_signmask);
  const Packet abs_x = pabs(x_in);

  // Use the same range reduction strategy (to [0:tan(pi/8)]) as the
  // Cephes library:
  //   "Large":  For |x| > tan(3*pi/8), use atan(x) = pi/2 - atan(1/x).
  //   "Medium": For |x| in (tan(pi/8) : tan(3*pi/8)],
  //             use atan(x) = pi/4 + atan((x-1)/(x+1)).
  //   "Small":  For |x| <= tan(pi/8), approximate atan(x) directly by a
  //             polynomial calculated using Sollya.
  const Packet large_mask = pcmp_lt(cst_large, abs_x);
  Packet x = pselect(large_mask, preciprocal(abs_x), abs_x);
  // Lanes already handled by the "large" branch are excluded explicitly.
  const Packet medium_mask = pandnot(pcmp_lt(cst_medium, x), large_mask);
  x = pselect(medium_mask, pdiv(psub(x, cst_one), padd(x, cst_one)), x);

  // Approximate atan(x) on the reduced interval by a polynomial of the form
  //   P(x) = x + x^3 * Q(x^2),
  // where Q(x^2) is a 9th order polynomial in x^2. Q is split into two
  // independent Horner chains in x^4 (odd and even powers of x^2) that are
  // merged by the final pmadd.
  const Packet x2 = pmul(x, x);
  const Packet x4 = pmul(x2, x2);
  Packet q_odd = pmadd(q18, x4, q14);
  Packet q_even = pmadd(q16, x4, q12);
  q_odd = pmadd(q_odd, x4, q10);
  q_even = pmadd(q_even, x4, q8);
  q_odd = pmadd(q_odd, x4, q6);
  q_even = pmadd(q_even, x4, q4);
  q_odd = pmadd(q_odd, x4, q2);
  q_even = pmadd(q_even, x4, q0);
  const Packet q = pmadd(q_odd, x2, q_even);
  Packet p = pmadd(q, pmul(x, x2), x);

  // Apply transformations according to the range reduction masks.
  p = pselect(large_mask, psub(cst_pi_over_two, p), p);
  p = pselect(medium_mask, padd(cst_pi_over_four, p), p);

  // p holds atan(|x|); flip its sign bit wherever the input's was set.
  return pxor(p, x_signmask);
}
template<typename Packet> template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet pdiv_complex(const Packet& x, const Packet& y) { Packet pdiv_complex(const Packet& x, const Packet& y) {
@ -958,8 +1020,8 @@ Packet psqrt_complex(const Packet& a) {
// Step 4. Compute solution for inputs with negative real part: // Step 4. Compute solution for inputs with negative real part:
// [|eta0|, sign(y0)*rho0, |eta1|, sign(y1)*rho1] // [|eta0|, sign(y0)*rho0, |eta1|, sign(y1)*rho1]
const RealScalar neg_zero = RealScalar(numext::bit_cast<float>(0x80000000u)); const RealPacket cst_imag_sign_mask =
const RealPacket cst_imag_sign_mask = pset1<Packet>(Scalar(RealScalar(0.0), neg_zero)).v; pset1<Packet>(Scalar(RealScalar(0.0), RealScalar(-0.0))).v;
RealPacket imag_signs = pand(a.v, cst_imag_sign_mask); RealPacket imag_signs = pand(a.v, cst_imag_sign_mask);
Packet negative_real_result; Packet negative_real_result;
// Notice that rho is positive, so taking it's absolute value is a noop. // Notice that rho is positive, so taking it's absolute value is a noop.

View File

@ -104,6 +104,11 @@ template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet patan_float(const Packet& x); Packet patan_float(const Packet& x);
/** \internal \returns atan(x), computed coefficient-wise, for a packet whose
  * scalar type is double */
template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet patan_double(const Packet& x);
/** \internal \returns sqrt(x) for complex types */ /** \internal \returns sqrt(x) for complex types */
template<typename Packet> template<typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS

View File

@ -102,6 +102,9 @@ template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet2d pexp<Pac
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet2d plog<Packet2d>(const Packet2d& x) template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet2d plog<Packet2d>(const Packet2d& x)
{ return plog_double(x); } { return plog_double(x); }
// Specialization of patan for Packet2d: forwards to the generic
// double-precision kernel patan_double.
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet2d patan<Packet2d>(const Packet2d& x)
{ return patan_double(x); }
#endif #endif
} // end namespace internal } // end namespace internal

View File

@ -3762,10 +3762,13 @@ template<> struct packet_traits<double> : default_packet_traits
HasCeil = 1, HasCeil = 1,
HasRint = 1, HasRint = 1,
#if EIGEN_ARCH_ARM64 && !EIGEN_APPLE_DOUBLE_NEON_BUG
HasExp = 1,
HasLog = 1,
HasATan = 1,
#endif
HasSin = 0, HasSin = 0,
HasCos = 0, HasCos = 0,
HasLog = 1,
HasExp = 1,
HasSqrt = 1, HasSqrt = 1,
HasRsqrt = 1, HasRsqrt = 1,
HasTanh = 0, HasTanh = 0,

View File

@ -81,6 +81,11 @@ Packet4f pacos<Packet4f>(const Packet4f& _x)
return pacos_float(_x); return pacos_float(_x);
} }
// Specialization of patan for Packet2d: forwards to the generic
// double-precision kernel patan_double.
template <> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet2d patan<Packet2d>(const Packet2d& _x) {
return patan_double(_x);
}
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
Packet4f pasin<Packet4f>(const Packet4f& _x) Packet4f pasin<Packet4f>(const Packet4f& _x)
{ {

View File

@ -178,6 +178,7 @@ struct packet_traits<double> : default_packet_traits {
HasExp = 1, HasExp = 1,
HasSqrt = 1, HasSqrt = 1,
HasRsqrt = 1, HasRsqrt = 1,
HasATan = 1,
HasBlend = 1, HasBlend = 1,
HasFloor = 1, HasFloor = 1,
HasCeil = 1, HasCeil = 1,