Increased the functionality of the tensor devices

This commit is contained in:
Benoit Steiner 2015-01-14 11:45:17 -08:00
parent 5692723c58
commit 0a0ab6dd15

View File

@ -43,11 +43,14 @@ typedef std::promise<void> Promise;
static EIGEN_STRONG_INLINE void wait_until_ready(const Future* f) { static EIGEN_STRONG_INLINE void wait_until_ready(const Future* f) {
f->wait(); f->wait();
// eigen_assert(f->ready()); }
static EIGEN_STRONG_INLINE void get_when_ready(Future* f) {
f->get();
} }
struct ThreadPoolDevice { struct ThreadPoolDevice {
ThreadPoolDevice(/*ThreadPool* pool, */size_t num_cores) : num_threads_(num_cores) { } ThreadPoolDevice(size_t num_cores) : num_threads_(num_cores) { }
EIGEN_STRONG_INLINE void* allocate(size_t num_bytes) const { EIGEN_STRONG_INLINE void* allocate(size_t num_bytes) const {
return internal::aligned_malloc(num_bytes); return internal::aligned_malloc(num_bytes);
@ -79,9 +82,9 @@ struct ThreadPoolDevice {
} }
private: private:
// todo: NUMA, ...
size_t num_threads_; size_t num_threads_;
}; };
#endif #endif
@ -114,6 +117,10 @@ static inline int sharedMemPerBlock() {
return m_deviceProperties.sharedMemPerBlock; return m_deviceProperties.sharedMemPerBlock;
} }
// Applies the requested shared memory configuration by forwarding `config`
// to cudaDeviceSetSharedMemConfig().
//
// config: the cudaSharedMemConfig value to install.
//
// A non-success status trips the assertion below. The check is an assert(),
// so it is only active when NDEBUG is not defined.
static inline void setCudaSharedMemConfig(cudaSharedMemConfig config) {
  cudaError_t status = cudaDeviceSetSharedMemConfig(config);
  // Keep `status` referenced when assert() compiles away (NDEBUG); otherwise
  // release builds report an unused variable.
  (void)status;
  assert(status == cudaSuccess);
}
struct GpuDevice { struct GpuDevice {
// The cudastream is not owned: the caller is responsible for its initialization and eventual destruction. // The cudastream is not owned: the caller is responsible for its initialization and eventual destruction.
@ -163,10 +170,19 @@ struct GpuDevice {
return 32; return 32;
} }
// Waits for the work queued on this device's stream by calling
// cudaStreamSynchronize() on *stream_.
//
// The function is declared with EIGEN_DEVICE_FUNC, but cudaStreamSynchronize()
// is a host runtime call, so the call is only compiled for the host pass
// (__CUDA_ARCH__ undefined). Reaching this function from device code is a
// usage error and trips an assertion instead.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void synchronize() const {
#ifndef __CUDA_ARCH__
  cudaError_t status = cudaStreamSynchronize(*stream_);
  // Referenced explicitly so NDEBUG builds do not warn about an unused local.
  (void)status;
  assert(status == cudaSuccess);
#else
  assert(false && "GpuDevice::synchronize() must be called from host code");
#endif
}
private: private:
// TODO: multigpu. // TODO: multigpu.
const cudaStream_t* stream_; const cudaStream_t* stream_;
}; };
// Launches `kernel` with the given grid size, block size and dynamic shared
// memory size on the stream returned by (device).stream(), forwarding the
// remaining arguments to the kernel, then asserts that cudaGetLastError()
// reports cudaSuccess.
//
// The expansion is wrapped in do { ... } while (0) so that the launch and the
// check form ONE statement: `if (c) LAUNCH_CUDA_KERNEL(...);` now guards both,
// instead of guarding only the launch and running the check unconditionally.
// The trailing semicolon is kept on purpose so call sites written with or
// without their own `;` keep compiling.
//
// NOTE: the check is an assert(); when NDEBUG is defined the whole expression,
// including the cudaGetLastError() call, is compiled out.
#define LAUNCH_CUDA_KERNEL(kernel, gridsize, blocksize, sharedmem, device, ...) \
  do { \
    (kernel) <<< (gridsize), (blocksize), (sharedmem), (device).stream() >>> (__VA_ARGS__); \
    assert(cudaGetLastError() == cudaSuccess); \
  } while (0);
#endif #endif
} // end namespace Eigen } // end namespace Eigen