diff --git a/Eigen/src/Core/arch/AVX512/PacketMath.h b/Eigen/src/Core/arch/AVX512/PacketMath.h index 8087a6166..1fe212ab9 100644 --- a/Eigen/src/Core/arch/AVX512/PacketMath.h +++ b/Eigen/src/Core/arch/AVX512/PacketMath.h @@ -2444,6 +2444,26 @@ EIGEN_STRONG_INLINE Packet16h pdiv(const Packet16h& a, const Packet16 return float2half(rf); } +template <> +EIGEN_STRONG_INLINE Packet16h pmadd(const Packet16h& a, const Packet16h& b, const Packet16h& c) { + return float2half(pmadd(half2float(a), half2float(b), half2float(c))); +} + +template <> +EIGEN_STRONG_INLINE Packet16h pmsub(const Packet16h& a, const Packet16h& b, const Packet16h& c) { + return float2half(pmsub(half2float(a), half2float(b), half2float(c))); +} + +template <> +EIGEN_STRONG_INLINE Packet16h pnmadd(const Packet16h& a, const Packet16h& b, const Packet16h& c) { + return float2half(pnmadd(half2float(a), half2float(b), half2float(c))); +} + +template <> +EIGEN_STRONG_INLINE Packet16h pnmsub(const Packet16h& a, const Packet16h& b, const Packet16h& c) { + return float2half(pnmsub(half2float(a), half2float(b), half2float(c))); +} + template <> EIGEN_STRONG_INLINE half predux(const Packet16h& from) { Packet16f from_float = half2float(from); diff --git a/test/packetmath.cpp b/test/packetmath.cpp index 1e85f8b31..ed3950927 100644 --- a/test/packetmath.cpp +++ b/test/packetmath.cpp @@ -700,6 +700,12 @@ void packetmath() { for (int i = 0; i < PacketSize; ++i) { data1[i] = internal::random(Scalar(0) - limit, limit); } + } else if (!NumTraits::IsInteger && !NumTraits::IsComplex) { + // Prevent very small product results by adjusting range. Otherwise, + // we may end up with multiplying e.g. 32 Eigen::halfs with values < 1. + for (int i = 0; i < PacketSize; ++i) { + data1[i] = internal::random(Scalar(0.5), Scalar(1)) * (internal::random() ? Scalar(-1) : Scalar(1)); + } } ref[0] = Scalar(1); for (int i = 0; i < PacketSize; ++i) ref[0] = REF_MUL(ref[0], data1[i]); diff --git a/test/product.h b/test/product.h index f8eb5df85..f37a932d0 100644 --- a/test/product.h +++ b/test/product.h @@ -38,6 +38,15 @@ template std::enable_if_t check_mismatched_product(LhsType& /*unused*/, const RhsType& /*unused*/) {} +template +Scalar ref_dot_product(const V1& v1, const V2& v2) { + Scalar out = Scalar(0); + for (Index i = 0; i < v1.size(); ++i) { + out = Eigen::numext::fma(v1[i], v2[i], out); + } + return out; +} + template void product(const MatrixType& m) { /* this test covers the following files: @@ -245,7 +254,10 @@ void product(const MatrixType& m) { // inner product { Scalar x = square2.row(c) * square2.col(c2); - VERIFY_IS_APPROX(x, square2.row(c).transpose().cwiseProduct(square2.col(c2)).sum()); + // NOTE: FMA is necessary here in the reference to ensure accuracy for + // large vector sizes and float16/bfloat16 types. + Scalar y = ref_dot_product(square2.row(c), square2.col(c2)); + VERIFY_IS_APPROX(x, y); } // outer product