Some serialization API changes were made in commit...

This commit is contained in:
Rohit Santhanam 2022-01-05 16:18:45 +00:00 committed by Antonio Sánchez
parent 9210e71fb3
commit 27a78e4f96

View File

@ -231,6 +231,7 @@ auto run_serialized_on_gpu(size_t buffer_capacity_hint,
std::vector<uint8_t> buffer(capacity);
uint8_t* host_data = nullptr;
uint8_t* host_data_end = nullptr;
uint8_t* host_ptr = nullptr;
uint8_t* device_data = nullptr;
size_t output_data_size = 0;
@ -239,8 +240,9 @@ auto run_serialized_on_gpu(size_t buffer_capacity_hint,
capacity = std::max<size_t>(capacity, output_data_size);
buffer.resize(capacity);
host_data = buffer.data();
host_ptr = Eigen::serialize(host_data, input_data_size);
host_ptr = Eigen::serialize(host_ptr, args...);
host_data_end = buffer.data() + capacity;
host_ptr = Eigen::serialize(host_data, host_data_end, input_data_size);
host_ptr = Eigen::serialize(host_ptr, host_data_end, args...);
// Copy inputs to host.
gpuMalloc((void**)(&device_data), capacity);
@ -265,7 +267,7 @@ auto run_serialized_on_gpu(size_t buffer_capacity_hint,
GPU_CHECK(gpuDeviceSynchronize());
// Determine output buffer size.
host_ptr = Eigen::deserialize(host_data, output_data_size);
const uint8_t* c_host_ptr = Eigen::deserialize(host_data, host_data_end, output_data_size);
// If the output doesn't fit in the buffer, spit out warning and fail.
if (output_data_size > capacity) {
std::cerr << "The serialized output does not fit in the output buffer, "
@ -280,11 +282,11 @@ auto run_serialized_on_gpu(size_t buffer_capacity_hint,
// Deserialize outputs.
auto args_tuple = test_detail::tie(args...);
EIGEN_UNUSED_VARIABLE(args_tuple) // Avoid NVCC compile warning.
host_ptr = Eigen::deserialize(host_ptr, test_detail::get<OutputIndices, Args&...>(args_tuple)...);
c_host_ptr = Eigen::deserialize(c_host_ptr, host_data_end, test_detail::get<OutputIndices, Args&...>(args_tuple)...);
// Maybe deserialize return value, properly handling void.
typename void_helper::ReturnType<decltype(kernel(args...))> result;
host_ptr = Eigen::deserialize(host_ptr, result);
c_host_ptr = Eigen::deserialize(c_host_ptr, host_data_end, result);
return void_helper::restore(result);
}