Add a ThreadPoolInterface* getter for ThreadPoolDevice.

This commit is contained in:
Penporn Koanantakool 2018-06-02 12:07:49 -07:00
parent 84868da904
commit e2ed0cf8ab

View File

@ -169,7 +169,7 @@ struct ThreadPoolDevice {
// parallelFor executes f with [0, n) arguments in parallel and waits for // parallelFor executes f with [0, n) arguments in parallel and waits for
// completion. F accepts a half-open interval [first, last). // completion. F accepts a half-open interval [first, last).
// Block size is choosen based on the iteration cost and resulting parallel // Block size is chosen based on the iteration cost and resulting parallel
// efficiency. If block_align is not nullptr, it is called to round up the // efficiency. If block_align is not nullptr, it is called to round up the
// block size. // block size.
void parallelFor(Index n, const TensorOpCost& cost, void parallelFor(Index n, const TensorOpCost& cost,
@ -261,6 +261,9 @@ struct ThreadPoolDevice {
parallelFor(n, cost, nullptr, std::move(f)); parallelFor(n, cost, nullptr, std::move(f));
} }
// Thread pool accessor.
ThreadPoolInterface* getPool() const { return pool_; }
private: private:
ThreadPoolInterface* pool_; ThreadPoolInterface* pool_;
int num_threads_; int num_threads_;