Improve EventCount used by the non-blocking threadpool.

The current algorithm requires threads to commit/cancel waiting in the order
in which they called Prewait. The spinning caused by that serialization can
consume a lot of CPU time on some workloads. Restructure the algorithm so that
it no longer requires that serialization, and remove the spin waits from
CommitWait/CancelWait.

Note: this reduces the maximum number of threads from 2^16 to 2^14 to leave
more space for the ABA counter (which is now 22 bits).

Implementation details are explained in the comments.
Author: Rasmus Munk Larsen, 2019-02-22 13:56:26 -08:00
parent 0b25a5c431
commit 6560692c67

3 changed files with 110 additions and 93 deletions
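
For reference, a minimal standalone sketch (not part of the commit) of the bit budget described in the commit message above; the constant names mirror those introduced in the EventCount.h diff below, and everything follows from kWaiterBits = 14:

#include <cstdint>

// Width of each of the three counter fields packed into the 64-bit state word.
constexpr uint64_t kWaiterBits = 14;
// Whatever is left over becomes the ABA counter for the lock-free stack.
constexpr uint64_t kEpochBits = 64 - 3 * kWaiterBits;

static_assert(kEpochBits == 22, "ABA counter is 22 bits, as noted above");
static_assert((1ull << kWaiterBits) - 1 == 16383,
              "each field can track at most 2^14 - 1 threads");

// Bit ranges inside state_:
//   [ 0, 14)  stack of committed waiters (all ones == empty stack)
//   [14, 28)  count of threads in the pre-wait state
//   [28, 42)  count of pending signals
//   [42, 64)  ABA epoch, stored in the Waiter node and bumped on push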


@@ -20,7 +20,8 @@ namespace Eigen {
 //   if (predicate)
 //     return act();
 //   EventCount::Waiter& w = waiters[my_index];
-//   ec.Prewait(&w);
+//   if (!ec.Prewait(&w))
+//     return act();
 //   if (predicate) {
 //     ec.CancelWait(&w);
 //     return act();
@@ -50,78 +51,78 @@ class EventCount {
  public:
   class Waiter;
 
-  EventCount(MaxSizeVector<Waiter>& waiters) : waiters_(waiters) {
+  EventCount(MaxSizeVector<Waiter>& waiters)
+      : state_(kStackMask), waiters_(waiters) {
     eigen_plain_assert(waiters.size() < (1 << kWaiterBits) - 1);
-    // Initialize epoch to something close to overflow to test overflow.
-    state_ = kStackMask | (kEpochMask - kEpochInc * waiters.size() * 2);
   }
 
   ~EventCount() {
     // Ensure there are no waiters.
-    eigen_plain_assert((state_.load() & (kStackMask | kWaiterMask)) == kStackMask);
+    eigen_plain_assert(state_.load() == kStackMask);
   }
 
   // Prewait prepares for waiting.
-  // After calling this function the thread must re-check the wait predicate
-  // and call either CancelWait or CommitWait passing the same Waiter object.
-  void Prewait(Waiter* w) {
-    w->epoch = state_.fetch_add(kWaiterInc, std::memory_order_relaxed);
-    std::atomic_thread_fence(std::memory_order_seq_cst);
+  // If Prewait returns true, the thread must re-check the wait predicate
+  // and then call either CancelWait or CommitWait.
+  // Otherwise, the thread should assume the predicate may be true
+  // and don't call CancelWait/CommitWait (there was a concurrent Notify call).
+  bool Prewait() {
+    uint64_t state = state_.load(std::memory_order_relaxed);
+    for (;;) {
+      CheckState(state);
+      uint64_t newstate = state + kWaiterInc;
+      if ((state & kSignalMask) != 0) {
+        // Consume the signal and cancel waiting.
+        newstate -= kSignalInc + kWaiterInc;
+      }
+      CheckState(newstate);
+      if (state_.compare_exchange_weak(state, newstate,
+                                       std::memory_order_seq_cst))
+        return (state & kSignalMask) == 0;
+    }
   }
 
-  // CommitWait commits waiting.
+  // CommitWait commits waiting after Prewait.
   void CommitWait(Waiter* w) {
+    eigen_plain_assert((w->epoch & ~kEpochMask) == 0);
     w->state = Waiter::kNotSignaled;
-    // Modification epoch of this waiter.
-    uint64_t epoch =
-        (w->epoch & kEpochMask) +
-        (((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
+    const uint64_t me = (w - &waiters_[0]) | w->epoch;
     uint64_t state = state_.load(std::memory_order_seq_cst);
     for (;;) {
-      if (int64_t((state & kEpochMask) - epoch) < 0) {
-        // The preceding waiter has not decided on its fate. Wait until it
-        // calls either CancelWait or CommitWait, or is notified.
-        EIGEN_THREAD_YIELD();
-        state = state_.load(std::memory_order_seq_cst);
-        continue;
+      CheckState(state, true);
+      uint64_t newstate;
+      if ((state & kSignalMask) != 0) {
+        // Consume the signal and return immidiately.
+        newstate = state - kWaiterInc - kSignalInc;
+      } else {
+        // Remove this thread from pre-wait counter and add to the waiter stack.
+        newstate = ((state & kWaiterMask) - kWaiterInc) | me;
+        w->next.store(state & (kStackMask | kEpochMask),
+                      std::memory_order_relaxed);
       }
-      // We've already been notified.
-      if (int64_t((state & kEpochMask) - epoch) > 0) return;
-      // Remove this thread from prewait counter and add it to the waiter list.
-      eigen_plain_assert((state & kWaiterMask) != 0);
-      uint64_t newstate = state - kWaiterInc + kEpochInc;
-      newstate = (newstate & ~kStackMask) | (w - &waiters_[0]);
-      if ((state & kStackMask) == kStackMask)
-        w->next.store(nullptr, std::memory_order_relaxed);
-      else
-        w->next.store(&waiters_[state & kStackMask], std::memory_order_relaxed);
+      CheckState(newstate);
       if (state_.compare_exchange_weak(state, newstate,
-                                       std::memory_order_release))
-        break;
+                                       std::memory_order_acq_rel)) {
+        if ((state & kSignalMask) == 0) {
+          w->epoch += kEpochInc;
+          Park(w);
+        }
+        return;
+      }
     }
-    Park(w);
   }
 
   // CancelWait cancels effects of the previous Prewait call.
-  void CancelWait(Waiter* w) {
-    uint64_t epoch =
-        (w->epoch & kEpochMask) +
-        (((w->epoch & kWaiterMask) >> kWaiterShift) << kEpochShift);
+  void CancelWait() {
     uint64_t state = state_.load(std::memory_order_relaxed);
     for (;;) {
-      if (int64_t((state & kEpochMask) - epoch) < 0) {
-        // The preceding waiter has not decided on its fate. Wait until it
-        // calls either CancelWait or CommitWait, or is notified.
-        EIGEN_THREAD_YIELD();
-        state = state_.load(std::memory_order_relaxed);
-        continue;
-      }
-      // We've already been notified.
-      if (int64_t((state & kEpochMask) - epoch) > 0) return;
-      // Remove this thread from prewait counter.
-      eigen_plain_assert((state & kWaiterMask) != 0);
-      if (state_.compare_exchange_weak(state, state - kWaiterInc + kEpochInc,
-                                       std::memory_order_relaxed))
+      CheckState(state, true);
+      uint64_t newstate = state - kWaiterInc;
+      // Also take away a signal if any.
+      if ((state & kSignalMask) != 0) newstate -= kSignalInc;
+      CheckState(newstate);
+      if (state_.compare_exchange_weak(state, newstate,
+                                       std::memory_order_acq_rel))
         return;
     }
   }
@@ -132,35 +133,33 @@ class EventCount {
     std::atomic_thread_fence(std::memory_order_seq_cst);
     uint64_t state = state_.load(std::memory_order_acquire);
     for (;;) {
+      CheckState(state);
+      const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
+      const uint64_t signals = (state & kSignalMask) >> kSignalShift;
       // Easy case: no waiters.
-      if ((state & kStackMask) == kStackMask && (state & kWaiterMask) == 0)
-        return;
-      uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
+      if ((state & kStackMask) == kStackMask && waiters == signals) return;
       uint64_t newstate;
       if (notifyAll) {
-        // Reset prewait counter and empty wait list.
-        newstate = (state & kEpochMask) + (kEpochInc * waiters) + kStackMask;
-      } else if (waiters) {
+        // Empty wait stack and set signal to number of pre-wait threads.
+        newstate =
+            (state & kWaiterMask) | (waiters << kSignalShift) | kStackMask;
+      } else if (signals < waiters) {
         // There is a thread in pre-wait state, unblock it.
-        newstate = state + kEpochInc - kWaiterInc;
+        newstate = state + kSignalInc;
       } else {
         // Pop a waiter from list and unpark it.
         Waiter* w = &waiters_[state & kStackMask];
-        Waiter* wnext = w->next.load(std::memory_order_relaxed);
-        uint64_t next = kStackMask;
-        if (wnext != nullptr) next = wnext - &waiters_[0];
-        // Note: we don't add kEpochInc here. ABA problem on the lock-free stack
-        // can't happen because a waiter is re-pushed onto the stack only after
-        // it was in the pre-wait state which inevitably leads to epoch
-        // increment.
-        newstate = (state & kEpochMask) + next;
+        uint64_t next = w->next.load(std::memory_order_relaxed);
+        newstate = (state & (kWaiterMask | kSignalMask)) | next;
       }
+      CheckState(newstate);
       if (state_.compare_exchange_weak(state, newstate,
-                                       std::memory_order_acquire)) {
-        if (!notifyAll && waiters) return;  // unblocked pre-wait thread
+                                       std::memory_order_acq_rel)) {
+        if (!notifyAll && (signals < waiters))
+          return;  // unblocked pre-wait thread
         if ((state & kStackMask) == kStackMask) return;
         Waiter* w = &waiters_[state & kStackMask];
-        if (!notifyAll) w->next.store(nullptr, std::memory_order_relaxed);
+        if (!notifyAll) w->next.store(kStackMask, std::memory_order_relaxed);
         Unpark(w);
         return;
       }
@@ -171,11 +170,11 @@ class EventCount {
     friend class EventCount;
     // Align to 128 byte boundary to prevent false sharing with other Waiter
     // objects in the same vector.
-    EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<Waiter*> next;
+    EIGEN_ALIGN_TO_BOUNDARY(128) std::atomic<uint64_t> next;
     std::mutex mu;
     std::condition_variable cv;
-    uint64_t epoch;
-    unsigned state;
+    uint64_t epoch = 0;
+    unsigned state = kNotSignaled;
     enum {
       kNotSignaled,
       kWaiting,
@@ -185,23 +184,41 @@ class EventCount {
  private:
   // State_ layout:
-  // - low kStackBits is a stack of waiters committed wait.
+  // - low kWaiterBits is a stack of waiters committed wait
+  //   (indexes in waiters_ array are used as stack elements,
+  //   kStackMask means empty stack).
   // - next kWaiterBits is count of waiters in prewait state.
-  // - next kEpochBits is modification counter.
-  static const uint64_t kStackBits = 16;
-  static const uint64_t kStackMask = (1ull << kStackBits) - 1;
-  static const uint64_t kWaiterBits = 16;
-  static const uint64_t kWaiterShift = 16;
+  // - next kWaiterBits is count of pending signals.
+  // - remaining bits are ABA counter for the stack.
+  //   (stored in Waiter node and incremented on push).
+  static const uint64_t kWaiterBits = 14;
+  static const uint64_t kStackMask = (1ull << kWaiterBits) - 1;
+  static const uint64_t kWaiterShift = kWaiterBits;
   static const uint64_t kWaiterMask = ((1ull << kWaiterBits) - 1)
                                       << kWaiterShift;
-  static const uint64_t kWaiterInc = 1ull << kWaiterBits;
-  static const uint64_t kEpochBits = 32;
-  static const uint64_t kEpochShift = 32;
+  static const uint64_t kWaiterInc = 1ull << kWaiterShift;
+  static const uint64_t kSignalShift = 2 * kWaiterBits;
+  static const uint64_t kSignalMask = ((1ull << kWaiterBits) - 1)
+                                      << kSignalShift;
+  static const uint64_t kSignalInc = 1ull << kSignalShift;
+  static const uint64_t kEpochShift = 3 * kWaiterBits;
+  static const uint64_t kEpochBits = 64 - kEpochShift;
   static const uint64_t kEpochMask = ((1ull << kEpochBits) - 1) << kEpochShift;
   static const uint64_t kEpochInc = 1ull << kEpochShift;
 
   std::atomic<uint64_t> state_;
   MaxSizeVector<Waiter>& waiters_;
 
+  static void CheckState(uint64_t state, bool waiter = false) {
+    static_assert(kEpochBits >= 20, "not enough bits to prevent ABA problem");
+    const uint64_t waiters = (state & kWaiterMask) >> kWaiterShift;
+    const uint64_t signals = (state & kSignalMask) >> kSignalShift;
+    eigen_plain_assert(waiters >= signals);
+    eigen_plain_assert(waiters < (1 << kWaiterBits) - 1);
+    eigen_plain_assert(!waiter || waiters > 0);
+    (void)waiters;
+    (void)signals;
+  }
+
   void Park(Waiter* w) {
     std::unique_lock<std::mutex> lock(w->mu);
     while (w->state != Waiter::kSignaled) {
@@ -210,10 +227,10 @@ class EventCount {
     }
   }
 
-  void Unpark(Waiter* waiters) {
-    Waiter* next = nullptr;
-    for (Waiter* w = waiters; w; w = next) {
-      next = w->next.load(std::memory_order_relaxed);
+  void Unpark(Waiter* w) {
+    for (Waiter* next; w; w = next) {
+      uint64_t wnext = w->next.load(std::memory_order_relaxed) & kStackMask;
+      next = wnext == kStackMask ? nullptr : &waiters_[wnext];
       unsigned state;
       {
         std::unique_lock<std::mutex> lock(w->mu);
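
To make the new caller contract concrete, here is a minimal sketch of the waiting side, following the usage comment at the top of EventCount.h; WaitForWorkSketch and predicate are hypothetical placeholders, not part of the commit:

// Sketch only: one round of the reworked waiting protocol.
// Returns when the caller should re-check for work.
template <class Predicate>
void WaitForWorkSketch(EventCount& ec, EventCount::Waiter& w, Predicate predicate) {
  if (predicate()) return;
  // Announce the intent to block. A false return means a concurrent Notify
  // already consumed this pre-wait slot, so re-check for work right away.
  if (!ec.Prewait()) return;
  if (predicate()) {
    // Work appeared between Prewait and the re-check: undo the pre-wait.
    ec.CancelWait();
    return;
  }
  // Still nothing to do: push w onto the waiter stack and park until notified.
  ec.CommitWait(&w);
}

The visible API changes, Prewait returning bool and CancelWait no longer taking a Waiter, are exactly what the thread pool and test diffs below adapt to.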


@@ -374,11 +374,11 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
     eigen_plain_assert(!t->f);
     // We already did best-effort emptiness check in Steal, so prepare for
     // blocking.
-    ec_.Prewait(waiter);
+    if (!ec_.Prewait()) return true;
     // Now do a reliable emptiness check.
     int victim = NonEmptyQueueIndex();
     if (victim != -1) {
-      ec_.CancelWait(waiter);
+      ec_.CancelWait();
       if (cancelled_) {
         return false;
       } else {
@@ -392,7 +392,7 @@ class ThreadPoolTempl : public Eigen::ThreadPoolInterface {
     blocked_++;
     // TODO is blocked_ required to be unsigned?
     if (done_ && blocked_ == static_cast<unsigned>(num_threads_)) {
-      ec_.CancelWait(waiter);
+      ec_.CancelWait();
       // Almost done, but need to re-check queues.
       // Consider that all queues are empty and all worker threads are preempted
       // right after incrementing blocked_ above. Now a free-standing thread


@@ -30,11 +30,11 @@ static void test_basic_eventcount()
   EventCount ec(waiters);
   EventCount::Waiter& w = waiters[0];
   ec.Notify(false);
-  ec.Prewait(&w);
+  VERIFY(ec.Prewait());
   ec.Notify(true);
   ec.CommitWait(&w);
-  ec.Prewait(&w);
-  ec.CancelWait(&w);
+  VERIFY(ec.Prewait());
+  ec.CancelWait();
 }
 
 // Fake bounded counter-based queue.
@@ -112,7 +112,7 @@ static void test_stress_eventcount()
         unsigned idx = rand_reentrant(&rnd) % kQueues;
         if (queues[idx].Pop()) continue;
         j--;
-        ec.Prewait(&w);
+        if (!ec.Prewait()) continue;
         bool empty = true;
         for (int q = 0; q < kQueues; q++) {
           if (!queues[q].Empty()) {
@@ -121,7 +121,7 @@ static void test_stress_eventcount()
           }
         }
         if (!empty) {
-          ec.CancelWait(&w);
+          ec.CancelWait();
          continue;
        }
        ec.CommitWait(&w);
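
As an illustration of the new signal accounting (not part of this commit), a fragment in the style of test_basic_eventcount: a Notify that arrives while a thread is in the pre-wait state records a pending signal, and the subsequent CommitWait consumes that signal and returns without parking.

// Hypothetical extra check; the setup mirrors test_basic_eventcount above.
static void test_prewait_signal_consumed()
{
  MaxSizeVector<EventCount::Waiter> waiters(1);
  waiters.resize(1);
  EventCount ec(waiters);
  EventCount::Waiter& w = waiters[0];
  VERIFY(ec.Prewait());  // one thread is now counted in the pre-wait field
  ec.Notify(false);      // signals < waiters, so a pending signal is recorded
  ec.CommitWait(&w);     // consumes the signal and returns without parking
}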