Speed up pldexp_generic.

This commit is contained in:
Rasmus Munk Larsen 2024-04-12 01:32:17 +00:00
parent 3c6521ed90
commit 5226566a14
2 changed files with 5 additions and 5 deletions

View File

@@ -1293,13 +1293,13 @@ EIGEN_DEVICE_FUNC inline Packet pmsub(const Packet& a, const Packet& b, const Pa
/** \internal \returns -(a * b) + c (coeff-wise) */ /** \internal \returns -(a * b) + c (coeff-wise) */
template <typename Packet> template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pnmadd(const Packet& a, const Packet& b, const Packet& c) { EIGEN_DEVICE_FUNC inline Packet pnmadd(const Packet& a, const Packet& b, const Packet& c) {
return padd(pnegate(pmul(a, b)), c); return psub(c, pmul(a, b));
} }
/** \internal \returns -(a * b) - c (coeff-wise) */ /** \internal \returns -((a * b) + c) (coeff-wise) */
template <typename Packet> template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pnmsub(const Packet& a, const Packet& b, const Packet& c) { EIGEN_DEVICE_FUNC inline Packet pnmsub(const Packet& a, const Packet& b, const Packet& c) {
return psub(pnegate(pmul(a, b)), c); return pnegate(pmadd(a, b, c));
} }
/** \internal copy a packet with constant coefficient \a a (e.g., [a,a,a,a]) to \a *to. \a to must be 16 bytes aligned /** \internal copy a packet with constant coefficient \a a (e.g., [a,a,a,a]) to \a *to. \a to must be 16 bytes aligned

View File

@@ -129,8 +129,8 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Packet pldexp_generic(const Packet& a, con
const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent)); const PacketI e = pcast<Packet, PacketI>(pmin(pmax(exponent, pnegate(max_exponent)), max_exponent));
PacketI b = parithmetic_shift_right<2>(e); // floor(e/4); PacketI b = parithmetic_shift_right<2>(e); // floor(e/4);
Packet c = preinterpret<Packet>(plogical_shift_left<MantissaBits>(padd(b, bias))); // 2^b Packet c = preinterpret<Packet>(plogical_shift_left<MantissaBits>(padd(b, bias))); // 2^b
Packet out = pmul(pmul(pmul(a, c), c), c); // a * 2^(3b) Packet out = pmul(pmul(a, c), pmul(c, c)); // a * 2^(3b)
b = psub(psub(psub(e, b), b), b); // e - 3b b = pnmadd(pset1<PacketI>(3), b, e); // e - 3b
c = preinterpret<Packet>(plogical_shift_left<MantissaBits>(padd(b, bias))); // 2^(e-3*b) c = preinterpret<Packet>(plogical_shift_left<MantissaBits>(padd(b, bias))); // 2^(e-3*b)
out = pmul(out, c); out = pmul(out, c);
return out; return out;