Remove inline assembly for FMA (AVX) and add remaining extensions as packet ops: pmsub, pnmadd, and pnmsub.

This commit is contained in:
Rasmus Munk Larsen 2022-01-26 04:25:41 +00:00
parent 4e629b3c1b
commit 51311ec651
5 changed files with 152 additions and 31 deletions

View File

@ -939,6 +939,35 @@ template<typename Packet> EIGEN_DEVICE_FUNC inline bool predux_any(const Packet&
* The following functions might not have to be overwritten for vectorized types
***************************************************************************/
// FMA instructions.
/** \internal \returns a * b + c (coeff-wise) */
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pmadd(const Packet& a, const Packet& b,
const Packet& c) {
return padd(pmul(a, b), c);
}
/** \internal \returns a * b - c (coeff-wise) */
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pmsub(const Packet& a, const Packet& b,
const Packet& c) {
return psub(pmul(a, b), c);
}
/** \internal \returns -(a * b) + c (coeff-wise) */
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pnmadd(const Packet& a, const Packet& b,
const Packet& c) {
return padd(pnegate(pmul(a, b)), c);
}
/** \internal \returns -(a * b) - c (coeff-wise) */
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pnmsub(const Packet& a, const Packet& b,
const Packet& c) {
return psub(pnegate(pmul(a, b)), c);
}
/** \internal copy a packet with constant coefficient \a a (e.g., [a,a,a,a]) to \a *to. \a to must be 16 bytes aligned */
// NOTE: this function must really be templated on the packet type (think about different packet types for the same scalar type)
template<typename Packet>
@ -947,13 +976,6 @@ inline void pstore1(typename unpacket_traits<Packet>::type* to, const typename u
pstore(to, pset1<Packet>(a));
}
/** \internal \returns a * b + c (coeff-wise) */
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
pmadd(const Packet& a,
const Packet& b,
const Packet& c)
{ return padd(pmul(a, b),c); }
/** \internal \returns a packet version of \a *from.
* The pointer \a from must be aligned on a \a Alignment bytes boundary. */
template<typename Packet, int Alignment>

View File

@ -540,30 +540,46 @@ template<> EIGEN_STRONG_INLINE Packet8i pdiv<Packet8i>(const Packet8i& /*a*/, co
}
#ifdef EIGEN_VECTORIZE_FMA
template<> EIGEN_STRONG_INLINE Packet8f pmadd(const Packet8f& a, const Packet8f& b, const Packet8f& c) {
#if ( (EIGEN_COMP_GNUC_STRICT && EIGEN_COMP_GNUC<80) || (EIGEN_COMP_CLANG) )
// Clang stupidly generates a vfmadd213ps instruction plus some vmovaps on registers,
// and even register spilling with clang>=6.0 (bug 1637).
// Gcc stupidly generates a vfmadd132ps instruction.
// So let's enforce it to generate a vfmadd231ps instruction since the most common use
// case is to accumulate the result of the product.
Packet8f res = c;
__asm__("vfmadd231ps %[a], %[b], %[c]" : [c] "+x" (res) : [a] "x" (a), [b] "x" (b));
return res;
#else
return _mm256_fmadd_ps(a,b,c);
#endif
template <>
EIGEN_STRONG_INLINE Packet8f pmadd(const Packet8f& a, const Packet8f& b, const Packet8f& c) {
return _mm256_fmadd_ps(a, b, c);
}
template<> EIGEN_STRONG_INLINE Packet4d pmadd(const Packet4d& a, const Packet4d& b, const Packet4d& c) {
#if ( (EIGEN_COMP_GNUC_STRICT && EIGEN_COMP_GNUC<80) || (EIGEN_COMP_CLANG) )
// see above
Packet4d res = c;
__asm__("vfmadd231pd %[a], %[b], %[c]" : [c] "+x" (res) : [a] "x" (a), [b] "x" (b));
return res;
#else
return _mm256_fmadd_pd(a,b,c);
#endif
template <>
EIGEN_STRONG_INLINE Packet4d pmadd(const Packet4d& a, const Packet4d& b, const Packet4d& c) {
return _mm256_fmadd_pd(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8f pmsub(const Packet8f& a, const Packet8f& b, const Packet8f& c) {
return _mm256_fmsub_ps(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet4d pmsub(const Packet4d& a, const Packet4d& b, const Packet4d& c) {
return _mm256_fmsub_pd(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8f pnmadd(const Packet8f& a, const Packet8f& b, const Packet8f& c) {
return _mm256_fnmadd_ps(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet4d pnmadd(const Packet4d& a, const Packet4d& b, const Packet4d& c) {
return _mm256_fnmadd_pd(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8f pnmsub(const Packet8f& a, const Packet8f& b, const Packet8f& c) {
return _mm256_fnmsub_ps(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet4d pnmsub(const Packet4d& a, const Packet4d& b, const Packet4d& c) {
return _mm256_fnmsub_pd(a, b, c);
}
#endif
template<> EIGEN_STRONG_INLINE Packet8f pcmp_le(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LE_OQ); }

View File

@ -359,6 +359,39 @@ EIGEN_STRONG_INLINE Packet8d pmadd(const Packet8d& a, const Packet8d& b,
const Packet8d& c) {
return _mm512_fmadd_pd(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet16f pmsub(const Packet16f& a, const Packet16f& b,
const Packet16f& c) {
return _mm512_fmsub_ps(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8d pmsub(const Packet8d& a, const Packet8d& b,
const Packet8d& c) {
return _mm512_fmsub_pd(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet16f pnmadd(const Packet16f& a, const Packet16f& b,
const Packet16f& c) {
return _mm512_fnmadd_ps(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8d pnmadd(const Packet8d& a, const Packet8d& b,
const Packet8d& c) {
return _mm512_fnmadd_pd(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet16f pnmsub(const Packet16f& a, const Packet16f& b,
const Packet16f& c) {
return _mm512_fnmsub_ps(a, b, c);
}
template <>
EIGEN_STRONG_INLINE Packet8d pnmsub(const Packet8d& a, const Packet8d& b,
const Packet8d& c) {
return _mm512_fnmsub_pd(a, b, c);
}
#endif
template <>
@ -2281,13 +2314,13 @@ EIGEN_STRONG_INLINE Packet16bf pxor(const Packet16bf& a, const Packet16bf& b) {
template <>
EIGEN_STRONG_INLINE Packet16bf pand(const Packet16bf& a, const Packet16bf& b) {
return Packet16bf(pand<Packet8i>((Packet8i)a, (Packet8i)b));
return Packet16bf(pand<Packet8i>(Packet8i(a), Packet8i(b)));
}
template <>
EIGEN_STRONG_INLINE Packet16bf pandnot(const Packet16bf& a,
const Packet16bf& b) {
return Packet16bf(pandnot<Packet8i>((Packet8i)a, (Packet8i)b));
return Packet16bf(pandnot<Packet8i>(Packet8i(a), Packet8i(b)));
}
template <>

View File

@ -364,6 +364,12 @@ template<> EIGEN_STRONG_INLINE Packet4i pmadd(const Packet4i& a, const Packet4i&
#ifdef EIGEN_VECTORIZE_FMA
template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fmadd_ps(a,b,c); }
template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fmadd_pd(a,b,c); }
template<> EIGEN_STRONG_INLINE Packet4f pmsub(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fmsub_ps(a,b,c); }
template<> EIGEN_STRONG_INLINE Packet2d pmsub(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fmsub_pd(a,b,c); }
template<> EIGEN_STRONG_INLINE Packet4f pnmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fnmadd_ps(a,b,c); }
template<> EIGEN_STRONG_INLINE Packet2d pnmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fnmadd_pd(a,b,c); }
template<> EIGEN_STRONG_INLINE Packet4f pnmsub(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fnmsub_ps(a,b,c); }
template<> EIGEN_STRONG_INLINE Packet2d pnmsub(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fnmsub_pd(a,b,c); }
#endif
#ifdef EIGEN_VECTORIZE_SSE4_1
@ -1263,6 +1269,24 @@ template<> EIGEN_STRONG_INLINE float pmadd(const float& a, const float& b, const
template<> EIGEN_STRONG_INLINE double pmadd(const double& a, const double& b, const double& c) {
return ::fma(a,b,c);
}
template<> EIGEN_STRONG_INLINE float pmsub(const float& a, const float& b, const float& c) {
return ::fmaf(a,b,-c);
}
template<> EIGEN_STRONG_INLINE double pmsub(const double& a, const double& b, const double& c) {
return ::fma(a,b,-c);
}
template<> EIGEN_STRONG_INLINE float pnmadd(const float& a, const float& b, const float& c) {
return ::fmaf(-a,b,c);
}
template<> EIGEN_STRONG_INLINE double pnmadd(const double& a, const double& b, const double& c) {
return ::fma(-a,b,c);
}
template<> EIGEN_STRONG_INLINE float pnmsub(const float& a, const float& b, const float& c) {
return ::fmaf(-a,b,-c);
}
template<> EIGEN_STRONG_INLINE double pnmsub(const double& a, const double& b, const double& c) {
return ::fma(-a,b,-c);
}
#endif
#ifdef EIGEN_VECTORIZE_SSE4_1

View File

@ -24,6 +24,22 @@ inline T REF_MUL(const T& a, const T& b) {
return a * b;
}
template <typename T>
inline T REF_MADD(const T& a, const T& b, const T& c) {
return a * b + c;
}
template <typename T>
inline T REF_MSUB(const T& a, const T& b, const T& c) {
return a * b - c;
}
template <typename T>
inline T REF_NMADD(const T& a, const T& b, const T& c) {
return (-a * b) + c;
}
template <typename T>
inline T REF_NMSUB(const T& a, const T& b, const T& c) {
return (-a * b) - c;
}
template <typename T>
inline T REF_DIV(const T& a, const T& b) {
return a / b;
}
@ -49,6 +65,10 @@ template <>
inline bool REF_MUL(const bool& a, const bool& b) {
return a && b;
}
template <>
inline bool REF_MADD(const bool& a, const bool& b, const bool& c) {
return (a && b) || c;
}
template <typename T>
inline T REF_FREXP(const T& x, T& exp) {
@ -622,6 +642,12 @@ void packetmath() {
}
CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt);
CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt);
CHECK_CWISE3_IF(true, REF_MADD, internal::pmadd);
if (!std::is_same<Scalar, bool>::value) {
CHECK_CWISE3_IF(true, REF_MSUB, internal::pmsub);
CHECK_CWISE3_IF(true, REF_NMADD, internal::pnmadd);
CHECK_CWISE3_IF(true, REF_NMSUB, internal::pnmsub);
}
}
// Notice that this definition works for complex types as well.