mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 19:29:02 +08:00
Don't make assumptions about NaN-propagation for pmin/pmax - it various across platforms.
Change test to only test for NaN-propagation for pfmin/pfmax.
This commit is contained in:
parent
f66f3393e3
commit
b431024404
@ -216,12 +216,12 @@ template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
|||||||
pdiv(const Packet& a, const Packet& b) { return a/b; }
|
pdiv(const Packet& a, const Packet& b) { return a/b; }
|
||||||
|
|
||||||
/** \internal \returns the min of \a a and \a b (coeff-wise).
|
/** \internal \returns the min of \a a and \a b (coeff-wise).
|
||||||
Equivalent to std::min(a, b), so if either a or b is NaN, a is returned. */
|
If \a a or \b b is NaN, the return value is implementation defined. */
|
||||||
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
||||||
pmin(const Packet& a, const Packet& b) { return numext::mini(a, b); }
|
pmin(const Packet& a, const Packet& b) { return numext::mini(a, b); }
|
||||||
|
|
||||||
/** \internal \returns the max of \a a and \a b (coeff-wise)
|
/** \internal \returns the max of \a a and \a b (coeff-wise)
|
||||||
Equivalent to std::max(a, b), so if either a or b is NaN, a is returned.*/
|
If \a a or \b b is NaN, the return value is implementation defined. */
|
||||||
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
||||||
pmax(const Packet& a, const Packet& b) { return numext::maxi(a, b); }
|
pmax(const Packet& a, const Packet& b) { return numext::maxi(a, b); }
|
||||||
|
|
||||||
@ -635,23 +635,54 @@ Packet print(const Packet& a) { using numext::rint; return rint(a); }
|
|||||||
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
template<typename Packet> EIGEN_DECLARE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS
|
||||||
Packet pceil(const Packet& a) { using numext::ceil; return ceil(a); }
|
Packet pceil(const Packet& a) { using numext::ceil; return ceil(a); }
|
||||||
|
|
||||||
/** \internal \returns the min of \a a and \a b (coeff-wise)
|
|
||||||
Equivalent to std::fmin(a, b). Only if both a and b are NaN is NaN returned.
|
/** \internal \returns the max of \a a and \a b (coeff-wise)
|
||||||
*/
|
If both \a a and \a b are NaN, NaN is returned.
|
||||||
|
Equivalent to std::fmax(a, b). */
|
||||||
|
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
||||||
|
pfmax(const Packet& a, const Packet& b) {
|
||||||
|
Packet not_nan_mask_a = pcmp_eq(a, a);
|
||||||
|
Packet not_nan_mask_b = pcmp_eq(b, b);
|
||||||
|
return pselect(not_nan_mask_a,
|
||||||
|
pselect(not_nan_mask_b, pmax(a, b), a),
|
||||||
|
b);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** \internal \returns the min of \a a and \a b (coeff-wise)
|
||||||
|
If both \a a and \a b are NaN, NaN is returned.
|
||||||
|
Equivalent to std::fmin(a, b). */
|
||||||
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
||||||
pfmin(const Packet& a, const Packet& b) {
|
pfmin(const Packet& a, const Packet& b) {
|
||||||
Packet not_nan_mask = pcmp_eq(a, a);
|
Packet not_nan_mask_a = pcmp_eq(a, a);
|
||||||
return pselect(not_nan_mask, pmin(a, b), b);
|
Packet not_nan_mask_b = pcmp_eq(b, b);
|
||||||
|
return pselect(not_nan_mask_a,
|
||||||
|
pselect(not_nan_mask_b, pmin(a, b), a),
|
||||||
|
b);
|
||||||
}
|
}
|
||||||
|
|
||||||
/** \internal \returns the max of \a a and \a b (coeff-wise)
|
/** \internal \returns the max of \a a and \a b (coeff-wise)
|
||||||
Equivalent to std::fmax(a, b). Only if both a and b are NaN is NaN returned.*/
|
If either \a a or \a b are NaN, NaN is returned. */
|
||||||
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
||||||
pfmax(const Packet& a, const Packet& b) {
|
pfmax_nan(const Packet& a, const Packet& b) {
|
||||||
Packet not_nan_mask = pcmp_eq(a, a);
|
Packet not_nan_mask_a = pcmp_eq(a, a);
|
||||||
return pselect(not_nan_mask, pmax(a, b), b);
|
Packet not_nan_mask_b = pcmp_eq(b, b);
|
||||||
|
return pselect(not_nan_mask_a,
|
||||||
|
pselect(not_nan_mask_b, pmax(a, b), b),
|
||||||
|
a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** \internal \returns the min of \a a and \a b (coeff-wise)
|
||||||
|
If either \a a or \a b are NaN, NaN is returned. */
|
||||||
|
template<typename Packet> EIGEN_DEVICE_FUNC inline Packet
|
||||||
|
pfmin_nan(const Packet& a, const Packet& b) {
|
||||||
|
Packet not_nan_mask_a = pcmp_eq(a, a);
|
||||||
|
Packet not_nan_mask_b = pcmp_eq(b, b);
|
||||||
|
return pselect(not_nan_mask_a,
|
||||||
|
pselect(not_nan_mask_b, pmin(a, b), b),
|
||||||
|
a);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/***************************************************************************
|
/***************************************************************************
|
||||||
* The following functions might not have to be overwritten for vectorized types
|
* The following functions might not have to be overwritten for vectorized types
|
||||||
***************************************************************************/
|
***************************************************************************/
|
||||||
|
@ -134,21 +134,39 @@ struct functor_traits<scalar_conj_product_op<LhsScalar,RhsScalar> > {
|
|||||||
*
|
*
|
||||||
* \sa class CwiseBinaryOp, MatrixBase::cwiseMin, class VectorwiseOp, MatrixBase::minCoeff()
|
* \sa class CwiseBinaryOp, MatrixBase::cwiseMin, class VectorwiseOp, MatrixBase::minCoeff()
|
||||||
*/
|
*/
|
||||||
template<typename LhsScalar,typename RhsScalar>
|
template<typename LhsScalar,typename RhsScalar, int NaNPropagation>
|
||||||
struct scalar_min_op : binary_op_base<LhsScalar,RhsScalar>
|
struct scalar_min_op : binary_op_base<LhsScalar,RhsScalar>
|
||||||
{
|
{
|
||||||
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_min_op>::ReturnType result_type;
|
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_min_op>::ReturnType result_type;
|
||||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_min_op)
|
EIGEN_EMPTY_STRUCT_CTOR(scalar_min_op)
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const { return numext::mini(a, b); }
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const {
|
||||||
|
if (NaNPropagation == PropagateFast) {
|
||||||
|
return numext::mini(a, b);
|
||||||
|
} else if (NaNPropagation == PropagateNumbers) {
|
||||||
|
return internal::pfmin(a,b);
|
||||||
|
} else if (NaNPropagation == PropagateNaN) {
|
||||||
|
return internal::pfmin_nan(a,b);
|
||||||
|
}
|
||||||
|
}
|
||||||
template<typename Packet>
|
template<typename Packet>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
|
||||||
{ return internal::pmin(a,b); }
|
{
|
||||||
|
if (NaNPropagation == PropagateFast) {
|
||||||
|
return internal::pmin(a,b);
|
||||||
|
} else if (NaNPropagation == PropagateNumbers) {
|
||||||
|
return internal::pfmin(a,b);
|
||||||
|
} else if (NaNPropagation == PropagateNaN) {
|
||||||
|
return internal::pfmin_nan(a,b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// TODO(rmlarsen): Handle all NaN propagation semantics reductions.
|
||||||
template<typename Packet>
|
template<typename Packet>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type predux(const Packet& a) const
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type predux(const Packet& a) const
|
||||||
{ return internal::predux_min(a); }
|
{ return internal::predux_min(a); }
|
||||||
};
|
};
|
||||||
template<typename LhsScalar,typename RhsScalar>
|
|
||||||
struct functor_traits<scalar_min_op<LhsScalar,RhsScalar> > {
|
template<typename LhsScalar,typename RhsScalar, int NaNPropagation>
|
||||||
|
struct functor_traits<scalar_min_op<LhsScalar,RhsScalar, NaNPropagation> > {
|
||||||
enum {
|
enum {
|
||||||
Cost = (NumTraits<LhsScalar>::AddCost+NumTraits<RhsScalar>::AddCost)/2,
|
Cost = (NumTraits<LhsScalar>::AddCost+NumTraits<RhsScalar>::AddCost)/2,
|
||||||
PacketAccess = internal::is_same<LhsScalar, RhsScalar>::value && packet_traits<LhsScalar>::HasMin
|
PacketAccess = internal::is_same<LhsScalar, RhsScalar>::value && packet_traits<LhsScalar>::HasMin
|
||||||
@ -160,21 +178,39 @@ struct functor_traits<scalar_min_op<LhsScalar,RhsScalar> > {
|
|||||||
*
|
*
|
||||||
* \sa class CwiseBinaryOp, MatrixBase::cwiseMax, class VectorwiseOp, MatrixBase::maxCoeff()
|
* \sa class CwiseBinaryOp, MatrixBase::cwiseMax, class VectorwiseOp, MatrixBase::maxCoeff()
|
||||||
*/
|
*/
|
||||||
template<typename LhsScalar,typename RhsScalar>
|
template<typename LhsScalar,typename RhsScalar, int NaNPropagation>
|
||||||
struct scalar_max_op : binary_op_base<LhsScalar,RhsScalar>
|
struct scalar_max_op : binary_op_base<LhsScalar,RhsScalar>
|
||||||
{
|
{
|
||||||
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_max_op>::ReturnType result_type;
|
typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar,scalar_max_op>::ReturnType result_type;
|
||||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_max_op)
|
EIGEN_EMPTY_STRUCT_CTOR(scalar_max_op)
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const { return numext::maxi(a, b); }
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator() (const LhsScalar& a, const RhsScalar& b) const {
|
||||||
|
if (NaNPropagation == PropagateFast) {
|
||||||
|
return numext::maxi(a, b);
|
||||||
|
} else if (NaNPropagation == PropagateNumbers) {
|
||||||
|
return internal::pfmax(a,b);
|
||||||
|
} else if (NaNPropagation == PropagateNaN) {
|
||||||
|
return internal::pfmax_nan(a,b);
|
||||||
|
}
|
||||||
|
}
|
||||||
template<typename Packet>
|
template<typename Packet>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
|
||||||
{ return internal::pmax(a,b); }
|
{
|
||||||
|
if (NaNPropagation == PropagateFast) {
|
||||||
|
return internal::pmax(a,b);
|
||||||
|
} else if (NaNPropagation == PropagateNumbers) {
|
||||||
|
return internal::pfmax(a,b);
|
||||||
|
} else if (NaNPropagation == PropagateNaN) {
|
||||||
|
return internal::pfmax_nan(a,b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// TODO(rmlarsen): Handle all NaN propagation semantics reductions.
|
||||||
template<typename Packet>
|
template<typename Packet>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type predux(const Packet& a) const
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type predux(const Packet& a) const
|
||||||
{ return internal::predux_max(a); }
|
{ return internal::predux_max(a); }
|
||||||
};
|
};
|
||||||
template<typename LhsScalar,typename RhsScalar>
|
|
||||||
struct functor_traits<scalar_max_op<LhsScalar,RhsScalar> > {
|
template<typename LhsScalar,typename RhsScalar, int NaNPropagation>
|
||||||
|
struct functor_traits<scalar_max_op<LhsScalar,RhsScalar, NaNPropagation> > {
|
||||||
enum {
|
enum {
|
||||||
Cost = (NumTraits<LhsScalar>::AddCost+NumTraits<RhsScalar>::AddCost)/2,
|
Cost = (NumTraits<LhsScalar>::AddCost+NumTraits<RhsScalar>::AddCost)/2,
|
||||||
PacketAccess = internal::is_same<LhsScalar, RhsScalar>::value && packet_traits<LhsScalar>::HasMax
|
PacketAccess = internal::is_same<LhsScalar, RhsScalar>::value && packet_traits<LhsScalar>::HasMax
|
||||||
|
@ -328,12 +328,21 @@ enum StorageOptions {
|
|||||||
* Enum for specifying whether to apply or solve on the left or right. */
|
* Enum for specifying whether to apply or solve on the left or right. */
|
||||||
enum SideType {
|
enum SideType {
|
||||||
/** Apply transformation on the left. */
|
/** Apply transformation on the left. */
|
||||||
OnTheLeft = 1,
|
OnTheLeft = 1,
|
||||||
/** Apply transformation on the right. */
|
/** Apply transformation on the right. */
|
||||||
OnTheRight = 2
|
OnTheRight = 2
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/** \ingroup enums
|
||||||
|
* Enum for specifying NaN-propagation behavior, e.g. for coeff-wise min/max. */
|
||||||
|
enum NaNPropagationOptions {
|
||||||
|
/** Implementation defined behavior if NaNs are present. */
|
||||||
|
PropagateFast = 0,
|
||||||
|
/** Always propagate NaNs. */
|
||||||
|
PropagateNaN,
|
||||||
|
/** Always propagate not-NaNs. */
|
||||||
|
PropagateNumbers
|
||||||
|
};
|
||||||
|
|
||||||
/* the following used to be written as:
|
/* the following used to be written as:
|
||||||
*
|
*
|
||||||
|
@ -180,8 +180,8 @@ template<typename LhsScalar, typename RhsScalar, bool ConjLhs=false, bool ConjRh
|
|||||||
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_sum_op;
|
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_sum_op;
|
||||||
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_difference_op;
|
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_difference_op;
|
||||||
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_conj_product_op;
|
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_conj_product_op;
|
||||||
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_min_op;
|
template<typename LhsScalar,typename RhsScalar=LhsScalar, int NaNPropagation=PropagateFast> struct scalar_min_op;
|
||||||
template<typename LhsScalar,typename RhsScalar=LhsScalar> struct scalar_max_op;
|
template<typename LhsScalar,typename RhsScalar=LhsScalar, int NaNPropagation=PropagateFast> struct scalar_max_op;
|
||||||
template<typename Scalar> struct scalar_opposite_op;
|
template<typename Scalar> struct scalar_opposite_op;
|
||||||
template<typename Scalar> struct scalar_conjugate_op;
|
template<typename Scalar> struct scalar_conjugate_op;
|
||||||
template<typename Scalar> struct scalar_real_op;
|
template<typename Scalar> struct scalar_real_op;
|
||||||
|
@ -763,6 +763,20 @@ void packetmath_real<bfloat16, typename internal::packet_traits<bfloat16>::type>
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Scalar>
|
||||||
|
Scalar propagate_nan_max(const Scalar& a, const Scalar& b) {
|
||||||
|
if ((std::isnan)(a)) return a;
|
||||||
|
if ((std::isnan)(b)) return b;
|
||||||
|
return (std::max)(a,b);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Scalar>
|
||||||
|
Scalar propagate_nan_min(const Scalar& a, const Scalar& b) {
|
||||||
|
if ((std::isnan)(a)) return a;
|
||||||
|
if ((std::isnan)(b)) return b;
|
||||||
|
return (std::min)(a,b);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename Scalar, typename Packet>
|
template <typename Scalar, typename Packet>
|
||||||
void packetmath_notcomplex() {
|
void packetmath_notcomplex() {
|
||||||
typedef internal::packet_traits<Scalar> PacketTraits;
|
typedef internal::packet_traits<Scalar> PacketTraits;
|
||||||
@ -829,12 +843,12 @@ void packetmath_notcomplex() {
|
|||||||
data1[i] = internal::random<bool>() ? std::numeric_limits<Scalar>::quiet_NaN() : Scalar(0);
|
data1[i] = internal::random<bool>() ? std::numeric_limits<Scalar>::quiet_NaN() : Scalar(0);
|
||||||
data1[i + PacketSize] = internal::random<bool>() ? std::numeric_limits<Scalar>::quiet_NaN() : Scalar(0);
|
data1[i + PacketSize] = internal::random<bool>() ? std::numeric_limits<Scalar>::quiet_NaN() : Scalar(0);
|
||||||
}
|
}
|
||||||
// Test NaN propagation for pmin and pmax. It should be equivalent to std::min.
|
|
||||||
CHECK_CWISE2_IF(PacketTraits::HasMin, (std::min), internal::pmin);
|
|
||||||
CHECK_CWISE2_IF(PacketTraits::HasMax, (std::max), internal::pmax);
|
|
||||||
// Test NaN propagation for pfmin and pfmax. It should be equivalent to std::fmin.
|
// Test NaN propagation for pfmin and pfmax. It should be equivalent to std::fmin.
|
||||||
|
// Note: NaN propagation is implementation defined for pmin/pmax, so we do not test it here.
|
||||||
CHECK_CWISE2_IF(PacketTraits::HasMin, fmin, internal::pfmin);
|
CHECK_CWISE2_IF(PacketTraits::HasMin, fmin, internal::pfmin);
|
||||||
CHECK_CWISE2_IF(PacketTraits::HasMax, fmax, internal::pfmax);
|
CHECK_CWISE2_IF(PacketTraits::HasMax, fmax, internal::pfmax);
|
||||||
|
CHECK_CWISE2_IF(PacketTraits::HasMin, propagate_nan_min, internal::pfmin_nan);
|
||||||
|
CHECK_CWISE2_IF(PacketTraits::HasMax, propagate_nan_max, internal::pfmax_nan);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
|
@ -395,16 +395,18 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
|||||||
return unaryExpr(internal::scalar_mod_op<Scalar>(rhs));
|
return unaryExpr(internal::scalar_mod_op<Scalar>(rhs));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int NanPropagation=PropagateFast>
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
|
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar,NanPropagation>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
|
||||||
cwiseMax(Scalar threshold) const {
|
cwiseMax(Scalar threshold) const {
|
||||||
return cwiseMax(constant(threshold));
|
return cwiseMax<NanPropagation>(constant(threshold));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int NanPropagation=PropagateFast>
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
|
EIGEN_STRONG_INLINE const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar,NanPropagation>, const Derived, const TensorCwiseNullaryOp<internal::scalar_constant_op<Scalar>, const Derived> >
|
||||||
cwiseMin(Scalar threshold) const {
|
cwiseMin(Scalar threshold) const {
|
||||||
return cwiseMin(constant(threshold));
|
return cwiseMin<NanPropagation>(constant(threshold));
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename NewType>
|
template<typename NewType>
|
||||||
@ -472,16 +474,16 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
|||||||
return binaryExpr(other.derived(), internal::scalar_quotient_op<Scalar>());
|
return binaryExpr(other.derived(), internal::scalar_quotient_op<Scalar>());
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
template<int NaNPropagation=PropagateFast, typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, const OtherDerived>
|
const TensorCwiseBinaryOp<internal::scalar_max_op<Scalar,Scalar, NaNPropagation>, const Derived, const OtherDerived>
|
||||||
cwiseMax(const OtherDerived& other) const {
|
cwiseMax(const OtherDerived& other) const {
|
||||||
return binaryExpr(other.derived(), internal::scalar_max_op<Scalar>());
|
return binaryExpr(other.derived(), internal::scalar_max_op<Scalar,Scalar, NaNPropagation>());
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
template<int NaNPropagation=PropagateFast, typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, const OtherDerived>
|
const TensorCwiseBinaryOp<internal::scalar_min_op<Scalar,Scalar, NaNPropagation>, const Derived, const OtherDerived>
|
||||||
cwiseMin(const OtherDerived& other) const {
|
cwiseMin(const OtherDerived& other) const {
|
||||||
return binaryExpr(other.derived(), internal::scalar_min_op<Scalar>());
|
return binaryExpr(other.derived(), internal::scalar_min_op<Scalar,Scalar, NaNPropagation>());
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
template<typename OtherDerived> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
@ -303,40 +303,79 @@ template <typename Scalar>
|
|||||||
void test_minmax_nan_propagation_templ() {
|
void test_minmax_nan_propagation_templ() {
|
||||||
for (int size = 1; size < 17; ++size) {
|
for (int size = 1; size < 17; ++size) {
|
||||||
const Scalar kNan = std::numeric_limits<Scalar>::quiet_NaN();
|
const Scalar kNan = std::numeric_limits<Scalar>::quiet_NaN();
|
||||||
|
const Scalar kZero(0);
|
||||||
Tensor<Scalar, 1> vec_nan(size);
|
Tensor<Scalar, 1> vec_nan(size);
|
||||||
Tensor<Scalar, 1> vec_zero(size);
|
Tensor<Scalar, 1> vec_zero(size);
|
||||||
Tensor<Scalar, 1> vec_res(size);
|
|
||||||
vec_nan.setConstant(kNan);
|
vec_nan.setConstant(kNan);
|
||||||
vec_zero.setZero();
|
vec_zero.setZero();
|
||||||
vec_res.setZero();
|
|
||||||
|
|
||||||
// Test that we propagate NaNs in the tensor when applying the
|
auto verify_all_nan = [&](const Tensor<Scalar, 1>& v) {
|
||||||
// cwiseMax(scalar) operator, which is used for the Relu operator.
|
for (int i = 0; i < size; ++i) {
|
||||||
vec_res = vec_nan.cwiseMax(Scalar(0));
|
VERIFY((numext::isnan)(v(i)));
|
||||||
for (int i = 0; i < size; ++i) {
|
}
|
||||||
VERIFY((numext::isnan)(vec_res(i)));
|
};
|
||||||
}
|
|
||||||
|
|
||||||
// Test that NaNs do not propagate if we reverse the arguments.
|
auto verify_all_zero = [&](const Tensor<Scalar, 1>& v) {
|
||||||
vec_res = vec_zero.cwiseMax(kNan);
|
for (int i = 0; i < size; ++i) {
|
||||||
for (int i = 0; i < size; ++i) {
|
VERIFY_IS_EQUAL(v(i), Scalar(0));
|
||||||
VERIFY_IS_EQUAL(vec_res(i), Scalar(0));
|
}
|
||||||
}
|
};
|
||||||
|
|
||||||
// Test that we propagate NaNs in the tensor when applying the
|
// Test NaN propagating max.
|
||||||
// cwiseMin(scalar) operator.
|
// max(nan, nan) = nan
|
||||||
vec_res.setZero();
|
// max(nan, 0) = nan
|
||||||
vec_res = vec_nan.cwiseMin(Scalar(0));
|
// max(0, nan) = nan
|
||||||
for (int i = 0; i < size; ++i) {
|
// max(0, 0) = 0
|
||||||
VERIFY((numext::isnan)(vec_res(i)));
|
verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(kNan));
|
||||||
}
|
verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(vec_nan));
|
||||||
|
verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(kZero));
|
||||||
|
verify_all_nan(vec_nan.template cwiseMax<PropagateNaN>(vec_zero));
|
||||||
|
verify_all_nan(vec_zero.template cwiseMax<PropagateNaN>(kNan));
|
||||||
|
verify_all_nan(vec_zero.template cwiseMax<PropagateNaN>(vec_nan));
|
||||||
|
verify_all_zero(vec_zero.template cwiseMax<PropagateNaN>(kZero));
|
||||||
|
verify_all_zero(vec_zero.template cwiseMax<PropagateNaN>(vec_zero));
|
||||||
|
|
||||||
|
// Test number propagating max.
|
||||||
|
// max(nan, nan) = nan
|
||||||
|
// max(nan, 0) = 0
|
||||||
|
// max(0, nan) = 0
|
||||||
|
// max(0, 0) = 0
|
||||||
|
verify_all_nan(vec_nan.template cwiseMax<PropagateNumbers>(kNan));
|
||||||
|
verify_all_nan(vec_nan.template cwiseMax<PropagateNumbers>(vec_nan));
|
||||||
|
verify_all_zero(vec_nan.template cwiseMax<PropagateNumbers>(kZero));
|
||||||
|
verify_all_zero(vec_nan.template cwiseMax<PropagateNumbers>(vec_zero));
|
||||||
|
verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(kNan));
|
||||||
|
verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(vec_nan));
|
||||||
|
verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(kZero));
|
||||||
|
verify_all_zero(vec_zero.template cwiseMax<PropagateNumbers>(vec_zero));
|
||||||
|
|
||||||
// Test that NaNs do not propagate if we reverse the arguments.
|
// Test NaN propagating min.
|
||||||
vec_res = vec_zero.cwiseMin(kNan);
|
// min(nan, nan) = nan
|
||||||
for (int i = 0; i < size; ++i) {
|
// min(nan, 0) = nan
|
||||||
VERIFY_IS_EQUAL(vec_res(i), Scalar(0));
|
// min(0, nan) = nan
|
||||||
}
|
// min(0, 0) = 0
|
||||||
|
verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(kNan));
|
||||||
|
verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(vec_nan));
|
||||||
|
verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(kZero));
|
||||||
|
verify_all_nan(vec_nan.template cwiseMin<PropagateNaN>(vec_zero));
|
||||||
|
verify_all_nan(vec_zero.template cwiseMin<PropagateNaN>(kNan));
|
||||||
|
verify_all_nan(vec_zero.template cwiseMin<PropagateNaN>(vec_nan));
|
||||||
|
verify_all_zero(vec_zero.template cwiseMin<PropagateNaN>(kZero));
|
||||||
|
verify_all_zero(vec_zero.template cwiseMin<PropagateNaN>(vec_zero));
|
||||||
|
|
||||||
|
// Test number propagating min.
|
||||||
|
// min(nan, nan) = nan
|
||||||
|
// min(nan, 0) = 0
|
||||||
|
// min(0, nan) = 0
|
||||||
|
// min(0, 0) = 0
|
||||||
|
verify_all_nan(vec_nan.template cwiseMin<PropagateNumbers>(kNan));
|
||||||
|
verify_all_nan(vec_nan.template cwiseMin<PropagateNumbers>(vec_nan));
|
||||||
|
verify_all_zero(vec_nan.template cwiseMin<PropagateNumbers>(kZero));
|
||||||
|
verify_all_zero(vec_nan.template cwiseMin<PropagateNumbers>(vec_zero));
|
||||||
|
verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(kNan));
|
||||||
|
verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(vec_nan));
|
||||||
|
verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(kZero));
|
||||||
|
verify_all_zero(vec_zero.template cwiseMin<PropagateNumbers>(vec_zero));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user