Simd sincos double

This commit is contained in:
Damiano Franzò 2024-04-15 21:12:32 +00:00 committed by Rasmus Munk Larsen
parent 6ad2ccea4e
commit 888fca0e2b
6 changed files with 185 additions and 2 deletions

View File

@ -142,6 +142,8 @@ struct packet_traits<double> : default_packet_traits {
HasCmp = 1,
HasDiv = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasLog = 1,
HasExp = 1,
HasSqrt = 1,

View File

@ -156,6 +156,8 @@ struct packet_traits<double> : default_packet_traits {
HasBlend = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasLog = 1,
HasExp = 1,
HasATan = 1,

View File

@ -696,6 +696,173 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcos_float(const Pack
return psincos_float<false>(x);
}
// Trigonometric argument reduction for double for inputs smaller than 15.
// Reduces trigonometric arguments for double inputs where x < 15. Given an argument x and its corresponding quadrant
// count n, the function computes and returns the reduced argument t such that x = n * pi/2 + t.
template <typename Packet>
Packet trig_reduce_small_double(const Packet& x, const Packet& q) {
// Pi/2 split into 2 values
const Packet cst_pio2_a = pset1<Packet>(-1.570796325802803);
const Packet cst_pio2_b = pset1<Packet>(-9.920935184482005e-10);
Packet t;
t = pmadd(cst_pio2_a, q, x);
t = pmadd(cst_pio2_b, q, t);
return t;
}
// Trigonometric argument reduction for double for inputs smaller than 1e14.
// Reduces trigonometric arguments for double inputs where x < 1e14. Given an argument x and its corresponding quadrant
// count n, the function computes and returns the reduced argument t such that x = n * pi/2 + t.
template <typename Packet>
Packet trig_reduce_medium_double(const Packet& x, const Packet& q_high, const Packet& q_low) {
// Pi/2 split into 4 values
const Packet cst_pio2_a = pset1<Packet>(-1.570796325802803);
const Packet cst_pio2_b = pset1<Packet>(-9.920935184482005e-10);
const Packet cst_pio2_c = pset1<Packet>(-6.123234014771656e-17);
const Packet cst_pio2_d = pset1<Packet>(1.903488962019325e-25);
Packet t;
t = pmadd(cst_pio2_a, q_high, x);
t = pmadd(cst_pio2_a, q_low, t);
t = pmadd(cst_pio2_b, q_high, t);
t = pmadd(cst_pio2_b, q_low, t);
t = pmadd(cst_pio2_c, q_high, t);
t = pmadd(cst_pio2_c, q_low, t);
t = pmadd(cst_pio2_d, padd(q_low, q_high), t);
return t;
}
template <bool ComputeSine, typename Packet, bool ComputeBoth = false>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
#if EIGEN_COMP_GNUC_STRICT
__attribute__((optimize("-fno-unsafe-math-optimizations")))
#endif
Packet
psincos_double(const Packet& x) {
typedef typename unpacket_traits<Packet>::integer_packet PacketI;
const Packet cst_sign_mask = pset1frombits<Packet>(static_cast<Eigen::numext::uint64_t>(0x8000000000000000u));
// If the argument is smaller than this value, use a simpler argument reduction
const double small_th = 15;
// If the argument is bigger than this value, use the non-vectorized std version
const double huge_th = 1e14;
const Packet cst_2oPI = pset1<Packet>(0.63661977236758134307553505349006); // 2/PI
// Integer Packet constants
const PacketI cst_one = pset1<PacketI>(1);
// Constant for splitting
const Packet cst_split = pset1<Packet>(1 << 24);
Packet x_abs = pabs(x);
// Scale x by 2/Pi
PacketI q_int;
Packet s;
// TODO Implement huge angle argument reduction
if (EIGEN_PREDICT_FALSE(predux_any(pcmp_le(pset1<Packet>(small_th), x_abs)))) {
Packet q_high = pmul(pfloor(pmul(x_abs, pdiv(cst_2oPI, cst_split))), cst_split);
Packet q_low_noround = psub(pmul(x_abs, cst_2oPI), q_high);
q_int = pcast<Packet, PacketI>(padd(q_low_noround, pset1<Packet>(0.5)));
Packet q_low = pcast<PacketI, Packet>(q_int);
s = trig_reduce_medium_double(x_abs, q_high, q_low);
} else {
Packet qval_noround = pmul(x_abs, cst_2oPI);
q_int = pcast<Packet, PacketI>(padd(qval_noround, pset1<Packet>(0.5)));
Packet q = pcast<PacketI, Packet>(q_int);
s = trig_reduce_small_double(x_abs, q);
}
// All the upcoming approximating polynomials have even exponents
Packet ss = pmul(s, s);
// Padé approximant of cos(x)
// Assuring < 1 ULP error on the interval [-pi/4, pi/4]
// cos(x) ~= (80737373*x^8 - 13853547000*x^6 + 727718024880*x^4 - 11275015752000*x^2 + 23594700729600)/(147173*x^8 +
// 39328920*x^6 + 5772800880*x^4 + 522334612800*x^2 + 23594700729600)
// MATLAB code to compute those coefficients:
// syms x;
// cosf = @(x) cos(x);
// pade_cosf = pade(cosf(x), x, 0, 'Order', 8)
Packet sc1_num = pmadd(ss, pset1<Packet>(80737373), pset1<Packet>(-13853547000));
Packet sc2_num = pmadd(sc1_num, ss, pset1<Packet>(727718024880));
Packet sc3_num = pmadd(sc2_num, ss, pset1<Packet>(-11275015752000));
Packet sc4_num = pmadd(sc3_num, ss, pset1<Packet>(23594700729600));
Packet sc1_denum = pmadd(ss, pset1<Packet>(147173), pset1<Packet>(39328920));
Packet sc2_denum = pmadd(sc1_denum, ss, pset1<Packet>(5772800880));
Packet sc3_denum = pmadd(sc2_denum, ss, pset1<Packet>(522334612800));
Packet sc4_denum = pmadd(sc3_denum, ss, pset1<Packet>(23594700729600));
Packet scos = pdiv(sc4_num, sc4_denum);
// Padé approximant of sin(x)
// Assuring < 1 ULP error on the interval [-pi/4, pi/4]
// sin(x) ~= (x*(4585922449*x^8 - 1066023933480*x^6 + 83284044283440*x^4 - 2303682236856000*x^2 +
// 15605159573203200))/(45*(1029037*x^8 + 345207016*x^6 + 61570292784*x^4 + 6603948711360*x^2 + 346781323848960))
// MATLAB code to compute those coefficients:
// syms x;
// sinf = @(x) sin(x);
// pade_sinf = pade(sinf(x), x, 0, 'Order', 8, 'OrderMode', 'relative')
Packet ss1_num = pmadd(ss, pset1<Packet>(4585922449), pset1<Packet>(-1066023933480));
Packet ss2_num = pmadd(ss1_num, ss, pset1<Packet>(83284044283440));
Packet ss3_num = pmadd(ss2_num, ss, pset1<Packet>(-2303682236856000));
Packet ss4_num = pmadd(ss3_num, ss, pset1<Packet>(15605159573203200));
Packet ss1_denum = pmadd(ss, pset1<Packet>(1029037), pset1<Packet>(345207016));
Packet ss2_denum = pmadd(ss1_denum, ss, pset1<Packet>(61570292784));
Packet ss3_denum = pmadd(ss2_denum, ss, pset1<Packet>(6603948711360));
Packet ss4_denum = pmadd(ss3_denum, ss, pset1<Packet>(346781323848960));
Packet ssin = pdiv(pmul(s, ss4_num), pmul(pset1<Packet>(45), ss4_denum));
Packet poly_mask = preinterpret<Packet>(pcmp_eq(pand(q_int, cst_one), pzero(q_int)));
Packet sign_sin = pxor(x, preinterpret<Packet>(plogical_shift_left<62>(q_int)));
Packet sign_cos = preinterpret<Packet>(plogical_shift_left<62>(padd(q_int, cst_one)));
Packet sign_bit, sFinalRes;
if (ComputeBoth) {
Packet peven = peven_mask(x);
sign_bit = pselect((s), sign_sin, sign_cos);
sFinalRes = pselect(pxor(peven, poly_mask), ssin, scos);
} else {
sign_bit = ComputeSine ? sign_sin : sign_cos;
sFinalRes = ComputeSine ? pselect(poly_mask, ssin, scos) : pselect(poly_mask, scos, ssin);
}
sign_bit = pand(sign_bit, cst_sign_mask); // clear all but left most bit
sFinalRes = pxor(sFinalRes, sign_bit);
// If the inputs values are higher than that a value that the argument reduction can currently address, compute them
// using std::sin and std::cos
// TODO Remove it when huge angle argument reduction is implemented
if (EIGEN_PREDICT_FALSE(predux_any(pcmp_le(pset1<Packet>(huge_th), x_abs)))) {
const int PacketSize = unpacket_traits<Packet>::size;
EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) double sincos_vals[PacketSize];
EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) double x_cpy[PacketSize];
pstoreu(x_cpy, x);
pstoreu(sincos_vals, sFinalRes);
for (int k = 0; k < PacketSize; ++k) {
double val = x_cpy[k];
if (std::abs(val) > huge_th && (numext::isfinite)(val)) {
if (ComputeBoth)
sincos_vals[k] = k % 2 == 0 ? std::sin(val) : std::cos(val);
else
sincos_vals[k] = ComputeSine ? std::sin(val) : std::cos(val);
}
}
sFinalRes = ploadu<Packet>(sincos_vals);
}
return sFinalRes;
}
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psin_double(const Packet& x) {
return psincos_double<true>(x);
}
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcos_double(const Packet& x) {
return psincos_double<false>(x);
}
// Generic implementation of acos(x).
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pacos_float(const Packet& x_in) {

View File

@ -82,6 +82,14 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psin_float(const Pack
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcos_float(const Packet& x);
/** \internal \returns sin(x) for double precision float */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet psin_double(const Packet& x);
/** \internal \returns cos(x) for double precision float */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pcos_double(const Packet& x);
/** \internal \returns asin(x) for single precision float */
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pasin_float(const Packet& x);
@ -158,6 +166,8 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet pexp_complex(const Pa
#define EIGEN_INSTANTIATE_GENERIC_MATH_FUNCS_DOUBLE(PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(atan, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(log, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(sin, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(cos, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(log2, PACKET) \
EIGEN_DOUBLE_PACKET_FUNCTION(exp, PACKET)

View File

@ -5177,8 +5177,8 @@ struct packet_traits<double> : default_packet_traits {
HasLog = 1,
HasATan = 1,
#endif
HasSin = 0,
HasCos = 0,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasSqrt = 1,
HasRsqrt = 1,
HasTanh = 0,

View File

@ -218,6 +218,8 @@ struct packet_traits<double> : default_packet_traits {
HasCmp = 1,
HasDiv = 1,
HasSin = EIGEN_FAST_MATH,
HasCos = EIGEN_FAST_MATH,
HasLog = 1,
HasExp = 1,
HasSqrt = 1,