Mirror of https://gitlab.com/libeigen/eigen.git
Update ThreadLocal to use separate Initialize/Release callables

commit d918bd9a8b (parent e3dec4dcc1)
unsupported/Eigen/CXX11/src/ThreadPool/ThreadLocal.h:

@@ -64,27 +64,38 @@
 namespace Eigen {

-// Thread local container for elements of type Factory::T, that does not use
-// thread local storage. It will lazily initialize elements for each thread that
-// accesses this object. As long as the number of unique threads accessing this
-// storage is smaller than `kAllocationMultiplier * num_threads`, it is
-// lock-free and wait-free. Otherwise it will use a mutex for synchronization.
+namespace internal {
+template <typename T>
+struct ThreadLocalNoOpInitialize {
+  void operator()(T&) const {}
+};
+
+template <typename T>
+struct ThreadLocalNoOpRelease {
+  void operator()(T&) const {}
+};
+}  // namespace internal
+
+// Thread local container for elements of type T, that does not use thread local
+// storage. As long as the number of unique threads accessing this storage
+// is smaller than `capacity_`, it is lock-free and wait-free. Otherwise it will
+// use a mutex for synchronization.
+//
+// Type `T` has to be default constructible, and by default each thread will get
+// a default constructed value. It is possible to specify custom `initialize`
+// callable, that will be called lazily from each thread accessing this object,
+// and will be passed a default initialized object of type `T`. Also it's
+// possible to pass a custom `release` callable, that will be invoked before
+// calling ~T().
 //
 // Example:
 //
 //   struct Counter {
-//     int value;
+//     int value = 0;
 //   }
 //
-//   struct CounterFactory {
-//     using T = Counter;
-//
-//     Counter Allocate() { return {0}; }
-//     void Release(Counter&) {}
-//   };
-//
-//   CounterFactory factory;
-//   Eigen::ThreadLocal<CounterFactory> counter(factory, 10);
+//   Eigen::ThreadLocal<Counter> counter(10);
 //
 //   // Each thread will have access to it's own counter object.
 //   Counter& cnt = counter.local();
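The new interface separates construction from initialization: `initialize` runs lazily, once per thread, on a default-constructed `T`, and `release` runs on each stored value before `~T()`. A minimal sketch of the resulting API, assuming a hypothetical per-thread scratch buffer (the `Scratch` type, its size, and the callable names are illustrative, not part of this commit):

    #include <vector>
    #include <Eigen/CXX11/ThreadPool>

    // Hypothetical per-thread scratch buffer.
    struct Scratch {
      std::vector<float> buffer;
    };

    // Runs lazily, once per thread, on a default-constructed Scratch.
    struct InitScratch {
      void operator()(Scratch& s) const { s.buffer.resize(1024); }
    };

    // Runs on every stored value before ~Scratch().
    struct ReleaseScratch {
      void operator()(Scratch& s) const { s.buffer.clear(); }
    };

    void Example() {
      Eigen::ThreadLocal<Scratch, InitScratch, ReleaseScratch> scratch(
          /*capacity=*/16, InitScratch(), ReleaseScratch());
      Scratch& s = scratch.local();  // lazily initialized for this thread
      s.buffer[0] = 1.0f;
    }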
@@ -98,40 +109,43 @@ namespace Eigen {
 // Somewhat similar to TBB thread local storage, with similar restrictions:
 // https://www.threadingbuildingblocks.org/docs/help/reference/thread_local_storage/enumerable_thread_specific_cls.html
 //
-template<typename Factory>
+template <typename T,
+          typename Initialize = internal::ThreadLocalNoOpInitialize<T>,
+          typename Release = internal::ThreadLocalNoOpRelease<T>>
 class ThreadLocal {
-  // We allocate larger storage for thread local data, than the number of
-  // threads, because thread pool size might grow, or threads outside of a
-  // thread pool might steal the work. We still expect this number to be of the
-  // same order of magnitude as the original `num_threads`.
-  static constexpr int kAllocationMultiplier = 4;
-
-  using T = typename Factory::T;
-
   // We preallocate default constructed elements in MaxSizedVector.
   static_assert(std::is_default_constructible<T>::value,
                 "ThreadLocal data type must be default constructible");

  public:
-  explicit ThreadLocal(Factory& factory, int num_threads)
-      : factory_(factory),
-        num_records_(kAllocationMultiplier * num_threads),
-        data_(num_records_),
-        ptr_(num_records_),
+  explicit ThreadLocal(int capacity)
+      : ThreadLocal(capacity, internal::ThreadLocalNoOpInitialize<T>(),
+                    internal::ThreadLocalNoOpRelease<T>()) {}
+
+  ThreadLocal(int capacity, Initialize initialize)
+      : ThreadLocal(capacity, std::move(initialize),
+                    internal::ThreadLocalNoOpRelease<T>()) {}
+
+  ThreadLocal(int capacity, Initialize initialize, Release release)
+      : initialize_(std::move(initialize)),
+        release_(std::move(release)),
+        capacity_(capacity),
+        data_(capacity_),
+        ptr_(capacity_),
         filled_records_(0) {
-    eigen_assert(num_threads >= 0);
-    data_.resize(num_records_);
-    for (int i = 0; i < num_records_; ++i) {
+    eigen_assert(capacity_ >= 0);
+    data_.resize(capacity_);
+    for (int i = 0; i < capacity_; ++i) {
       ptr_.emplace_back(nullptr);
     }
   }

   T& local() {
     std::thread::id this_thread = std::this_thread::get_id();
-    if (num_records_ == 0) return SpilledLocal(this_thread);
+    if (capacity_ == 0) return SpilledLocal(this_thread);

     std::size_t h = std::hash<std::thread::id>()(this_thread);
-    const int start_idx = h % num_records_;
+    const int start_idx = h % capacity_;

     // NOTE: From the definition of `std::this_thread::get_id()` it is
     // guaranteed that we never can have concurrent insertions with the same key
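For orientation, `local()` implements a fixed-capacity, open-addressed lookup keyed by thread id, probed linearly with wrap-around. A standalone, simplified sketch of just that probing logic (plain array instead of the atomic pointer table; the insertion and spill paths are omitted):

    #include <functional>
    #include <thread>
    #include <vector>

    // Simplified mirror of the probing in local(): hash the thread id to a
    // start slot, then scan forward with wrap-around until a match or a full
    // cycle. The real code probes an array of std::atomic pointers and falls
    // back to a mutex-guarded map when the table is full.
    int FindSlot(const std::vector<std::thread::id>& slots,
                 std::thread::id tid) {
      const int capacity = static_cast<int>(slots.size());
      std::size_t h = std::hash<std::thread::id>()(tid);
      int idx = static_cast<int>(h % capacity);
      const int start_idx = idx;
      do {
        if (slots[idx] == tid) return idx;     // found this thread's record
        idx += 1;
        if (idx >= capacity) idx -= capacity;  // wrap-around
      } while (idx != start_idx);
      return -1;  // full traversal: no record for this thread
    }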
@@ -147,7 +161,7 @@ class ThreadLocal {
       if (record.thread_id == this_thread) return record.value;

       idx += 1;
-      if (idx >= num_records_) idx -= num_records_;
+      if (idx >= capacity_) idx -= capacity_;
       if (idx == start_idx) break;
     }

@@ -155,8 +169,7 @@ class ThreadLocal {
     // table at `idx`, or we did a full traversal and table is full.

     // If lock-free storage is full, fallback on mutex.
-    if (filled_records_.load() >= num_records_)
-      return SpilledLocal(this_thread);
+    if (filled_records_.load() >= capacity_) return SpilledLocal(this_thread);

     // We double check that we still have space to insert an element into a lock
     // free storage. If old value in `filled_records_` is larger than the
@@ -164,11 +177,12 @@ class ThreadLocal {
     // we were traversing lookup table.
     int insertion_index =
        filled_records_.fetch_add(1, std::memory_order_relaxed);
-    if (insertion_index >= num_records_) return SpilledLocal(this_thread);
+    if (insertion_index >= capacity_) return SpilledLocal(this_thread);

     // At this point it's guaranteed that we can access to
     // data_[insertion_index_] without a data race.
-    data_[insertion_index] = {this_thread, factory_.Allocate()};
+    data_[insertion_index].thread_id = this_thread;
+    initialize_(data_[insertion_index].value);

     // That's the pointer we'll put into the lookup table.
     ThreadIdAndValue* inserted = &data_[insertion_index];
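The check above is a claim-then-verify pattern: `fetch_add` hands each inserting thread a unique index, and any claimed index at or past capacity means the lock-free storage filled up while the lookup table was being traversed. The same idea in isolation (function and variable names hypothetical):

    #include <atomic>

    std::atomic<int> filled_records{0};

    // Atomically claim the next free index; a claimed index >= capacity means
    // the lock-free table is full and the caller must take the mutex path.
    bool TryClaimIndex(int capacity, int* out_index) {
      int idx = filled_records.fetch_add(1, std::memory_order_relaxed);
      if (idx >= capacity) return false;  // spill to the per-thread map
      *out_index = idx;
      return true;
    }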
@@ -187,7 +201,7 @@ class ThreadLocal {
     idx = insertion_idx;
     while (ptr_[idx].load() != nullptr) {
       idx += 1;
-      if (idx >= num_records_) idx -= num_records_;
+      if (idx >= capacity_) idx -= capacity_;
       // If we did a full loop, it means that we don't have any free entries
       // in the lookup table, and this means that something is terribly wrong.
       eigen_assert(idx != insertion_idx);
@@ -200,7 +214,7 @@ class ThreadLocal {
   }

   // WARN: It's not thread safe to call it concurrently with `local()`.
-  void ForEach(std::function<void(std::thread::id, T & )> f) {
+  void ForEach(std::function<void(std::thread::id, T&)> f) {
     // Reading directly from `data_` is unsafe, because only CAS to the
     // record in `ptr_` makes all changes visible to other threads.
     for (auto& ptr : ptr_) {
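A plausible use of `ForEach`, aggregating every per-thread value once all workers have finished (per the WARN, it must not race with `local()`; `counter` here is the `Eigen::ThreadLocal<Counter>` from the header's documentation example, and `total` is illustrative):

    int total = 0;
    counter.ForEach([&total](std::thread::id, Counter& cnt) {
      total += cnt.value;
    });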
@@ -210,7 +224,7 @@ class ThreadLocal {
     }

     // We did not spill into the map based storage.
-    if (filled_records_.load(std::memory_order_relaxed) < num_records_) return;
+    if (filled_records_.load(std::memory_order_relaxed) < capacity_) return;

     // Adds a happens before edge from the last call to SpilledLocal().
     std::unique_lock<std::mutex> lock(mu_);
@@ -226,16 +240,16 @@ class ThreadLocal {
     for (auto& ptr : ptr_) {
       ThreadIdAndValue* record = ptr.load();
       if (record == nullptr) continue;
-      factory_.Release(record->value);
+      release_(record->value);
     }

     // We did not spill into the map based storage.
-    if (filled_records_.load(std::memory_order_relaxed) < num_records_) return;
+    if (filled_records_.load(std::memory_order_relaxed) < capacity_) return;

     // Adds a happens before edge from the last call to SpilledLocal().
     std::unique_lock<std::mutex> lock(mu_);
     for (auto& kv : per_thread_map_) {
-      factory_.Release(kv.second);
+      release_(kv.second);
     }
   }

@@ -251,16 +265,18 @@ class ThreadLocal {

     auto it = per_thread_map_.find(this_thread);
     if (it == per_thread_map_.end()) {
-      auto result = per_thread_map_.emplace(this_thread, factory_.Allocate());
+      auto result = per_thread_map_.emplace(this_thread, T());
       eigen_assert(result.second);
+      initialize_((*result.first).second);
       return (*result.first).second;
     } else {
       return it->second;
     }
   }

-  Factory& factory_;
-  const int num_records_;
+  Initialize initialize_;
+  Release release_;
+  const int capacity_;

   // Storage that backs lock-free lookup table `ptr_`. Records stored in this
   // storage contiguously starting from index 0.
@@ -274,7 +290,7 @@ class ThreadLocal {
   std::atomic<int> filled_records_;

   // We fallback on per thread map if lock-free storage is full. In practice
-  // this should never happen, if `num_threads` is a reasonable estimate of the
+  // this should never happen, if `capacity_` is a reasonable estimate of the
   // number of threads running in a system.
   std::mutex mu_;  // Protects per_thread_map_.
   std::unordered_map<std::thread::id, T> per_thread_map_;
The second set of hunks updates the accompanying ThreadLocal unit test:

@@ -13,36 +13,30 @@
 #include "main.h"
 #include <Eigen/CXX11/ThreadPool>

-class Counter {
- public:
-  Counter() : Counter(0) {}
-  explicit Counter(int value)
-      : created_by_(std::this_thread::get_id()), value_(value) {}
+struct Counter {
+  Counter() = default;

   void inc() {
     // Check that mutation happens only in a thread that created this counter.
-    VERIFY_IS_EQUAL(std::this_thread::get_id(), created_by_);
-    value_++;
+    VERIFY_IS_EQUAL(std::this_thread::get_id(), created_by);
+    counter_value++;
   }
-  int value() { return value_; }
+  int value() { return counter_value; }

- private:
-  std::thread::id created_by_;
-  int value_;
+  std::thread::id created_by;
+  int counter_value = 0;
 };

-struct CounterFactory {
-  using T = Counter;
-
-  T Allocate() { return Counter(0); }
-  void Release(T&) {}
+struct InitCounter {
+  void operator()(Counter& counter) {
+    counter.created_by = std::this_thread::get_id();
+  }
 };

 void test_simple_thread_local() {
-  CounterFactory factory;
   int num_threads = internal::random<int>(4, 32);
   Eigen::ThreadPool thread_pool(num_threads);
-  Eigen::ThreadLocal<CounterFactory> counter(factory, num_threads);
+  Eigen::ThreadLocal<Counter, InitCounter> counter(num_threads, InitCounter());

   int num_tasks = 3 * num_threads;
   Eigen::Barrier barrier(num_tasks);
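The task bodies elided between these hunks follow the file's usual shape; a reconstruction under that assumption (not the verbatim test source), using `ThreadPool::Schedule` and `Barrier::Notify`/`Wait`:

    // Schedule num_tasks closures; each bumps the calling thread's own
    // counter, and the barrier blocks until every task has run.
    for (int i = 0; i < num_tasks; ++i) {
      thread_pool.Schedule([&counter, &barrier]() {
        counter.local().inc();
        barrier.Notify();
      });
    }
    barrier.Wait();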
@@ -64,8 +58,7 @@ void test_simple_thread_local() {
 }

 void test_zero_sized_thread_local() {
-  CounterFactory factory;
-  Eigen::ThreadLocal<CounterFactory> counter(factory, 0);
+  Eigen::ThreadLocal<Counter, InitCounter> counter(0, InitCounter());

   Counter& local = counter.local();
   local.inc();
@@ -81,10 +74,9 @@ void test_zero_sized_thread_local() {

 // All thread local values fits into the lock-free storage.
 void test_large_number_of_tasks_no_spill() {
-  CounterFactory factory;
   int num_threads = internal::random<int>(4, 32);
   Eigen::ThreadPool thread_pool(num_threads);
-  Eigen::ThreadLocal<CounterFactory> counter(factory, num_threads);
+  Eigen::ThreadLocal<Counter, InitCounter> counter(num_threads, InitCounter());

   int num_tasks = 10000;
   Eigen::Barrier barrier(num_tasks);
@@ -117,10 +109,9 @@ void test_large_number_of_tasks_no_spill() {
 // Lock free thread local storage is too small to fit all the unique threads,
 // and it spills to a map guarded by a mutex.
 void test_large_number_of_tasks_with_spill() {
-  CounterFactory factory;
   int num_threads = internal::random<int>(4, 32);
   Eigen::ThreadPool thread_pool(num_threads);
-  Eigen::ThreadLocal<CounterFactory> counter(factory, 1);  // This is too small
+  Eigen::ThreadLocal<Counter, InitCounter> counter(1, InitCounter());

   int num_tasks = 10000;
   Eigen::Barrier barrier(num_tasks);