Fix bug in minmax_coeff_visitor for matrix of all NaNs.

This commit is contained in:
Rasmus Munk Larsen 2023-03-13 18:25:22 +00:00
parent ee0ff0ab3a
commit 2067b54b13
2 changed files with 29 additions and 3 deletions

View File

@ -453,6 +453,7 @@ struct minmax_compare<Scalar, NaNPropagation, false> {
static EIGEN_DEVICE_FUNC inline Scalar predux(const Packet& p) { return predux_max<NaNPropagation>(p); } static EIGEN_DEVICE_FUNC inline Scalar predux(const Packet& p) { return predux_max<NaNPropagation>(p); }
}; };
// Default imlementatio
template <typename Derived, bool is_min, int NaNPropagation> template <typename Derived, bool is_min, int NaNPropagation>
struct minmax_coeff_visitor : coeff_visitor<Derived> { struct minmax_coeff_visitor : coeff_visitor<Derived> {
using Scalar = typename Derived::Scalar; using Scalar = typename Derived::Scalar;
@ -489,8 +490,8 @@ struct minmax_coeff_visitor : coeff_visitor<Derived> {
} }
}; };
// Suppress NaN. The only case in which we return NaN is if the matrix is all NaN, in which case, // Suppress NaN. The only case in which we return NaN is if the matrix is all NaN,
// the row=0, col=0 is returned for the location. // in which case, row=0, col=0 is returned for the location.
template <typename Derived, bool is_min> template <typename Derived, bool is_min>
struct minmax_coeff_visitor<Derived, is_min, PropagateNumbers> : coeff_visitor<Derived> { struct minmax_coeff_visitor<Derived, is_min, PropagateNumbers> : coeff_visitor<Derived> {
typedef typename Derived::Scalar Scalar; typedef typename Derived::Scalar Scalar;
@ -520,6 +521,12 @@ struct minmax_coeff_visitor<Derived, is_min, PropagateNumbers> : coeff_visitor<D
EIGEN_DEVICE_FUNC inline void initpacket(const Packet& p, Index i, Index j) { EIGEN_DEVICE_FUNC inline void initpacket(const Packet& p, Index i, Index j) {
const Index PacketSize = packet_traits<Scalar>::size; const Index PacketSize = packet_traits<Scalar>::size;
Scalar value = Comparator::predux(p); Scalar value = Comparator::predux(p);
if ((numext::isnan)(value)) {
this->res = value;
this->row = 0;
this->col = 0;
return;
}
const Packet range = preverse(plset<Packet>(Scalar(1))); const Packet range = preverse(plset<Packet>(Scalar(1)));
/* mask will be zero for NaNs, so they will be ignored. */ /* mask will be zero for NaNs, so they will be ignored. */
Packet mask = pcmp_eq(pset1<Packet>(value), p); Packet mask = pcmp_eq(pset1<Packet>(value), p);

View File

@ -92,8 +92,27 @@ template<typename MatrixType> void matrixVisitor(const MatrixType& p)
VERIFY(maxrow != eigen_maxrow || maxcol != eigen_maxcol); VERIFY(maxrow != eigen_maxrow || maxcol != eigen_maxcol);
VERIFY((numext::isnan)(eigen_minc)); VERIFY((numext::isnan)(eigen_minc));
VERIFY((numext::isnan)(eigen_maxc)); VERIFY((numext::isnan)(eigen_maxc));
}
// Test matrix of all NaNs.
m.fill(NumTraits<Scalar>::quiet_NaN());
eigen_minc = m.template minCoeff<PropagateNumbers>(&eigen_minrow, &eigen_mincol);
eigen_maxc = m.template maxCoeff<PropagateNumbers>(&eigen_maxrow, &eigen_maxcol);
VERIFY(eigen_minrow == 0);
VERIFY(eigen_maxrow == 0);
VERIFY(eigen_mincol == 0);
VERIFY(eigen_maxcol == 0);
VERIFY((numext::isnan)(eigen_minc));
VERIFY((numext::isnan)(eigen_maxc));
eigen_minc = m.template minCoeff<PropagateNaN>(&eigen_minrow, &eigen_mincol);
eigen_maxc = m.template maxCoeff<PropagateNaN>(&eigen_maxrow, &eigen_maxcol);
VERIFY(eigen_minrow == 0);
VERIFY(eigen_maxrow == 0);
VERIFY(eigen_mincol == 0);
VERIFY(eigen_maxcol == 0);
VERIFY((numext::isnan)(eigen_minc));
VERIFY((numext::isnan)(eigen_maxc));
}
} }
template<typename VectorType> void vectorVisitor(const VectorType& w) template<typename VectorType> void vectorVisitor(const VectorType& w)