mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-23 01:59:38 +08:00
Reduce flakiness of test for Eigen::half.
This commit is contained in:
parent
d935916ac6
commit
8e32cbf7da
@ -2444,6 +2444,26 @@ EIGEN_STRONG_INLINE Packet16h pdiv<Packet16h>(const Packet16h& a, const Packet16
|
|||||||
return float2half(rf);
|
return float2half(rf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet16h pmadd<Packet16h>(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
|
||||||
|
return float2half(pmadd(half2float(a), half2float(b), half2float(c)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet16h pmsub<Packet16h>(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
|
||||||
|
return float2half(pmsub(half2float(a), half2float(b), half2float(c)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet16h pnmadd<Packet16h>(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
|
||||||
|
return float2half(pnmadd(half2float(a), half2float(b), half2float(c)));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
EIGEN_STRONG_INLINE Packet16h pnmsub<Packet16h>(const Packet16h& a, const Packet16h& b, const Packet16h& c) {
|
||||||
|
return float2half(pnmsub(half2float(a), half2float(b), half2float(c)));
|
||||||
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& from) {
|
EIGEN_STRONG_INLINE half predux<Packet16h>(const Packet16h& from) {
|
||||||
Packet16f from_float = half2float(from);
|
Packet16f from_float = half2float(from);
|
||||||
|
@ -700,6 +700,12 @@ void packetmath() {
|
|||||||
for (int i = 0; i < PacketSize; ++i) {
|
for (int i = 0; i < PacketSize; ++i) {
|
||||||
data1[i] = internal::random<Scalar>(Scalar(0) - limit, limit);
|
data1[i] = internal::random<Scalar>(Scalar(0) - limit, limit);
|
||||||
}
|
}
|
||||||
|
} else if (!NumTraits<Scalar>::IsInteger && !NumTraits<Scalar>::IsComplex) {
|
||||||
|
// Prevent very small product results by adjusting range. Otherwise,
|
||||||
|
// we may end up with multiplying e.g. 32 Eigen::halfs with values < 1.
|
||||||
|
for (int i = 0; i < PacketSize; ++i) {
|
||||||
|
data1[i] = internal::random<Scalar>(Scalar(0.5), Scalar(1)) * (internal::random<bool>() ? Scalar(-1) : Scalar(1));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
ref[0] = Scalar(1);
|
ref[0] = Scalar(1);
|
||||||
for (int i = 0; i < PacketSize; ++i) ref[0] = REF_MUL(ref[0], data1[i]);
|
for (int i = 0; i < PacketSize; ++i) ref[0] = REF_MUL(ref[0], data1[i]);
|
||||||
|
@ -38,6 +38,15 @@ template <typename LhsType, typename RhsType>
|
|||||||
std::enable_if_t<RhsType::SizeAtCompileTime != Dynamic, void> check_mismatched_product(LhsType& /*unused*/,
|
std::enable_if_t<RhsType::SizeAtCompileTime != Dynamic, void> check_mismatched_product(LhsType& /*unused*/,
|
||||||
const RhsType& /*unused*/) {}
|
const RhsType& /*unused*/) {}
|
||||||
|
|
||||||
|
template <typename Scalar, typename V1, typename V2>
|
||||||
|
Scalar ref_dot_product(const V1& v1, const V2& v2) {
|
||||||
|
Scalar out = Scalar(0);
|
||||||
|
for (Index i = 0; i < v1.size(); ++i) {
|
||||||
|
out = Eigen::numext::fma(v1[i], v2[i], out);
|
||||||
|
}
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename MatrixType>
|
template <typename MatrixType>
|
||||||
void product(const MatrixType& m) {
|
void product(const MatrixType& m) {
|
||||||
/* this test covers the following files:
|
/* this test covers the following files:
|
||||||
@ -245,7 +254,10 @@ void product(const MatrixType& m) {
|
|||||||
// inner product
|
// inner product
|
||||||
{
|
{
|
||||||
Scalar x = square2.row(c) * square2.col(c2);
|
Scalar x = square2.row(c) * square2.col(c2);
|
||||||
VERIFY_IS_APPROX(x, square2.row(c).transpose().cwiseProduct(square2.col(c2)).sum());
|
// NOTE: FMA is necessary here in the reference to ensure accuracy for
|
||||||
|
// large vector sizes and float16/bfloat16 types.
|
||||||
|
Scalar y = ref_dot_product<Scalar>(square2.row(c), square2.col(c2));
|
||||||
|
VERIFY_IS_APPROX(x, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
// outer product
|
// outer product
|
||||||
|
Loading…
x
Reference in New Issue
Block a user