Fix cwise NaN propagation for scalar input.

This commit is contained in:
Antonio Sánchez 2022-04-16 05:07:44 +00:00
parent a4bb513b99
commit f845a8bb1a
2 changed files with 12 additions and 3 deletions

View File

@ -86,10 +86,10 @@ cwiseMin(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
*/
template<int NaNPropagation=PropagateFast>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar,NaNPropagation>, const Derived, const ConstantReturnType>
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar,NaNPropagation>, const Derived, const ConstantReturnType>
cwiseMin(const Scalar &other) const
{
return cwiseMin(Derived::Constant(rows(), cols(), other));
return cwiseMin<NaNPropagation>(Derived::Constant(rows(), cols(), other));
}
/** \returns an expression of the coefficient-wise max of *this and \a other
@ -116,7 +116,7 @@ EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar,NaNPropagation>, const Derived, const ConstantReturnType>
cwiseMax(const Scalar &other) const
{
return cwiseMax(Derived::Constant(rows(), cols(), other));
return cwiseMax<NaNPropagation>(Derived::Constant(rows(), cols(), other));
}

View File

@ -219,11 +219,20 @@ template<typename MatrixType> void cwise_min_max(const MatrixType& m)
VERIFY((numext::isnan)(m1.template cwiseMin<PropagateNaN>(MatrixType::Constant(rows,cols, Scalar(1)))(0,0)));
VERIFY(!(numext::isnan)(m1.template cwiseMax<PropagateNumbers>(MatrixType::Constant(rows,cols, Scalar(1)))(0,0)));
VERIFY(!(numext::isnan)(m1.template cwiseMin<PropagateNumbers>(MatrixType::Constant(rows,cols, Scalar(1)))(0,0)));
VERIFY((numext::isnan)(m1.template cwiseMax<PropagateNaN>(Scalar(1))(0,0)));
VERIFY((numext::isnan)(m1.template cwiseMin<PropagateNaN>(Scalar(1))(0,0)));
VERIFY(!(numext::isnan)(m1.template cwiseMax<PropagateNumbers>(Scalar(1))(0,0)));
VERIFY(!(numext::isnan)(m1.template cwiseMin<PropagateNumbers>(Scalar(1))(0,0)));
VERIFY((numext::isnan)(m1.array().template max<PropagateNaN>(MatrixType::Constant(rows,cols, Scalar(1)).array())(0,0)));
VERIFY((numext::isnan)(m1.array().template min<PropagateNaN>(MatrixType::Constant(rows,cols, Scalar(1)).array())(0,0)));
VERIFY(!(numext::isnan)(m1.array().template max<PropagateNumbers>(MatrixType::Constant(rows,cols, Scalar(1)).array())(0,0)));
VERIFY(!(numext::isnan)(m1.array().template min<PropagateNumbers>(MatrixType::Constant(rows,cols, Scalar(1)).array())(0,0)));
VERIFY((numext::isnan)(m1.array().template max<PropagateNaN>(Scalar(1))(0,0)));
VERIFY((numext::isnan)(m1.array().template min<PropagateNaN>(Scalar(1))(0,0)));
VERIFY(!(numext::isnan)(m1.array().template max<PropagateNumbers>(Scalar(1))(0,0)));
VERIFY(!(numext::isnan)(m1.array().template min<PropagateNumbers>(Scalar(1))(0,0)));
// Reductions.
VERIFY((numext::isnan)(m1.template maxCoeff<PropagateNaN>()));