Reduce flakiness of test for Eigen::half.

This commit is contained in:
Antonio Sanchez 2025-03-23 22:31:25 -07:00
parent d935916ac6
commit 8e32cbf7da
3 changed files with 39 additions and 1 deletions

View File

@ -2444,6 +2444,26 @@ EIGEN_STRONG_INLINE Packet16h pdiv<Packet16h>(const Packet16h& a, const Packet16
return float2half(rf);
}
template <>
EIGEN_STRONG_INLINE Packet16h pmadd<Packet16h>(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
return float2half(pmadd(half2float(a), half2float(b), half2float(c)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pmsub<Packet16h>(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
return float2half(pmsub(half2float(a), half2float(b), half2float(c)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pnmadd<Packet16h>(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
return float2half(pnmadd(half2float(a), half2float(b), half2float(c)));
}
template <>
EIGEN_STRONG_INLINE Packet16h pnmsub<Packet16h>(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
return float2half(pnmsub(half2float(a), half2float(b), half2float(c)));
}
template <>
EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& from) {
Packet16f from_float = half2float(from);

View File

@ -700,6 +700,12 @@ void packetmath() {
for (int i = 0; i < PacketSize; ++i) {
data1[i] = internal::random<Scalar>(Scalar(0) - limit, limit);
}
} else if (!NumTraits<Scalar>::IsInteger && !NumTraits<Scalar>::IsComplex) {
// Prevent very small product results by adjusting range. Otherwise,
// we may end up with multiplying e.g. 32 Eigen::halfs with values < 1.
for (int i = 0; i < PacketSize; ++i) {
data1[i] = internal::random<Scalar>(Scalar(0.5), Scalar(1)) * (internal::random<bool>() ? Scalar(-1) : Scalar(1));
}
}
ref[0] = Scalar(1);
for (int i = 0; i < PacketSize; ++i) ref[0] = REF_MUL(ref[0], data1[i]);

View File

@ -38,6 +38,15 @@ template <typename LhsType, typename RhsType>
std::enable_if_t<RhsType::SizeAtCompileTime != Dynamic, void> check_mismatched_product(LhsType& /*unused*/,
const RhsType& /*unused*/) {}
template <typename Scalar, typename V1, typename V2>
Scalar ref_dot_product(const V1& v1, const V2& v2) {
Scalar out = Scalar(0);
for (Index i = 0; i < v1.size(); ++i) {
out = Eigen::numext::fma(v1[i], v2[i], out);
}
return out;
}
template <typename MatrixType>
void product(const MatrixType& m) {
/* this test covers the following files:
@ -245,7 +254,10 @@ void product(const MatrixType& m) {
// inner product
{
Scalar x = square2.row(c) * square2.col(c2);
VERIFY_IS_APPROX(x, square2.row(c).transpose().cwiseProduct(square2.col(c2)).sum());
// NOTE: FMA is necessary here in the reference to ensure accuracy for
// large vector sizes and float16/bfloat16 types.
Scalar y = ref_dot_product<Scalar>(square2.row(c), square2.col(c2));
VERIFY_IS_APPROX(x, y);
}
// outer product