Generalize the Eigen ForkJoin scheduler to use any ThreadPool interface.

This commit is contained in:
William Kong 2025-03-19 18:44:03 +00:00 committed by Rasmus Munk Larsen
parent 70f2aead9a
commit 3143968195

View File

@ -61,9 +61,9 @@ class ForkJoinScheduler {
// Runs `do_func` asynchronously for the range [start, end) with a specified // Runs `do_func` asynchronously for the range [start, end) with a specified
// granularity. `do_func` should be of type `std::function<void(Index, // granularity. `do_func` should be of type `std::function<void(Index,
// Index)>`. `done()` is called exactly once after all tasks have been executed. // Index)>`. `done()` is called exactly once after all tasks have been executed.
template <typename DoFnType, typename DoneFnType> template <typename DoFnType, typename DoneFnType, typename ThreadPoolEnv>
static void ParallelForAsync(Index start, Index end, Index granularity, DoFnType&& do_func, DoneFnType&& done, static void ParallelForAsync(Index start, Index end, Index granularity, DoFnType&& do_func, DoneFnType&& done,
ThreadPool* thread_pool) { ThreadPoolTempl<ThreadPoolEnv>* thread_pool) {
if (start >= end) { if (start >= end) {
done(); done();
return; return;
@ -76,8 +76,11 @@ class ForkJoinScheduler {
} }
// Synchronous variant of ParallelForAsync. // Synchronous variant of ParallelForAsync.
template <typename DoFnType> // WARNING: Making nested calls to `ParallelFor`, e.g., calling `ParallelFor` inside a task passed into another
static void ParallelFor(Index start, Index end, Index granularity, DoFnType&& do_func, ThreadPool* thread_pool) { // `ParallelFor` call, may lead to deadlocks due to how task stealing is implemented.
template <typename DoFnType, typename ThreadPoolEnv>
static void ParallelFor(Index start, Index end, Index granularity, DoFnType&& do_func,
ThreadPoolTempl<ThreadPoolEnv>* thread_pool) {
if (start >= end) return; if (start >= end) return;
Barrier barrier(1); Barrier barrier(1);
auto done = [&barrier]() { barrier.Notify(); }; auto done = [&barrier]() { barrier.Notify(); };
@ -87,8 +90,8 @@ class ForkJoinScheduler {
private: private:
// Schedules `right_thunk`, runs `left_thunk`, and runs other tasks until `right_thunk` has finished. // Schedules `right_thunk`, runs `left_thunk`, and runs other tasks until `right_thunk` has finished.
template <typename LeftType, typename RightType> template <typename LeftType, typename RightType, typename ThreadPoolEnv>
static void ForkJoin(LeftType&& left_thunk, RightType&& right_thunk, ThreadPool* thread_pool) { static void ForkJoin(LeftType&& left_thunk, RightType&& right_thunk, ThreadPoolTempl<ThreadPoolEnv>* thread_pool) {
std::atomic<bool> right_done(false); std::atomic<bool> right_done(false);
auto execute_right = [&right_thunk, &right_done]() { auto execute_right = [&right_thunk, &right_done]() {
std::forward<RightType>(right_thunk)(); std::forward<RightType>(right_thunk)();
@ -114,16 +117,16 @@ class ForkJoinScheduler {
return start + offset; return start + offset;
} }
template <typename DoFnType> template <typename DoFnType, typename ThreadPoolEnv>
static void RunParallelFor(Index start, Index end, Index granularity, DoFnType&& do_func, ThreadPool* thread_pool) { static void RunParallelFor(Index start, Index end, Index granularity, DoFnType&& do_func,
ThreadPoolTempl<ThreadPoolEnv>* thread_pool) {
Index mid = ComputeMidpoint(start, end, granularity); Index mid = ComputeMidpoint(start, end, granularity);
if ((end - start) < granularity || mid == start || mid == end) { if ((end - start) < granularity || mid == start || mid == end) {
do_func(start, end); do_func(start, end);
return; return;
} }
ForkJoin([start, mid, granularity, &do_func, thread_pool]() { ForkJoin([start, mid, granularity, &do_func,
RunParallelFor(start, mid, granularity, do_func, thread_pool); thread_pool]() { RunParallelFor(start, mid, granularity, do_func, thread_pool); },
},
[mid, end, granularity, &do_func, thread_pool]() { [mid, end, granularity, &do_func, thread_pool]() {
RunParallelFor(mid, end, granularity, do_func, thread_pool); RunParallelFor(mid, end, granularity, do_func, thread_pool);
}, },