Added a test to validate full reduction on tensor of half floats

This commit is contained in:
Benoit Steiner 2016-05-05 16:52:50 -07:00
parent 678a17ba79
commit 69a8a4e1f3

View File

@ -305,6 +305,49 @@ void test_cuda_reductions() {
gpu_device.deallocate(d_res_float);
}
void test_cuda_full_reductions() {
Eigen::CudaStreamDevice stream;
Eigen::GpuDevice gpu_device(&stream);
int size = 13;
int num_elem = size*size;
float* d_float1 = (float*)gpu_device.allocate(num_elem * sizeof(float));
float* d_float2 = (float*)gpu_device.allocate(num_elem * sizeof(float));
Eigen::half* d_res_half = (Eigen::half*)gpu_device.allocate(1 * sizeof(Eigen::half));
Eigen::half* d_res_float = (Eigen::half*)gpu_device.allocate(1 * sizeof(Eigen::half));
Eigen::TensorMap<Eigen::Tensor<float, 2>, Eigen::Aligned> gpu_float1(
d_float1, size, size);
Eigen::TensorMap<Eigen::Tensor<float, 2>, Eigen::Aligned> gpu_float2(
d_float2, size, size);
Eigen::TensorMap<Eigen::Tensor<Eigen::half, 0>, Eigen::Aligned> gpu_res_half(
d_res_half);
Eigen::TensorMap<Eigen::Tensor<Eigen::half, 0>, Eigen::Aligned> gpu_res_float(
d_res_float);
gpu_float1.device(gpu_device) = gpu_float1.random();
gpu_float2.device(gpu_device) = gpu_float2.random();
gpu_res_float.device(gpu_device) = gpu_float1.sum().cast<Eigen::half>();
gpu_res_half.device(gpu_device) = gpu_float1.cast<Eigen::half>().sum();
Tensor<Eigen::half, 0> half_prec;
Tensor<Eigen::half, 0> full_prec;
gpu_device.memcpyDeviceToHost(half_prec.data(), d_res_half, sizeof(Eigen::half));
gpu_device.memcpyDeviceToHost(full_prec.data(), d_res_float, sizeof(Eigen::half));
gpu_device.synchronize();
VERIFY_IS_APPROX(full_prec(), half_prec());
gpu_device.deallocate(d_float1);
gpu_device.deallocate(d_float2);
gpu_device.deallocate(d_res_half);
gpu_device.deallocate(d_res_float);
}
void test_cuda_forced_evals() {
Eigen::CudaStreamDevice stream;
@ -354,6 +397,7 @@ void test_cxx11_tensor_of_float16_cuda()
CALL_SUBTEST_1(test_cuda_trancendental());
CALL_SUBTEST_2(test_cuda_contractions());
CALL_SUBTEST_3(test_cuda_reductions());
CALL_SUBTEST_3(test_cuda_full_reductions());
CALL_SUBTEST_4(test_cuda_forced_evals());
#else