Return -1 from CurrentThreadId when called by thread outside the pool.

This commit is contained in:
Rasmus Munk Larsen 2016-06-23 16:40:07 -07:00
parent d39df320d2
commit a9c1e4d7b7
5 changed files with 14 additions and 11 deletions

View File

@ -172,6 +172,8 @@ struct ThreadPoolDevice {
pool_->Schedule(func);
}
// Returns a logical thread index between 0 and pool_->NumThreads() - 1 if
// called from one of the threads in pool_. Returns -1 otherwise.
EIGEN_STRONG_INLINE int currentThreadId() const {
return pool_->CurrentThreadId();
}

View File

@ -99,13 +99,13 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
return static_cast<int>(threads_.size());
}
int CurrentThreadId() const {
int CurrentThreadId() const final {
const PerThread* pt =
const_cast<NonBlockingThreadPoolTempl*>(this)->GetPerThread();
if (pt->pool == this) {
return pt->thread_id;
} else {
return NumThreads();
return -1;
}
}
@ -113,10 +113,10 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
typedef typename Environment::EnvThread Thread;
struct PerThread {
constexpr PerThread() : pool(NULL), index(-1), rand(0) { }
constexpr PerThread() : pool(NULL), rand(0), thread_id(-1) { }
NonBlockingThreadPoolTempl* pool; // Parent pool, or null for normal threads.
int thread_id; // Worker thread index in pool.
uint64_t rand; // Random generator state.
uint64_t rand; // Random generator state.
int thread_id; // Worker thread index in pool.
};
Environment env_;

View File

@ -78,7 +78,7 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {
if (pt->pool == this) {
return pt->thread_id;
} else {
return NumThreads();
return -1;
}
}
@ -128,8 +128,9 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {
};
struct PerThread {
ThreadPoolTempl* pool; // Parent pool, or null for normal threads.
int thread_id; // Worker thread index in pool.
constexpr PerThread() : pool(NULL), thread_id(-1) { }
SimpleThreadPoolTempl* pool; // Parent pool, or null for normal threads.
int thread_id; // Worker thread index in pool.
};
Environment env_;
@ -141,7 +142,7 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {
bool exiting_ = false;
PerThread* GetPerThread() const {
static EIGEN_THREAD_LOCAL PerThread per_thread;
EIGEN_THREAD_LOCAL PerThread per_thread;
return &per_thread;
}
};

View File

@ -22,7 +22,7 @@ class ThreadPoolInterface {
virtual int NumThreads() const = 0;
// Returns a logical thread index between 0 and NumThreads() - 1 if called
// from one of the threads in the pool. Returns NumThreads() otherwise.
// from one of the threads in the pool. Returns -1 otherwise.
virtual int CurrentThreadId() const = 0;
virtual ~ThreadPoolInterface() {}

View File

@ -28,7 +28,7 @@ static void test_parallelism()
const int kThreads = 16; // code below expects that this is a multiple of 4
NonBlockingThreadPool tp(kThreads);
VERIFY_IS_EQUAL(tp.NumThreads(), kThreads);
VERIFY_IS_EQUAL(tp.CurrentThreadId(), kThreads);
VERIFY_IS_EQUAL(tp.CurrentThreadId(), -1);
for (int iter = 0; iter < 100; ++iter) {
std::atomic<int> running(0);
std::atomic<int> done(0);