diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index 1fe212ab9..b7c8b9028 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -2953,7 +2953,27 @@ EIGEN_STRONG_INLINE Packet16bf psub(const Packet16bf& a, const Packe template <> EIGEN_STRONG_INLINE Packet16bf pmul(const Packet16bf& a, const Packet16bf& b) { - return F32ToBf16(pmul(Bf16ToF32(a), Bf16ToF32(b))); + return F32ToBf16(pmul(Bf16ToF32(a), Bf16ToF32(b))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pmadd(const Packet16bf& a, const Packet16bf& b, const Packet16bf& c) { + return F32ToBf16(pmadd(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pmsub(const Packet16bf& a, const Packet16bf& b, const Packet16bf& c) { + return F32ToBf16(pmsub(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pnmadd(const Packet16bf& a, const Packet16bf& b, const Packet16bf& c) { + return F32ToBf16(pnmadd(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c))); +} + +template <> +EIGEN_STRONG_INLINE Packet16bf pnmsub(const Packet16bf& a, const Packet16bf& b, const Packet16bf& c) { + return F32ToBf16(pnmsub(Bf16ToF32(a), Bf16ToF32(b), Bf16ToF32(c))); } template <>