mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-16 14:49:39 +08:00
improve random
This commit is contained in:
parent
a9ddab3e06
commit
d626762e3f
@ -563,34 +563,6 @@ struct pow_impl<ScalarX, ScalarY, true> {
|
||||
}
|
||||
};
|
||||
|
||||
/****************************************************************************
|
||||
* Implementation of random *
|
||||
****************************************************************************/
|
||||
|
||||
template <typename Scalar, bool IsComplex, bool IsInteger>
|
||||
struct random_default_impl {};
|
||||
|
||||
template <typename Scalar>
|
||||
struct random_impl : random_default_impl<Scalar, NumTraits<Scalar>::IsComplex, NumTraits<Scalar>::IsInteger> {};
|
||||
|
||||
template <typename Scalar>
|
||||
struct random_retval {
|
||||
typedef Scalar type;
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
inline EIGEN_MATHFUNC_RETVAL(random, Scalar) random(const Scalar& x, const Scalar& y);
|
||||
template <typename Scalar>
|
||||
inline EIGEN_MATHFUNC_RETVAL(random, Scalar) random();
|
||||
|
||||
template <typename Scalar>
|
||||
struct random_default_impl<Scalar, false, false> {
|
||||
static inline Scalar run(const Scalar& x, const Scalar& y) {
|
||||
return x + (y - x) * Scalar(std::rand()) / Scalar(RAND_MAX);
|
||||
}
|
||||
static inline Scalar run() { return run(Scalar(NumTraits<Scalar>::IsSigned ? -1 : 0), Scalar(1)); }
|
||||
};
|
||||
|
||||
enum { meta_floor_log2_terminate, meta_floor_log2_move_up, meta_floor_log2_move_down, meta_floor_log2_bogus };
|
||||
|
||||
template <unsigned int n, int lower, int upper>
|
||||
@ -769,56 +741,166 @@ struct count_bits_impl<
|
||||
|
||||
#endif // EIGEN_COMP_GNUC || EIGEN_COMP_CLANG
|
||||
|
||||
template <typename BitsType>
|
||||
int log2_ceil(BitsType x) {
|
||||
int n = CHAR_BIT * sizeof(BitsType) - clz(x);
|
||||
bool powerOfTwo = (x & (x - 1)) == 0;
|
||||
return x == 0 ? 0 : powerOfTwo ? n - 1 : n;
|
||||
}
|
||||
|
||||
template <typename BitsType>
|
||||
int log2_floor(BitsType x) {
|
||||
int n = CHAR_BIT * sizeof(BitsType) - clz(x);
|
||||
return x == 0 ? 0 : n - 1;
|
||||
}
|
||||
|
||||
/****************************************************************************
|
||||
* Implementation of random *
|
||||
****************************************************************************/
|
||||
|
||||
// return a Scalar filled with numRandomBits beginning from the least significant bit
|
||||
template <typename Scalar>
|
||||
Scalar getRandomBits(int numRandomBits) {
|
||||
using BitsType = typename numext::get_integer_by_size<sizeof(Scalar)>::unsigned_type;
|
||||
enum : int {
|
||||
StdRandBits = meta_floor_log2<(unsigned int)(RAND_MAX) + 1>::value,
|
||||
ScalarBits = sizeof(Scalar) * CHAR_BIT
|
||||
};
|
||||
eigen_assert((numRandomBits >= 0) && (numRandomBits <= ScalarBits));
|
||||
const BitsType mask = BitsType(-1) >> (ScalarBits - numRandomBits);
|
||||
BitsType randomBits = BitsType(0);
|
||||
for (int shift = 0; shift < numRandomBits; shift += StdRandBits) {
|
||||
int r = std::rand();
|
||||
randomBits |= static_cast<BitsType>(r) << shift;
|
||||
}
|
||||
// clear the excess bits
|
||||
randomBits &= mask;
|
||||
return numext::bit_cast<Scalar, BitsType>(randomBits);
|
||||
}
|
||||
|
||||
template <typename Scalar, bool IsComplex, bool IsInteger>
|
||||
struct random_default_impl {};
|
||||
|
||||
template <typename Scalar>
|
||||
struct random_impl : random_default_impl<Scalar, NumTraits<Scalar>::IsComplex, NumTraits<Scalar>::IsInteger> {};
|
||||
|
||||
template <typename Scalar>
|
||||
struct random_retval {
|
||||
typedef Scalar type;
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
inline EIGEN_MATHFUNC_RETVAL(random, Scalar) random(const Scalar& x, const Scalar& y);
|
||||
template <typename Scalar>
|
||||
inline EIGEN_MATHFUNC_RETVAL(random, Scalar) random();
|
||||
|
||||
template <typename Scalar>
|
||||
struct random_default_impl<Scalar, false, false> {
|
||||
using BitsType = typename numext::get_integer_by_size<sizeof(Scalar)>::unsigned_type;
|
||||
enum : int { MantissaBits = NumTraits<Scalar>::digits() - 1 };
|
||||
static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar& x, const Scalar& y, int numRandomBits = MantissaBits) {
|
||||
Scalar half_x = Scalar(0.5) * x;
|
||||
Scalar half_y = Scalar(0.5) * y;
|
||||
Scalar result = (half_x + half_y) + (half_y - half_x) * run(numRandomBits);
|
||||
// result is in the half-open interval [x, y) -- provided that x < y
|
||||
return result;
|
||||
}
|
||||
static EIGEN_DEVICE_FUNC inline Scalar run(int numRandomBits = MantissaBits) {
|
||||
eigen_assert(numRandomBits >= 0 && numRandomBits <= MantissaBits);
|
||||
BitsType randomBits = getRandomBits<BitsType>(numRandomBits);
|
||||
// if fewer than MantissaBits is requested, shift them to the left
|
||||
randomBits <<= (MantissaBits - numRandomBits);
|
||||
// randomBits is in the half-open interval [2,4)
|
||||
randomBits |= numext::bit_cast<BitsType>(Scalar(2));
|
||||
// result is in the half-open interval [-1,1)
|
||||
Scalar result = numext::bit_cast<Scalar>(randomBits) - Scalar(3);
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: fix this for PPC
|
||||
template <bool Specialize = sizeof(long double) == 2 * sizeof(uint64_t) && !EIGEN_ARCH_PPC>
|
||||
struct random_longdouble_impl {
|
||||
enum : int {
|
||||
Size = sizeof(long double),
|
||||
MantissaBits = NumTraits<long double>::digits() - 1,
|
||||
LowBits = MantissaBits > 64 ? 64 : MantissaBits,
|
||||
HighBits = MantissaBits > 64 ? MantissaBits - 64 : 0
|
||||
};
|
||||
static EIGEN_DEVICE_FUNC inline long double run() {
|
||||
EIGEN_USING_STD(memcpy)
|
||||
uint64_t randomBits[2];
|
||||
long double result = 2.0L;
|
||||
memcpy(&randomBits, &result, Size);
|
||||
randomBits[0] |= getRandomBits<uint64_t>(LowBits);
|
||||
randomBits[1] |= getRandomBits<uint64_t>(HighBits);
|
||||
memcpy(&result, &randomBits, Size);
|
||||
result -= 3.0L;
|
||||
return result;
|
||||
}
|
||||
};
|
||||
template <>
|
||||
struct random_longdouble_impl<false> {
|
||||
using Impl = random_impl<double>;
|
||||
static EIGEN_DEVICE_FUNC inline long double run() { return static_cast<long double>(Impl::run()); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct random_impl<long double> {
|
||||
static EIGEN_DEVICE_FUNC inline long double run(const long double& x, const long double& y) {
|
||||
long double half_x = 0.5L * x;
|
||||
long double half_y = 0.5L * y;
|
||||
long double result = (half_x + half_y) + (half_y - half_x) * run();
|
||||
return result;
|
||||
}
|
||||
static EIGEN_DEVICE_FUNC inline long double run() { return random_longdouble_impl<>::run(); }
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
struct random_default_impl<Scalar, false, true> {
|
||||
static inline Scalar run(const Scalar& x, const Scalar& y) {
|
||||
using BitsType = typename numext::get_integer_by_size<sizeof(Scalar)>::unsigned_type;
|
||||
enum : int { ScalarBits = sizeof(Scalar) * CHAR_BIT };
|
||||
static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar& x, const Scalar& y) {
|
||||
if (y <= x) return x;
|
||||
// ScalarU is the unsigned counterpart of Scalar, possibly Scalar itself.
|
||||
typedef typename make_unsigned<Scalar>::type ScalarU;
|
||||
// ScalarX is the widest of ScalarU and unsigned int.
|
||||
// We'll deal only with ScalarX and unsigned int below thus avoiding signed
|
||||
// types and arithmetic and signed overflows (which are undefined behavior).
|
||||
typedef std::conditional_t<(ScalarU(-1) > unsigned(-1)), ScalarU, unsigned> ScalarX;
|
||||
// The following difference doesn't overflow, provided our integer types are two's
|
||||
// complement and have the same number of padding bits in signed and unsigned variants.
|
||||
// This is the case in most modern implementations of C++.
|
||||
ScalarX range = ScalarX(y) - ScalarX(x);
|
||||
ScalarX offset = 0;
|
||||
ScalarX divisor = 1;
|
||||
ScalarX multiplier = 1;
|
||||
const unsigned rand_max = RAND_MAX;
|
||||
if (range <= rand_max)
|
||||
divisor = (rand_max + 1) / (range + 1);
|
||||
else
|
||||
multiplier = 1 + range / (rand_max + 1);
|
||||
// Rejection sampling.
|
||||
const BitsType range = static_cast<BitsType>(y) - static_cast<BitsType>(x) + 1;
|
||||
// handle edge case where [x,y] spans the entire range of Scalar
|
||||
if (range == 0) return getRandomBits<Scalar>(ScalarBits);
|
||||
// calculate the number of random bits needed to fill range
|
||||
const int numRandomBits = log2_ceil(range);
|
||||
BitsType randomBits;
|
||||
do {
|
||||
offset = (unsigned(std::rand()) * multiplier) / divisor;
|
||||
} while (offset > range);
|
||||
return Scalar(ScalarX(x) + offset);
|
||||
randomBits = getRandomBits<BitsType>(numRandomBits);
|
||||
// if the random draw is outside [0, range), try again (rejection sampling)
|
||||
// in the worst-case scenario, the probability of rejection is: 1/2 - 1/2^numRandomBits < 50%
|
||||
} while (randomBits >= range);
|
||||
Scalar result = x + static_cast<Scalar>(randomBits);
|
||||
return result;
|
||||
}
|
||||
|
||||
static inline Scalar run() {
|
||||
static EIGEN_DEVICE_FUNC inline Scalar run() {
|
||||
#ifdef EIGEN_MAKING_DOCS
|
||||
return run(Scalar(NumTraits<Scalar>::IsSigned ? -10 : 0), Scalar(10));
|
||||
#else
|
||||
enum {
|
||||
rand_bits = meta_floor_log2<(unsigned int)(RAND_MAX) + 1>::value,
|
||||
scalar_bits = sizeof(Scalar) * CHAR_BIT,
|
||||
shift = plain_enum_max(0, int(rand_bits) - int(scalar_bits)),
|
||||
offset = NumTraits<Scalar>::IsSigned ? (1 << (plain_enum_min(rand_bits, scalar_bits) - 1)) : 0
|
||||
};
|
||||
return Scalar((std::rand() >> shift) - offset);
|
||||
return getRandomBits<Scalar>(ScalarBits);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct random_impl<bool> {
|
||||
static EIGEN_DEVICE_FUNC inline bool run(const bool& x, const bool& y) {
|
||||
if (y <= x) return x;
|
||||
return run();
|
||||
}
|
||||
static EIGEN_DEVICE_FUNC inline bool run() { return getRandomBits<int>(1) ? true : false; }
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
struct random_default_impl<Scalar, true, false> {
|
||||
static inline Scalar run(const Scalar& x, const Scalar& y) {
|
||||
static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar& x, const Scalar& y) {
|
||||
return Scalar(random(x.real(), y.real()), random(x.imag(), y.imag()));
|
||||
}
|
||||
static inline Scalar run() {
|
||||
static EIGEN_DEVICE_FUNC inline Scalar run() {
|
||||
typedef typename NumTraits<Scalar>::Real RealScalar;
|
||||
return Scalar(random<RealScalar>(), random<RealScalar>());
|
||||
}
|
||||
@ -1863,13 +1945,6 @@ EIGEN_DEVICE_FUNC inline bool isApproxOrLessThan(
|
||||
*** The special case of the bool type ***
|
||||
******************************************/
|
||||
|
||||
template <>
|
||||
struct random_impl<bool> {
|
||||
static inline bool run() { return random<int>(0, 1) == 0 ? false : true; }
|
||||
|
||||
static inline bool run(const bool& a, const bool& b) { return random<int>(a, b) == 0 ? false : true; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct scalar_fuzzy_impl<bool> {
|
||||
typedef bool RealScalar;
|
||||
|
@ -206,9 +206,7 @@ struct GenericNumTraits {
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static inline T highest() { return (numext::numeric_limits<T>::max)(); }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static inline T lowest() {
|
||||
return IsInteger ? (numext::numeric_limits<T>::min)() : static_cast<T>(-(numext::numeric_limits<T>::max)());
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static inline T lowest() { return (numext::numeric_limits<T>::lowest)(); }
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_CONSTEXPR static inline T infinity() { return numext::numeric_limits<T>::infinity(); }
|
||||
|
||||
|
@ -677,16 +677,22 @@ EIGEN_ALWAYS_INLINE std::ostream& operator<<(std::ostream& os, const bfloat16& v
|
||||
namespace internal {
|
||||
|
||||
template <>
|
||||
struct random_default_impl<bfloat16, false, false> {
|
||||
static inline bfloat16 run(const bfloat16& x, const bfloat16& y) {
|
||||
return x + (y - x) * bfloat16(float(std::rand()) / float(RAND_MAX));
|
||||
}
|
||||
static inline bfloat16 run() { return run(bfloat16(-1.f), bfloat16(1.f)); }
|
||||
struct is_arithmetic<bfloat16> {
|
||||
enum { value = true };
|
||||
};
|
||||
|
||||
template <>
|
||||
struct is_arithmetic<bfloat16> {
|
||||
enum { value = true };
|
||||
struct random_impl<bfloat16> {
|
||||
enum : int { MantissaBits = 7 };
|
||||
using Impl = random_impl<float>;
|
||||
static EIGEN_DEVICE_FUNC inline bfloat16 run(const bfloat16& x, const bfloat16& y) {
|
||||
float result = Impl::run(x, y, MantissaBits);
|
||||
return bfloat16(result);
|
||||
}
|
||||
static EIGEN_DEVICE_FUNC inline bfloat16 run() {
|
||||
float result = Impl::run(MantissaBits);
|
||||
return bfloat16(result);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
@ -762,16 +762,22 @@ EIGEN_ALWAYS_INLINE std::ostream& operator<<(std::ostream& os, const half& v) {
|
||||
namespace internal {
|
||||
|
||||
template <>
|
||||
struct random_default_impl<half, false, false> {
|
||||
static inline half run(const half& x, const half& y) {
|
||||
return x + (y - x) * half(float(std::rand()) / float(RAND_MAX));
|
||||
}
|
||||
static inline half run() { return run(half(-1.f), half(1.f)); }
|
||||
struct is_arithmetic<half> {
|
||||
enum { value = true };
|
||||
};
|
||||
|
||||
template <>
|
||||
struct is_arithmetic<half> {
|
||||
enum { value = true };
|
||||
struct random_impl<half> {
|
||||
enum : int { MantissaBits = 10 };
|
||||
using Impl = random_impl<float>;
|
||||
static EIGEN_DEVICE_FUNC inline half run(const half& x, const half& y) {
|
||||
float result = Impl::run(x, y, MantissaBits);
|
||||
return half(result);
|
||||
}
|
||||
static EIGEN_DEVICE_FUNC inline half run() {
|
||||
float result = Impl::run(MantissaBits);
|
||||
return half(result);
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
|
@ -183,6 +183,20 @@ template <>
|
||||
EIGEN_STRONG_INLINE float cast(const AnnoyingScalar& x) {
|
||||
return *x.v;
|
||||
}
|
||||
|
||||
template <>
|
||||
struct random_impl<AnnoyingScalar> {
|
||||
using Impl = random_impl<float>;
|
||||
static EIGEN_DEVICE_FUNC inline AnnoyingScalar run(const AnnoyingScalar& x, const AnnoyingScalar& y) {
|
||||
float result = Impl::run(*x.v, *y.v);
|
||||
return AnnoyingScalar(result);
|
||||
}
|
||||
static EIGEN_DEVICE_FUNC inline AnnoyingScalar run() {
|
||||
float result = Impl::run();
|
||||
return AnnoyingScalar(result);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
} // namespace Eigen
|
||||
|
||||
|
@ -28,6 +28,23 @@ struct MovableScalar : public Base {
|
||||
|
||||
template <>
|
||||
struct NumTraits<MovableScalar<float>> : GenericNumTraits<float> {};
|
||||
|
||||
namespace internal {
|
||||
template <typename T>
|
||||
struct random_impl<MovableScalar<T>> {
|
||||
using MoveableT = MovableScalar<T>;
|
||||
using Impl = random_impl<T>;
|
||||
static EIGEN_DEVICE_FUNC inline MoveableT run(const MoveableT& x, const MoveableT& y) {
|
||||
T result = Impl::run(x, y);
|
||||
return MoveableT(result);
|
||||
}
|
||||
static EIGEN_DEVICE_FUNC inline MoveableT run() {
|
||||
T result = Impl::run();
|
||||
return MoveableT(result);
|
||||
}
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
} // namespace Eigen
|
||||
|
||||
#endif
|
||||
|
@ -26,3 +26,21 @@ class SafeScalar {
|
||||
T val_;
|
||||
bool initialized_;
|
||||
};
|
||||
|
||||
namespace Eigen {
|
||||
namespace internal {
|
||||
template <typename T>
|
||||
struct random_impl<SafeScalar<T>> {
|
||||
using SafeT = SafeScalar<T>;
|
||||
using Impl = random_impl<T>;
|
||||
static EIGEN_DEVICE_FUNC inline SafeT run(const SafeT& x, const SafeT& y) {
|
||||
T result = Impl::run(x, y);
|
||||
return SafeT(result);
|
||||
}
|
||||
static EIGEN_DEVICE_FUNC inline SafeT run() {
|
||||
T result = Impl::run();
|
||||
return SafeT(result);
|
||||
}
|
||||
};
|
||||
} // namespace internal
|
||||
} // namespace Eigen
|
||||
|
@ -36,10 +36,10 @@ void alignedbox(const BoxType& box) {
|
||||
|
||||
const Index dim = box.dim();
|
||||
|
||||
VectorType p0 = VectorType::Random(dim);
|
||||
VectorType p1 = VectorType::Random(dim);
|
||||
VectorType p0 = VectorType::Random(dim) / Scalar(2);
|
||||
VectorType p1 = VectorType::Random(dim) / Scalar(2);
|
||||
while (p1 == p0) {
|
||||
p1 = VectorType::Random(dim);
|
||||
p1 = VectorType::Random(dim) / Scalar(2);
|
||||
}
|
||||
RealScalar s1 = internal::random<RealScalar>(0, 1);
|
||||
|
||||
@ -216,7 +216,7 @@ template <typename Scalar, int Dim>
|
||||
Matrix<Scalar, Dim, (1 << Dim)> boxGetCorners(const Matrix<Scalar, Dim, 1>& min_, const Matrix<Scalar, Dim, 1>& max_) {
|
||||
Matrix<Scalar, Dim, (1 << Dim)> result;
|
||||
for (Index i = 0; i < (1 << Dim); ++i) {
|
||||
for (Index j = 0; j < Dim; ++j) result(j, i) = (i & (1 << j)) ? min_(j) : max_(j);
|
||||
for (Index j = 0; j < Dim; ++j) result(j, i) = (i & (Index(1) << j)) ? min_(j) : max_(j);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
244
test/rand.cpp
244
test/rand.cpp
@ -7,10 +7,9 @@
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#include <cstdlib>
|
||||
#include "main.h"
|
||||
|
||||
typedef long long int64;
|
||||
|
||||
template <typename Scalar>
|
||||
Scalar check_in_range(Scalar x, Scalar y) {
|
||||
Scalar r = internal::random<Scalar>(x, y);
|
||||
@ -25,8 +24,8 @@ template <typename Scalar>
|
||||
void check_all_in_range(Scalar x, Scalar y) {
|
||||
Array<int, 1, Dynamic> mask(y - x + 1);
|
||||
mask.fill(0);
|
||||
long n = (y - x + 1) * 32;
|
||||
for (long k = 0; k < n; ++k) {
|
||||
int64_t n = (y - x + 1) * 32;
|
||||
for (int64_t k = 0; k < n; ++k) {
|
||||
mask(check_in_range(x, y) - x)++;
|
||||
}
|
||||
for (Index i = 0; i < mask.size(); ++i)
|
||||
@ -34,82 +33,203 @@ void check_all_in_range(Scalar x, Scalar y) {
|
||||
VERIFY((mask > 0).all());
|
||||
}
|
||||
|
||||
template <typename Scalar, typename EnableIf = void>
|
||||
class HistogramHelper {
|
||||
public:
|
||||
HistogramHelper(int nbins) : HistogramHelper(Scalar(-1), Scalar(1), nbins) {}
|
||||
HistogramHelper(Scalar lower, Scalar upper, int nbins) {
|
||||
lower_ = static_cast<double>(lower);
|
||||
upper_ = static_cast<double>(upper);
|
||||
num_bins_ = nbins;
|
||||
bin_width_ = (upper_ - lower_) / static_cast<double>(nbins);
|
||||
}
|
||||
int bin(Scalar v) {
|
||||
double result = (static_cast<double>(v) - lower_) / bin_width_;
|
||||
return std::min<int>(static_cast<int>(result), num_bins_ - 1);
|
||||
}
|
||||
|
||||
double uniform_bin_probability(int bin) {
|
||||
double range = upper_ - lower_;
|
||||
if (bin < num_bins_ - 1) {
|
||||
return bin_width_ / range;
|
||||
}
|
||||
return (upper_ - (lower_ + double(bin) * bin_width_)) / range;
|
||||
}
|
||||
|
||||
private:
|
||||
double lower_;
|
||||
double upper_;
|
||||
int num_bins_;
|
||||
double bin_width_;
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
class HistogramHelper<Scalar, std::enable_if_t<Eigen::NumTraits<Scalar>::IsInteger>> {
|
||||
public:
|
||||
using RangeType = typename Eigen::internal::make_unsigned<Scalar>::type;
|
||||
HistogramHelper(int nbins)
|
||||
: HistogramHelper(Eigen::NumTraits<Scalar>::lowest(), Eigen::NumTraits<Scalar>::highest(), nbins) {}
|
||||
HistogramHelper(Scalar lower, Scalar upper, int nbins)
|
||||
: lower_{lower}, upper_{upper}, num_bins_{nbins}, bin_width_{bin_width(lower, upper, nbins)} {}
|
||||
|
||||
int bin(Scalar v) { return static_cast<int>(RangeType(v - lower_) / bin_width_); }
|
||||
|
||||
double uniform_bin_probability(int bin) {
|
||||
// Avoid overflow in computing range.
|
||||
double range = static_cast<double>(RangeType(upper_ - lower_)) + 1.0;
|
||||
if (bin < num_bins_ - 1) {
|
||||
return static_cast<double>(bin_width_) / range;
|
||||
}
|
||||
return static_cast<double>(RangeType(upper_) - RangeType((lower_ + bin * bin_width_)) + 1) / range;
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr Scalar bin_width(Scalar lower, Scalar upper, int nbins) {
|
||||
// Avoid overflow in computing the full range.
|
||||
return RangeType(upper - nbins - lower + 1) / nbins + 1;
|
||||
}
|
||||
|
||||
Scalar lower_;
|
||||
Scalar upper_;
|
||||
int num_bins_;
|
||||
Scalar bin_width_;
|
||||
};
|
||||
|
||||
template <typename Scalar>
|
||||
void check_histogram(Scalar x, Scalar y, int bins) {
|
||||
Array<int, 1, Dynamic> hist(bins);
|
||||
hist.fill(0);
|
||||
int f = 100000;
|
||||
int n = bins * f;
|
||||
int64 range = int64(y) - int64(x);
|
||||
int divisor = int((range + 1) / bins);
|
||||
assert(((range + 1) % bins) == 0);
|
||||
for (int k = 0; k < n; ++k) {
|
||||
Eigen::VectorXd hist = Eigen::VectorXd::Zero(bins);
|
||||
HistogramHelper<Scalar> hist_helper(x, y, bins);
|
||||
int64_t n = static_cast<int64_t>(bins) * 10000; // Approx 10000 per bin.
|
||||
for (int64_t k = 0; k < n; ++k) {
|
||||
Scalar r = check_in_range(x, y);
|
||||
hist(int((int64(r) - int64(x)) / divisor))++;
|
||||
int bin = hist_helper.bin(r);
|
||||
hist(bin)++;
|
||||
}
|
||||
VERIFY((((hist.cast<double>() / double(f)) - 1.0).abs() < 0.03).all());
|
||||
// Normalize bins by probability.
|
||||
for (int i = 0; i < bins; ++i) {
|
||||
hist(i) = hist(i) / n / hist_helper.uniform_bin_probability(i);
|
||||
}
|
||||
VERIFY(((hist.array() - 1.0).abs() < 0.05).all());
|
||||
}
|
||||
|
||||
template <typename Scalar>
|
||||
void check_histogram(int bins) {
|
||||
Eigen::VectorXd hist = Eigen::VectorXd::Zero(bins);
|
||||
HistogramHelper<Scalar> hist_helper(bins);
|
||||
int64_t n = static_cast<int64_t>(bins) * 10000; // Approx 10000 per bin.
|
||||
for (int64_t k = 0; k < n; ++k) {
|
||||
Scalar r = Eigen::internal::random<Scalar>();
|
||||
int bin = hist_helper.bin(r);
|
||||
hist(bin)++;
|
||||
}
|
||||
// Normalize bins by probability.
|
||||
for (int i = 0; i < bins; ++i) {
|
||||
hist(i) = hist(i) / n / hist_helper.uniform_bin_probability(i);
|
||||
}
|
||||
VERIFY(((hist.array() - 1.0).abs() < 0.05).all());
|
||||
}
|
||||
|
||||
EIGEN_DECLARE_TEST(rand) {
|
||||
long long_ref = NumTraits<long>::highest() / 10;
|
||||
int64_t int64_ref = NumTraits<int64_t>::highest() / 10;
|
||||
// the minimum guarantees that these conversions are safe
|
||||
auto char_offset = static_cast<signed char>((std::min)(g_repeat, 64));
|
||||
auto short_offset = static_cast<signed short>((std::min)(g_repeat, 8000));
|
||||
int8_t int8t_offset = static_cast<int8_t>((std::min)(g_repeat, 64));
|
||||
int16_t int16t_offset = static_cast<int16_t>((std::min)(g_repeat, 8000));
|
||||
EIGEN_UNUSED_VARIABLE(int64_ref);
|
||||
EIGEN_UNUSED_VARIABLE(int8t_offset);
|
||||
EIGEN_UNUSED_VARIABLE(int16t_offset);
|
||||
|
||||
for (int i = 0; i < g_repeat * 10000; i++) {
|
||||
CALL_SUBTEST(check_in_range<float>(10, 11));
|
||||
CALL_SUBTEST(check_in_range<float>(1.24234523f, 1.24234523f));
|
||||
CALL_SUBTEST(check_in_range<float>(-1, 1));
|
||||
CALL_SUBTEST(check_in_range<float>(-1432.2352f, -1432.2352f));
|
||||
CALL_SUBTEST_1(check_in_range<float>(10.0f, 11.0f));
|
||||
CALL_SUBTEST_1(check_in_range<float>(1.24234523f, 1.24234523f));
|
||||
CALL_SUBTEST_1(check_in_range<float>(-1.0f, 1.0f));
|
||||
CALL_SUBTEST_1(check_in_range<float>(-1432.2352f, -1432.2352f));
|
||||
|
||||
CALL_SUBTEST(check_in_range<double>(10, 11));
|
||||
CALL_SUBTEST(check_in_range<double>(1.24234523, 1.24234523));
|
||||
CALL_SUBTEST(check_in_range<double>(-1, 1));
|
||||
CALL_SUBTEST(check_in_range<double>(-1432.2352, -1432.2352));
|
||||
CALL_SUBTEST_2(check_in_range<double>(10.0, 11.0));
|
||||
CALL_SUBTEST_2(check_in_range<double>(1.24234523, 1.24234523));
|
||||
CALL_SUBTEST_2(check_in_range<double>(-1.0, 1.0));
|
||||
CALL_SUBTEST_2(check_in_range<double>(-1432.2352, -1432.2352));
|
||||
|
||||
CALL_SUBTEST(check_in_range<int>(0, -1));
|
||||
CALL_SUBTEST(check_in_range<short>(0, -1));
|
||||
CALL_SUBTEST(check_in_range<long>(0, -1));
|
||||
CALL_SUBTEST(check_in_range<int>(-673456, 673456));
|
||||
CALL_SUBTEST(check_in_range<int>(-RAND_MAX + 10, RAND_MAX - 10));
|
||||
CALL_SUBTEST(check_in_range<short>(-24345, 24345));
|
||||
CALL_SUBTEST(check_in_range<long>(-long_ref, long_ref));
|
||||
CALL_SUBTEST_3(check_in_range<long double>(10.0L, 11.0L));
|
||||
CALL_SUBTEST_3(check_in_range<long double>(1.24234523L, 1.24234523L));
|
||||
CALL_SUBTEST_3(check_in_range<long double>(-1.0L, 1.0L));
|
||||
CALL_SUBTEST_3(check_in_range<long double>(-1432.2352L, -1432.2352L));
|
||||
|
||||
CALL_SUBTEST_4(check_in_range<half>(half(10.0f), half(11.0f)));
|
||||
CALL_SUBTEST_4(check_in_range<half>(half(1.24234523f), half(1.24234523f)));
|
||||
CALL_SUBTEST_4(check_in_range<half>(half(-1.0f), half(1.0f)));
|
||||
CALL_SUBTEST_4(check_in_range<half>(half(-1432.2352f), half(-1432.2352f)));
|
||||
|
||||
CALL_SUBTEST_5(check_in_range<bfloat16>(bfloat16(10.0f), bfloat16(11.0f)));
|
||||
CALL_SUBTEST_5(check_in_range<bfloat16>(bfloat16(1.24234523f), bfloat16(1.24234523f)));
|
||||
CALL_SUBTEST_5(check_in_range<bfloat16>(bfloat16(-1.0f), bfloat16(1.0f)));
|
||||
CALL_SUBTEST_5(check_in_range<bfloat16>(bfloat16(-1432.2352f), bfloat16(-1432.2352f)));
|
||||
|
||||
CALL_SUBTEST_6(check_in_range<int32_t>(0, -1));
|
||||
CALL_SUBTEST_6(check_in_range<int16_t>(0, -1));
|
||||
CALL_SUBTEST_6(check_in_range<int64_t>(0, -1));
|
||||
CALL_SUBTEST_6(check_in_range<int32_t>(-673456, 673456));
|
||||
CALL_SUBTEST_6(check_in_range<int32_t>(-RAND_MAX + 10, RAND_MAX - 10));
|
||||
CALL_SUBTEST_6(check_in_range<int16_t>(-24345, 24345));
|
||||
CALL_SUBTEST_6(check_in_range<int64_t>(-int64_ref, int64_ref));
|
||||
}
|
||||
|
||||
CALL_SUBTEST(check_all_in_range<signed char>(11, 11));
|
||||
CALL_SUBTEST(check_all_in_range<signed char>(11, 11 + char_offset));
|
||||
CALL_SUBTEST(check_all_in_range<signed char>(-5, 5));
|
||||
CALL_SUBTEST(check_all_in_range<signed char>(-11 - char_offset, -11));
|
||||
CALL_SUBTEST(check_all_in_range<signed char>(-126, -126 + char_offset));
|
||||
CALL_SUBTEST(check_all_in_range<signed char>(126 - char_offset, 126));
|
||||
CALL_SUBTEST(check_all_in_range<signed char>(-126, 126));
|
||||
CALL_SUBTEST_7(check_all_in_range<int8_t>(11, 11));
|
||||
CALL_SUBTEST_7(check_all_in_range<int8_t>(11, 11 + int8t_offset));
|
||||
CALL_SUBTEST_7(check_all_in_range<int8_t>(-5, 5));
|
||||
CALL_SUBTEST_7(check_all_in_range<int8_t>(-11 - int8t_offset, -11));
|
||||
CALL_SUBTEST_7(check_all_in_range<int8_t>(-126, -126 + int8t_offset));
|
||||
CALL_SUBTEST_7(check_all_in_range<int8_t>(126 - int8t_offset, 126));
|
||||
CALL_SUBTEST_7(check_all_in_range<int8_t>(-126, 126));
|
||||
|
||||
CALL_SUBTEST(check_all_in_range<short>(11, 11));
|
||||
CALL_SUBTEST(check_all_in_range<short>(11, 11 + short_offset));
|
||||
CALL_SUBTEST(check_all_in_range<short>(-5, 5));
|
||||
CALL_SUBTEST(check_all_in_range<short>(-11 - short_offset, -11));
|
||||
CALL_SUBTEST(check_all_in_range<short>(-24345, -24345 + short_offset));
|
||||
CALL_SUBTEST(check_all_in_range<short>(24345, 24345 + short_offset));
|
||||
CALL_SUBTEST_8(check_all_in_range<int16_t>(11, 11));
|
||||
CALL_SUBTEST_8(check_all_in_range<int16_t>(11, 11 + int16t_offset));
|
||||
CALL_SUBTEST_8(check_all_in_range<int16_t>(-5, 5));
|
||||
CALL_SUBTEST_8(check_all_in_range<int16_t>(-11 - int16t_offset, -11));
|
||||
CALL_SUBTEST_8(check_all_in_range<int16_t>(-24345, -24345 + int16t_offset));
|
||||
CALL_SUBTEST_8(check_all_in_range<int16_t>(24345, 24345 + int16t_offset));
|
||||
|
||||
CALL_SUBTEST(check_all_in_range<int>(11, 11));
|
||||
CALL_SUBTEST(check_all_in_range<int>(11, 11 + g_repeat));
|
||||
CALL_SUBTEST(check_all_in_range<int>(-5, 5));
|
||||
CALL_SUBTEST(check_all_in_range<int>(-11 - g_repeat, -11));
|
||||
CALL_SUBTEST(check_all_in_range<int>(-673456, -673456 + g_repeat));
|
||||
CALL_SUBTEST(check_all_in_range<int>(673456, 673456 + g_repeat));
|
||||
CALL_SUBTEST_9(check_all_in_range<int32_t>(11, 11));
|
||||
CALL_SUBTEST_9(check_all_in_range<int32_t>(11, 11 + g_repeat));
|
||||
CALL_SUBTEST_9(check_all_in_range<int32_t>(-5, 5));
|
||||
CALL_SUBTEST_9(check_all_in_range<int32_t>(-11 - g_repeat, -11));
|
||||
CALL_SUBTEST_9(check_all_in_range<int32_t>(-673456, -673456 + g_repeat));
|
||||
CALL_SUBTEST_9(check_all_in_range<int32_t>(673456, 673456 + g_repeat));
|
||||
|
||||
CALL_SUBTEST(check_all_in_range<long>(11, 11));
|
||||
CALL_SUBTEST(check_all_in_range<long>(11, 11 + g_repeat));
|
||||
CALL_SUBTEST(check_all_in_range<long>(-5, 5));
|
||||
CALL_SUBTEST(check_all_in_range<long>(-11 - g_repeat, -11));
|
||||
CALL_SUBTEST(check_all_in_range<long>(-long_ref, -long_ref + g_repeat));
|
||||
CALL_SUBTEST(check_all_in_range<long>(long_ref, long_ref + g_repeat));
|
||||
CALL_SUBTEST_10(check_all_in_range<int64_t>(11, 11));
|
||||
CALL_SUBTEST_10(check_all_in_range<int64_t>(11, 11 + g_repeat));
|
||||
CALL_SUBTEST_10(check_all_in_range<int64_t>(-5, 5));
|
||||
CALL_SUBTEST_10(check_all_in_range<int64_t>(-11 - g_repeat, -11));
|
||||
CALL_SUBTEST_10(check_all_in_range<int64_t>(-int64_ref, -int64_ref + g_repeat));
|
||||
CALL_SUBTEST_10(check_all_in_range<int64_t>(int64_ref, int64_ref + g_repeat));
|
||||
|
||||
CALL_SUBTEST(check_histogram<int>(-5, 5, 11));
|
||||
CALL_SUBTEST_11(check_histogram<int32_t>(-5, 5, 11));
|
||||
int bins = 100;
|
||||
CALL_SUBTEST(check_histogram<int>(-3333, -3333 + bins * (3333 / bins) - 1, bins));
|
||||
EIGEN_UNUSED_VARIABLE(bins)
|
||||
CALL_SUBTEST_11(check_histogram<int32_t>(-3333, -3333 + bins * (3333 / bins) - 1, bins));
|
||||
bins = 1000;
|
||||
CALL_SUBTEST(check_histogram<int>(-RAND_MAX + 10, -RAND_MAX + 10 + bins * (RAND_MAX / bins) - 1, bins));
|
||||
CALL_SUBTEST(
|
||||
check_histogram<int>(-RAND_MAX + 10, -int64(RAND_MAX) + 10 + bins * (2 * int64(RAND_MAX) / bins) - 1, bins));
|
||||
CALL_SUBTEST_11(check_histogram<int32_t>(-RAND_MAX + 10, -RAND_MAX + 10 + bins * (RAND_MAX / bins) - 1, bins));
|
||||
CALL_SUBTEST_11(check_histogram<int32_t>(-RAND_MAX + 10,
|
||||
-int64_t(RAND_MAX) + 10 + bins * (2 * int64_t(RAND_MAX) / bins) - 1, bins));
|
||||
|
||||
CALL_SUBTEST_12(check_histogram<uint8_t>(/*bins=*/16));
|
||||
CALL_SUBTEST_12(check_histogram<uint16_t>(/*bins=*/1024));
|
||||
CALL_SUBTEST_12(check_histogram<uint32_t>(/*bins=*/1024));
|
||||
CALL_SUBTEST_12(check_histogram<uint64_t>(/*bins=*/1024));
|
||||
|
||||
CALL_SUBTEST_13(check_histogram<int8_t>(/*bins=*/16));
|
||||
CALL_SUBTEST_13(check_histogram<int16_t>(/*bins=*/1024));
|
||||
CALL_SUBTEST_13(check_histogram<int32_t>(/*bins=*/1024));
|
||||
CALL_SUBTEST_13(check_histogram<int64_t>(/*bins=*/1024));
|
||||
|
||||
CALL_SUBTEST_14(check_histogram<float>(-10.0f, 10.0f, /*bins=*/1024));
|
||||
CALL_SUBTEST_14(check_histogram<double>(-10.0, 10.0, /*bins=*/1024));
|
||||
CALL_SUBTEST_14(check_histogram<long double>(-10.0L, 10.0L, /*bins=*/1024));
|
||||
CALL_SUBTEST_14(check_histogram<half>(half(-10.0f), half(10.0f), /*bins=*/512));
|
||||
CALL_SUBTEST_14(check_histogram<bfloat16>(bfloat16(-10.0f), bfloat16(10.0f), /*bins=*/64));
|
||||
|
||||
CALL_SUBTEST_15(check_histogram<float>(/*bins=*/1024));
|
||||
CALL_SUBTEST_15(check_histogram<double>(/*bins=*/1024));
|
||||
CALL_SUBTEST_15(check_histogram<long double>(/*bins=*/1024));
|
||||
CALL_SUBTEST_15(check_histogram<half>(/*bins=*/512));
|
||||
CALL_SUBTEST_15(check_histogram<bfloat16>(/*bins=*/64));
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user