Fix signed integer UB in random.

This commit is contained in:
Antonio Sánchez 2024-02-24 13:16:23 +00:00 committed by Charles Schlosser
parent a6dc930d16
commit 7a88cdd6ad
2 changed files with 20 additions and 9 deletions

View File

@ -881,7 +881,9 @@ struct random_default_impl<Scalar, false, true> {
// if the random draw is outside [0, range), try again (rejection sampling)
// in the worst-case scenario, the probability of rejection is: 1/2 - 1/2^numRandomBits < 50%
} while (randomBits >= range);
Scalar result = x + static_cast<Scalar>(randomBits);
// Avoid overflow in the case where `x` is negative and there is a large range so
// `randomBits` would also be negative if cast to `Scalar` first.
Scalar result = static_cast<Scalar>(static_cast<BitsType>(x) + randomBits);
return result;
}

View File

@ -75,27 +75,36 @@ class HistogramHelper<Scalar, std::enable_if_t<Eigen::NumTraits<Scalar>::IsInteg
HistogramHelper(Scalar lower, Scalar upper, int nbins)
: lower_{lower}, upper_{upper}, num_bins_{nbins}, bin_width_{bin_width(lower, upper, nbins)} {}
int bin(Scalar v) { return static_cast<int>(RangeType(v - lower_) / bin_width_); }
int bin(Scalar v) { return static_cast<int>(RangeType(RangeType(v) - RangeType(lower_)) / bin_width_); }
double uniform_bin_probability(int bin) {
// Avoid overflow in computing range.
double range = static_cast<double>(RangeType(upper_ - lower_)) + 1.0;
// The full range upper - lower + 1 might overflow the RangeType by one.
// So instead, we know we have (nbins - 1) bins of width bin_width_,
// and the last bin of width:
RangeType last_bin_width =
RangeType(upper_) - (RangeType(lower_) + RangeType(num_bins_ - 1) * bin_width_) + RangeType(1);
double last_bin_ratio = static_cast<double>(last_bin_width) / static_cast<double>(bin_width_);
// Total probability = (nbins - 1) * p + last_bin_ratio * p = 1.0
// p = 1.0 / (nbins - 1 + last_bin_ratio)
double p = 1.0 / (last_bin_ratio + num_bins_ - 1);
if (bin < num_bins_ - 1) {
return static_cast<double>(bin_width_) / range;
return p;
}
return static_cast<double>(RangeType(upper_) - RangeType((lower_ + bin * bin_width_)) + 1) / range;
return last_bin_ratio * p;
}
private:
static constexpr Scalar bin_width(Scalar lower, Scalar upper, int nbins) {
static constexpr RangeType bin_width(Scalar lower, Scalar upper, int nbins) {
// Avoid overflow in computing the full range.
return RangeType(upper - nbins - lower + 1) / nbins + 1;
// floor( (upper - lower + 1) / nbins) )
// = floor( (upper- nbins - lower + 1 + nbins) / nbins) )
return RangeType(RangeType(upper - nbins) - RangeType(lower) + 1) / nbins + 1;
}
Scalar lower_;
Scalar upper_;
int num_bins_;
Scalar bin_width_;
RangeType bin_width_;
};
template <typename Scalar>