Scalarize comps

This commit is contained in:
Charles Schlosser 2023-03-02 17:06:23 +00:00 committed by Rasmus Munk Larsen
parent 3abe12472e
commit 0b396c3167
3 changed files with 107 additions and 44 deletions

View File

@ -195,7 +195,7 @@ struct functor_traits<scalar_max_op<LhsScalar,RhsScalar, NaNPropagation> > {
* \todo Implement packet-comparisons
*/
template <typename LhsScalar, typename RhsScalar, ComparisonName cmp,
bool UseTypedComparators = true>
bool UseTypedComparators = false>
struct scalar_cmp_op;
template <typename LhsScalar, typename RhsScalar, ComparisonName cmp, bool UseTypedComparators>

View File

@ -25,6 +25,13 @@ cwiseProduct(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
return EIGEN_CWISE_BINARY_RETURN_TYPE(Derived,OtherDerived,product)(derived(), other.derived());
}
template<typename OtherDerived> using CwiseBinaryEqualReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_EQ>, const Derived, const OtherDerived>;
template<typename OtherDerived> using CwiseBinaryNotEqualReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ>, const Derived, const OtherDerived>;
template<typename OtherDerived> using CwiseBinaryLesserReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT>, const Derived, const OtherDerived>;
template<typename OtherDerived> using CwiseBinaryGreaterReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT>, const Derived, const OtherDerived>;
template<typename OtherDerived> using CwiseBinaryLesserOrEqualReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE>, const Derived, const OtherDerived>;
template<typename OtherDerived> using CwiseBinaryGreaterOrEqualReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE>, const Derived, const OtherDerived>;
/** \returns an expression of the coefficient-wise == operator of *this and \a other
*
* \warning this performs an exact comparison, which is generally a bad idea with floating-point types.
@ -39,10 +46,10 @@ cwiseProduct(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
*/
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
inline const CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_EQ>, const Derived, const OtherDerived>
inline const CwiseBinaryEqualReturnType<OtherDerived>
cwiseEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
{
return CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_EQ>, const Derived, const OtherDerived>(derived(), other.derived());
return CwiseBinaryEqualReturnType<OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise != operator of *this and \a other
@ -59,46 +66,46 @@ cwiseEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
*/
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
inline const CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ>, const Derived, const OtherDerived>
inline const CwiseBinaryNotEqualReturnType<OtherDerived>
cwiseNotEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
{
return CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ>, const Derived, const OtherDerived>(derived(), other.derived());
return CwiseBinaryNotEqualReturnType<OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise < operator of *this and \a other */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
inline const CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT>, const Derived, const OtherDerived>
inline const CwiseBinaryLesserReturnType<OtherDerived>
cwiseLesser(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived>& other) const
{
return CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT>, const Derived, const OtherDerived>(derived(), other.derived());
return CwiseBinaryLesserReturnType<OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise > operator of *this and \a other */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
inline const CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT>, const Derived, const OtherDerived>
inline const CwiseBinaryGreaterReturnType<OtherDerived>
cwiseGreater(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived>& other) const
{
return CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT>, const Derived, const OtherDerived>(derived(), other.derived());
return CwiseBinaryGreaterReturnType<OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise <= operator of *this and \a other */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
inline const CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE>, const Derived, const OtherDerived>
inline const CwiseBinaryLesserOrEqualReturnType<OtherDerived>
cwiseLesserOrEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived>& other) const
{
return CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE>, const Derived, const OtherDerived>(derived(), other.derived());
return CwiseBinaryLesserOrEqualReturnType<OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise >= operator of *this and \a other */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
inline const CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE>, const Derived, const OtherDerived>
inline const CwiseBinaryGreaterOrEqualReturnType<OtherDerived>
cwiseGreaterOrEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived>& other) const
{
return CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE>, const Derived, const OtherDerived>(derived(), other.derived());
return CwiseBinaryGreaterOrEqualReturnType<OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise min of *this and \a other
@ -191,7 +198,7 @@ EIGEN_DEVICE_FUNC
inline const CwiseScalarEqualReturnType
cwiseEqual(const Scalar& s) const
{
return CwiseScalarEqualReturnType(derived(), Derived::Constant(rows(), cols(), s), internal::scalar_cmp_op<Scalar,Scalar,internal::cmp_EQ>());
return CwiseScalarEqualReturnType(derived(), Derived::Constant(rows(), cols(), s));
}
@ -208,7 +215,7 @@ EIGEN_DEVICE_FUNC
inline const CwiseScalarNotEqualReturnType
cwiseNotEqual(const Scalar& s) const
{
return CwiseScalarNotEqualReturnType(derived(), Derived::Constant(rows(), cols(), s), internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ>());
return CwiseScalarNotEqualReturnType(derived(), Derived::Constant(rows(), cols(), s));
}
/** \returns an expression of the coefficient-wise < operator of \c *this and a scalar \a s */
@ -216,7 +223,7 @@ EIGEN_DEVICE_FUNC
inline const CwiseScalarLesserReturnType
cwiseLesser(const Scalar& s) const
{
return CwiseScalarLesserReturnType(derived(), Derived::Constant(rows(), cols(), s), internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT>());
return CwiseScalarLesserReturnType(derived(), Derived::Constant(rows(), cols(), s));
}
/** \returns an expression of the coefficient-wise > operator of \c *this and a scalar \a s */
@ -224,7 +231,7 @@ EIGEN_DEVICE_FUNC
inline const CwiseScalarGreaterReturnType
cwiseGreater(const Scalar& s) const
{
return CwiseScalarGreaterReturnType(derived(), Derived::Constant(rows(), cols(), s), internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT>());
return CwiseScalarGreaterReturnType(derived(), Derived::Constant(rows(), cols(), s));
}
/** \returns an expression of the coefficient-wise <= operator of \c *this and a scalar \a s */
@ -232,7 +239,7 @@ EIGEN_DEVICE_FUNC
inline const CwiseScalarLesserOrEqualReturnType
cwiseLesserOrEqual(const Scalar& s) const
{
return CwiseScalarLesserOrEqualReturnType(derived(), Derived::Constant(rows(), cols(), s), internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE>());
return CwiseScalarLesserOrEqualReturnType(derived(), Derived::Constant(rows(), cols(), s));
}
/** \returns an expression of the coefficient-wise >= operator of \c *this and a scalar \a s */
@ -240,7 +247,62 @@ EIGEN_DEVICE_FUNC
inline const CwiseScalarGreaterOrEqualReturnType
cwiseGreaterOrEqual(const Scalar& s) const
{
return CwiseScalarGreaterOrEqualReturnType(derived(), Derived::Constant(rows(), cols(), s), internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE>());
return CwiseScalarGreaterOrEqualReturnType(derived(), Derived::Constant(rows(), cols(), s));
}
template<typename OtherDerived> using CwiseBinaryTypedEqualReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_EQ, true>, const Derived, const OtherDerived>;
template<typename OtherDerived> using CwiseBinaryTypedNotEqualReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ, true>, const Derived, const OtherDerived>;
template<typename OtherDerived> using CwiseBinaryTypedLesserReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT, true>, const Derived, const OtherDerived>;
template<typename OtherDerived> using CwiseBinaryTypedGreaterReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT, true>, const Derived, const OtherDerived>;
template<typename OtherDerived> using CwiseBinaryTypedLesserOrEqualReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE, true>, const Derived, const OtherDerived>;
template<typename OtherDerived> using CwiseBinaryTypedGreaterOrEqualReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE, true>, const Derived, const OtherDerived>;
template <typename OtherDerived>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseBinaryTypedEqualReturnType<OtherDerived>
cwiseTypedEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived>& other) const { return CwiseBinaryTypedEqualReturnType<OtherDerived>(derived(), other.derived()); }
template <typename OtherDerived>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseBinaryTypedNotEqualReturnType<OtherDerived>
cwiseTypedNotEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived>& other) const { return CwiseBinaryTypedNotEqualReturnType<OtherDerived>(derived(), other.derived()); }
template <typename OtherDerived>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseBinaryTypedLesserReturnType<OtherDerived>
cwiseTypedLesser(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived>& other) const { return CwiseBinaryTypedLesserReturnType<OtherDerived>(derived(), other.derived()); }
template <typename OtherDerived>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseBinaryTypedGreaterReturnType<OtherDerived>
cwiseTypedGreater(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived>& other) const { return CwiseBinaryTypedGreaterReturnType<OtherDerived>(derived(), other.derived()); }
template <typename OtherDerived>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseBinaryTypedLesserOrEqualReturnType<OtherDerived>
cwiseTypedLesserOrEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived>& other) const { return CwiseBinaryTypedLesserOrEqualReturnType<OtherDerived>(derived(), other.derived()); }
template <typename OtherDerived>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseBinaryTypedGreaterOrEqualReturnType<OtherDerived>
cwiseTypedGreaterOrEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived>& other) const { return CwiseBinaryTypedGreaterOrEqualReturnType<OtherDerived>(derived(), other.derived()); }
using CwiseScalarTypedEqualReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_EQ, true>, const Derived, const ConstantReturnType>;
using CwiseScalarTypedNotEqualReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ, true>, const Derived, const ConstantReturnType>;
using CwiseScalarTypedLesserReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT, true>, const Derived, const ConstantReturnType>;
using CwiseScalarTypedGreaterReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT, true>, const Derived, const ConstantReturnType>;
using CwiseScalarTypedLesserOrEqualReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE, true>, const Derived, const ConstantReturnType>;
using CwiseScalarTypedGreaterOrEqualReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE, true>, const Derived, const ConstantReturnType>;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseScalarTypedEqualReturnType
cwiseTypedEqual(const Scalar& s) const { return CwiseScalarTypedEqualReturnType(derived(), ConstantReturnType(rows(), cols(), s)); }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseScalarTypedNotEqualReturnType
cwiseTypedNotEqual(const Scalar& s) const { return CwiseScalarTypedNotEqualReturnType(derived(), ConstantReturnType(rows(), cols(), s)); }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseScalarTypedLesserReturnType
cwiseTypedLesser(const Scalar& s) const { return CwiseScalarTypedLesserReturnType(derived(), ConstantReturnType(rows(), cols(), s)); }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseScalarTypedGreaterReturnType
cwiseTypedGreater(const Scalar& s) const { return CwiseScalarTypedGreaterReturnType(derived(), ConstantReturnType(rows(), cols(), s)); }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseScalarTypedLesserOrEqualReturnType
cwiseTypedLesserOrEqual(const Scalar& s) const { return CwiseScalarTypedLesserOrEqualReturnType(derived(), ConstantReturnType(rows(), cols(), s)); }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const CwiseScalarTypedGreaterOrEqualReturnType
cwiseTypedGreaterOrEqual(const Scalar& s) const { return CwiseScalarTypedGreaterOrEqualReturnType(derived(), ConstantReturnType(rows(), cols(), s)); }

View File

@ -590,21 +590,6 @@ template<typename ArrayType> void comparisons(const ArrayType& m)
typedef typename ArrayType::Scalar Scalar;
typedef typename NumTraits<Scalar>::Real RealScalar;
// explicitly test both typed and boolean comparison ops
using typed_eq = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_EQ, true>;
using typed_ne = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ, true>;
using typed_lt = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT, true>;
using typed_le = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE, true>;
using typed_gt = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT, true>;
using typed_ge = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE, true>;
using bool_eq = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_EQ, false>;
using bool_ne = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ, false>;
using bool_lt = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT, false>;
using bool_le = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE, false>;
using bool_gt = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT, false>;
using bool_ge = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE, false>;
Index rows = m.rows();
Index cols = m.cols();
@ -649,28 +634,44 @@ template<typename ArrayType> void comparisons(const ArrayType& m)
// use typed comparisons, regardless of operator overload behavior
typename ArrayType::ConstantReturnType typed_true = ArrayType::Constant(rows, cols, Scalar(1));
// (m1 + Scalar(1)) > m1).all()
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).binaryExpr(m1, typed_gt()), typed_true);
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).cwiseTypedGreater(m1), typed_true);
// (m1 - Scalar(1)) < m1).all()
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).binaryExpr(m1, typed_lt()), typed_true);
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).cwiseTypedLesser(m1), typed_true);
// (m1 + Scalar(1)) == (m1 + Scalar(1))).all()
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).binaryExpr(m1 + Scalar(1), typed_eq()), typed_true);
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).cwiseTypedEqual(m1 + Scalar(1)), typed_true);
// (m1 - Scalar(1)) != m1).all()
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).binaryExpr(m1, typed_ne()), typed_true);
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).cwiseTypedNotEqual(m1), typed_true);
// (m1 <= m2 || m1 >= m2).all()
VERIFY_IS_CWISE_EQUAL(m1.binaryExpr(m2, typed_le()) || m1.binaryExpr(m2, typed_ge()), typed_true);
VERIFY_IS_CWISE_EQUAL(m1.cwiseTypedGreaterOrEqual(m2) || m1.cwiseTypedLesserOrEqual(m2), typed_true);
// use boolean comparisons, regardless of operator overload behavior
ArrayXX<bool>::ConstantReturnType bool_true = ArrayXX<bool>::Constant(rows, cols, true);
// (m1 + Scalar(1)) > m1).all()
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).binaryExpr(m1, bool_gt()), bool_true);
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).cwiseGreater(m1), bool_true);
// (m1 - Scalar(1)) < m1).all()
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).binaryExpr(m1, bool_lt()), bool_true);
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).cwiseLesser(m1), bool_true);
// (m1 + Scalar(1)) == (m1 + Scalar(1))).all()
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).binaryExpr(m1 + Scalar(1), bool_eq()), bool_true);
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).cwiseEqual(m1 + Scalar(1)), bool_true);
// (m1 - Scalar(1)) != m1).all()
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).binaryExpr(m1, bool_ne()), bool_true);
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).cwiseNotEqual(m1), bool_true);
// (m1 <= m2 || m1 >= m2).all()
VERIFY_IS_CWISE_EQUAL(m1.binaryExpr(m2, bool_le()) || m1.binaryExpr(m2, bool_ge()), bool_true);
VERIFY_IS_CWISE_EQUAL(m1.cwiseLesserOrEqual(m2) || m1.cwiseGreaterOrEqual(m2), bool_true);
// test typed comparisons with scalar argument
VERIFY_IS_CWISE_EQUAL((m1 - m1).cwiseTypedEqual(Scalar(0)), typed_true);
VERIFY_IS_CWISE_EQUAL((m1.abs() + Scalar(1)).cwiseTypedNotEqual(Scalar(0)), typed_true);
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).cwiseTypedGreater(m1.minCoeff()), typed_true);
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).cwiseTypedLesser(m1.maxCoeff()), typed_true);
VERIFY_IS_CWISE_EQUAL(m1.abs().cwiseTypedLesserOrEqual(NumTraits<Scalar>::highest()), typed_true);
VERIFY_IS_CWISE_EQUAL((m1 * m1).cwiseTypedGreaterOrEqual(Scalar(0)), typed_true);
// test boolean comparisons with scalar argument
VERIFY_IS_CWISE_EQUAL((m1 - m1).cwiseEqual(Scalar(0)), bool_true);
VERIFY_IS_CWISE_EQUAL((m1.abs() + Scalar(1)).cwiseNotEqual(Scalar(0)), bool_true);
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).cwiseGreater(m1.minCoeff()), bool_true);
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).cwiseLesser(m1.maxCoeff()), bool_true);
VERIFY_IS_CWISE_EQUAL(m1.abs().cwiseLesserOrEqual(NumTraits<Scalar>::highest()), bool_true);
VERIFY_IS_CWISE_EQUAL((m1 * m1).cwiseGreaterOrEqual(Scalar(0)), bool_true);
// test Select
VERIFY_IS_APPROX( (m1<m2).select(m1,m2), m1.cwiseMin(m2) );