Use fma<float> for fma<half> and fma<bfloat16> if native fma is not available on the platform.

This commit is contained in:
Rasmus Munk Larsen 2025-03-28 04:26:04 +00:00 committed by Antonio Sánchez
parent 44fb6422be
commit 63a40ffb95
2 changed files with 2 additions and 2 deletions

View File

@ -675,7 +675,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 fmax(const bfloat16& a, const bfl
// Computes a * b + c for bfloat16 operands.
// The three operands are widened to float, combined with numext::fma, and the
// float result is converted back to bfloat16.
// NOTE(review): the rounding behaviour of numext::fma<float> (fused vs. unfused) is
// not visible from here; confirm against its definition for the target platform.
EIGEN_DEVICE_FUNC inline bfloat16 fma(const bfloat16& a, const bfloat16& b, const bfloat16& c) {
  // Emulate FMA via float.
  return bfloat16(numext::fma(static_cast<float>(a), static_cast<float>(b), static_cast<float>(c)));
}
#ifndef EIGEN_NO_IO

View File

@ -812,7 +812,7 @@ EIGEN_DEVICE_FUNC inline half fma(const half& a, const half& b, const half& c) {
return half(_mm_cvtsh_h(_mm_fmadd_ph(_mm_set_sh(a.x), _mm_set_sh(b.x), _mm_set_sh(c.x))));
#else
// Emulate FMA via float.
return half(static_cast<float>(a) * static_cast<float>(b) + static_cast<float>(c));
return half(numext::fma(static_cast<float>(a), static_cast<float>(b), static_cast<float>(c)));
#endif
}