diff --git a/Eigen/src/ThreadPool/ForkJoin.h b/Eigen/src/ThreadPool/ForkJoin.h index f67abd33d..233d2837a 100644 --- a/Eigen/src/ThreadPool/ForkJoin.h +++ b/Eigen/src/ThreadPool/ForkJoin.h @@ -61,9 +61,9 @@ class ForkJoinScheduler { // Runs `do_func` asynchronously for the range [start, end) with a specified // granularity. `do_func` should be of type `std::function + template static void ParallelForAsync(Index start, Index end, Index granularity, DoFnType&& do_func, DoneFnType&& done, - ThreadPool* thread_pool) { + ThreadPoolTempl* thread_pool) { if (start >= end) { done(); return; @@ -76,8 +76,11 @@ class ForkJoinScheduler { } // Synchronous variant of ParallelForAsync. - template - static void ParallelFor(Index start, Index end, Index granularity, DoFnType&& do_func, ThreadPool* thread_pool) { + // WARNING: Making nested calls to `ParallelFor`, e.g., calling `ParallelFor` inside a task passed into another + // `ParallelFor` call, may lead to deadlocks due to how task stealing is implemented. + template + static void ParallelFor(Index start, Index end, Index granularity, DoFnType&& do_func, + ThreadPoolTempl* thread_pool) { if (start >= end) return; Barrier barrier(1); auto done = [&barrier]() { barrier.Notify(); }; @@ -87,8 +90,8 @@ class ForkJoinScheduler { private: // Schedules `right_thunk`, runs `left_thunk`, and runs other tasks until `right_thunk` has finished. - template - static void ForkJoin(LeftType&& left_thunk, RightType&& right_thunk, ThreadPool* thread_pool) { + template + static void ForkJoin(LeftType&& left_thunk, RightType&& right_thunk, ThreadPoolTempl* thread_pool) { std::atomic right_done(false); auto execute_right = [&right_thunk, &right_done]() { std::forward(right_thunk)(); @@ -114,16 +117,16 @@ class ForkJoinScheduler { return start + offset; } - template - static void RunParallelFor(Index start, Index end, Index granularity, DoFnType&& do_func, ThreadPool* thread_pool) { + template + static void RunParallelFor(Index start, Index end, Index granularity, DoFnType&& do_func, + ThreadPoolTempl* thread_pool) { Index mid = ComputeMidpoint(start, end, granularity); if ((end - start) < granularity || mid == start || mid == end) { do_func(start, end); return; } - ForkJoin([start, mid, granularity, &do_func, thread_pool]() { - RunParallelFor(start, mid, granularity, do_func, thread_pool); - }, + ForkJoin([start, mid, granularity, &do_func, + thread_pool]() { RunParallelFor(start, mid, granularity, do_func, thread_pool); }, [mid, end, granularity, &do_func, thread_pool]() { RunParallelFor(mid, end, granularity, do_func, thread_pool); },