mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-14 12:46:00 +08:00
Fix random for custom scalars that don't have constexpr digits().
This commit is contained in:
parent
500a3602f0
commit
2a9055b50e
@ -797,25 +797,33 @@ inline EIGEN_MATHFUNC_RETVAL(random, Scalar) random();
|
|||||||
template <typename Scalar>
|
template <typename Scalar>
|
||||||
struct random_default_impl<Scalar, false, false> {
|
struct random_default_impl<Scalar, false, false> {
|
||||||
using BitsType = typename numext::get_integer_by_size<sizeof(Scalar)>::unsigned_type;
|
using BitsType = typename numext::get_integer_by_size<sizeof(Scalar)>::unsigned_type;
|
||||||
enum : int { MantissaBits = NumTraits<Scalar>::digits() - 1 };
|
static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar& x, const Scalar& y, int numRandomBits) {
|
||||||
static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar& x, const Scalar& y, int numRandomBits = MantissaBits) {
|
|
||||||
Scalar half_x = Scalar(0.5) * x;
|
Scalar half_x = Scalar(0.5) * x;
|
||||||
Scalar half_y = Scalar(0.5) * y;
|
Scalar half_y = Scalar(0.5) * y;
|
||||||
Scalar result = (half_x + half_y) + (half_y - half_x) * run(numRandomBits);
|
Scalar result = (half_x + half_y) + (half_y - half_x) * run(numRandomBits);
|
||||||
// result is in the half-open interval [x, y) -- provided that x < y
|
// result is in the half-open interval [x, y) -- provided that x < y
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
static EIGEN_DEVICE_FUNC inline Scalar run(int numRandomBits = MantissaBits) {
|
static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar& x, const Scalar& y) {
|
||||||
eigen_assert(numRandomBits >= 0 && numRandomBits <= MantissaBits);
|
const int mantissa_bits = NumTraits<Scalar>::digits() - 1;
|
||||||
|
return run(x, y, mantissa_bits);
|
||||||
|
}
|
||||||
|
static EIGEN_DEVICE_FUNC inline Scalar run(int numRandomBits) {
|
||||||
|
const int mantissa_bits = NumTraits<Scalar>::digits() - 1;
|
||||||
|
eigen_assert(numRandomBits >= 0 && numRandomBits <= mantissa_bits);
|
||||||
BitsType randomBits = getRandomBits<BitsType>(numRandomBits);
|
BitsType randomBits = getRandomBits<BitsType>(numRandomBits);
|
||||||
// if fewer than MantissaBits is requested, shift them to the left
|
// if fewer than MantissaBits is requested, shift them to the left
|
||||||
randomBits <<= (MantissaBits - numRandomBits);
|
randomBits <<= (mantissa_bits - numRandomBits);
|
||||||
// randomBits is in the half-open interval [2,4)
|
// randomBits is in the half-open interval [2,4)
|
||||||
randomBits |= numext::bit_cast<BitsType>(Scalar(2));
|
randomBits |= numext::bit_cast<BitsType>(Scalar(2));
|
||||||
// result is in the half-open interval [-1,1)
|
// result is in the half-open interval [-1,1)
|
||||||
Scalar result = numext::bit_cast<Scalar>(randomBits) - Scalar(3);
|
Scalar result = numext::bit_cast<Scalar>(randomBits) - Scalar(3);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
static EIGEN_DEVICE_FUNC inline Scalar run() {
|
||||||
|
const int mantissa_bits = NumTraits<Scalar>::digits() - 1;
|
||||||
|
return run(mantissa_bits);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: fix this for PPC
|
// TODO: fix this for PPC
|
||||||
|
Loading…
x
Reference in New Issue
Block a user