mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Return -1 from CurrentThreadId when called by thread outside the pool.
This commit is contained in:
parent
d39df320d2
commit
a9c1e4d7b7
@ -172,6 +172,8 @@ struct ThreadPoolDevice {
|
|||||||
pool_->Schedule(func);
|
pool_->Schedule(func);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns a logical thread index between 0 and pool_->NumThreads() - 1 if
|
||||||
|
// called from one of the threads in pool_. Returns -1 otherwise.
|
||||||
EIGEN_STRONG_INLINE int currentThreadId() const {
|
EIGEN_STRONG_INLINE int currentThreadId() const {
|
||||||
return pool_->CurrentThreadId();
|
return pool_->CurrentThreadId();
|
||||||
}
|
}
|
||||||
|
@ -99,13 +99,13 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
|
|||||||
return static_cast<int>(threads_.size());
|
return static_cast<int>(threads_.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
int CurrentThreadId() const {
|
int CurrentThreadId() const final {
|
||||||
const PerThread* pt =
|
const PerThread* pt =
|
||||||
const_cast<NonBlockingThreadPoolTempl*>(this)->GetPerThread();
|
const_cast<NonBlockingThreadPoolTempl*>(this)->GetPerThread();
|
||||||
if (pt->pool == this) {
|
if (pt->pool == this) {
|
||||||
return pt->thread_id;
|
return pt->thread_id;
|
||||||
} else {
|
} else {
|
||||||
return NumThreads();
|
return -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -113,10 +113,10 @@ class NonBlockingThreadPoolTempl : public Eigen::ThreadPoolInterface {
|
|||||||
typedef typename Environment::EnvThread Thread;
|
typedef typename Environment::EnvThread Thread;
|
||||||
|
|
||||||
struct PerThread {
|
struct PerThread {
|
||||||
constexpr PerThread() : pool(NULL), index(-1), rand(0) { }
|
constexpr PerThread() : pool(NULL), rand(0), thread_id(-1) { }
|
||||||
NonBlockingThreadPoolTempl* pool; // Parent pool, or null for normal threads.
|
NonBlockingThreadPoolTempl* pool; // Parent pool, or null for normal threads.
|
||||||
int thread_id; // Worker thread index in pool.
|
|
||||||
uint64_t rand; // Random generator state.
|
uint64_t rand; // Random generator state.
|
||||||
|
int thread_id; // Worker thread index in pool.
|
||||||
};
|
};
|
||||||
|
|
||||||
Environment env_;
|
Environment env_;
|
||||||
|
@ -78,7 +78,7 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {
|
|||||||
if (pt->pool == this) {
|
if (pt->pool == this) {
|
||||||
return pt->thread_id;
|
return pt->thread_id;
|
||||||
} else {
|
} else {
|
||||||
return NumThreads();
|
return -1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -128,7 +128,8 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct PerThread {
|
struct PerThread {
|
||||||
ThreadPoolTempl* pool; // Parent pool, or null for normal threads.
|
constexpr PerThread() : pool(NULL), thread_id(-1) { }
|
||||||
|
SimpleThreadPoolTempl* pool; // Parent pool, or null for normal threads.
|
||||||
int thread_id; // Worker thread index in pool.
|
int thread_id; // Worker thread index in pool.
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -141,7 +142,7 @@ class SimpleThreadPoolTempl : public ThreadPoolInterface {
|
|||||||
bool exiting_ = false;
|
bool exiting_ = false;
|
||||||
|
|
||||||
PerThread* GetPerThread() const {
|
PerThread* GetPerThread() const {
|
||||||
static EIGEN_THREAD_LOCAL PerThread per_thread;
|
EIGEN_THREAD_LOCAL PerThread per_thread;
|
||||||
return &per_thread;
|
return &per_thread;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -22,7 +22,7 @@ class ThreadPoolInterface {
|
|||||||
virtual int NumThreads() const = 0;
|
virtual int NumThreads() const = 0;
|
||||||
|
|
||||||
// Returns a logical thread index between 0 and NumThreads() - 1 if called
|
// Returns a logical thread index between 0 and NumThreads() - 1 if called
|
||||||
// from one of the threads in the pool. Returns NumThreads() otherwise.
|
// from one of the threads in the pool. Returns -1 otherwise.
|
||||||
virtual int CurrentThreadId() const = 0;
|
virtual int CurrentThreadId() const = 0;
|
||||||
|
|
||||||
virtual ~ThreadPoolInterface() {}
|
virtual ~ThreadPoolInterface() {}
|
||||||
|
@ -28,7 +28,7 @@ static void test_parallelism()
|
|||||||
const int kThreads = 16; // code below expects that this is a multiple of 4
|
const int kThreads = 16; // code below expects that this is a multiple of 4
|
||||||
NonBlockingThreadPool tp(kThreads);
|
NonBlockingThreadPool tp(kThreads);
|
||||||
VERIFY_IS_EQUAL(tp.NumThreads(), kThreads);
|
VERIFY_IS_EQUAL(tp.NumThreads(), kThreads);
|
||||||
VERIFY_IS_EQUAL(tp.CurrentThreadId(), kThreads);
|
VERIFY_IS_EQUAL(tp.CurrentThreadId(), -1);
|
||||||
for (int iter = 0; iter < 100; ++iter) {
|
for (int iter = 0; iter < 100; ++iter) {
|
||||||
std::atomic<int> running(0);
|
std::atomic<int> running(0);
|
||||||
std::atomic<int> done(0);
|
std::atomic<int> done(0);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user