diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index 4f1ff6b03..f60724ac0 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -939,6 +939,35 @@ template EIGEN_DEVICE_FUNC inline bool predux_any(const Packet& * The following functions might not have to be overwritten for vectorized types ***************************************************************************/ +// FMA instructions. +/** \internal \returns a * b + c (coeff-wise) */ +template +EIGEN_DEVICE_FUNC inline Packet pmadd(const Packet& a, const Packet& b, + const Packet& c) { + return padd(pmul(a, b), c); +} + +/** \internal \returns a * b - c (coeff-wise) */ +template +EIGEN_DEVICE_FUNC inline Packet pmsub(const Packet& a, const Packet& b, + const Packet& c) { + return psub(pmul(a, b), c); +} + +/** \internal \returns -(a * b) + c (coeff-wise) */ +template +EIGEN_DEVICE_FUNC inline Packet pnmadd(const Packet& a, const Packet& b, + const Packet& c) { + return padd(pnegate(pmul(a, b)), c); +} + +/** \internal \returns -(a * b) - c (coeff-wise) */ +template +EIGEN_DEVICE_FUNC inline Packet pnmsub(const Packet& a, const Packet& b, + const Packet& c) { + return psub(pnegate(pmul(a, b)), c); +} + /** \internal copy a packet with constant coefficient \a a (e.g., [a,a,a,a]) to \a *to. \a to must be 16 bytes aligned */ // NOTE: this function must really be templated on the packet type (think about different packet types for the same scalar type) template @@ -947,13 +976,6 @@ inline void pstore1(typename unpacket_traits::type* to, const typename u pstore(to, pset1(a)); } -/** \internal \returns a * b + c (coeff-wise) */ -template EIGEN_DEVICE_FUNC inline Packet -pmadd(const Packet& a, - const Packet& b, - const Packet& c) -{ return padd(pmul(a, b),c); } - /** \internal \returns a packet version of \a *from. * The pointer \a from must be aligned on a \a Alignment bytes boundary. */ template diff --git a/Eigen/src/Core/arch/AVX/PacketMath.h b/Eigen/src/Core/arch/AVX/PacketMath.h index bf832c9c2..2df899d6e 100644 --- a/Eigen/src/Core/arch/AVX/PacketMath.h +++ b/Eigen/src/Core/arch/AVX/PacketMath.h @@ -540,30 +540,46 @@ template<> EIGEN_STRONG_INLINE Packet8i pdiv(const Packet8i& /*a*/, co } #ifdef EIGEN_VECTORIZE_FMA -template<> EIGEN_STRONG_INLINE Packet8f pmadd(const Packet8f& a, const Packet8f& b, const Packet8f& c) { -#if ( (EIGEN_COMP_GNUC_STRICT && EIGEN_COMP_GNUC<80) || (EIGEN_COMP_CLANG) ) - // Clang stupidly generates a vfmadd213ps instruction plus some vmovaps on registers, - // and even register spilling with clang>=6.0 (bug 1637). - // Gcc stupidly generates a vfmadd132ps instruction. - // So let's enforce it to generate a vfmadd231ps instruction since the most common use - // case is to accumulate the result of the product. - Packet8f res = c; - __asm__("vfmadd231ps %[a], %[b], %[c]" : [c] "+x" (res) : [a] "x" (a), [b] "x" (b)); - return res; -#else - return _mm256_fmadd_ps(a,b,c); -#endif + +template <> +EIGEN_STRONG_INLINE Packet8f pmadd(const Packet8f& a, const Packet8f& b, const Packet8f& c) { + return _mm256_fmadd_ps(a, b, c); } -template<> EIGEN_STRONG_INLINE Packet4d pmadd(const Packet4d& a, const Packet4d& b, const Packet4d& c) { -#if ( (EIGEN_COMP_GNUC_STRICT && EIGEN_COMP_GNUC<80) || (EIGEN_COMP_CLANG) ) - // see above - Packet4d res = c; - __asm__("vfmadd231pd %[a], %[b], %[c]" : [c] "+x" (res) : [a] "x" (a), [b] "x" (b)); - return res; -#else - return _mm256_fmadd_pd(a,b,c); -#endif +template <> +EIGEN_STRONG_INLINE Packet4d pmadd(const Packet4d& a, const Packet4d& b, const Packet4d& c) { + return _mm256_fmadd_pd(a, b, c); } + +template <> +EIGEN_STRONG_INLINE Packet8f pmsub(const Packet8f& a, const Packet8f& b, const Packet8f& c) { + return _mm256_fmsub_ps(a, b, c); +} + +template <> +EIGEN_STRONG_INLINE Packet4d pmsub(const Packet4d& a, const Packet4d& b, const Packet4d& c) { + return _mm256_fmsub_pd(a, b, c); +} + +template <> +EIGEN_STRONG_INLINE Packet8f pnmadd(const Packet8f& a, const Packet8f& b, const Packet8f& c) { + return _mm256_fnmadd_ps(a, b, c); +} + +template <> +EIGEN_STRONG_INLINE Packet4d pnmadd(const Packet4d& a, const Packet4d& b, const Packet4d& c) { + return _mm256_fnmadd_pd(a, b, c); +} + +template <> +EIGEN_STRONG_INLINE Packet8f pnmsub(const Packet8f& a, const Packet8f& b, const Packet8f& c) { + return _mm256_fnmsub_ps(a, b, c); +} + +template <> +EIGEN_STRONG_INLINE Packet4d pnmsub(const Packet4d& a, const Packet4d& b, const Packet4d& c) { + return _mm256_fnmsub_pd(a, b, c); +} + #endif template<> EIGEN_STRONG_INLINE Packet8f pcmp_le(const Packet8f& a, const Packet8f& b) { return _mm256_cmp_ps(a,b,_CMP_LE_OQ); } diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index 8a00c626c..d0ccfd8cc 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -359,6 +359,39 @@ EIGEN_STRONG_INLINE Packet8d pmadd(const Packet8d& a, const Packet8d& b, const Packet8d& c) { return _mm512_fmadd_pd(a, b, c); } + +template <> +EIGEN_STRONG_INLINE Packet16f pmsub(const Packet16f& a, const Packet16f& b, + const Packet16f& c) { + return _mm512_fmsub_ps(a, b, c); +} +template <> +EIGEN_STRONG_INLINE Packet8d pmsub(const Packet8d& a, const Packet8d& b, + const Packet8d& c) { + return _mm512_fmsub_pd(a, b, c); +} + +template <> +EIGEN_STRONG_INLINE Packet16f pnmadd(const Packet16f& a, const Packet16f& b, + const Packet16f& c) { + return _mm512_fnmadd_ps(a, b, c); +} +template <> +EIGEN_STRONG_INLINE Packet8d pnmadd(const Packet8d& a, const Packet8d& b, + const Packet8d& c) { + return _mm512_fnmadd_pd(a, b, c); +} + +template <> +EIGEN_STRONG_INLINE Packet16f pnmsub(const Packet16f& a, const Packet16f& b, + const Packet16f& c) { + return _mm512_fnmsub_ps(a, b, c); +} +template <> +EIGEN_STRONG_INLINE Packet8d pnmsub(const Packet8d& a, const Packet8d& b, + const Packet8d& c) { + return _mm512_fnmsub_pd(a, b, c); +} #endif template <> @@ -2281,13 +2314,13 @@ EIGEN_STRONG_INLINE Packet16bf pxor(const Packet16bf& a, const Packet16bf& b) { template <> EIGEN_STRONG_INLINE Packet16bf pand(const Packet16bf& a, const Packet16bf& b) { - return Packet16bf(pand((Packet8i)a, (Packet8i)b)); + return Packet16bf(pand(Packet8i(a), Packet8i(b))); } template <> EIGEN_STRONG_INLINE Packet16bf pandnot(const Packet16bf& a, const Packet16bf& b) { - return Packet16bf(pandnot((Packet8i)a, (Packet8i)b)); + return Packet16bf(pandnot(Packet8i(a), Packet8i(b))); } template <> diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h index 4de3d4707..80f86ffab 100755 --- a/Eigen/src/Core/arch/SSE/PacketMath.h +++ b/Eigen/src/Core/arch/SSE/PacketMath.h @@ -364,6 +364,12 @@ template<> EIGEN_STRONG_INLINE Packet4i pmadd(const Packet4i& a, const Packet4i& #ifdef EIGEN_VECTORIZE_FMA template<> EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fmadd_ps(a,b,c); } template<> EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fmadd_pd(a,b,c); } +template<> EIGEN_STRONG_INLINE Packet4f pmsub(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fmsub_ps(a,b,c); } +template<> EIGEN_STRONG_INLINE Packet2d pmsub(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fmsub_pd(a,b,c); } +template<> EIGEN_STRONG_INLINE Packet4f pnmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fnmadd_ps(a,b,c); } +template<> EIGEN_STRONG_INLINE Packet2d pnmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fnmadd_pd(a,b,c); } +template<> EIGEN_STRONG_INLINE Packet4f pnmsub(const Packet4f& a, const Packet4f& b, const Packet4f& c) { return _mm_fnmsub_ps(a,b,c); } +template<> EIGEN_STRONG_INLINE Packet2d pnmsub(const Packet2d& a, const Packet2d& b, const Packet2d& c) { return _mm_fnmsub_pd(a,b,c); } #endif #ifdef EIGEN_VECTORIZE_SSE4_1 @@ -1263,6 +1269,24 @@ template<> EIGEN_STRONG_INLINE float pmadd(const float& a, const float& b, const template<> EIGEN_STRONG_INLINE double pmadd(const double& a, const double& b, const double& c) { return ::fma(a,b,c); } +template<> EIGEN_STRONG_INLINE float pmsub(const float& a, const float& b, const float& c) { + return ::fmaf(a,b,-c); +} +template<> EIGEN_STRONG_INLINE double pmsub(const double& a, const double& b, const double& c) { + return ::fma(a,b,-c); +} +template<> EIGEN_STRONG_INLINE float pnmadd(const float& a, const float& b, const float& c) { + return ::fmaf(-a,b,c); +} +template<> EIGEN_STRONG_INLINE double pnmadd(const double& a, const double& b, const double& c) { + return ::fma(-a,b,c); +} +template<> EIGEN_STRONG_INLINE float pnmsub(const float& a, const float& b, const float& c) { + return ::fmaf(-a,b,-c); +} +template<> EIGEN_STRONG_INLINE double pnmsub(const double& a, const double& b, const double& c) { + return ::fma(-a,b,-c); +} #endif #ifdef EIGEN_VECTORIZE_SSE4_1 diff --git a/test/packetmath.cpp b/test/packetmath.cpp index 455ecab09..e60308a90 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -24,6 +24,22 @@ inline T REF_MUL(const T& a, const T& b) { return a * b; } template +inline T REF_MADD(const T& a, const T& b, const T& c) { + return a * b + c; +} +template +inline T REF_MSUB(const T& a, const T& b, const T& c) { + return a * b - c; +} +template +inline T REF_NMADD(const T& a, const T& b, const T& c) { + return (-a * b) + c; +} +template +inline T REF_NMSUB(const T& a, const T& b, const T& c) { + return (-a * b) - c; +} +template inline T REF_DIV(const T& a, const T& b) { return a / b; } @@ -49,6 +65,10 @@ template <> inline bool REF_MUL(const bool& a, const bool& b) { return a && b; } +template <> +inline bool REF_MADD(const bool& a, const bool& b, const bool& c) { + return (a && b) || c; +} template inline T REF_FREXP(const T& x, T& exp) { @@ -622,6 +642,12 @@ void packetmath() { } CHECK_CWISE1_IF(PacketTraits::HasSqrt, numext::sqrt, internal::psqrt); CHECK_CWISE1_IF(PacketTraits::HasRsqrt, numext::rsqrt, internal::prsqrt); + CHECK_CWISE3_IF(true, REF_MADD, internal::pmadd); + if (!std::is_same::value) { + CHECK_CWISE3_IF(true, REF_MSUB, internal::pmsub); + CHECK_CWISE3_IF(true, REF_NMADD, internal::pnmadd); + CHECK_CWISE3_IF(true, REF_NMSUB, internal::pnmsub); + } } // Notice that this definition works for complex types as well.