Update ThreadLocal to use separate Initialize/Release callables

Eugene Zhulenev 2019-09-10 16:13:32 -07:00
parent e3dec4dcc1
commit d918bd9a8b
2 changed files with 80 additions and 73 deletions

View File

@@ -64,27 +64,38 @@
 namespace Eigen {
 
-// Thread local container for elements of type Factory::T, that does not use
-// thread local storage. It will lazily initialize elements for each thread that
-// accesses this object. As long as the number of unique threads accessing this
-// storage is smaller than `kAllocationMultiplier * num_threads`, it is
-// lock-free and wait-free. Otherwise it will use a mutex for synchronization.
+namespace internal {
+template <typename T>
+struct ThreadLocalNoOpInitialize {
+  void operator()(T&) const {}
+};
+
+template <typename T>
+struct ThreadLocalNoOpRelease {
+  void operator()(T&) const {}
+};
+
+}  // namespace internal
+
+// Thread local container for elements of type T, that does not use thread local
+// storage. As long as the number of unique threads accessing this storage
+// is smaller than `capacity_`, it is lock-free and wait-free. Otherwise it will
+// use a mutex for synchronization.
+//
+// Type `T` has to be default constructible, and by default each thread will get
+// a default constructed value. It is possible to specify custom `initialize`
+// callable, that will be called lazily from each thread accessing this object,
+// and will be passed a default initialized object of type `T`. Also it's
+// possible to pass a custom `release` callable, that will be invoked before
+// calling ~T().
 //
 // Example:
 //
 //   struct Counter {
-//     int value;
+//     int value = 0;
 //   }
 //
-//   struct CounterFactory {
-//     using T = Counter;
-//
-//     Counter Allocate() { return {0}; }
-//     void Release(Counter&) {}
-//   };
-//
-//   CounterFactory factory;
-//   Eigen::ThreadLocal<CounterFactory> counter(factory, 10);
+//   Eigen::ThreadLocal<Counter> counter(10);
 //
 //   // Each thread will have access to it's own counter object.
 //   Counter& cnt = counter.local();
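
For reference, a minimal usage sketch of the new three-argument form with separate initialize and release callables; the `Slot` type, the lambdas, and the include path are illustrative assumptions, not part of this commit:

  #include <cstdio>
  #include <thread>

  #include <Eigen/CXX11/ThreadPool>  // assumed include path, same as the test file below

  // Illustrative payload type; not part of the Eigen sources.
  struct Slot {
    int value = 0;
  };

  int main() {
    // Runs lazily in every thread that touches the container.
    auto initialize = [](Slot& s) { s.value = 1; };
    // Runs for every initialized Slot before ~Slot(), i.e. when `tls` is destroyed.
    auto release = [](Slot& s) { std::printf("released %d\n", s.value); };

    Eigen::ThreadLocal<Slot, decltype(initialize), decltype(release)> tls(
        /*capacity=*/8, initialize, release);

    std::thread worker([&tls] { tls.local().value += 41; });
    worker.join();
    return 0;  // ~ThreadLocal invokes `release` on the worker's Slot, printing "released 42".
  }
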
@@ -98,40 +109,43 @@ namespace Eigen {
 // Somewhat similar to TBB thread local storage, with similar restrictions:
 // https://www.threadingbuildingblocks.org/docs/help/reference/thread_local_storage/enumerable_thread_specific_cls.html
 //
-template<typename Factory>
+template <typename T,
+          typename Initialize = internal::ThreadLocalNoOpInitialize<T>,
+          typename Release = internal::ThreadLocalNoOpRelease<T>>
 class ThreadLocal {
-  // We allocate larger storage for thread local data, than the number of
-  // threads, because thread pool size might grow, or threads outside of a
-  // thread pool might steal the work. We still expect this number to be of the
-  // same order of magnitude as the original `num_threads`.
-  static constexpr int kAllocationMultiplier = 4;
-
-  using T = typename Factory::T;
-
   // We preallocate default constructed elements in MaxSizedVector.
   static_assert(std::is_default_constructible<T>::value,
                 "ThreadLocal data type must be default constructible");
 
  public:
-  explicit ThreadLocal(Factory& factory, int num_threads)
-      : factory_(factory),
-        num_records_(kAllocationMultiplier * num_threads),
-        data_(num_records_),
-        ptr_(num_records_),
+  explicit ThreadLocal(int capacity)
+      : ThreadLocal(capacity, internal::ThreadLocalNoOpInitialize<T>(),
+                    internal::ThreadLocalNoOpRelease<T>()) {}
+
+  ThreadLocal(int capacity, Initialize initialize)
+      : ThreadLocal(capacity, std::move(initialize),
+                    internal::ThreadLocalNoOpRelease<T>()) {}
+
+  ThreadLocal(int capacity, Initialize initialize, Release release)
+      : initialize_(std::move(initialize)),
+        release_(std::move(release)),
+        capacity_(capacity),
+        data_(capacity_),
+        ptr_(capacity_),
         filled_records_(0) {
-    eigen_assert(num_threads >= 0);
-    data_.resize(num_records_);
-    for (int i = 0; i < num_records_; ++i) {
+    eigen_assert(capacity_ >= 0);
+    data_.resize(capacity_);
+    for (int i = 0; i < capacity_; ++i) {
       ptr_.emplace_back(nullptr);
     }
   }
 
   T& local() {
     std::thread::id this_thread = std::this_thread::get_id();
-    if (num_records_ == 0) return SpilledLocal(this_thread);
+    if (capacity_ == 0) return SpilledLocal(this_thread);
 
     std::size_t h = std::hash<std::thread::id>()(this_thread);
-    const int start_idx = h % num_records_;
+    const int start_idx = h % capacity_;
 
     // NOTE: From the definition of `std::this_thread::get_id()` it is
     // guaranteed that we never can have concurrent insertions with the same key
@@ -147,7 +161,7 @@ class ThreadLocal {
       if (record.thread_id == this_thread) return record.value;
 
       idx += 1;
-      if (idx >= num_records_) idx -= num_records_;
+      if (idx >= capacity_) idx -= capacity_;
       if (idx == start_idx) break;
     }
@@ -155,8 +169,7 @@ class ThreadLocal {
     // table at `idx`, or we did a full traversal and table is full.
 
     // If lock-free storage is full, fallback on mutex.
-    if (filled_records_.load() >= num_records_)
-      return SpilledLocal(this_thread);
+    if (filled_records_.load() >= capacity_) return SpilledLocal(this_thread);
 
     // We double check that we still have space to insert an element into a lock
     // free storage. If old value in `filled_records_` is larger than the
@@ -164,11 +177,12 @@ class ThreadLocal {
     // we were traversing lookup table.
     int insertion_index =
         filled_records_.fetch_add(1, std::memory_order_relaxed);
-    if (insertion_index >= num_records_) return SpilledLocal(this_thread);
+    if (insertion_index >= capacity_) return SpilledLocal(this_thread);
 
     // At this point it's guaranteed that we can access to
     // data_[insertion_index_] without a data race.
-    data_[insertion_index] = {this_thread, factory_.Allocate()};
+    data_[insertion_index].thread_id = this_thread;
+    initialize_(data_[insertion_index].value);
 
     // That's the pointer we'll put into the lookup table.
     ThreadIdAndValue* inserted = &data_[insertion_index];
@@ -187,7 +201,7 @@ class ThreadLocal {
       idx = insertion_idx;
       while (ptr_[idx].load() != nullptr) {
         idx += 1;
-        if (idx >= num_records_) idx -= num_records_;
+        if (idx >= capacity_) idx -= capacity_;
         // If we did a full loop, it means that we don't have any free entries
         // in the lookup table, and this means that something is terribly wrong.
         eigen_assert(idx != insertion_idx);
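
The insertion path above is standard open addressing over an array of atomic pointers: claim a slot with `fetch_add`, then publish a pointer to it via CAS, probing linearly from a hash-derived index. A stripped-down, self-contained sketch of that scheme (the `ProbeTable` type is illustrative, not Eigen code):

  #include <atomic>
  #include <cstddef>
  #include <vector>

  struct ProbeTable {
    explicit ProbeTable(std::size_t capacity) : data(capacity), ptr(capacity), filled(0) {
      for (auto& p : ptr) p.store(nullptr, std::memory_order_relaxed);
    }

    // Returns nullptr when the fixed-size table is full (the ThreadLocal above
    // falls back to a mutex-guarded map in that case).
    int* Insert(std::size_t hash, int value) {
      int slot = filled.fetch_add(1, std::memory_order_relaxed);
      if (slot >= static_cast<int>(data.size())) return nullptr;
      data[slot] = value;

      int* inserted = &data[slot];
      std::size_t idx = hash % ptr.size();
      int* expected = nullptr;
      // Publish `inserted`; on CAS failure move to the next entry, wrapping
      // around, just like the insertion loop in ThreadLocal::local().
      while (!ptr[idx].compare_exchange_strong(expected, inserted)) {
        expected = nullptr;
        idx = (idx + 1) % ptr.size();
      }
      return inserted;
    }

    std::vector<int> data;                // storage claimed contiguously from index 0
    std::vector<std::atomic<int*>> ptr;   // lookup table published via CAS
    std::atomic<int> filled;              // number of claimed storage slots
  };
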
@@ -200,7 +214,7 @@ class ThreadLocal {
   }
 
   // WARN: It's not thread safe to call it concurrently with `local()`.
-  void ForEach(std::function<void(std::thread::id, T & )> f) {
+  void ForEach(std::function<void(std::thread::id, T&)> f) {
     // Reading directly from `data_` is unsafe, because only CAS to the
     // record in `ptr_` makes all changes visible to other threads.
     for (auto& ptr : ptr_) {
@@ -210,7 +224,7 @@ class ThreadLocal {
     }
 
     // We did not spill into the map based storage.
-    if (filled_records_.load(std::memory_order_relaxed) < num_records_) return;
+    if (filled_records_.load(std::memory_order_relaxed) < capacity_) return;
 
     // Adds a happens before edge from the last call to SpilledLocal().
     std::unique_lock<std::mutex> lock(mu_);
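
Since `ForEach` must not run concurrently with `local()`, the typical pattern is to aggregate per-thread values only after all workers are done. A small, hedged sketch (the `Cell` type and the include path are assumptions, not part of this commit):

  #include <cstdio>
  #include <thread>

  #include <Eigen/CXX11/ThreadPool>  // assumed include path, same as the test file below

  struct Cell {
    int value = 0;
  };

  int main() {
    Eigen::ThreadLocal<Cell> tls(/*capacity=*/4);

    std::thread a([&tls] { tls.local().value += 1; });
    std::thread b([&tls] { tls.local().value += 2; });
    a.join();
    b.join();

    // Safe here: both workers are joined, so nothing races with local().
    int total = 0;
    tls.ForEach([&total](std::thread::id, Cell& cell) { total += cell.value; });
    std::printf("total = %d\n", total);  // prints "total = 3"
    return 0;
  }
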
@@ -226,16 +240,16 @@ class ThreadLocal {
     for (auto& ptr : ptr_) {
       ThreadIdAndValue* record = ptr.load();
       if (record == nullptr) continue;
-      factory_.Release(record->value);
+      release_(record->value);
     }
 
     // We did not spill into the map based storage.
-    if (filled_records_.load(std::memory_order_relaxed) < num_records_) return;
+    if (filled_records_.load(std::memory_order_relaxed) < capacity_) return;
 
     // Adds a happens before edge from the last call to SpilledLocal().
     std::unique_lock<std::mutex> lock(mu_);
 
     for (auto& kv : per_thread_map_) {
-      factory_.Release(kv.second);
+      release_(kv.second);
     }
   }
@@ -251,16 +265,18 @@ class ThreadLocal {
     auto it = per_thread_map_.find(this_thread);
     if (it == per_thread_map_.end()) {
-      auto result = per_thread_map_.emplace(this_thread, factory_.Allocate());
+      auto result = per_thread_map_.emplace(this_thread, T());
       eigen_assert(result.second);
+      initialize_((*result.first).second);
       return (*result.first).second;
     } else {
       return it->second;
     }
   }
 
-  Factory& factory_;
-  const int num_records_;
+  Initialize initialize_;
+  Release release_;
+
+  const int capacity_;
 
   // Storage that backs lock-free lookup table `ptr_`. Records stored in this
   // storage contiguously starting from index 0.
@@ -274,7 +290,7 @@ class ThreadLocal {
   std::atomic<int> filled_records_;
 
   // We fallback on per thread map if lock-free storage is full. In practice
-  // this should never happen, if `num_threads` is a reasonable estimate of the
+  // this should never happen, if `capacity_` is a reasonable estimate of the
   // number of threads running in a system.
   std::mutex mu_;  // Protects per_thread_map_.
   std::unordered_map<std::thread::id, T> per_thread_map_;
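
Conceptually, the spill path above reduces to a mutex-guarded map keyed by thread id, with every lookup for a new thread paying a lock. A hedged sketch of that reduction (`SpillMap` is illustrative, not the Eigen code verbatim):

  #include <mutex>
  #include <thread>
  #include <unordered_map>

  template <typename T, typename Initialize>
  class SpillMap {
   public:
    explicit SpillMap(Initialize initialize) : initialize_(std::move(initialize)) {}

    T& local() {
      std::thread::id this_thread = std::this_thread::get_id();
      std::unique_lock<std::mutex> lock(mu_);
      auto it = map_.find(this_thread);
      if (it != map_.end()) return it->second;
      // First lookup from this thread: default-construct T, then run the
      // initialize callable, mirroring SpilledLocal() above.
      auto result = map_.emplace(this_thread, T());
      initialize_(result.first->second);
      return result.first->second;
    }

   private:
    Initialize initialize_;
    std::mutex mu_;  // Protects map_.
    std::unordered_map<std::thread::id, T> map_;
  };
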

View File

@@ -13,36 +13,30 @@
 #include "main.h"
 #include <Eigen/CXX11/ThreadPool>
 
-class Counter {
- public:
-  Counter() : Counter(0) {}
-  explicit Counter(int value)
-      : created_by_(std::this_thread::get_id()), value_(value) {}
+struct Counter {
+  Counter() = default;
 
   void inc() {
     // Check that mutation happens only in a thread that created this counter.
-    VERIFY_IS_EQUAL(std::this_thread::get_id(), created_by_);
-    value_++;
+    VERIFY_IS_EQUAL(std::this_thread::get_id(), created_by);
+    counter_value++;
   }
-  int value() { return value_; }
+  int value() { return counter_value; }
 
- private:
-  std::thread::id created_by_;
-  int value_;
+  std::thread::id created_by;
+  int counter_value = 0;
 };
 
-struct CounterFactory {
-  using T = Counter;
-
-  T Allocate() { return Counter(0); }
-  void Release(T&) {}
+struct InitCounter {
+  void operator()(Counter& counter) {
+    counter.created_by = std::this_thread::get_id();
+  }
 };
 
 void test_simple_thread_local() {
-  CounterFactory factory;
   int num_threads = internal::random<int>(4, 32);
   Eigen::ThreadPool thread_pool(num_threads);
-  Eigen::ThreadLocal<CounterFactory> counter(factory, num_threads);
+  Eigen::ThreadLocal<Counter, InitCounter> counter(num_threads, InitCounter());
 
   int num_tasks = 3 * num_threads;
   Eigen::Barrier barrier(num_tasks);
@@ -64,8 +58,7 @@ void test_simple_thread_local() {
 }
 
 void test_zero_sized_thread_local() {
-  CounterFactory factory;
-  Eigen::ThreadLocal<CounterFactory> counter(factory, 0);
+  Eigen::ThreadLocal<Counter, InitCounter> counter(0, InitCounter());
 
   Counter& local = counter.local();
   local.inc();
@@ -81,10 +74,9 @@ void test_zero_sized_thread_local() {
 // All thread local values fits into the lock-free storage.
 void test_large_number_of_tasks_no_spill() {
-  CounterFactory factory;
   int num_threads = internal::random<int>(4, 32);
   Eigen::ThreadPool thread_pool(num_threads);
-  Eigen::ThreadLocal<CounterFactory> counter(factory, num_threads);
+  Eigen::ThreadLocal<Counter, InitCounter> counter(num_threads, InitCounter());
 
   int num_tasks = 10000;
   Eigen::Barrier barrier(num_tasks);
@@ -117,10 +109,9 @@ void test_large_number_of_tasks_no_spill() {
 // Lock free thread local storage is too small to fit all the unique threads,
 // and it spills to a map guarded by a mutex.
 void test_large_number_of_tasks_with_spill() {
-  CounterFactory factory;
   int num_threads = internal::random<int>(4, 32);
   Eigen::ThreadPool thread_pool(num_threads);
-  Eigen::ThreadLocal<CounterFactory> counter(factory, 1);  // This is too small
+  Eigen::ThreadLocal<Counter, InitCounter> counter(1, InitCounter());
 
   int num_tasks = 10000;
   Eigen::Barrier barrier(num_tasks);
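
The tests in this file only exercise the `Initialize` callable. A hedged sketch of what a companion test for the `Release` callable could look like, reusing `Counter` and `InitCounter` from above and assuming `std::vector` is available through the test's existing includes (`ReleaseCounter` and `test_release_callable` are hypothetical, not part of this commit):

  // Hypothetical companion to InitCounter: records every released counter value
  // so the test can check that each initialized Counter is released once.
  struct ReleaseCounter {
    void operator()(Counter& counter) { released->push_back(counter.value()); }
    std::vector<int>* released;
  };

  void test_release_callable() {
    std::vector<int> released;
    {
      Eigen::ThreadLocal<Counter, InitCounter, ReleaseCounter> counter(
          1, InitCounter(), ReleaseCounter{&released});
      counter.local().inc();
    }  // ~ThreadLocal invokes ReleaseCounter before destroying the Counter.
    VERIFY_IS_EQUAL(released.size(), size_t(1));
    VERIFY_IS_EQUAL(released[0], 1);
  }
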