diff --git a/Eigen/src/Core/MathFunctions.h b/Eigen/src/Core/MathFunctions.h index a0e9d1a35..1bb57bbc4 100644 --- a/Eigen/src/Core/MathFunctions.h +++ b/Eigen/src/Core/MathFunctions.h @@ -881,7 +881,9 @@ struct random_default_impl { // if the random draw is outside [0, range), try again (rejection sampling) // in the worst-case scenario, the probability of rejection is: 1/2 - 1/2^numRandomBits < 50% } while (randomBits >= range); - Scalar result = x + static_cast(randomBits); + // Avoid overflow in the case where `x` is negative and there is a large range so + // `randomBits` would also be negative if cast to `Scalar` first. + Scalar result = static_cast(static_cast(x) + randomBits); return result; } diff --git a/test/rand.cpp b/test/rand.cpp index b5cf801f3..6a7c316d8 100644 --- a/test/rand.cpp +++ b/test/rand.cpp @@ -75,27 +75,36 @@ class HistogramHelper::IsInteg HistogramHelper(Scalar lower, Scalar upper, int nbins) : lower_{lower}, upper_{upper}, num_bins_{nbins}, bin_width_{bin_width(lower, upper, nbins)} {} - int bin(Scalar v) { return static_cast(RangeType(v - lower_) / bin_width_); } + int bin(Scalar v) { return static_cast(RangeType(RangeType(v) - RangeType(lower_)) / bin_width_); } double uniform_bin_probability(int bin) { - // Avoid overflow in computing range. - double range = static_cast(RangeType(upper_ - lower_)) + 1.0; + // The full range upper - lower + 1 might overflow the RangeType by one. + // So instead, we know we have (nbins - 1) bins of width bin_width_, + // and the last bin of width: + RangeType last_bin_width = + RangeType(upper_) - (RangeType(lower_) + RangeType(num_bins_ - 1) * bin_width_) + RangeType(1); + double last_bin_ratio = static_cast(last_bin_width) / static_cast(bin_width_); + // Total probability = (nbins - 1) * p + last_bin_ratio * p = 1.0 + // p = 1.0 / (nbins - 1 + last_bin_ratio) + double p = 1.0 / (last_bin_ratio + num_bins_ - 1); if (bin < num_bins_ - 1) { - return static_cast(bin_width_) / range; + return p; } - return static_cast(RangeType(upper_) - RangeType((lower_ + bin * bin_width_)) + 1) / range; + return last_bin_ratio * p; } private: - static constexpr Scalar bin_width(Scalar lower, Scalar upper, int nbins) { + static constexpr RangeType bin_width(Scalar lower, Scalar upper, int nbins) { // Avoid overflow in computing the full range. - return RangeType(upper - nbins - lower + 1) / nbins + 1; + // floor( (upper - lower + 1) / nbins) ) + // = floor( (upper- nbins - lower + 1 + nbins) / nbins) ) + return RangeType(RangeType(upper - nbins) - RangeType(lower) + 1) / nbins + 1; } Scalar lower_; Scalar upper_; int num_bins_; - Scalar bin_width_; + RangeType bin_width_; }; template