Specialize numext::madd for half/bfloat16.

This commit is contained in:
Antonio Sánchez 2025-08-29 18:11:25 +00:00
parent 1e9d7ed7d3
commit 7f0cb638c5
2 changed files with 12 additions and 0 deletions

View File

@ -793,6 +793,12 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 nextafter(const bfloat16& from, c
return numext::bit_cast<bfloat16>(from_bits);
}
// Specialize multiply-add to match packet operations and reduce conversions to/from float.
template<>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 madd<Eigen::bfloat16>(const Eigen::bfloat16& x, const Eigen::bfloat16& y, const Eigen::bfloat16& z) {
return Eigen::bfloat16(static_cast<float>(x) * static_cast<float>(y) + static_cast<float>(z));
}
} // namespace numext
} // namespace Eigen

View File

@ -955,6 +955,12 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::half>(c
return Eigen::half_impl::raw_half_as_uint16(src);
}
// Specialize multiply-add to match packet operations and reduce conversions to/from float.
template<>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half madd<Eigen::half>(const Eigen::half& x, const Eigen::half& y, const Eigen::half& z) {
return Eigen::half(static_cast<float>(x) * static_cast<float>(y) + static_cast<float>(z));
}
} // namespace numext
} // namespace Eigen