mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-14 12:46:00 +08:00
Generalize the Eigen ForkJoin scheduler to use any ThreadPool interface.
This commit is contained in:
parent
70f2aead9a
commit
3143968195
@ -61,9 +61,9 @@ class ForkJoinScheduler {
|
||||
// Runs `do_func` asynchronously for the range [start, end) with a specified
|
||||
// granularity. `do_func` should be of type `std::function<void(Index,
|
||||
// Index)`. `done()` is called exactly once after all tasks have been executed.
|
||||
template <typename DoFnType, typename DoneFnType>
|
||||
template <typename DoFnType, typename DoneFnType, typename ThreadPoolEnv>
|
||||
static void ParallelForAsync(Index start, Index end, Index granularity, DoFnType&& do_func, DoneFnType&& done,
|
||||
ThreadPool* thread_pool) {
|
||||
ThreadPoolTempl<ThreadPoolEnv>* thread_pool) {
|
||||
if (start >= end) {
|
||||
done();
|
||||
return;
|
||||
@ -76,8 +76,11 @@ class ForkJoinScheduler {
|
||||
}
|
||||
|
||||
// Synchronous variant of ParallelForAsync.
|
||||
template <typename DoFnType>
|
||||
static void ParallelFor(Index start, Index end, Index granularity, DoFnType&& do_func, ThreadPool* thread_pool) {
|
||||
// WARNING: Making nested calls to `ParallelFor`, e.g., calling `ParallelFor` inside a task passed into another
|
||||
// `ParallelFor` call, may lead to deadlocks due to how task stealing is implemented.
|
||||
template <typename DoFnType, typename ThreadPoolEnv>
|
||||
static void ParallelFor(Index start, Index end, Index granularity, DoFnType&& do_func,
|
||||
ThreadPoolTempl<ThreadPoolEnv>* thread_pool) {
|
||||
if (start >= end) return;
|
||||
Barrier barrier(1);
|
||||
auto done = [&barrier]() { barrier.Notify(); };
|
||||
@ -87,8 +90,8 @@ class ForkJoinScheduler {
|
||||
|
||||
private:
|
||||
// Schedules `right_thunk`, runs `left_thunk`, and runs other tasks until `right_thunk` has finished.
|
||||
template <typename LeftType, typename RightType>
|
||||
static void ForkJoin(LeftType&& left_thunk, RightType&& right_thunk, ThreadPool* thread_pool) {
|
||||
template <typename LeftType, typename RightType, typename ThreadPoolEnv>
|
||||
static void ForkJoin(LeftType&& left_thunk, RightType&& right_thunk, ThreadPoolTempl<ThreadPoolEnv>* thread_pool) {
|
||||
std::atomic<bool> right_done(false);
|
||||
auto execute_right = [&right_thunk, &right_done]() {
|
||||
std::forward<RightType>(right_thunk)();
|
||||
@ -114,16 +117,16 @@ class ForkJoinScheduler {
|
||||
return start + offset;
|
||||
}
|
||||
|
||||
template <typename DoFnType>
|
||||
static void RunParallelFor(Index start, Index end, Index granularity, DoFnType&& do_func, ThreadPool* thread_pool) {
|
||||
template <typename DoFnType, typename ThreadPoolEnv>
|
||||
static void RunParallelFor(Index start, Index end, Index granularity, DoFnType&& do_func,
|
||||
ThreadPoolTempl<ThreadPoolEnv>* thread_pool) {
|
||||
Index mid = ComputeMidpoint(start, end, granularity);
|
||||
if ((end - start) < granularity || mid == start || mid == end) {
|
||||
do_func(start, end);
|
||||
return;
|
||||
}
|
||||
ForkJoin([start, mid, granularity, &do_func, thread_pool]() {
|
||||
RunParallelFor(start, mid, granularity, do_func, thread_pool);
|
||||
},
|
||||
ForkJoin([start, mid, granularity, &do_func,
|
||||
thread_pool]() { RunParallelFor(start, mid, granularity, do_func, thread_pool); },
|
||||
[mid, end, granularity, &do_func, thread_pool]() {
|
||||
RunParallelFor(mid, end, granularity, do_func, thread_pool);
|
||||
},
|
||||
|
Loading…
x
Reference in New Issue
Block a user