From 7f0cb638c5e967da36d4599bdf0aea793e2f1303 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20S=C3=A1nchez?= Date: Fri, 29 Aug 2025 18:11:25 +0000 Subject: [PATCH] Specialize numext::madd for half/bfloat16. --- Eigen/src/Core/arch/Default/BFloat16.h | 6 ++++++ Eigen/src/Core/arch/Default/Half.h | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/Eigen/src/Core/arch/Default/BFloat16.h b/Eigen/src/Core/arch/Default/BFloat16.h index f2e55f345..b93c4bc2e 100644 --- a/Eigen/src/Core/arch/Default/BFloat16.h +++ b/Eigen/src/Core/arch/Default/BFloat16.h @@ -793,6 +793,12 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 nextafter(const bfloat16& from, c return numext::bit_cast(from_bits); } +// Specialize multiply-add to match packet operations and reduce conversions to/from float. +template<> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 madd(const Eigen::bfloat16& x, const Eigen::bfloat16& y, const Eigen::bfloat16& z) { + return Eigen::bfloat16(static_cast(x) * static_cast(y) + static_cast(z)); +} + } // namespace numext } // namespace Eigen diff --git a/Eigen/src/Core/arch/Default/Half.h b/Eigen/src/Core/arch/Default/Half.h index c073fe8ec..210dfff1e 100644 --- a/Eigen/src/Core/arch/Default/Half.h +++ b/Eigen/src/Core/arch/Default/Half.h @@ -955,6 +955,12 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast(c return Eigen::half_impl::raw_half_as_uint16(src); } +// Specialize multiply-add to match packet operations and reduce conversions to/from float. +template<> +EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half madd(const Eigen::half& x, const Eigen::half& y, const Eigen::half& z) { + return Eigen::half(static_cast(x) * static_cast(y) + static_cast(z)); +} + } // namespace numext } // namespace Eigen