mirror of https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 19:29:02 +08:00

ThreadLocal container that does not rely on thread local storage

This commit is contained in:
parent 17226100c5
commit e3dec4dcc1
@ -45,11 +45,7 @@
#include <functional>
#include <memory>
#include <utility>
#include "src/util/CXX11Meta.h"
#include "src/util/MaxSizeVector.h"

#include "src/ThreadPool/ThreadLocal.h"
#ifndef EIGEN_THREAD_LOCAL
// There are non-parenthesized calls to "max" in the <unordered_map> header,
// which trigger a check in test/main.h causing compilation to fail.
// We work around the check here by removing the check for max in
@ -58,7 +54,11 @@
#undef max
#endif
#include <unordered_map>
#endif

#include "src/util/CXX11Meta.h"
#include "src/util/MaxSizeVector.h"

#include "src/ThreadPool/ThreadLocal.h"
#include "src/ThreadPool/ThreadYield.h"
#include "src/ThreadPool/ThreadCancel.h"
#include "src/ThreadPool/EventCount.h"

@ -60,6 +60,226 @@
#endif
#endif  // defined(__ANDROID__) && defined(__clang__)

#endif  // EIGEN_AVOID_THREAD_LOCAL

namespace Eigen {

// Thread local container for elements of type Factory::T that does not use
// thread local storage. It will lazily initialize elements for each thread
// that accesses this object. As long as the number of unique threads accessing
// this storage is smaller than `kAllocationMultiplier * num_threads`, it is
// lock-free and wait-free. Otherwise it will use a mutex for synchronization.
//
// Example:
//
//   struct Counter {
//     int value;
//   };
//
//   struct CounterFactory {
//     using T = Counter;
//
//     Counter Allocate() { return {0}; }
//     void Release(Counter&) {}
//   };
//
//   CounterFactory factory;
//   Eigen::ThreadLocal<CounterFactory> counter(factory, 10);
//
//   // Each thread will have access to its own counter object.
//   Counter& cnt = counter.local();
//   cnt.value++;
//
// WARNING: Eigen::ThreadLocal uses the OS-specific value returned by
// std::this_thread::get_id() to identify threads. This value is not guaranteed
// to be unique except for the life of the thread. A newly created thread may
// get an OS-specific ID equal to that of an already destroyed thread.
//
// Somewhat similar to TBB thread local storage, with similar restrictions:
// https://www.threadingbuildingblocks.org/docs/help/reference/thread_local_storage/enumerable_thread_specific_cls.html
//
template<typename Factory>
class ThreadLocal {
  // We allocate larger storage for thread local data than the number of
  // threads, because the thread pool size might grow, or threads outside of
  // the thread pool might steal the work. We still expect this number to be of
  // the same order of magnitude as the original `num_threads`.
  static constexpr int kAllocationMultiplier = 4;

  using T = typename Factory::T;

  // We preallocate default constructed elements in MaxSizeVector.
  static_assert(std::is_default_constructible<T>::value,
                "ThreadLocal data type must be default constructible");

 public:
  explicit ThreadLocal(Factory& factory, int num_threads)
      : factory_(factory),
        num_records_(kAllocationMultiplier * num_threads),
        data_(num_records_),
        ptr_(num_records_),
        filled_records_(0) {
    eigen_assert(num_threads >= 0);
    data_.resize(num_records_);
    for (int i = 0; i < num_records_; ++i) {
      ptr_.emplace_back(nullptr);
    }
  }

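  // Every `ptr_` slot starts out as nullptr, which the probing code in
  // `local()` below treats as an empty slot.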
  T& local() {
    std::thread::id this_thread = std::this_thread::get_id();
    if (num_records_ == 0) return SpilledLocal(this_thread);

    std::size_t h = std::hash<std::thread::id>()(this_thread);
    const int start_idx = h % num_records_;
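    // For example, with num_records_ = 4 and h = 6, probing starts at
    // start_idx = 6 % 4 = 2 and visits slots 2, 3, 0, 1 before wrapping
    // back around to start_idx.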

    // NOTE: From the definition of `std::this_thread::get_id()` it is
    // guaranteed that we never can have concurrent insertions with the same
    // key into our hash-map like data structure. If we didn't find an element
    // during the initial traversal, it's guaranteed that no one else could
    // have inserted it while we are in this function. This allows us to
    // massively simplify our lock-free insert-only hash map.

    // Check if we already have an element for `this_thread`.
    int idx = start_idx;
    while (ptr_[idx].load() != nullptr) {
      ThreadIdAndValue& record = *(ptr_[idx].load());
      if (record.thread_id == this_thread) return record.value;

      idx += 1;
      if (idx >= num_records_) idx -= num_records_;
      if (idx == start_idx) break;
    }

    // If we are here, it means that we found an insertion point in the lookup
    // table at `idx`, or we did a full traversal and the table is full.

    // If the lock-free storage is full, fall back on the mutex.
    if (filled_records_.load() >= num_records_)
      return SpilledLocal(this_thread);

    // We double check that we still have space to insert an element into the
    // lock-free storage. If the old value in `filled_records_` is larger than
    // the records capacity, it means that some other thread added an element
    // while we were traversing the lookup table.
    int insertion_index =
        filled_records_.fetch_add(1, std::memory_order_relaxed);
    if (insertion_index >= num_records_) return SpilledLocal(this_thread);

    // At this point it's guaranteed that we can access
    // `data_[insertion_index]` without a data race.
    data_[insertion_index] = {this_thread, factory_.Allocate()};

    // That's the pointer we'll put into the lookup table.
    ThreadIdAndValue* inserted = &data_[insertion_index];

    // We'll use a nullptr ThreadIdAndValue pointer in the compare-and-swap
    // loop below.
    ThreadIdAndValue* empty = nullptr;

    // Now we have to find an insertion point in the lookup table. We start
    // from the `idx` that was identified as an insertion point above; it's
    // guaranteed that we will have an empty record somewhere in the lookup
    // table (because we created a record in `data_`).
    const int insertion_idx = idx;

    do {
      // Always start the search from the original insertion candidate.
      idx = insertion_idx;
      while (ptr_[idx].load() != nullptr) {
        idx += 1;
        if (idx >= num_records_) idx -= num_records_;
        // If we did a full loop, it means that we don't have any free entries
        // in the lookup table, and this means that something is terribly
        // wrong.
        eigen_assert(idx != insertion_idx);
      }
      // Atomic CAS of the pointer guarantees that any other thread that
      // follows this pointer will see all the mutations in `data_`.
    } while (!ptr_[idx].compare_exchange_weak(empty, inserted));

    return inserted->value;
  }

  // WARN: It's not thread-safe to call it concurrently with `local()`.
  void ForEach(std::function<void(std::thread::id, T&)> f) {
    // Reading directly from `data_` is unsafe, because only the CAS to the
    // record in `ptr_` makes all changes visible to other threads.
    for (auto& ptr : ptr_) {
      ThreadIdAndValue* record = ptr.load();
      if (record == nullptr) continue;
      f(record->thread_id, record->value);
    }

    // We did not spill into the map-based storage.
    if (filled_records_.load(std::memory_order_relaxed) < num_records_) return;

    // Adds a happens-before edge from the last call to SpilledLocal().
    std::unique_lock<std::mutex> lock(mu_);
    for (auto& kv : per_thread_map_) {
      f(kv.first, kv.second);
    }
  }

  // WARN: It's not thread-safe to call it concurrently with `local()`.
  ~ThreadLocal() {
    // Reading directly from `data_` is unsafe, because only the CAS to the
    // record in `ptr_` makes all changes visible to other threads.
    for (auto& ptr : ptr_) {
      ThreadIdAndValue* record = ptr.load();
      if (record == nullptr) continue;
      factory_.Release(record->value);
    }

    // We did not spill into the map-based storage.
    if (filled_records_.load(std::memory_order_relaxed) < num_records_) return;

    // Adds a happens-before edge from the last call to SpilledLocal().
    std::unique_lock<std::mutex> lock(mu_);
    for (auto& kv : per_thread_map_) {
      factory_.Release(kv.second);
    }
  }

 private:
  struct ThreadIdAndValue {
    std::thread::id thread_id;
    T value;
  };

  // Use an unordered map guarded by a mutex when the lock-free storage is
  // full.
  T& SpilledLocal(std::thread::id this_thread) {
    std::unique_lock<std::mutex> lock(mu_);

    auto it = per_thread_map_.find(this_thread);
    if (it == per_thread_map_.end()) {
      auto result = per_thread_map_.emplace(this_thread, factory_.Allocate());
      eigen_assert(result.second);
      return (*result.first).second;
    } else {
      return it->second;
    }
  }

  Factory& factory_;
  const int num_records_;

  // Storage that backs the lock-free lookup table `ptr_`. Records are stored
  // in this storage contiguously starting from index 0.
  MaxSizeVector<ThreadIdAndValue> data_;

  // Atomic pointers to the data stored in `data_`. Used as a lookup table for
  // a linear probing hash map (https://en.wikipedia.org/wiki/Linear_probing).
  MaxSizeVector<std::atomic<ThreadIdAndValue*>> ptr_;

  // Number of records stored in `data_`.
  std::atomic<int> filled_records_;

  // We fall back on a per-thread map if the lock-free storage is full. In
  // practice this should never happen if `num_threads` is a reasonable
  // estimate of the number of threads running in the system.
  std::mutex mu_;  // Protects per_thread_map_.
  std::unordered_map<std::thread::id, T> per_thread_map_;
};

}  // namespace Eigen

#endif  // EIGEN_CXX11_THREADPOOL_THREAD_LOCAL_H

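A minimal usage sketch of the API added above; `IntFactory` and the plain
std::thread driver are illustrative assumptions, not part of this commit, and
the umbrella header is the one included by the new test below:

#include <thread>
#include <vector>

#include <Eigen/CXX11/ThreadPool>

// Hypothetical factory; T must be default constructible (see static_assert).
struct IntFactory {
  using T = int;
  int Allocate() { return 0; }
  void Release(int&) {}
};

int main() {
  IntFactory factory;
  // Room for kAllocationMultiplier * 4 = 16 unique threads before spilling.
  Eigen::ThreadLocal<IntFactory> tls(factory, /*num_threads=*/4);

  std::vector<std::thread> threads;
  for (int i = 0; i < 4; ++i) {
    threads.emplace_back([&tls] {
      // Each thread lazily allocates and mutates its own int.
      for (int j = 0; j < 100; ++j) tls.local()++;
    });
  }
  for (auto& t : threads) t.join();

  // Safe only after all local() callers are done (see the WARN on ForEach).
  int sum = 0;
  tls.ForEach([&sum](std::thread::id, int& v) { sum += v; });
  // sum == 400
  return 0;
}
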
@ -201,6 +201,7 @@ if(EIGEN_TEST_CXX11)
  ei_add_test(cxx11_tensor_shuffling)
  ei_add_test(cxx11_tensor_striding)
  ei_add_test(cxx11_tensor_notification "-pthread" "${CMAKE_THREAD_LIBS_INIT}")
  ei_add_test(cxx11_tensor_thread_local "-pthread" "${CMAKE_THREAD_LIBS_INIT}")
  ei_add_test(cxx11_tensor_thread_pool "-pthread" "${CMAKE_THREAD_LIBS_INIT}")
  ei_add_test(cxx11_tensor_executor "-pthread" "${CMAKE_THREAD_LIBS_INIT}")
  ei_add_test(cxx11_tensor_ref)

158 unsupported/test/cxx11_tensor_thread_local.cpp Normal file
@ -0,0 +1,158 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.

#define EIGEN_USE_THREADS

#include <iostream>
#include <unordered_set>

#include "main.h"
#include <Eigen/CXX11/ThreadPool>

class Counter {
 public:
  Counter() : Counter(0) {}
  explicit Counter(int value)
      : created_by_(std::this_thread::get_id()), value_(value) {}

  void inc() {
    // Check that mutation happens only in the thread that created this
    // counter.
    VERIFY_IS_EQUAL(std::this_thread::get_id(), created_by_);
    value_++;
  }
  int value() { return value_; }

 private:
  std::thread::id created_by_;
  int value_;
};

struct CounterFactory {
  using T = Counter;

  T Allocate() { return Counter(0); }
  void Release(T&) {}
};

void test_simple_thread_local() {
  CounterFactory factory;
  int num_threads = internal::random<int>(4, 32);
  Eigen::ThreadPool thread_pool(num_threads);
  Eigen::ThreadLocal<CounterFactory> counter(factory, num_threads);

  int num_tasks = 3 * num_threads;
  Eigen::Barrier barrier(num_tasks);

  for (int i = 0; i < num_tasks; ++i) {
    thread_pool.Schedule([&counter, &barrier]() {
      Counter& local = counter.local();
      local.inc();

      std::this_thread::sleep_for(std::chrono::milliseconds(100));
      barrier.Notify();
    });
  }

  barrier.Wait();

  counter.ForEach(
      [](std::thread::id, Counter& cnt) { VERIFY_IS_EQUAL(cnt.value(), 3); });
}

void test_zero_sized_thread_local() {
  CounterFactory factory;
  Eigen::ThreadLocal<CounterFactory> counter(factory, 0);
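  // num_threads = 0 gives the lock-free table zero capacity, so every
  // local() call below takes the mutex-guarded spill path.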

  Counter& local = counter.local();
  local.inc();

  int total = 0;
  counter.ForEach([&total](std::thread::id, Counter& cnt) {
    total += cnt.value();
    VERIFY_IS_EQUAL(cnt.value(), 1);
  });

  VERIFY_IS_EQUAL(total, 1);
}

// All thread local values fit into the lock-free storage.
void test_large_number_of_tasks_no_spill() {
  CounterFactory factory;
  int num_threads = internal::random<int>(4, 32);
  Eigen::ThreadPool thread_pool(num_threads);
  Eigen::ThreadLocal<CounterFactory> counter(factory, num_threads);

  int num_tasks = 10000;
  Eigen::Barrier barrier(num_tasks);

  for (int i = 0; i < num_tasks; ++i) {
    thread_pool.Schedule([&counter, &barrier]() {
      Counter& local = counter.local();
      local.inc();
      barrier.Notify();
    });
  }

  barrier.Wait();

  int total = 0;
  std::unordered_set<std::thread::id> unique_threads;

  counter.ForEach([&](std::thread::id id, Counter& cnt) {
    total += cnt.value();
    unique_threads.insert(id);
  });

  VERIFY_IS_EQUAL(total, num_tasks);
  // Not all threads in the pool might be woken up to execute submitted tasks,
  // and thread_pool.Schedule() might use the current thread if the queue is
  // full.
  VERIFY_IS_EQUAL(
      unique_threads.size() <= (static_cast<size_t>(num_threads + 1)), true);
}
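
// Note: the container allocates kAllocationMultiplier * num_threads = 4 *
// num_threads lock-free slots, so the test above never spills, while the test
// below constructs the container with num_threads = 1 (i.e. only 4 slots) so
// that the unique threads overflow into the mutex-guarded map.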

// Lock-free thread local storage is too small to fit all the unique threads,
// so it spills into a map guarded by a mutex.
void test_large_number_of_tasks_with_spill() {
  CounterFactory factory;
  int num_threads = internal::random<int>(4, 32);
  Eigen::ThreadPool thread_pool(num_threads);
  Eigen::ThreadLocal<CounterFactory> counter(factory, 1);  // This is too small.

  int num_tasks = 10000;
  Eigen::Barrier barrier(num_tasks);

  for (int i = 0; i < num_tasks; ++i) {
    thread_pool.Schedule([&counter, &barrier]() {
      Counter& local = counter.local();
      local.inc();
      barrier.Notify();
    });
  }

  barrier.Wait();

  int total = 0;
  std::unordered_set<std::thread::id> unique_threads;

  counter.ForEach([&](std::thread::id id, Counter& cnt) {
    total += cnt.value();
    unique_threads.insert(id);
  });

  VERIFY_IS_EQUAL(total, num_tasks);
  // Not all threads in the pool might be woken up to execute submitted tasks,
  // and thread_pool.Schedule() might use the current thread if the queue is
  // full.
  VERIFY_IS_EQUAL(
      unique_threads.size() <= (static_cast<size_t>(num_threads + 1)), true);
}

EIGEN_DECLARE_TEST(cxx11_tensor_thread_local) {
  CALL_SUBTEST(test_simple_thread_local());
  CALL_SUBTEST(test_zero_sized_thread_local());
  CALL_SUBTEST(test_large_number_of_tasks_no_spill());
  CALL_SUBTEST(test_large_number_of_tasks_with_spill());
}
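
For completeness, a minimal sketch of the Allocate()/Release() contract that
the ThreadLocal destructor honors; `TrackingFactory` is a hypothetical
factory, not part of this commit:

#include <cassert>

#include <Eigen/CXX11/ThreadPool>

struct TrackingFactory {
  using T = int;
  int allocated = 0;
  int released = 0;
  int Allocate() { ++allocated; return 0; }
  void Release(int&) { ++released; }
};

int main() {
  TrackingFactory factory;
  {
    Eigen::ThreadLocal<TrackingFactory> tls(factory, 2);
    tls.local() = 42;  // lazily allocates one element for the calling thread
  }  // ~ThreadLocal calls Release() on every allocated element
  assert(factory.allocated == 1);
  assert(factory.released == 1);
  return 0;
}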