Fix random for custom scalars that don't have constexpr digits().

This commit is contained in:
Antonio Sánchez 2024-02-16 02:30:54 +00:00 committed by Charles Schlosser
parent 500a3602f0
commit 2a9055b50e

View File

@ -797,25 +797,33 @@ inline EIGEN_MATHFUNC_RETVAL(random, Scalar) random();
template <typename Scalar>
struct random_default_impl<Scalar, false, false> {
using BitsType = typename numext::get_integer_by_size<sizeof(Scalar)>::unsigned_type;
enum : int { MantissaBits = NumTraits<Scalar>::digits() - 1 };
static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar& x, const Scalar& y, int numRandomBits = MantissaBits) {
static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar& x, const Scalar& y, int numRandomBits) {
Scalar half_x = Scalar(0.5) * x;
Scalar half_y = Scalar(0.5) * y;
Scalar result = (half_x + half_y) + (half_y - half_x) * run(numRandomBits);
// result is in the half-open interval [x, y) -- provided that x < y
return result;
}
static EIGEN_DEVICE_FUNC inline Scalar run(int numRandomBits = MantissaBits) {
eigen_assert(numRandomBits >= 0 && numRandomBits <= MantissaBits);
static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar& x, const Scalar& y) {
const int mantissa_bits = NumTraits<Scalar>::digits() - 1;
return run(x, y, mantissa_bits);
}
static EIGEN_DEVICE_FUNC inline Scalar run(int numRandomBits) {
const int mantissa_bits = NumTraits<Scalar>::digits() - 1;
eigen_assert(numRandomBits >= 0 && numRandomBits <= mantissa_bits);
BitsType randomBits = getRandomBits<BitsType>(numRandomBits);
// if fewer than MantissaBits is requested, shift them to the left
randomBits <<= (MantissaBits - numRandomBits);
randomBits <<= (mantissa_bits - numRandomBits);
// randomBits is in the half-open interval [2,4)
randomBits |= numext::bit_cast<BitsType>(Scalar(2));
// result is in the half-open interval [-1,1)
Scalar result = numext::bit_cast<Scalar>(randomBits) - Scalar(3);
return result;
}
static EIGEN_DEVICE_FUNC inline Scalar run() {
const int mantissa_bits = NumTraits<Scalar>::digits() - 1;
return run(mantissa_bits);
}
};
// TODO: fix this for PPC