From 2e8cc042a1f9c1b35e7ab3013bab2e01f7b04142 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20S=C3=A1nchez?= Date: Thu, 28 Aug 2025 21:40:19 +0000 Subject: [PATCH] Replace calls to numext::fma with numext::madd. --- Eigen/src/Core/GenericPacketMath.h | 10 ++-- Eigen/src/Core/MathFunctions.h | 74 ++++++++++++++++++++----- Eigen/src/Core/arch/SSE/PacketMath.h | 18 +++--- Eigen/src/SparseCore/SparseDot.h | 6 +- Eigen/src/SparseCore/TriangularSolver.h | 12 ++-- test/packetmath.cpp | 8 +-- test/product.h | 4 +- 7 files changed, 88 insertions(+), 44 deletions(-) diff --git a/Eigen/src/Core/GenericPacketMath.h b/Eigen/src/Core/GenericPacketMath.h index 21a1bfc41..139b10e8a 100644 --- a/Eigen/src/Core/GenericPacketMath.h +++ b/Eigen/src/Core/GenericPacketMath.h @@ -1350,20 +1350,20 @@ struct pmadd_impl { template struct pmadd_impl::value && NumTraits::IsSigned>> { static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pmadd(const Scalar& a, const Scalar& b, const Scalar& c) { - return numext::fma(a, b, c); + return numext::madd(a, b, c); } static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pmsub(const Scalar& a, const Scalar& b, const Scalar& c) { - return numext::fma(a, b, Scalar(-c)); + return numext::madd(a, b, Scalar(-c)); } static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pnmadd(const Scalar& a, const Scalar& b, const Scalar& c) { - return numext::fma(Scalar(-a), b, c); + return numext::madd(Scalar(-a), b, c); } static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pnmsub(const Scalar& a, const Scalar& b, const Scalar& c) { - return -Scalar(numext::fma(a, b, c)); + return -Scalar(numext::madd(a, b, c)); } }; -// FMA instructions. +// Multiply-add instructions. 
/** \internal \returns a * b + c (coeff-wise) */ template EIGEN_DEVICE_FUNC inline Packet pmadd(const Packet& a, const Packet& b, const Packet& c) { diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index 481e057d0..44b16be54 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -941,23 +941,44 @@ struct nearest_integer_impl { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_trunc(const Scalar& x) { return x; } }; +// Extra namespace to prevent leaking std::fma into Eigen::internal. +namespace has_fma_detail { + +template +struct has_fma_impl : public std::false_type {}; + +using std::fma; + +template +struct has_fma_impl< + T, std::enable_if_t(), std::declval(), std::declval()))>::value>> + : public std::true_type {}; + +} // namespace has_fma_detail + +template +struct has_fma : public has_fma_detail::has_fma_impl {}; + // Default implementation. -template +template struct fma_impl { - static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar& a, const Scalar& b, const Scalar& c) { - return a * b + c; + static_assert(has_fma::value, "No function fma(...) for type. Please provide an implementation."); +}; + +// STD or ADL version if it exists. +template +struct fma_impl::value>> { + static T run(const T& a, const T& b, const T& c) { + using std::fma; + return fma(a, b, c); } }; -// ADL version if it exists. 
-template -struct fma_impl< - T, - std::enable_if_t(), std::declval(), std::declval()))>::value>> { - static T run(const T& a, const T& b, const T& c) { return fma(a, b, c); } -}; - #if defined(EIGEN_GPUCC) +template <> +struct has_fma<float> : public std::true_type { +}; + template <> struct fma_impl { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float run(const float& a, const float& b, const float& c) { @@ -965,6 +986,10 @@ struct fma_impl { } }; +template <> +struct has_fma<double> : public std::true_type { +}; + template <> struct fma_impl { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double run(const double& a, const double& b, const double& c) { @@ -973,6 +998,24 @@ struct fma_impl { }; #endif +// Basic multiply-add. +template +struct madd_impl { + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar& x, const Scalar& y, const Scalar& z) { + return x * y + z; + } +}; + +// Use FMA if there is a single CPU instruction. +#ifdef EIGEN_VECTORIZE_FMA +template +struct madd_impl::value>> { + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar& x, const Scalar& y, const Scalar& z) { + return fma_impl::run(x, y, z); + } +}; +#endif + } // end namespace internal /**************************************************************************** @@ -1886,15 +1929,18 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar arithmetic_shift_right(const Scalar return bit_cast(bit_cast(a) >> n); } -// Use std::fma if available. -using std::fma; - // Otherwise, rely on template implementation. template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar fma(const Scalar& x, const Scalar& y, const Scalar& z) { return internal::fma_impl::run(x, y, z); } +// Multiply-add. 
+template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar madd(const Scalar& x, const Scalar& y, const Scalar& z) { + return internal::madd_impl::run(x, y, z); +} + } // end namespace numext namespace internal { diff --git a/Eigen/src/Core/arch/SSE/PacketMath.h b/Eigen/src/Core/arch/SSE/PacketMath.h index 64ba7ba3a..b66a4db7c 100644 --- a/Eigen/src/Core/arch/SSE/PacketMath.h +++ b/Eigen/src/Core/arch/SSE/PacketMath.h @@ -2026,38 +2026,38 @@ EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, const Packet2d& } // Scalar path for pmadd with FMA to ensure consistency with vectorized path. -#ifdef EIGEN_VECTORIZE_FMA +#if defined(EIGEN_VECTORIZE_FMA) template <> EIGEN_STRONG_INLINE float pmadd(const float& a, const float& b, const float& c) { - return ::fmaf(a, b, c); + return std::fmaf(a, b, c); } template <> EIGEN_STRONG_INLINE double pmadd(const double& a, const double& b, const double& c) { - return ::fma(a, b, c); + return std::fma(a, b, c); } template <> EIGEN_STRONG_INLINE float pmsub(const float& a, const float& b, const float& c) { - return ::fmaf(a, b, -c); + return std::fmaf(a, b, -c); } template <> EIGEN_STRONG_INLINE double pmsub(const double& a, const double& b, const double& c) { - return ::fma(a, b, -c); + return std::fma(a, b, -c); } template <> EIGEN_STRONG_INLINE float pnmadd(const float& a, const float& b, const float& c) { - return ::fmaf(-a, b, c); + return std::fmaf(-a, b, c); } template <> EIGEN_STRONG_INLINE double pnmadd(const double& a, const double& b, const double& c) { - return ::fma(-a, b, c); + return std::fma(-a, b, c); } template <> EIGEN_STRONG_INLINE float pnmsub(const float& a, const float& b, const float& c) { - return ::fmaf(-a, b, -c); + return std::fmaf(-a, b, -c); } template <> EIGEN_STRONG_INLINE double pnmsub(const double& a, const double& b, const double& c) { - return ::fma(-a, b, -c); + return std::fma(-a, b, -c); } #endif diff --git a/Eigen/src/SparseCore/SparseDot.h b/Eigen/src/SparseCore/SparseDot.h index 
8aeebc8f4..485605fd4 100644 --- a/Eigen/src/SparseCore/SparseDot.h +++ b/Eigen/src/SparseCore/SparseDot.h @@ -36,10 +36,10 @@ inline typename internal::traits::Scalar SparseMatrixBase::dot Scalar res1(0); Scalar res2(0); for (; i; ++i) { - res1 = numext::fma(numext::conj(i.value()), other.coeff(i.index()), res1); + res1 = numext::madd(numext::conj(i.value()), other.coeff(i.index()), res1); ++i; if (i) { - res2 = numext::fma(numext::conj(i.value()), other.coeff(i.index()), res2); + res2 = numext::madd(numext::conj(i.value()), other.coeff(i.index()), res2); } } return res1 + res2; @@ -67,7 +67,7 @@ inline typename internal::traits::Scalar SparseMatrixBase::dot Scalar res(0); while (i && j) { if (i.index() == j.index()) { - res = numext::fma(numext::conj(i.value()), j.value(), res); + res = numext::madd(numext::conj(i.value()), j.value(), res); ++i; ++j; } else if (i.index() < j.index()) diff --git a/Eigen/src/SparseCore/TriangularSolver.h b/Eigen/src/SparseCore/TriangularSolver.h index fb8c15781..684de4830 100644 --- a/Eigen/src/SparseCore/TriangularSolver.h +++ b/Eigen/src/SparseCore/TriangularSolver.h @@ -41,7 +41,7 @@ struct sparse_solve_triangular_selector { lastVal = it.value(); lastIndex = it.index(); if (lastIndex == i) break; - tmp = numext::fma(-lastVal, other.coeff(lastIndex, col), tmp); + tmp = numext::madd(-lastVal, other.coeff(lastIndex, col), tmp); } if (Mode & UnitDiag) other.coeffRef(i, col) = tmp; @@ -75,7 +75,7 @@ struct sparse_solve_triangular_selector { } else if (it && it.index() == i) ++it; for (; it; ++it) { - tmp = numext::fma(-it.value(), other.coeff(it.index(), col), tmp); + tmp = numext::madd(-it.value(), other.coeff(it.index(), col), tmp); } if (Mode & UnitDiag) @@ -108,7 +108,7 @@ struct sparse_solve_triangular_selector { } if (it && it.index() == i) ++it; for (; it; ++it) { - other.coeffRef(it.index(), col) = numext::fma(-tmp, it.value(), other.coeffRef(it.index(), col)); + other.coeffRef(it.index(), col) = numext::madd(-tmp, it.value(), 
other.coeffRef(it.index(), col)); } } } @@ -138,7 +138,7 @@ struct sparse_solve_triangular_selector { } LhsIterator it(lhsEval, i); for (; it && it.index() < i; ++it) { - other.coeffRef(it.index(), col) = numext::fma(-tmp, it.value(), other.coeffRef(it.index(), col)); + other.coeffRef(it.index(), col) = numext::madd(-tmp, it.value(), other.coeffRef(it.index(), col)); } } } @@ -220,11 +220,11 @@ struct sparse_solve_triangular_sparse_selector { if (IsLower) { if (it.index() == i) ++it; for (; it; ++it) { - tempVector.coeffRef(it.index()) = numext::fma(-ci, it.value(), tempVector.coeffRef(it.index())); + tempVector.coeffRef(it.index()) = numext::madd(-ci, it.value(), tempVector.coeffRef(it.index())); } } else { for (; it && it.index() < i; ++it) { - tempVector.coeffRef(it.index()) = numext::fma(-ci, it.value(), tempVector.coeffRef(it.index())); + tempVector.coeffRef(it.index()) = numext::madd(-ci, it.value(), tempVector.coeffRef(it.index())); } } } diff --git a/test/packetmath.cpp b/test/packetmath.cpp index 5f48d713c..f21c72621 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -44,16 +44,16 @@ template struct madd_impl::value && Eigen::NumTraits::IsSigned>> { static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar madd(const Scalar& a, const Scalar& b, const Scalar& c) { - return numext::fma(a, b, c); + return numext::madd(a, b, c); } static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar msub(const Scalar& a, const Scalar& b, const Scalar& c) { - return numext::fma(a, b, Scalar(-c)); + return numext::madd(a, b, Scalar(-c)); } static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmadd(const Scalar& a, const Scalar& b, const Scalar& c) { - return numext::fma(Scalar(-a), b, c); + return numext::madd(Scalar(-a), b, c); } static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmsub(const Scalar& a, const Scalar& b, const Scalar& c) { - return -Scalar(numext::fma(a, b, c)); + return -Scalar(numext::madd(a, b, c)); } }; diff --git a/test/product.h b/test/product.h index 
f37a932d0..21b470119 100644 --- a/test/product.h +++ b/test/product.h @@ -42,7 +42,7 @@ template Scalar ref_dot_product(const V1& v1, const V2& v2) { Scalar out = Scalar(0); for (Index i = 0; i < v1.size(); ++i) { - out = Eigen::numext::fma(v1[i], v2[i], out); + out = Eigen::numext::madd(v1[i], v2[i], out); } return out; } @@ -254,8 +254,6 @@ void product(const MatrixType& m) { // inner product { Scalar x = square2.row(c) * square2.col(c2); - // NOTE: FMA is necessary here in the reference to ensure accuracy for - // large vector sizes and float16/bfloat16 types. Scalar y = ref_dot_product(square2.row(c), square2.col(c2)); VERIFY_IS_APPROX(x, y); }