mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 19:29:02 +08:00
Add CurrentThreadId and NumThreads methods to Eigen threadpools and TensorDeviceThreadPool.
This commit is contained in:
parent
8d97ba6b22
commit
76308e7fd2
@ -172,6 +172,10 @@ struct ThreadPoolDevice {
|
|||||||
pool_->Schedule(func);
|
pool_->Schedule(func);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
EIGEN_STRONG_INLINE size_t currentThreadId() const {
|
||||||
|
return pool_->CurrentThreadId();
|
||||||
|
}
|
||||||
|
|
||||||
// parallelFor executes f with [0, n) arguments in parallel and waits for
|
// parallelFor executes f with [0, n) arguments in parallel and waits for
|
||||||
// completion. F accepts a half-open interval [first, last).
|
// completion. F accepts a half-open interval [first, last).
|
||||||
// Block size is chosen based on the iteration cost and resulting parallel
|
// Block size is chosen based on the iteration cost and resulting parallel
|
||||||
|
@ -74,7 +74,7 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
|
|||||||
PerThread* pt = GetPerThread();
|
PerThread* pt = GetPerThread();
|
||||||
if (pt->pool == this) {
|
if (pt->pool == this) {
|
||||||
// Worker thread of this pool, push onto the thread's queue.
|
// Worker thread of this pool, push onto the thread's queue.
|
||||||
Queue* q = queues_[pt->index];
|
Queue* q = queues_[pt->thread_id];
|
||||||
t = q->PushFront(std::move(t));
|
t = q->PushFront(std::move(t));
|
||||||
} else {
|
} else {
|
||||||
// A free-standing thread (or worker of another pool), push onto a random
|
// A free-standing thread (or worker of another pool), push onto a random
|
||||||
@ -95,13 +95,27 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
|
|||||||
env_.ExecuteTask(t); // Push failed, execute directly.
|
env_.ExecuteTask(t); // Push failed, execute directly.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t NumThreads() const final {
|
||||||
|
return threads_.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t CurrentThreadId() const {
|
||||||
|
const PerThread* pt =
|
||||||
|
const_cast<NonBlockingThreadPoolTempl*>(this)->GetPerThread();
|
||||||
|
if (pt->pool == this) {
|
||||||
|
return static_cast<size_t>(pt->thread_id);
|
||||||
|
} else {
|
||||||
|
return threads_.size();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
typedef typename Environment::EnvThread Thread;
|
typedef typename Environment::EnvThread Thread;
|
||||||
|
|
||||||
struct PerThread {
|
struct PerThread {
|
||||||
bool inited;
|
bool inited;
|
||||||
NonBlockingThreadPoolTempl* pool; // Parent pool, or null for normal threads.
|
NonBlockingThreadPoolTempl* pool; // Parent pool, or null for normal threads.
|
||||||
unsigned index; // Worker thread index in pool.
|
unsigned thread_id; // Worker thread index in pool.
|
||||||
unsigned rand; // Random generator state.
|
unsigned rand; // Random generator state.
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -116,12 +130,12 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
|
|||||||
EventCount ec_;
|
EventCount ec_;
|
||||||
|
|
||||||
// Main worker thread loop.
|
// Main worker thread loop.
|
||||||
void WorkerLoop(unsigned index) {
|
void WorkerLoop(unsigned thread_id) {
|
||||||
PerThread* pt = GetPerThread();
|
PerThread* pt = GetPerThread();
|
||||||
pt->pool = this;
|
pt->pool = this;
|
||||||
pt->index = index;
|
pt->thread_id = thread_id;
|
||||||
Queue* q = queues_[index];
|
Queue* q = queues_[thread_id];
|
||||||
EventCount::Waiter* waiter = &waiters_[index];
|
EventCount::Waiter* waiter = &waiters_[thread_id];
|
||||||
for (;;) {
|
for (;;) {
|
||||||
Task t = q->PopFront();
|
Task t = q->PopFront();
|
||||||
if (!t.f) {
|
if (!t.f) {
|
||||||
|
@ -24,7 +24,7 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {
|
|||||||
explicit SimpleThreadPoolTempl(int num_threads, Environment env = Environment())
|
explicit SimpleThreadPoolTempl(int num_threads, Environment env = Environment())
|
||||||
: env_(env), threads_(num_threads), waiters_(num_threads) {
|
: env_(env), threads_(num_threads), waiters_(num_threads) {
|
||||||
for (int i = 0; i < num_threads; i++) {
|
for (int i = 0; i < num_threads; i++) {
|
||||||
threads_.push_back(env.CreateThread([this]() { WorkerLoop(); }));
|
threads_.push_back(env.CreateThread([this, i]() { WorkerLoop(i); }));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -55,7 +55,7 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {
|
|||||||
|
|
||||||
// Schedule fn() for execution in the pool of threads. The functions are
|
// Schedule fn() for execution in the pool of threads. The functions are
|
||||||
// executed in the order in which they are scheduled.
|
// executed in the order in which they are scheduled.
|
||||||
void Schedule(std::function<void()> fn) {
|
void Schedule(std::function<void()> fn) final {
|
||||||
Task t = env_.CreateTask(std::move(fn));
|
Task t = env_.CreateTask(std::move(fn));
|
||||||
std::unique_lock<std::mutex> l(mu_);
|
std::unique_lock<std::mutex> l(mu_);
|
||||||
if (waiters_.empty()) {
|
if (waiters_.empty()) {
|
||||||
@ -69,9 +69,25 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t NumThreads() const final {
|
||||||
|
return threads_.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t CurrentThreadId() const final {
|
||||||
|
const PerThread* pt = this->GetPerThread();
|
||||||
|
if (pt->pool == this) {
|
||||||
|
return pt->thread_id;
|
||||||
|
} else {
|
||||||
|
return threads_.size();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
void WorkerLoop() {
|
void WorkerLoop(size_t thread_id) {
|
||||||
std::unique_lock<std::mutex> l(mu_);
|
std::unique_lock<std::mutex> l(mu_);
|
||||||
|
PerThread* pt = GetPerThread();
|
||||||
|
pt->pool = this;
|
||||||
|
pt->thread_id = thread_id;
|
||||||
Waiter w;
|
Waiter w;
|
||||||
Task t;
|
Task t;
|
||||||
while (!exiting_) {
|
while (!exiting_) {
|
||||||
@ -111,6 +127,11 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {
|
|||||||
bool ready;
|
bool ready;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct PerThread {
|
||||||
|
ThreadPoolTempl* pool; // Parent pool, or null for normal threads.
|
||||||
|
size_t thread_id; // Worker thread index in pool.
|
||||||
|
};
|
||||||
|
|
||||||
Environment env_;
|
Environment env_;
|
||||||
std::mutex mu_;
|
std::mutex mu_;
|
||||||
MaxSizeVector<Thread*> threads_; // All threads
|
MaxSizeVector<Thread*> threads_; // All threads
|
||||||
@ -118,6 +139,11 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {
|
|||||||
std::deque<Task> pending_; // Queue of pending work
|
std::deque<Task> pending_; // Queue of pending work
|
||||||
std::condition_variable empty_; // Signaled on pending_.empty()
|
std::condition_variable empty_; // Signaled on pending_.empty()
|
||||||
bool exiting_ = false;
|
bool exiting_ = false;
|
||||||
|
|
||||||
|
PerThread* GetPerThread() const {
|
||||||
|
static EIGEN_THREAD_LOCAL PerThread per_thread;
|
||||||
|
return &per_thread;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef SimpleThreadPoolTempl<StlThreadEnvironment> SimpleThreadPool;
|
typedef SimpleThreadPoolTempl<StlThreadEnvironment> SimpleThreadPool;
|
||||||
|
@ -18,6 +18,13 @@ class ThreadPoolInterface {
|
|||||||
public:
|
public:
|
||||||
virtual void Schedule(std::function<void()> fn) = 0;
|
virtual void Schedule(std::function<void()> fn) = 0;
|
||||||
|
|
||||||
|
// Returns the number of threads in the pool.
|
||||||
|
virtual size_t NumThreads() const = 0;
|
||||||
|
|
||||||
|
// Returns a logical thread index between 0 and NumThreads() - 1 if called
|
||||||
|
// from one of the threads in the pool. Returns NumThreads() otherwise.
|
||||||
|
virtual size_t CurrentThreadId() const = 0;
|
||||||
|
|
||||||
virtual ~ThreadPoolInterface() {}
|
virtual ~ThreadPoolInterface() {}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -27,6 +27,8 @@ static void test_parallelism()
|
|||||||
// Test we never-ever fail to match available tasks with idle threads.
|
// Test we never-ever fail to match available tasks with idle threads.
|
||||||
const int kThreads = 16; // code below expects that this is a multiple of 4
|
const int kThreads = 16; // code below expects that this is a multiple of 4
|
||||||
NonBlockingThreadPool tp(kThreads);
|
NonBlockingThreadPool tp(kThreads);
|
||||||
|
VERIFY_IS_EQUAL(tp.NumThreads(), kThreads);
|
||||||
|
VERIFY_IS_EQUAL(tp.CurrentThreadId(), kThreads);
|
||||||
for (int iter = 0; iter < 100; ++iter) {
|
for (int iter = 0; iter < 100; ++iter) {
|
||||||
std::atomic<int> running(0);
|
std::atomic<int> running(0);
|
||||||
std::atomic<int> done(0);
|
std::atomic<int> done(0);
|
||||||
@ -34,6 +36,9 @@ static void test_parallelism()
|
|||||||
// Schedule kThreads tasks and ensure that they all are running.
|
// Schedule kThreads tasks and ensure that they all are running.
|
||||||
for (int i = 0; i < kThreads; ++i) {
|
for (int i = 0; i < kThreads; ++i) {
|
||||||
tp.Schedule([&]() {
|
tp.Schedule([&]() {
|
||||||
|
const size_t thread_id = tp.CurrentThreadId();
|
||||||
|
VERIFY_GE(thread_id, 0);
|
||||||
|
VERIFY_LE(thread_id, kThreads - 1);
|
||||||
running++;
|
running++;
|
||||||
while (phase < 1) {
|
while (phase < 1) {
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user