mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-13 20:26:03 +08:00
Properly vectorized the random number generators
This commit is contained in:
parent
caa54d888f
commit
ac2e6e0d03
@ -342,17 +342,17 @@ template <typename T> class UniformRandomGenerator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<typename Index>
|
template<typename Index>
|
||||||
T operator()(Index, Index = 0) const {
|
T operator()(Index) const {
|
||||||
return random<T>();
|
return random<T>();
|
||||||
}
|
}
|
||||||
template<typename Index>
|
template<typename Index, typename PacketType>
|
||||||
typename internal::packet_traits<T>::type packetOp(Index, Index = 0) const {
|
PacketType packetOp(Index) const {
|
||||||
const int packetSize = internal::packet_traits<T>::size;
|
const int packetSize = internal::unpacket_traits<PacketType>::size;
|
||||||
EIGEN_ALIGN_MAX T values[packetSize];
|
EIGEN_ALIGN_MAX T values[packetSize];
|
||||||
for (int i = 0; i < packetSize; ++i) {
|
for (int i = 0; i < packetSize; ++i) {
|
||||||
values[i] = random<T>();
|
values[i] = random<T>();
|
||||||
}
|
}
|
||||||
return internal::pload<typename internal::packet_traits<T>::type>(values);
|
return internal::pload<PacketType>(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -370,22 +370,22 @@ template <> class UniformRandomGenerator<float> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
UniformRandomGenerator(const UniformRandomGenerator<float>& other) {
|
UniformRandomGenerator(const UniformRandomGenerator<float>& other) {
|
||||||
m_generator.seed(other(0, 0) * UINT_MAX);
|
m_generator.seed(other(0) * UINT_MAX);
|
||||||
m_deterministic = other.m_deterministic;
|
m_deterministic = other.m_deterministic;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Index>
|
template<typename Index>
|
||||||
float operator()(Index, Index = 0) const {
|
float operator()(Index) const {
|
||||||
return m_distribution(m_generator);
|
return m_distribution(m_generator);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
template<typename Index, typename PacketType>
|
||||||
typename internal::packet_traits<float>::type packetOp(Index i, Index j = 0) const {
|
PacketType packetOp(Index i) const {
|
||||||
const int packetSize = internal::packet_traits<float>::size;
|
const int packetSize = internal::unpacket_traits<PacketType>::size;
|
||||||
EIGEN_ALIGN_MAX float values[packetSize];
|
EIGEN_ALIGN_MAX float values[packetSize];
|
||||||
for (int k = 0; k < packetSize; ++k) {
|
for (int k = 0; k < packetSize; ++k) {
|
||||||
values[k] = this->operator()(i, j);
|
values[k] = this->operator()(i);
|
||||||
}
|
}
|
||||||
return internal::pload<typename internal::packet_traits<float>::type>(values);
|
return internal::pload<PacketType>(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -407,22 +407,22 @@ template <> class UniformRandomGenerator<double> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
UniformRandomGenerator(const UniformRandomGenerator<double>& other) {
|
UniformRandomGenerator(const UniformRandomGenerator<double>& other) {
|
||||||
m_generator.seed(other(0, 0) * UINT_MAX);
|
m_generator.seed(other(0) * UINT_MAX);
|
||||||
m_deterministic = other.m_deterministic;
|
m_deterministic = other.m_deterministic;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Index>
|
template<typename Index>
|
||||||
double operator()(Index, Index = 0) const {
|
double operator()(Index) const {
|
||||||
return m_distribution(m_generator);
|
return m_distribution(m_generator);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
template<typename Index, typename PacketType>
|
||||||
typename internal::packet_traits<double>::type packetOp(Index i, Index j = 0) const {
|
PacketType packetOp(Index i) const {
|
||||||
const int packetSize = internal::packet_traits<double>::size;
|
const int packetSize = internal::unpacket_traits<PacketType>::size;
|
||||||
EIGEN_ALIGN_MAX double values[packetSize];
|
EIGEN_ALIGN_MAX double values[packetSize];
|
||||||
for (int k = 0; k < packetSize; ++k) {
|
for (int k = 0; k < packetSize; ++k) {
|
||||||
values[k] = this->operator()(i, j);
|
values[k] = this->operator()(i);
|
||||||
}
|
}
|
||||||
return internal::pload<typename internal::packet_traits<double>::type>(values);
|
return internal::pload<PacketType>(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -458,11 +458,12 @@ template <> class UniformRandomGenerator<float> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<typename Index>
|
template<typename Index>
|
||||||
__device__ float operator()(Index, Index = 0) const {
|
__device__ float operator()(Index) const {
|
||||||
return curand_uniform(&m_state);
|
return curand_uniform(&m_state);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
template<typename Index, typename PacketType>
|
||||||
__device__ float4 packetOp(Index, Index = 0) const {
|
__device__ float4 packetOp(Index) const {
|
||||||
|
EIGEN_STATIC_ASSERT((is_same<PacketType, float4>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
return curand_uniform4(&m_state);
|
return curand_uniform4(&m_state);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -487,11 +488,12 @@ template <> class UniformRandomGenerator<double> {
|
|||||||
curand_init(seed, tid, 0, &m_state);
|
curand_init(seed, tid, 0, &m_state);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
template<typename Index>
|
||||||
__device__ double operator()(Index, Index = 0) const {
|
__device__ double operator()(Index) const {
|
||||||
return curand_uniform_double(&m_state);
|
return curand_uniform_double(&m_state);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
template<typename Index, typename PacketType>
|
||||||
__device__ double2 packetOp(Index, Index = 0) const {
|
__device__ double2 packetOp(Index) const {
|
||||||
|
EIGEN_STATIC_ASSERT((is_same<PacketType, double2>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
return curand_uniform2_double(&m_state);
|
return curand_uniform2_double(&m_state);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -516,7 +518,7 @@ template <> class UniformRandomGenerator<std::complex<float> > {
|
|||||||
curand_init(seed, tid, 0, &m_state);
|
curand_init(seed, tid, 0, &m_state);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
template<typename Index>
|
||||||
__device__ std::complex<float> operator()(Index, Index = 0) const {
|
__device__ std::complex<float> operator()(Index) const {
|
||||||
float4 vals = curand_uniform4(&m_state);
|
float4 vals = curand_uniform4(&m_state);
|
||||||
return std::complex<float>(vals.x, vals.y);
|
return std::complex<float>(vals.x, vals.y);
|
||||||
}
|
}
|
||||||
@ -542,7 +544,7 @@ template <> class UniformRandomGenerator<std::complex<double> > {
|
|||||||
curand_init(seed, tid, 0, &m_state);
|
curand_init(seed, tid, 0, &m_state);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
template<typename Index>
|
||||||
__device__ std::complex<double> operator()(Index, Index = 0) const {
|
__device__ std::complex<double> operator()(Index) const {
|
||||||
double2 vals = curand_uniform2_double(&m_state);
|
double2 vals = curand_uniform2_double(&m_state);
|
||||||
return std::complex<double>(vals.x, vals.y);
|
return std::complex<double>(vals.x, vals.y);
|
||||||
}
|
}
|
||||||
@ -554,6 +556,14 @@ template <> class UniformRandomGenerator<std::complex<double> > {
|
|||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
template <typename Scalar>
|
||||||
|
struct functor_traits<UniformRandomGenerator<Scalar> > {
|
||||||
|
enum {
|
||||||
|
PacketAccess = UniformRandomGenerator<Scalar>::PacketAccess
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#if (!defined (EIGEN_USE_GPU) || !defined(__CUDACC__) || !defined(__CUDA_ARCH__)) && __cplusplus > 199711
|
#if (!defined (EIGEN_USE_GPU) || !defined(__CUDACC__) || !defined(__CUDA_ARCH__)) && __cplusplus > 199711
|
||||||
// We're not compiling a cuda kernel
|
// We're not compiling a cuda kernel
|
||||||
@ -568,21 +578,21 @@ template <typename T> class NormalRandomGenerator {
|
|||||||
}
|
}
|
||||||
NormalRandomGenerator(const NormalRandomGenerator& other)
|
NormalRandomGenerator(const NormalRandomGenerator& other)
|
||||||
: m_deterministic(other.m_deterministic), m_distribution(other.m_distribution) {
|
: m_deterministic(other.m_deterministic), m_distribution(other.m_distribution) {
|
||||||
m_generator.seed(other(0, 0) * UINT_MAX);
|
m_generator.seed(other(0) * UINT_MAX);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Index>
|
template<typename Index>
|
||||||
T operator()(Index, Index = 0) const {
|
T operator()(Index) const {
|
||||||
return m_distribution(m_generator);
|
return m_distribution(m_generator);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
template<typename Index, typename PacketType>
|
||||||
typename internal::packet_traits<T>::type packetOp(Index, Index = 0) const {
|
PacketType packetOp(Index) const {
|
||||||
const int packetSize = internal::packet_traits<T>::size;
|
const int packetSize = internal::unpacket_traits<PacketType>::size;
|
||||||
EIGEN_ALIGN_MAX T values[packetSize];
|
EIGEN_ALIGN_MAX T values[packetSize];
|
||||||
for (int i = 0; i < packetSize; ++i) {
|
for (int i = 0; i < packetSize; ++i) {
|
||||||
values[i] = m_distribution(m_generator);
|
values[i] = m_distribution(m_generator);
|
||||||
}
|
}
|
||||||
return internal::pload<typename internal::packet_traits<T>::type>(values);
|
return internal::pload<PacketType>(values);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -612,11 +622,12 @@ template <> class NormalRandomGenerator<float> {
|
|||||||
curand_init(seed, tid, 0, &m_state);
|
curand_init(seed, tid, 0, &m_state);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
template<typename Index>
|
||||||
__device__ float operator()(Index, Index = 0) const {
|
__device__ float operator()(Index) const {
|
||||||
return curand_normal(&m_state);
|
return curand_normal(&m_state);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
template<typename Index, typename PacketType>
|
||||||
__device__ float4 packetOp(Index, Index = 0) const {
|
__device__ float4 packetOp(Index) const {
|
||||||
|
EIGEN_STATIC_ASSERT((is_same<PacketType, float4>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
return curand_normal4(&m_state);
|
return curand_normal4(&m_state);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -641,11 +652,12 @@ template <> class NormalRandomGenerator<double> {
|
|||||||
curand_init(seed, tid, 0, &m_state);
|
curand_init(seed, tid, 0, &m_state);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
template<typename Index>
|
||||||
__device__ double operator()(Index, Index = 0) const {
|
__device__ double operator()(Index) const {
|
||||||
return curand_normal_double(&m_state);
|
return curand_normal_double(&m_state);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
template<typename Index, typename PacketType>
|
||||||
__device__ double2 packetOp(Index, Index = 0) const {
|
__device__ double2 packetOp(Index) const {
|
||||||
|
EIGEN_STATIC_ASSERT((is_same<PacketType, double2>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
return curand_normal2_double(&m_state);
|
return curand_normal2_double(&m_state);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -670,7 +682,7 @@ template <> class NormalRandomGenerator<std::complex<float> > {
|
|||||||
curand_init(seed, tid, 0, &m_state);
|
curand_init(seed, tid, 0, &m_state);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
template<typename Index>
|
||||||
__device__ std::complex<float> operator()(Index, Index = 0) const {
|
__device__ std::complex<float> operator()(Index) const {
|
||||||
float4 vals = curand_normal4(&m_state);
|
float4 vals = curand_normal4(&m_state);
|
||||||
return std::complex<float>(vals.x, vals.y);
|
return std::complex<float>(vals.x, vals.y);
|
||||||
}
|
}
|
||||||
@ -696,7 +708,7 @@ template <> class NormalRandomGenerator<std::complex<double> > {
|
|||||||
curand_init(seed, tid, 0, &m_state);
|
curand_init(seed, tid, 0, &m_state);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
template<typename Index>
|
||||||
__device__ std::complex<double> operator()(Index, Index = 0) const {
|
__device__ std::complex<double> operator()(Index) const {
|
||||||
double2 vals = curand_normal2_double(&m_state);
|
double2 vals = curand_normal2_double(&m_state);
|
||||||
return std::complex<double>(vals.x, vals.y);
|
return std::complex<double>(vals.x, vals.y);
|
||||||
}
|
}
|
||||||
@ -718,6 +730,13 @@ template <typename T> class NormalRandomGenerator {
|
|||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
template <typename Scalar>
|
||||||
|
struct functor_traits<NormalRandomGenerator<Scalar> > {
|
||||||
|
enum {
|
||||||
|
PacketAccess = NormalRandomGenerator<Scalar>::PacketAccess
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
template <typename T, typename Index, size_t NumDims>
|
template <typename T, typename Index, size_t NumDims>
|
||||||
class GaussianGenerator {
|
class GaussianGenerator {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user