Adjust thresholds for bfloat16 product tests that are currently failing.

This commit is contained in:
Antonio Sánchez 2022-12-28 19:32:25 +00:00 committed by Rasmus Munk Larsen
parent 311cc0f9cc
commit 910f6f65d0

View File

@ -17,6 +17,19 @@ bool areNotApprox(const MatrixBase<Derived1>& m1, const MatrixBase<Derived2>& m2
* (std::max)(m1.cwiseAbs2().maxCoeff(), m2.cwiseAbs2().maxCoeff())); * (std::max)(m1.cwiseAbs2().maxCoeff(), m2.cwiseAbs2().maxCoeff()));
} }
// Allow specifying tolerance for verifying error.
template<typename Type1, typename Type2, typename Tol>
inline bool verifyIsApprox(const Type1& a, const Type2& b, Tol tol)
{
bool ret = a.isApprox(b, tol);
if(!ret)
{
std::cerr << "Difference too large wrt tolerance " << tol << ", relative error is: " << test_relative_error(a,b) << std::endl;
}
return ret;
}
template <typename LhsType, typename RhsType> template <typename LhsType, typename RhsType>
std::enable_if_t<RhsType::SizeAtCompileTime==Dynamic,void> std::enable_if_t<RhsType::SizeAtCompileTime==Dynamic,void>
check_mismatched_product(LhsType& lhs, const RhsType& rhs) { check_mismatched_product(LhsType& lhs, const RhsType& rhs) {
@ -34,6 +47,7 @@ template<typename MatrixType> void product(const MatrixType& m)
Identity.h Product.h Identity.h Product.h
*/ */
typedef typename MatrixType::Scalar Scalar; typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::RealScalar RealScalar;
typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> RowVectorType; typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> RowVectorType;
typedef Matrix<Scalar, MatrixType::ColsAtCompileTime, 1> ColVectorType; typedef Matrix<Scalar, MatrixType::ColsAtCompileTime, 1> ColVectorType;
typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, MatrixType::RowsAtCompileTime> RowSquareMatrixType; typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, MatrixType::RowsAtCompileTime> RowSquareMatrixType;
@ -41,6 +55,11 @@ template<typename MatrixType> void product(const MatrixType& m)
typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, MatrixType::ColsAtCompileTime, typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, MatrixType::ColsAtCompileTime,
MatrixType::Flags&RowMajorBit?ColMajor:RowMajor> OtherMajorMatrixType; MatrixType::Flags&RowMajorBit?ColMajor:RowMajor> OtherMajorMatrixType;
// Wwe want a tighter epsilon for not-approx tests. Otherwise, for certain
// low-precision types (e.g. bfloat16), the bound ends up being relatively large
// (e.g. 0.12), causing flaky tests.
RealScalar not_approx_epsilon = RealScalar(0.1) * NumTraits<RealScalar>::dummy_precision();
Index rows = m.rows(); Index rows = m.rows();
Index cols = m.cols(); Index cols = m.cols();
@ -68,7 +87,11 @@ template<typename MatrixType> void product(const MatrixType& m)
// begin testing Product.h: only associativity for now // begin testing Product.h: only associativity for now
// (we use Transpose.h but this doesn't count as a test for it) // (we use Transpose.h but this doesn't count as a test for it)
VERIFY_IS_APPROX((m1*m1.transpose())*m2, m1*(m1.transpose()*m2)); {
// Increase tolerance, since coefficients here can get relatively large.
RealScalar tol = RealScalar(2) * get_test_precision(m1);
VERIFY(verifyIsApprox((m1*m1.transpose())*m2, m1*(m1.transpose()*m2), tol));
}
m3 = m1; m3 = m1;
m3 *= m1.transpose() * m2; m3 *= m1.transpose() * m2;
VERIFY_IS_APPROX(m3, m1 * (m1.transpose()*m2)); VERIFY_IS_APPROX(m3, m1 * (m1.transpose()*m2));
@ -96,7 +119,7 @@ template<typename MatrixType> void product(const MatrixType& m)
// (we use the more accurate default epsilon) // (we use the more accurate default epsilon)
if (!NumTraits<Scalar>::IsInteger && (std::min)(rows,cols)>1) if (!NumTraits<Scalar>::IsInteger && (std::min)(rows,cols)>1)
{ {
VERIFY(areNotApprox(m1.transpose()*m2,m2.transpose()*m1)); VERIFY(areNotApprox(m1.transpose()*m2,m2.transpose()*m1, not_approx_epsilon));
} }
// test optimized operator+= path // test optimized operator+= path
@ -105,7 +128,7 @@ template<typename MatrixType> void product(const MatrixType& m)
VERIFY_IS_APPROX(res, square + m1 * m2.transpose()); VERIFY_IS_APPROX(res, square + m1 * m2.transpose());
if (!NumTraits<Scalar>::IsInteger && (std::min)(rows,cols)>1) if (!NumTraits<Scalar>::IsInteger && (std::min)(rows,cols)>1)
{ {
VERIFY(areNotApprox(res,square + m2 * m1.transpose())); VERIFY(areNotApprox(res,square + m2 * m1.transpose(), not_approx_epsilon));
} }
vcres = vc2; vcres = vc2;
vcres.noalias() += m1.transpose() * v1; vcres.noalias() += m1.transpose() * v1;
@ -117,7 +140,7 @@ template<typename MatrixType> void product(const MatrixType& m)
VERIFY_IS_APPROX(res, square - (m1 * m2.transpose())); VERIFY_IS_APPROX(res, square - (m1 * m2.transpose()));
if (!NumTraits<Scalar>::IsInteger && (std::min)(rows,cols)>1) if (!NumTraits<Scalar>::IsInteger && (std::min)(rows,cols)>1)
{ {
VERIFY(areNotApprox(res,square - m2 * m1.transpose())); VERIFY(areNotApprox(res,square - m2 * m1.transpose(), not_approx_epsilon));
} }
vcres = vc2; vcres = vc2;
vcres.noalias() -= m1.transpose() * v1; vcres.noalias() -= m1.transpose() * v1;
@ -169,7 +192,7 @@ template<typename MatrixType> void product(const MatrixType& m)
VERIFY_IS_APPROX(res2, square2 + m1.transpose() * m2); VERIFY_IS_APPROX(res2, square2 + m1.transpose() * m2);
if (!NumTraits<Scalar>::IsInteger && (std::min)(rows,cols)>1) if (!NumTraits<Scalar>::IsInteger && (std::min)(rows,cols)>1)
{ {
VERIFY(areNotApprox(res2,square2 + m2.transpose() * m1)); VERIFY(areNotApprox(res2,square2 + m2.transpose() * m1, not_approx_epsilon));
} }
VERIFY_IS_APPROX(res.col(r).noalias() = square.adjoint() * square.col(r), (square.adjoint() * square.col(r)).eval()); VERIFY_IS_APPROX(res.col(r).noalias() = square.adjoint() * square.col(r), (square.adjoint() * square.col(r)).eval());
@ -247,10 +270,12 @@ template<typename MatrixType> void product(const MatrixType& m)
// regression for blas_trais // regression for blas_trais
{ {
VERIFY_IS_APPROX(square * (square*square).transpose(), square * square.transpose() * square.transpose()); // Increase test tolerance, since coefficients can get relatively large.
VERIFY_IS_APPROX(square * (-(square*square)), -square * square * square); RealScalar tol = RealScalar(2) * get_test_precision(square);
VERIFY_IS_APPROX(square * (s1*(square*square)), s1 * square * square * square); VERIFY(verifyIsApprox(square * (square*square).transpose(), square * square.transpose() * square.transpose(), tol));
VERIFY_IS_APPROX(square * (square*square).conjugate(), square * square.conjugate() * square.conjugate()); VERIFY(verifyIsApprox(square * (-(square*square)), -square * square * square, tol));
VERIFY(verifyIsApprox(square * (s1*(square*square)), s1 * square * square * square, tol));
VERIFY(verifyIsApprox(square * (square*square).conjugate(), square * square.conjugate() * square.conjugate(), tol));
} }
// destination with a non-default inner-stride // destination with a non-default inner-stride