From 2067b54b135adede6fb9124f3d457cac53538913 Mon Sep 17 00:00:00 2001 From: Rasmus Munk Larsen Date: Mon, 13 Mar 2023 18:25:22 +0000 Subject: [PATCH] Fix bug in minmax_coeff_visitor for matrix of all NaNs. --- Eigen/src/Core/Visitor.h | 11 +++++++++-- test/visitor.cpp | 21 ++++++++++++++++++++- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/Eigen/src/Core/Visitor.h b/Eigen/src/Core/Visitor.h index 4e9a85e80..7b29dae22 100644 --- a/Eigen/src/Core/Visitor.h +++ b/Eigen/src/Core/Visitor.h @@ -453,6 +453,7 @@ struct minmax_compare { static EIGEN_DEVICE_FUNC inline Scalar predux(const Packet& p) { return predux_max(p); } }; +// Default imlementatio template struct minmax_coeff_visitor : coeff_visitor { using Scalar = typename Derived::Scalar; @@ -489,8 +490,8 @@ struct minmax_coeff_visitor : coeff_visitor { } }; -// Suppress NaN. The only case in which we return NaN is if the matrix is all NaN, in which case, -// the row=0, col=0 is returned for the location. +// Suppress NaN. The only case in which we return NaN is if the matrix is all NaN, +// in which case, row=0, col=0 is returned for the location. template struct minmax_coeff_visitor : coeff_visitor { typedef typename Derived::Scalar Scalar; @@ -520,6 +521,12 @@ struct minmax_coeff_visitor : coeff_visitor::size; Scalar value = Comparator::predux(p); + if ((numext::isnan)(value)) { + this->res = value; + this->row = 0; + this->col = 0; + return; + } const Packet range = preverse(plset(Scalar(1))); /* mask will be zero for NaNs, so they will be ignored. */ Packet mask = pcmp_eq(pset1(value), p); diff --git a/test/visitor.cpp b/test/visitor.cpp index 9586539a2..174a1bb1d 100644 --- a/test/visitor.cpp +++ b/test/visitor.cpp @@ -92,8 +92,27 @@ template void matrixVisitor(const MatrixType& p) VERIFY(maxrow != eigen_maxrow || maxcol != eigen_maxcol); VERIFY((numext::isnan)(eigen_minc)); VERIFY((numext::isnan)(eigen_maxc)); - } + // Test matrix of all NaNs. + m.fill(NumTraits::quiet_NaN()); + eigen_minc = m.template minCoeff(&eigen_minrow, &eigen_mincol); + eigen_maxc = m.template maxCoeff(&eigen_maxrow, &eigen_maxcol); + VERIFY(eigen_minrow == 0); + VERIFY(eigen_maxrow == 0); + VERIFY(eigen_mincol == 0); + VERIFY(eigen_maxcol == 0); + VERIFY((numext::isnan)(eigen_minc)); + VERIFY((numext::isnan)(eigen_maxc)); + + eigen_minc = m.template minCoeff(&eigen_minrow, &eigen_mincol); + eigen_maxc = m.template maxCoeff(&eigen_maxrow, &eigen_maxcol); + VERIFY(eigen_minrow == 0); + VERIFY(eigen_maxrow == 0); + VERIFY(eigen_mincol == 0); + VERIFY(eigen_maxcol == 0); + VERIFY((numext::isnan)(eigen_minc)); + VERIFY((numext::isnan)(eigen_maxc)); + } } template void vectorVisitor(const VectorType& w)