diff --git a/test/product.h b/test/product.h index bd9fa7127..b0ce06db9 100644 --- a/test/product.h +++ b/test/product.h @@ -17,6 +17,19 @@ bool areNotApprox(const MatrixBase& m1, const MatrixBase& m2 * (std::max)(m1.cwiseAbs2().maxCoeff(), m2.cwiseAbs2().maxCoeff())); } +// Allow specifying tolerance for verifying error. +template +inline bool verifyIsApprox(const Type1& a, const Type2& b, Tol tol) +{ + bool ret = a.isApprox(b, tol); + if(!ret) + { + std::cerr << "Difference too large wrt tolerance " << tol << ", relative error is: " << test_relative_error(a,b) << std::endl; + } + return ret; +} + + template std::enable_if_t check_mismatched_product(LhsType& lhs, const RhsType& rhs) { @@ -34,6 +47,7 @@ template void product(const MatrixType& m) Identity.h Product.h */ typedef typename MatrixType::Scalar Scalar; + typedef typename MatrixType::RealScalar RealScalar; typedef Matrix RowVectorType; typedef Matrix ColVectorType; typedef Matrix RowSquareMatrixType; @@ -41,6 +55,11 @@ template void product(const MatrixType& m) typedef Matrix OtherMajorMatrixType; + // Wwe want a tighter epsilon for not-approx tests. Otherwise, for certain + // low-precision types (e.g. bfloat16), the bound ends up being relatively large + // (e.g. 0.12), causing flaky tests. + RealScalar not_approx_epsilon = RealScalar(0.1) * NumTraits::dummy_precision(); + Index rows = m.rows(); Index cols = m.cols(); @@ -68,7 +87,11 @@ template void product(const MatrixType& m) // begin testing Product.h: only associativity for now // (we use Transpose.h but this doesn't count as a test for it) - VERIFY_IS_APPROX((m1*m1.transpose())*m2, m1*(m1.transpose()*m2)); + { + // Increase tolerance, since coefficients here can get relatively large. + RealScalar tol = RealScalar(2) * get_test_precision(m1); + VERIFY(verifyIsApprox((m1*m1.transpose())*m2, m1*(m1.transpose()*m2), tol)); + } m3 = m1; m3 *= m1.transpose() * m2; VERIFY_IS_APPROX(m3, m1 * (m1.transpose()*m2)); @@ -96,7 +119,7 @@ template void product(const MatrixType& m) // (we use the more accurate default epsilon) if (!NumTraits::IsInteger && (std::min)(rows,cols)>1) { - VERIFY(areNotApprox(m1.transpose()*m2,m2.transpose()*m1)); + VERIFY(areNotApprox(m1.transpose()*m2,m2.transpose()*m1, not_approx_epsilon)); } // test optimized operator+= path @@ -105,7 +128,7 @@ template void product(const MatrixType& m) VERIFY_IS_APPROX(res, square + m1 * m2.transpose()); if (!NumTraits::IsInteger && (std::min)(rows,cols)>1) { - VERIFY(areNotApprox(res,square + m2 * m1.transpose())); + VERIFY(areNotApprox(res,square + m2 * m1.transpose(), not_approx_epsilon)); } vcres = vc2; vcres.noalias() += m1.transpose() * v1; @@ -117,7 +140,7 @@ template void product(const MatrixType& m) VERIFY_IS_APPROX(res, square - (m1 * m2.transpose())); if (!NumTraits::IsInteger && (std::min)(rows,cols)>1) { - VERIFY(areNotApprox(res,square - m2 * m1.transpose())); + VERIFY(areNotApprox(res,square - m2 * m1.transpose(), not_approx_epsilon)); } vcres = vc2; vcres.noalias() -= m1.transpose() * v1; @@ -169,7 +192,7 @@ template void product(const MatrixType& m) VERIFY_IS_APPROX(res2, square2 + m1.transpose() * m2); if (!NumTraits::IsInteger && (std::min)(rows,cols)>1) { - VERIFY(areNotApprox(res2,square2 + m2.transpose() * m1)); + VERIFY(areNotApprox(res2,square2 + m2.transpose() * m1, not_approx_epsilon)); } VERIFY_IS_APPROX(res.col(r).noalias() = square.adjoint() * square.col(r), (square.adjoint() * square.col(r)).eval()); @@ -247,10 +270,12 @@ template void product(const MatrixType& m) // regression for blas_trais { - VERIFY_IS_APPROX(square * (square*square).transpose(), square * square.transpose() * square.transpose()); - VERIFY_IS_APPROX(square * (-(square*square)), -square * square * square); - VERIFY_IS_APPROX(square * (s1*(square*square)), s1 * square * square * square); - VERIFY_IS_APPROX(square * (square*square).conjugate(), square * square.conjugate() * square.conjugate()); + // Increase test tolerance, since coefficients can get relatively large. + RealScalar tol = RealScalar(2) * get_test_precision(square); + VERIFY(verifyIsApprox(square * (square*square).transpose(), square * square.transpose() * square.transpose(), tol)); + VERIFY(verifyIsApprox(square * (-(square*square)), -square * square * square, tol)); + VERIFY(verifyIsApprox(square * (s1*(square*square)), s1 * square * square * square, tol)); + VERIFY(verifyIsApprox(square * (square*square).conjugate(), square * square.conjugate() * square.conjugate(), tol)); } // destination with a non-default inner-stride