Made ThreadPoolDevice inherit from a new pure abstract ThreadPoolInterface class: this enables users to leverage their existing threadpool when using eigen tensors.

This commit is contained in:
Benoit Steiner 2015-06-30 14:21:24 -07:00
parent 28b36632ec
commit f587075987

View File

@ -55,10 +55,19 @@ struct DefaultDevice {
// We should really use a thread pool here but first we need to find a portable thread pool library. // We should really use a thread pool here but first we need to find a portable thread pool library.
#ifdef EIGEN_USE_THREADS #ifdef EIGEN_USE_THREADS
// This defines an interface that ThreadPoolDevice can take to use
// custom thread pools underneath.
class ThreadPoolInterface {
public:
virtual void Schedule(std::function<void()> fn) = 0;
virtual ~ThreadPoolInterface() {}
};
// The implementation of the ThreadPool type ensures that the Schedule method // The implementation of the ThreadPool type ensures that the Schedule method
// runs the functions it is provided in FIFO order when the scheduling is done // runs the functions it is provided in FIFO order when the scheduling is done
// by a single thread. // by a single thread.
class ThreadPool { class ThreadPool : public ThreadPoolInterface {
public: public:
// Construct a pool that contains "num_threads" threads. // Construct a pool that contains "num_threads" threads.
explicit ThreadPool(int num_threads) { explicit ThreadPool(int num_threads) {
@ -199,7 +208,7 @@ static EIGEN_STRONG_INLINE void wait_until_ready(Notification* n) {
// Build a thread pool device on top the an existing pool of threads. // Build a thread pool device on top the an existing pool of threads.
struct ThreadPoolDevice { struct ThreadPoolDevice {
ThreadPoolDevice(ThreadPool* pool, size_t num_cores) : pool_(pool), num_threads_(num_cores) { } ThreadPoolDevice(ThreadPoolInterface* pool, size_t num_cores) : pool_(pool), num_threads_(num_cores) { }
EIGEN_STRONG_INLINE void* allocate(size_t num_bytes) const { EIGEN_STRONG_INLINE void* allocate(size_t num_bytes) const {
return internal::aligned_malloc(num_bytes); return internal::aligned_malloc(num_bytes);
@ -241,7 +250,7 @@ struct ThreadPoolDevice {
} }
private: private:
ThreadPool* pool_; ThreadPoolInterface* pool_;
size_t num_threads_; size_t num_threads_;
}; };