Fix numext::round pre c++11 for large inputs.

This is to resolve an issue for large inputs when +0.5 can
actually lead to +1 if the input doesn't have enough precision
to resolve the addition - leading to an off-by-one error.

See discussion on 9a663973.
This commit is contained in:
Antonio Sanchez 2021-03-10 21:27:35 -08:00 committed by Rasmus Munk Larsen
parent c9d4367fa4
commit 14b7ebea11

View File

@ -476,31 +476,35 @@ inline NewType cast(const OldType& x)
* Implementation of round * * Implementation of round *
****************************************************************************/ ****************************************************************************/
template<typename Scalar>
struct round_impl
{
EIGEN_DEVICE_FUNC
static inline Scalar run(const Scalar& x)
{
EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex), NUMERIC_TYPE_MUST_BE_REAL)
#if EIGEN_HAS_CXX11_MATH #if EIGEN_HAS_CXX11_MATH
template<typename Scalar> EIGEN_USING_STD(round);
struct round_impl { return Scalar(round(x));
EIGEN_DEVICE_FUNC #elif EIGEN_HAS_C99_MATH
static inline Scalar run(const Scalar& x) if (is_same<Scalar, float>::value) {
{ return Scalar(::roundf(x));
EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex), NUMERIC_TYPE_MUST_BE_REAL) } else {
EIGEN_USING_STD(round);
return Scalar(round(x)); return Scalar(round(x));
} }
};
#else #else
template<typename Scalar> EIGEN_USING_STD(floor);
struct round_impl EIGEN_USING_STD(ceil);
{ // If not enough precision to resolve a decimal at all, return the input.
EIGEN_DEVICE_FUNC // Otherwise, adding 0.5 can trigger an increment by 1.
static inline Scalar run(const Scalar& x) const Scalar limit = Scalar(1ull << (NumTraits<Scalar>::digits() - 1));
{ if (x >= limit || x <= -limit) {
EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex), NUMERIC_TYPE_MUST_BE_REAL) return x;
EIGEN_USING_STD(floor);
EIGEN_USING_STD(ceil);
return (x > Scalar(0)) ? floor(x + Scalar(0.5)) : ceil(x - Scalar(0.5));
} }
}; return (x > Scalar(0)) ? Scalar(floor(x + Scalar(0.5))) : Scalar(ceil(x - Scalar(0.5)));
#endif #endif
}
};
template<typename Scalar> template<typename Scalar>
struct round_retval struct round_retval