From 44c20bbbe3da051e6318414bd0a3126c362d3d51 Mon Sep 17 00:00:00 2001 From: Charles Schlosser Date: Fri, 23 Jun 2023 16:29:16 +0000 Subject: [PATCH] rint round floor ceil --- Eigen/src/Core/MathFunctions.h | 91 ++++++++----------------- Eigen/src/Core/functors/UnaryFunctors.h | 8 +-- 2 files changed, 34 insertions(+), 65 deletions(-) diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index d7851e319..822701bc3 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -462,51 +462,6 @@ inline NewType cast(const OldType& x) return cast_impl::run(x); } -/**************************************************************************** -* Implementation of round * -****************************************************************************/ - -template -struct round_impl -{ - EIGEN_STATIC_ASSERT((!NumTraits::IsComplex), NUMERIC_TYPE_MUST_BE_REAL) - - EIGEN_DEVICE_FUNC - static inline Scalar run(const Scalar& x) - { - EIGEN_USING_STD(round); - return Scalar(round(x)); - } -}; - -template -struct round_retval -{ - typedef Scalar type; -}; - -/**************************************************************************** -* Implementation of rint * -****************************************************************************/ - -template -struct rint_impl { - EIGEN_STATIC_ASSERT((!NumTraits::IsComplex), NUMERIC_TYPE_MUST_BE_REAL) - - EIGEN_DEVICE_FUNC - static inline Scalar run(const Scalar& x) - { - EIGEN_USING_STD(rint); - return rint(x); - } -}; - -template -struct rint_retval -{ - typedef Scalar type; -}; - /**************************************************************************** * Implementation of arg * ****************************************************************************/ @@ -996,6 +951,22 @@ struct sign_retval { typedef Scalar type; }; + +template ::type>::IsInteger> +struct nearest_integer_impl { + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_floor(const Scalar& x) { EIGEN_USING_STD(floor) return floor(x); } + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_ceil(const Scalar& x) { EIGEN_USING_STD(ceil) return ceil(x); } + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_rint(const Scalar& x) { EIGEN_USING_STD(rint) return rint(x); } + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_round(const Scalar& x) { EIGEN_USING_STD(round) return round(x); } +}; +template +struct nearest_integer_impl { + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_floor(const Scalar& x) { return x; } + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_ceil(const Scalar& x) { return x; } + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_rint(const Scalar& x) { return x; } + static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_round(const Scalar& x) { return x; } +}; + } // end namespace internal /**************************************************************************** @@ -1317,29 +1288,28 @@ SYCL_SPECIALIZE_FLOATING_TYPES_UNARY_FUNC_RET_TYPE(isfinite, isfinite, bool) #endif template -EIGEN_DEVICE_FUNC -inline EIGEN_MATHFUNC_RETVAL(rint, Scalar) rint(const Scalar& x) +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +Scalar rint(const Scalar& x) { - return EIGEN_MATHFUNC_IMPL(rint, Scalar)::run(x); + return internal::nearest_integer_impl::run_rint(x); } template -EIGEN_DEVICE_FUNC -inline EIGEN_MATHFUNC_RETVAL(round, Scalar) round(const Scalar& x) +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +Scalar round(const Scalar& x) { - return EIGEN_MATHFUNC_IMPL(round, Scalar)::run(x); + return internal::nearest_integer_impl::run_round(x); } #if defined(SYCL_DEVICE_ONLY) SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(round, round) #endif -template -EIGEN_DEVICE_FUNC -T (floor)(const T& x) +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +Scalar (floor)(const Scalar& x) { - EIGEN_USING_STD(floor) - return floor(x); + return internal::nearest_integer_impl::run_floor(x); } #if defined(SYCL_DEVICE_ONLY) @@ -1354,12 +1324,11 @@ template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE double floor(const double &x) { return ::floor(x); } #endif -template -EIGEN_DEVICE_FUNC -T (ceil)(const T& x) +template +EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE +Scalar (ceil)(const Scalar& x) { - EIGEN_USING_STD(ceil); - return ceil(x); + return internal::nearest_integer_impl::run_ceil(x); } #if defined(SYCL_DEVICE_ONLY) diff --git a/Eigen/src/Core/functors/UnaryFunctors.h b/Eigen/src/Core/functors/UnaryFunctors.h index 4760d9b59..fcd81c186 100644 --- a/Eigen/src/Core/functors/UnaryFunctors.h +++ b/Eigen/src/Core/functors/UnaryFunctors.h @@ -809,7 +809,7 @@ struct functor_traits > { enum { Cost = NumTraits::MulCost, - PacketAccess = packet_traits::HasRound + PacketAccess = packet_traits::HasRound || NumTraits::IsInteger }; }; @@ -827,7 +827,7 @@ struct functor_traits > { enum { Cost = NumTraits::MulCost, - PacketAccess = packet_traits::HasFloor + PacketAccess = packet_traits::HasFloor || NumTraits::IsInteger }; }; @@ -845,7 +845,7 @@ struct functor_traits > { enum { Cost = NumTraits::MulCost, - PacketAccess = packet_traits::HasRint + PacketAccess = packet_traits::HasRint || NumTraits::IsInteger }; }; @@ -863,7 +863,7 @@ struct functor_traits > { enum { Cost = NumTraits::MulCost, - PacketAccess = packet_traits::HasCeil + PacketAccess = packet_traits::HasCeil || NumTraits::IsInteger }; };