mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 19:29:02 +08:00
Added a test to validate tensor casting on cuda devices
This commit is contained in:
parent
6620aaa4b3
commit
4470c99975
@ -460,6 +460,45 @@ static void test_cuda_constant_broadcast()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void test_cuda_cast()
|
||||
{
|
||||
Tensor<double, 3> in(Eigen::array<int, 3>(72,53,97));
|
||||
Tensor<float, 3> out(Eigen::array<int, 3>(72,53,97));
|
||||
in.setRandom();
|
||||
|
||||
std::size_t in_bytes = in.size() * sizeof(double);
|
||||
std::size_t out_bytes = out.size() * sizeof(float);
|
||||
|
||||
double* d_in;
|
||||
float* d_out;
|
||||
cudaMalloc((void**)(&d_in), in_bytes);
|
||||
cudaMalloc((void**)(&d_out), out_bytes);
|
||||
|
||||
cudaMemcpy(d_in, in.data(), in_bytes, cudaMemcpyHostToDevice);
|
||||
|
||||
cudaStream_t stream;
|
||||
assert(cudaStreamCreate(&stream) == cudaSuccess);
|
||||
Eigen::GpuDevice gpu_device(&stream);
|
||||
|
||||
Eigen::TensorMap<Eigen::Tensor<double, 3> > gpu_in(d_in, Eigen::array<int, 3>(72,53,97));
|
||||
Eigen::TensorMap<Eigen::Tensor<float, 3> > gpu_out(d_out, Eigen::array<int, 3>(72,53,97));
|
||||
|
||||
gpu_out.device(gpu_device) = gpu_in.template cast<float>();
|
||||
|
||||
assert(cudaMemcpyAsync(out.data(), d_out, out_bytes, cudaMemcpyDeviceToHost, gpu_device.stream()) == cudaSuccess);
|
||||
assert(cudaStreamSynchronize(gpu_device.stream()) == cudaSuccess);
|
||||
|
||||
for (int i = 0; i < 72; ++i) {
|
||||
for (int j = 0; j < 53; ++j) {
|
||||
for (int k = 0; k < 97; ++k) {
|
||||
VERIFY_IS_APPROX(out(Eigen::array<int, 3>(i,j,k)), static_cast<float>(in(Eigen::array<int, 3>(i,j,k))));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void test_cxx11_tensor_cuda()
|
||||
{
|
||||
CALL_SUBTEST(test_cuda_elementwise_small());
|
||||
@ -471,4 +510,5 @@ void test_cxx11_tensor_cuda()
|
||||
CALL_SUBTEST(test_cuda_convolution_2d());
|
||||
CALL_SUBTEST(test_cuda_convolution_3d());
|
||||
CALL_SUBTEST(test_cuda_constant_broadcast());
|
||||
CALL_SUBTEST(test_cuda_cast());
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user