mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-23 10:09:36 +08:00
Fix signed integer UB in random.
This commit is contained in:
parent
a6dc930d16
commit
7a88cdd6ad
@ -881,7 +881,9 @@ struct random_default_impl<Scalar, false, true> {
|
||||
// if the random draw is outside [0, range), try again (rejection sampling)
|
||||
// in the worst-case scenario, the probability of rejection is: 1/2 - 1/2^numRandomBits < 50%
|
||||
} while (randomBits >= range);
|
||||
Scalar result = x + static_cast<Scalar>(randomBits);
|
||||
// Avoid overflow in the case where `x` is negative and there is a large range so
|
||||
// `randomBits` would also be negative if cast to `Scalar` first.
|
||||
Scalar result = static_cast<Scalar>(static_cast<BitsType>(x) + randomBits);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
@ -75,27 +75,36 @@ class HistogramHelper<Scalar, std::enable_if_t<Eigen::NumTraits<Scalar>::IsInteg
|
||||
HistogramHelper(Scalar lower, Scalar upper, int nbins)
|
||||
: lower_{lower}, upper_{upper}, num_bins_{nbins}, bin_width_{bin_width(lower, upper, nbins)} {}
|
||||
|
||||
int bin(Scalar v) { return static_cast<int>(RangeType(v - lower_) / bin_width_); }
|
||||
int bin(Scalar v) { return static_cast<int>(RangeType(RangeType(v) - RangeType(lower_)) / bin_width_); }
|
||||
|
||||
double uniform_bin_probability(int bin) {
|
||||
// Avoid overflow in computing range.
|
||||
double range = static_cast<double>(RangeType(upper_ - lower_)) + 1.0;
|
||||
// The full range upper - lower + 1 might overflow the RangeType by one.
|
||||
// So instead, we know we have (nbins - 1) bins of width bin_width_,
|
||||
// and the last bin of width:
|
||||
RangeType last_bin_width =
|
||||
RangeType(upper_) - (RangeType(lower_) + RangeType(num_bins_ - 1) * bin_width_) + RangeType(1);
|
||||
double last_bin_ratio = static_cast<double>(last_bin_width) / static_cast<double>(bin_width_);
|
||||
// Total probability = (nbins - 1) * p + last_bin_ratio * p = 1.0
|
||||
// p = 1.0 / (nbins - 1 + last_bin_ratio)
|
||||
double p = 1.0 / (last_bin_ratio + num_bins_ - 1);
|
||||
if (bin < num_bins_ - 1) {
|
||||
return static_cast<double>(bin_width_) / range;
|
||||
return p;
|
||||
}
|
||||
return static_cast<double>(RangeType(upper_) - RangeType((lower_ + bin * bin_width_)) + 1) / range;
|
||||
return last_bin_ratio * p;
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr Scalar bin_width(Scalar lower, Scalar upper, int nbins) {
|
||||
static constexpr RangeType bin_width(Scalar lower, Scalar upper, int nbins) {
|
||||
// Avoid overflow in computing the full range.
|
||||
return RangeType(upper - nbins - lower + 1) / nbins + 1;
|
||||
// floor( (upper - lower + 1) / nbins) )
|
||||
// = floor( (upper- nbins - lower + 1 + nbins) / nbins) )
|
||||
return RangeType(RangeType(upper - nbins) - RangeType(lower) + 1) / nbins + 1;
|
||||
}
|
||||
|
||||
Scalar lower_;
|
||||
Scalar upper_;
|
||||
int num_bins_;
|
||||
Scalar bin_width_;
|
||||
RangeType bin_width_;
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
|
Loading…
x
Reference in New Issue
Block a user