mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Added a few tests to validate the generation of random tensors on GPU.
This commit is contained in:
parent
6a5717dc74
commit
cbb14ed47e
@ -16,7 +16,55 @@
|
||||
#include "main.h"
|
||||
#include <Eigen/CXX11/Tensor>
|
||||
|
||||
static void test_default()
|
||||
|
||||
void test_cuda_random_uniform()
|
||||
{
|
||||
Tensor<float, 2> out(72,97);
|
||||
out.setZero();
|
||||
|
||||
std::size_t out_bytes = out.size() * sizeof(float);
|
||||
|
||||
float* d_out;
|
||||
cudaMalloc((void**)(&d_out), out_bytes);
|
||||
|
||||
Eigen::CudaStreamDevice stream;
|
||||
Eigen::GpuDevice gpu_device(&stream);
|
||||
|
||||
Eigen::TensorMap<Eigen::Tensor<float, 2> > gpu_out(d_out, 72,97);
|
||||
|
||||
gpu_out.device(gpu_device) = gpu_out.random();
|
||||
|
||||
assert(cudaMemcpyAsync(out.data(), d_out, out_bytes, cudaMemcpyDeviceToHost, gpu_device.stream()) == cudaSuccess);
|
||||
assert(cudaStreamSynchronize(gpu_device.stream()) == cudaSuccess);
|
||||
|
||||
// For now we just check thes code doesn't crash.
|
||||
// TODO: come up with a valid test of randomness
|
||||
}
|
||||
|
||||
|
||||
void test_cuda_random_normal()
|
||||
{
|
||||
Tensor<float, 2> out(72,97);
|
||||
out.setZero();
|
||||
|
||||
std::size_t out_bytes = out.size() * sizeof(float);
|
||||
|
||||
float* d_out;
|
||||
cudaMalloc((void**)(&d_out), out_bytes);
|
||||
|
||||
Eigen::CudaStreamDevice stream;
|
||||
Eigen::GpuDevice gpu_device(&stream);
|
||||
|
||||
Eigen::TensorMap<Eigen::Tensor<float, 2> > gpu_out(d_out, 72,97);
|
||||
|
||||
Eigen::internal::NormalRandomGenerator<float> gen(true);
|
||||
gpu_out.device(gpu_device) = gpu_out.random(gen);
|
||||
|
||||
assert(cudaMemcpyAsync(out.data(), d_out, out_bytes, cudaMemcpyDeviceToHost, gpu_device.stream()) == cudaSuccess);
|
||||
assert(cudaStreamSynchronize(gpu_device.stream()) == cudaSuccess);
|
||||
}
|
||||
|
||||
static void test_complex()
|
||||
{
|
||||
Tensor<std::complex<float>, 1> vec(6);
|
||||
vec.setRandom();
|
||||
@ -31,5 +79,7 @@ static void test_default()
|
||||
|
||||
void test_cxx11_tensor_random_cuda()
|
||||
{
|
||||
CALL_SUBTEST(test_default());
|
||||
CALL_SUBTEST(test_cuda_random_uniform());
|
||||
CALL_SUBTEST(test_cuda_random_normal());
|
||||
CALL_SUBTEST(test_complex());
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user