Replace calls to numext::fma with numext:madd.

This commit is contained in:
Antonio Sánchez 2025-08-28 21:40:19 +00:00 committed by Rasmus Munk Larsen
parent 52f570a409
commit 2e8cc042a1
7 changed files with 88 additions and 44 deletions

View File

@ -1350,20 +1350,20 @@ struct pmadd_impl {
template <typename Scalar> template <typename Scalar>
struct pmadd_impl<Scalar, std::enable_if_t<is_scalar<Scalar>::value && NumTraits<Scalar>::IsSigned>> { struct pmadd_impl<Scalar, std::enable_if_t<is_scalar<Scalar>::value && NumTraits<Scalar>::IsSigned>> {
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pmadd(const Scalar& a, const Scalar& b, const Scalar& c) { static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
return numext::fma(a, b, c); return numext::madd<Scalar>(a, b, c);
} }
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pmsub(const Scalar& a, const Scalar& b, const Scalar& c) { static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
return numext::fma(a, b, Scalar(-c)); return numext::madd<Scalar>(a, b, Scalar(-c));
} }
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pnmadd(const Scalar& a, const Scalar& b, const Scalar& c) { static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pnmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
return numext::fma(Scalar(-a), b, c); return numext::madd<Scalar>(Scalar(-a), b, c);
} }
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pnmsub(const Scalar& a, const Scalar& b, const Scalar& c) { static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pnmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
return -Scalar(numext::fma(a, b, c)); return -Scalar(numext::madd<Scalar>(a, b, c));
} }
}; };
// FMA instructions. // Multiply-add instructions.
/** \internal \returns a * b + c (coeff-wise) */ /** \internal \returns a * b + c (coeff-wise) */
template <typename Packet> template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pmadd(const Packet& a, const Packet& b, const Packet& c) { EIGEN_DEVICE_FUNC inline Packet pmadd(const Packet& a, const Packet& b, const Packet& c) {

View File

@ -941,23 +941,44 @@ struct nearest_integer_impl<Scalar, true> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_trunc(const Scalar& x) { return x; } static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_trunc(const Scalar& x) { return x; }
}; };
// Extra namespace to prevent leaking std::fma into Eigen::internal.
namespace has_fma_detail {
template <typename T, typename EnableIf = void>
struct has_fma_impl : public std::false_type {};
using std::fma;
template <typename T>
struct has_fma_impl<
T, std::enable_if_t<std::is_same<T, decltype(fma(std::declval<T>(), std::declval<T>(), std::declval<T>()))>::value>>
: public std::true_type {};
} // namespace has_fma_detail
template <typename T>
struct has_fma : public has_fma_detail::has_fma_impl<T> {};
// Default implementation. // Default implementation.
template <typename Scalar, typename Enable = void> template <typename T, typename Enable = void>
struct fma_impl { struct fma_impl {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar& a, const Scalar& b, const Scalar& c) { static_assert(has_fma<T>::value, "No function fma(...) for type. Please provide an implementation.");
return a * b + c; };
// STD or ADL version if it exists.
template <typename T>
struct fma_impl<T, std::enable_if_t<has_fma<T>::value>> {
static T run(const T& a, const T& b, const T& c) {
using std::fma;
return fma(a, b, c);
} }
}; };
// ADL version if it exists.
template <typename T>
struct fma_impl<
T,
std::enable_if_t<std::is_same<T, decltype(fma(std::declval<T>(), std::declval<T>(), std::declval<T>()))>::value>> {
static T run(const T& a, const T& b, const T& c) { return fma(a, b, c); }
};
#if defined(EIGEN_GPUCC) #if defined(EIGEN_GPUCC)
template <>
struct has_fma<float> : public true_type {
}
template <> template <>
struct fma_impl<float, void> { struct fma_impl<float, void> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float run(const float& a, const float& b, const float& c) { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float run(const float& a, const float& b, const float& c) {
@ -965,6 +986,10 @@ struct fma_impl<float, void> {
} }
}; };
template <>
struct has_fma<double> : public true_type {
}
template <> template <>
struct fma_impl<double, void> { struct fma_impl<double, void> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double run(const double& a, const double& b, const double& c) { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double run(const double& a, const double& b, const double& c) {
@ -973,6 +998,24 @@ struct fma_impl<double, void> {
}; };
#endif #endif
// Basic multiply-add.
template <typename Scalar, typename EnableIf = void>
struct madd_impl {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar& x, const Scalar& y, const Scalar& z) {
return x * y + z;
}
};
// Use FMA if there is a single CPU instruction.
#ifdef EIGEN_VECTORIZE_FMA
template <typename Scalar>
struct madd_impl<Scalar, std::enable_if_t<has_fma<Scalar>::value>> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar& x, const Scalar& y, const Scalar& z) {
return fma_impl<Scalar>::run(x, y, z);
}
};
#endif
} // end namespace internal } // end namespace internal
/**************************************************************************** /****************************************************************************
@ -1886,15 +1929,18 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar arithmetic_shift_right(const Scalar
return bit_cast<Scalar, SignedScalar>(bit_cast<SignedScalar, Scalar>(a) >> n); return bit_cast<Scalar, SignedScalar>(bit_cast<SignedScalar, Scalar>(a) >> n);
} }
// Use std::fma if available.
using std::fma;
// Otherwise, rely on template implementation. // Otherwise, rely on template implementation.
template <typename Scalar> template <typename Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar fma(const Scalar& x, const Scalar& y, const Scalar& z) { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar fma(const Scalar& x, const Scalar& y, const Scalar& z) {
return internal::fma_impl<Scalar>::run(x, y, z); return internal::fma_impl<Scalar>::run(x, y, z);
} }
// Multiply-add.
template <typename Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar madd(const Scalar& x, const Scalar& y, const Scalar& z) {
return internal::madd_impl<Scalar>::run(x, y, z);
}
} // end namespace numext } // end namespace numext
namespace internal { namespace internal {

View File

@ -2026,38 +2026,38 @@ EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, const Packet2d&
} }
// Scalar path for pmadd with FMA to ensure consistency with vectorized path. // Scalar path for pmadd with FMA to ensure consistency with vectorized path.
#ifdef EIGEN_VECTORIZE_FMA #if defined(EIGEN_VECTORIZE_FMA)
template <> template <>
EIGEN_STRONG_INLINE float pmadd(const float& a, const float& b, const float& c) { EIGEN_STRONG_INLINE float pmadd(const float& a, const float& b, const float& c) {
return ::fmaf(a, b, c); return std::fmaf(a, b, c);
} }
template <> template <>
EIGEN_STRONG_INLINE double pmadd(const double& a, const double& b, const double& c) { EIGEN_STRONG_INLINE double pmadd(const double& a, const double& b, const double& c) {
return ::fma(a, b, c); return std::fma(a, b, c);
} }
template <> template <>
EIGEN_STRONG_INLINE float pmsub(const float& a, const float& b, const float& c) { EIGEN_STRONG_INLINE float pmsub(const float& a, const float& b, const float& c) {
return ::fmaf(a, b, -c); return std::fmaf(a, b, -c);
} }
template <> template <>
EIGEN_STRONG_INLINE double pmsub(const double& a, const double& b, const double& c) { EIGEN_STRONG_INLINE double pmsub(const double& a, const double& b, const double& c) {
return ::fma(a, b, -c); return std::fma(a, b, -c);
} }
template <> template <>
EIGEN_STRONG_INLINE float pnmadd(const float& a, const float& b, const float& c) { EIGEN_STRONG_INLINE float pnmadd(const float& a, const float& b, const float& c) {
return ::fmaf(-a, b, c); return std::fmaf(-a, b, c);
} }
template <> template <>
EIGEN_STRONG_INLINE double pnmadd(const double& a, const double& b, const double& c) { EIGEN_STRONG_INLINE double pnmadd(const double& a, const double& b, const double& c) {
return ::fma(-a, b, c); return std::fma(-a, b, c);
} }
template <> template <>
EIGEN_STRONG_INLINE float pnmsub(const float& a, const float& b, const float& c) { EIGEN_STRONG_INLINE float pnmsub(const float& a, const float& b, const float& c) {
return ::fmaf(-a, b, -c); return std::fmaf(-a, b, -c);
} }
template <> template <>
EIGEN_STRONG_INLINE double pnmsub(const double& a, const double& b, const double& c) { EIGEN_STRONG_INLINE double pnmsub(const double& a, const double& b, const double& c) {
return ::fma(-a, b, -c); return std::fma(-a, b, -c);
} }
#endif #endif

View File

@ -36,10 +36,10 @@ inline typename internal::traits<Derived>::Scalar SparseMatrixBase<Derived>::dot
Scalar res1(0); Scalar res1(0);
Scalar res2(0); Scalar res2(0);
for (; i; ++i) { for (; i; ++i) {
res1 = numext::fma(numext::conj(i.value()), other.coeff(i.index()), res1); res1 = numext::madd<Scalar>(numext::conj(i.value()), other.coeff(i.index()), res1);
++i; ++i;
if (i) { if (i) {
res2 = numext::fma(numext::conj(i.value()), other.coeff(i.index()), res2); res2 = numext::madd<Scalar>(numext::conj(i.value()), other.coeff(i.index()), res2);
} }
} }
return res1 + res2; return res1 + res2;
@ -67,7 +67,7 @@ inline typename internal::traits<Derived>::Scalar SparseMatrixBase<Derived>::dot
Scalar res(0); Scalar res(0);
while (i && j) { while (i && j) {
if (i.index() == j.index()) { if (i.index() == j.index()) {
res = numext::fma(numext::conj(i.value()), j.value(), res); res = numext::madd<Scalar>(numext::conj(i.value()), j.value(), res);
++i; ++i;
++j; ++j;
} else if (i.index() < j.index()) } else if (i.index() < j.index())

View File

@ -41,7 +41,7 @@ struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Lower, RowMajor> {
lastVal = it.value(); lastVal = it.value();
lastIndex = it.index(); lastIndex = it.index();
if (lastIndex == i) break; if (lastIndex == i) break;
tmp = numext::fma(-lastVal, other.coeff(lastIndex, col), tmp); tmp = numext::madd<Scalar>(-lastVal, other.coeff(lastIndex, col), tmp);
} }
if (Mode & UnitDiag) if (Mode & UnitDiag)
other.coeffRef(i, col) = tmp; other.coeffRef(i, col) = tmp;
@ -75,7 +75,7 @@ struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Upper, RowMajor> {
} else if (it && it.index() == i) } else if (it && it.index() == i)
++it; ++it;
for (; it; ++it) { for (; it; ++it) {
tmp = numext::fma<Scalar>(-it.value(), other.coeff(it.index(), col), tmp); tmp = numext::madd<Scalar>(-it.value(), other.coeff(it.index(), col), tmp);
} }
if (Mode & UnitDiag) if (Mode & UnitDiag)
@ -108,7 +108,7 @@ struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Lower, ColMajor> {
} }
if (it && it.index() == i) ++it; if (it && it.index() == i) ++it;
for (; it; ++it) { for (; it; ++it) {
other.coeffRef(it.index(), col) = numext::fma<Scalar>(-tmp, it.value(), other.coeffRef(it.index(), col)); other.coeffRef(it.index(), col) = numext::madd<Scalar>(-tmp, it.value(), other.coeffRef(it.index(), col));
} }
} }
} }
@ -138,7 +138,7 @@ struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Upper, ColMajor> {
} }
LhsIterator it(lhsEval, i); LhsIterator it(lhsEval, i);
for (; it && it.index() < i; ++it) { for (; it && it.index() < i; ++it) {
other.coeffRef(it.index(), col) = numext::fma<Scalar>(-tmp, it.value(), other.coeffRef(it.index(), col)); other.coeffRef(it.index(), col) = numext::madd<Scalar>(-tmp, it.value(), other.coeffRef(it.index(), col));
} }
} }
} }
@ -220,11 +220,11 @@ struct sparse_solve_triangular_sparse_selector<Lhs, Rhs, Mode, UpLo, ColMajor> {
if (IsLower) { if (IsLower) {
if (it.index() == i) ++it; if (it.index() == i) ++it;
for (; it; ++it) { for (; it; ++it) {
tempVector.coeffRef(it.index()) = numext::fma(-ci, it.value(), tempVector.coeffRef(it.index())); tempVector.coeffRef(it.index()) = numext::madd<Scalar>(-ci, it.value(), tempVector.coeffRef(it.index()));
} }
} else { } else {
for (; it && it.index() < i; ++it) { for (; it && it.index() < i; ++it) {
tempVector.coeffRef(it.index()) = numext::fma(-ci, it.value(), tempVector.coeffRef(it.index())); tempVector.coeffRef(it.index()) = numext::madd<Scalar>(-ci, it.value(), tempVector.coeffRef(it.index()));
} }
} }
} }

View File

@ -44,16 +44,16 @@ template <typename Scalar>
struct madd_impl<Scalar, struct madd_impl<Scalar,
std::enable_if_t<Eigen::internal::is_scalar<Scalar>::value && Eigen::NumTraits<Scalar>::IsSigned>> { std::enable_if_t<Eigen::internal::is_scalar<Scalar>::value && Eigen::NumTraits<Scalar>::IsSigned>> {
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar madd(const Scalar& a, const Scalar& b, const Scalar& c) { static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar madd(const Scalar& a, const Scalar& b, const Scalar& c) {
return numext::fma(a, b, c); return numext::madd(a, b, c);
} }
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar msub(const Scalar& a, const Scalar& b, const Scalar& c) { static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar msub(const Scalar& a, const Scalar& b, const Scalar& c) {
return numext::fma(a, b, Scalar(-c)); return numext::madd(a, b, Scalar(-c));
} }
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmadd(const Scalar& a, const Scalar& b, const Scalar& c) { static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
return numext::fma(Scalar(-a), b, c); return numext::madd(Scalar(-a), b, c);
} }
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmsub(const Scalar& a, const Scalar& b, const Scalar& c) { static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar nmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
return -Scalar(numext::fma(a, b, c)); return -Scalar(numext::madd(a, b, c));
} }
}; };

View File

@ -42,7 +42,7 @@ template <typename Scalar, typename V1, typename V2>
Scalar ref_dot_product(const V1& v1, const V2& v2) { Scalar ref_dot_product(const V1& v1, const V2& v2) {
Scalar out = Scalar(0); Scalar out = Scalar(0);
for (Index i = 0; i < v1.size(); ++i) { for (Index i = 0; i < v1.size(); ++i) {
out = Eigen::numext::fma(v1[i], v2[i], out); out = Eigen::numext::madd(v1[i], v2[i], out);
} }
return out; return out;
} }
@ -254,8 +254,6 @@ void product(const MatrixType& m) {
// inner product // inner product
{ {
Scalar x = square2.row(c) * square2.col(c2); Scalar x = square2.row(c) * square2.col(c2);
// NOTE: FMA is necessary here in the reference to ensure accuracy for
// large vector sizes and float16/bfloat16 types.
Scalar y = ref_dot_product<Scalar>(square2.row(c), square2.col(c2)); Scalar y = ref_dot_product<Scalar>(square2.row(c), square2.col(c2));
VERIFY_IS_APPROX(x, y); VERIFY_IS_APPROX(x, y);
} }