rint round floor ceil

This commit is contained in:
Charles Schlosser 2023-06-23 16:29:16 +00:00
parent 6ee86fd473
commit 44c20bbbe3
2 changed files with 34 additions and 65 deletions

View File

@ -462,51 +462,6 @@ inline NewType cast(const OldType& x)
return cast_impl<OldType, NewType>::run(x);
}
/****************************************************************************
* Implementation of round *
****************************************************************************/
template<typename Scalar>
struct round_impl
{
EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex), NUMERIC_TYPE_MUST_BE_REAL)
EIGEN_DEVICE_FUNC
static inline Scalar run(const Scalar& x)
{
EIGEN_USING_STD(round);
return Scalar(round(x));
}
};
template<typename Scalar>
struct round_retval
{
typedef Scalar type;
};
/****************************************************************************
* Implementation of rint *
****************************************************************************/
template<typename Scalar>
struct rint_impl {
EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex), NUMERIC_TYPE_MUST_BE_REAL)
EIGEN_DEVICE_FUNC
static inline Scalar run(const Scalar& x)
{
EIGEN_USING_STD(rint);
return rint(x);
}
};
template<typename Scalar>
struct rint_retval
{
typedef Scalar type;
};
/****************************************************************************
* Implementation of arg *
****************************************************************************/
@ -996,6 +951,22 @@ struct sign_retval {
typedef Scalar type;
};
template <typename Scalar, bool IsInteger = NumTraits<typename unpacket_traits<Scalar>::type>::IsInteger>
struct nearest_integer_impl {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_floor(const Scalar& x) { EIGEN_USING_STD(floor) return floor(x); }
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_ceil(const Scalar& x) { EIGEN_USING_STD(ceil) return ceil(x); }
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_rint(const Scalar& x) { EIGEN_USING_STD(rint) return rint(x); }
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_round(const Scalar& x) { EIGEN_USING_STD(round) return round(x); }
};
template <typename Scalar>
struct nearest_integer_impl<Scalar, true> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_floor(const Scalar& x) { return x; }
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_ceil(const Scalar& x) { return x; }
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_rint(const Scalar& x) { return x; }
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_round(const Scalar& x) { return x; }
};
} // end namespace internal
/****************************************************************************
@ -1317,29 +1288,28 @@ SYCL_SPECIALIZE_FLOATING_TYPES_UNARY_FUNC_RET_TYPE(isfinite, isfinite, bool)
#endif
template<typename Scalar>
EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(rint, Scalar) rint(const Scalar& x)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Scalar rint(const Scalar& x)
{
return EIGEN_MATHFUNC_IMPL(rint, Scalar)::run(x);
return internal::nearest_integer_impl<Scalar>::run_rint(x);
}
template<typename Scalar>
EIGEN_DEVICE_FUNC
inline EIGEN_MATHFUNC_RETVAL(round, Scalar) round(const Scalar& x)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Scalar round(const Scalar& x)
{
return EIGEN_MATHFUNC_IMPL(round, Scalar)::run(x);
return internal::nearest_integer_impl<Scalar>::run_round(x);
}
#if defined(SYCL_DEVICE_ONLY)
SYCL_SPECIALIZE_FLOATING_TYPES_UNARY(round, round)
#endif
template<typename T>
EIGEN_DEVICE_FUNC
T (floor)(const T& x)
template<typename Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Scalar (floor)(const Scalar& x)
{
EIGEN_USING_STD(floor)
return floor(x);
return internal::nearest_integer_impl<Scalar>::run_floor(x);
}
#if defined(SYCL_DEVICE_ONLY)
@ -1354,12 +1324,11 @@ template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
double floor(const double &x) { return ::floor(x); }
#endif
template<typename T>
EIGEN_DEVICE_FUNC
T (ceil)(const T& x)
template<typename Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
Scalar (ceil)(const Scalar& x)
{
EIGEN_USING_STD(ceil);
return ceil(x);
return internal::nearest_integer_impl<Scalar>::run_ceil(x);
}
#if defined(SYCL_DEVICE_ONLY)

View File

@ -809,7 +809,7 @@ struct functor_traits<scalar_round_op<Scalar> >
{
enum {
Cost = NumTraits<Scalar>::MulCost,
PacketAccess = packet_traits<Scalar>::HasRound
PacketAccess = packet_traits<Scalar>::HasRound || NumTraits<Scalar>::IsInteger
};
};
@ -827,7 +827,7 @@ struct functor_traits<scalar_floor_op<Scalar> >
{
enum {
Cost = NumTraits<Scalar>::MulCost,
PacketAccess = packet_traits<Scalar>::HasFloor
PacketAccess = packet_traits<Scalar>::HasFloor || NumTraits<Scalar>::IsInteger
};
};
@ -845,7 +845,7 @@ struct functor_traits<scalar_rint_op<Scalar> >
{
enum {
Cost = NumTraits<Scalar>::MulCost,
PacketAccess = packet_traits<Scalar>::HasRint
PacketAccess = packet_traits<Scalar>::HasRint || NumTraits<Scalar>::IsInteger
};
};
@ -863,7 +863,7 @@ struct functor_traits<scalar_ceil_op<Scalar> >
{
enum {
Cost = NumTraits<Scalar>::MulCost,
PacketAccess = packet_traits<Scalar>::HasCeil
PacketAccess = packet_traits<Scalar>::HasCeil || NumTraits<Scalar>::IsInteger
};
};