Made the initialization of a CUDA device thread safe.

This commit is contained in:
Benoit Steiner 2016-09-26 11:00:32 -07:00
parent 48dfe98abd
commit 6565f8d60f
2 changed files with 31 additions and 1 deletions

View File

@ -64,6 +64,10 @@ typedef unsigned __int64 uint64_t;
#if defined(__CUDACC__) #if defined(__CUDACC__)
#include <curand_kernel.h> #include <curand_kernel.h>
#endif #endif
#if __cplusplus >= 201103L
#include <atomic>
#include <unistd.h>
#endif
#endif #endif
#include "src/Tensor/TensorMacros.h" #include "src/Tensor/TensorMacros.h"

View File

@ -42,7 +42,21 @@ static bool m_devicePropInitialized = false;
static void initializeDeviceProp() { static void initializeDeviceProp() {
if (!m_devicePropInitialized) { if (!m_devicePropInitialized) {
if (!m_devicePropInitialized) { // Attempts to ensure proper behavior in the case of multiple threads
// calling this function simultaneously. This would be trivial to
// implement if we could use std::mutex, but unfortunately mutex don't
// compile with nvcc, so we resort to atomics and thread fences instead.
// Note that if the caller uses a compiler that doesn't support c++11 we
// can't ensure that the initialization is thread safe.
#if __cplusplus >= 201103L
static std::atomic<bool> first(true);
if (first.exchange(false)) {
#else
static bool first = true;
if (first) {
first = false;
#endif
// We're the first thread to reach this point.
int num_devices; int num_devices;
cudaError_t status = cudaGetDeviceCount(&num_devices); cudaError_t status = cudaGetDeviceCount(&num_devices);
if (status != cudaSuccess) { if (status != cudaSuccess) {
@ -63,7 +77,19 @@ static void initializeDeviceProp() {
assert(status == cudaSuccess); assert(status == cudaSuccess);
} }
} }
#if __cplusplus >= 201103L
std::atomic_thread_fence(std::memory_order_release);
#endif
m_devicePropInitialized = true; m_devicePropInitialized = true;
} else {
// Wait for the other thread to inititialize the properties.
while (!m_devicePropInitialized) {
#if __cplusplus >= 201103L
std::atomic_thread_fence(std::memory_order_acquire);
#endif
sleep(1);
}
} }
} }
} }