Avoid signed integer overflow in adjoint test.

This commit is contained in:
Antonio Sánchez 2022-05-23 14:46:16 +00:00
parent cbe03f3531
commit 32348091ba

View File

@ -62,6 +62,17 @@ template<> struct adjoint_specific<false> {
} }
}; };
template<typename MatrixType, typename Scalar = typename MatrixType::Scalar>
MatrixType RandomMatrix(int rows, int cols, Scalar min, Scalar max) {
MatrixType M = MatrixType(rows, cols);
for (int i=0; i<rows; ++i) {
for (int j=0; j<cols; ++j) {
M(i, j) = Eigen::internal::random<Scalar>(min, max);
}
}
return M;
}
template<typename MatrixType> void adjoint(const MatrixType& m) template<typename MatrixType> void adjoint(const MatrixType& m)
{ {
/* this test covers the following files: /* this test covers the following files:
@ -77,17 +88,21 @@ template<typename MatrixType> void adjoint(const MatrixType& m)
Index rows = m.rows(); Index rows = m.rows();
Index cols = m.cols(); Index cols = m.cols();
MatrixType m1 = MatrixType::Random(rows, cols), // Avoid integer overflow by limiting input values.
m2 = MatrixType::Random(rows, cols), RealScalar rmin = static_cast<RealScalar>(NumTraits<Scalar>::IsInteger ? NumTraits<Scalar>::IsSigned ? -100 : 0 : -1);
RealScalar rmax = static_cast<RealScalar>(NumTraits<Scalar>::IsInteger ? 100 : 1);
MatrixType m1 = RandomMatrix<MatrixType>(rows, cols, rmin, rmax),
m2 = RandomMatrix<MatrixType>(rows, cols, rmin, rmax),
m3(rows, cols), m3(rows, cols),
square = SquareMatrixType::Random(rows, rows); square = RandomMatrix<SquareMatrixType>(rows, rows, rmin, rmax);
VectorType v1 = VectorType::Random(rows), VectorType v1 = RandomMatrix<VectorType>(rows, 1, rmin, rmax),
v2 = VectorType::Random(rows), v2 = RandomMatrix<VectorType>(rows, 1, rmin, rmax),
v3 = VectorType::Random(rows), v3 = RandomMatrix<VectorType>(rows, 1, rmin, rmax),
vzero = VectorType::Zero(rows); vzero = VectorType::Zero(rows);
Scalar s1 = internal::random<Scalar>(), Scalar s1 = internal::random<Scalar>(rmin, rmax),
s2 = internal::random<Scalar>(); s2 = internal::random<Scalar>(rmin, rmax);
// check basic compatibility of adjoint, transpose, conjugate // check basic compatibility of adjoint, transpose, conjugate
VERIFY_IS_APPROX(m1.transpose().conjugate().adjoint(), m1); VERIFY_IS_APPROX(m1.transpose().conjugate().adjoint(), m1);
@ -138,7 +153,8 @@ template<typename MatrixType> void adjoint(const MatrixType& m)
// check mixed dot product // check mixed dot product
typedef Matrix<RealScalar, MatrixType::RowsAtCompileTime, 1> RealVectorType; typedef Matrix<RealScalar, MatrixType::RowsAtCompileTime, 1> RealVectorType;
RealVectorType rv1 = RealVectorType::Random(rows); RealVectorType rv1 = RandomMatrix<RealVectorType>(rows, 1, rmin, rmax);
VERIFY_IS_APPROX(v1.dot(rv1.template cast<Scalar>()), v1.dot(rv1)); VERIFY_IS_APPROX(v1.dot(rv1.template cast<Scalar>()), v1.dot(rv1));
VERIFY_IS_APPROX(rv1.template cast<Scalar>().dot(v1), rv1.dot(v1)); VERIFY_IS_APPROX(rv1.template cast<Scalar>().dot(v1), rv1.dot(v1));