# Reviewer notes (placed ahead of the first "diff --git" header, outside every hunk):
# - unsupported/Eigen/CXX11/Tensor: inside the EIGEN_USE_THREADS guard, the five removed
#   #include lines are replaced by a single #include "ThreadPool".
# - unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h: the in-header definitions of
#   ThreadPoolInterface, ThreadPoolTempl and StlThreadEnvironment are removed, and the
#   ThreadPool typedef now aliases SimpleThreadPool instead of ThreadPoolTempl.
# - SimpleThreadPool is not defined anywhere in this patch; presumably it is provided by the
#   newly included "ThreadPool" header -- TODO confirm.
# NOTE(review): this copy looks whitespace-collapsed and appears to have lost angle-bracketed
#   text (the removed #include lines carry no header names; "std::function fn" and
#   "template class ThreadPoolTempl" carry no template arguments). Regenerate the patch from
#   the repository before trying to apply it.
diff --git a/unsupported/Eigen/CXX11/Tensor b/unsupported/Eigen/CXX11/Tensor index 16132398d..65f5c87e9 100644 --- a/unsupported/Eigen/CXX11/Tensor +++ b/unsupported/Eigen/CXX11/Tensor @@ -51,11 +51,7 @@ typedef unsigned __int64 uint64_t; #endif #ifdef EIGEN_USE_THREADS -#include -#include -#include -#include -#include +#include "ThreadPool" #endif #ifdef EIGEN_USE_GPU diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h index cd3dd214b..6da16985f 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h @@ -12,145 +12,9 @@ namespace Eigen { -// This defines an interface that ThreadPoolDevice can take to use -// custom thread pools underneath. -class ThreadPoolInterface { - public: - virtual void Schedule(std::function fn) = 0; - - virtual ~ThreadPoolInterface() {} -}; - -// The implementation of the ThreadPool type ensures that the Schedule method -// runs the functions it is provided in FIFO order when the scheduling is done -// by a single thread. -// Environment provides a way to create threads and also allows to intercept -// task submission and execution. -template -class ThreadPoolTempl : public ThreadPoolInterface { - public: - // Construct a pool that contains "num_threads" threads. - explicit ThreadPoolTempl(int num_threads, Environment env = Environment()) - : env_(env), threads_(num_threads), waiters_(num_threads) { - for (int i = 0; i < num_threads; i++) { - threads_.push_back(env.CreateThread([this]() { WorkerLoop(); })); - } - } - - // Wait until all scheduled work has finished and then destroy the - // set of threads. - ~ThreadPoolTempl() { - { - // Wait for all work to get done. - std::unique_lock l(mu_); - while (!pending_.empty()) { - empty_.wait(l); - } - exiting_ = true; - - // Wakeup all waiters. 
- for (auto w : waiters_) { - w->ready = true; - w->task.f = nullptr; - w->cv.notify_one(); - } - } - - // Wait for threads to finish. - for (auto t : threads_) { - delete t; - } - } - - // Schedule fn() for execution in the pool of threads. The functions are - // executed in the order in which they are scheduled. - void Schedule(std::function fn) { - Task t = env_.CreateTask(std::move(fn)); - std::unique_lock l(mu_); - if (waiters_.empty()) { - pending_.push_back(std::move(t)); - } else { - Waiter* w = waiters_.back(); - waiters_.pop_back(); - w->ready = true; - w->task = std::move(t); - w->cv.notify_one(); - } - } - - protected: - void WorkerLoop() { - std::unique_lock l(mu_); - Waiter w; - Task t; - while (!exiting_) { - if (pending_.empty()) { - // Wait for work to be assigned to me - w.ready = false; - waiters_.push_back(&w); - while (!w.ready) { - w.cv.wait(l); - } - t = w.task; - w.task.f = nullptr; - } else { - // Pick up pending work - t = std::move(pending_.front()); - pending_.pop_front(); - if (pending_.empty()) { - empty_.notify_all(); - } - } - if (t.f) { - mu_.unlock(); - env_.ExecuteTask(t); - t.f = nullptr; - mu_.lock(); - } - } - } - - private: - typedef typename Environment::Task Task; - typedef typename Environment::EnvThread Thread; - - struct Waiter { - std::condition_variable cv; - Task task; - bool ready; - }; - - Environment env_; - std::mutex mu_; - MaxSizeVector threads_; // All threads - MaxSizeVector waiters_; // Stack of waiting threads. - std::deque pending_; // Queue of pending work - std::condition_variable empty_; // Signaled on pending_.empty() - bool exiting_ = false; -}; - -struct StlThreadEnvironment { - struct Task { - std::function f; - }; - - // EnvThread constructor must start the thread, - // destructor must join the thread. 
- class EnvThread { - public: - EnvThread(std::function f) : thr_(f) {} - ~EnvThread() { thr_.join(); } - - private: - std::thread thr_; - }; - - EnvThread* CreateThread(std::function f) { return new EnvThread(f); } - Task CreateTask(std::function f) { return Task{std::move(f)}; } - void ExecuteTask(const Task& t) { t.f(); } -}; - -typedef ThreadPoolTempl ThreadPool; +// Use the SimpleThreadPool by default. We'll switch to the new non blocking +// thread pool later. +typedef SimpleThreadPool ThreadPool; // Barrier is an object that allows one or more threads to wait until