mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-14 12:46:00 +08:00
Generalize the Eigen ForkJoin scheduler to use any ThreadPool interface.
This commit is contained in:
parent
70f2aead9a
commit
3143968195
@ -61,9 +61,9 @@ class ForkJoinScheduler {
|
|||||||
// Runs `do_func` asynchronously for the range [start, end) with a specified
|
// Runs `do_func` asynchronously for the range [start, end) with a specified
|
||||||
// granularity. `do_func` should be of type `std::function<void(Index,
|
// granularity. `do_func` should be of type `std::function<void(Index,
|
||||||
// Index)`. `done()` is called exactly once after all tasks have been executed.
|
// Index)`. `done()` is called exactly once after all tasks have been executed.
|
||||||
template <typename DoFnType, typename DoneFnType>
|
template <typename DoFnType, typename DoneFnType, typename ThreadPoolEnv>
|
||||||
static void ParallelForAsync(Index start, Index end, Index granularity, DoFnType&& do_func, DoneFnType&& done,
|
static void ParallelForAsync(Index start, Index end, Index granularity, DoFnType&& do_func, DoneFnType&& done,
|
||||||
ThreadPool* thread_pool) {
|
ThreadPoolTempl<ThreadPoolEnv>* thread_pool) {
|
||||||
if (start >= end) {
|
if (start >= end) {
|
||||||
done();
|
done();
|
||||||
return;
|
return;
|
||||||
@ -76,8 +76,11 @@ class ForkJoinScheduler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Synchronous variant of ParallelForAsync.
|
// Synchronous variant of ParallelForAsync.
|
||||||
template <typename DoFnType>
|
// WARNING: Making nested calls to `ParallelFor`, e.g., calling `ParallelFor` inside a task passed into another
|
||||||
static void ParallelFor(Index start, Index end, Index granularity, DoFnType&& do_func, ThreadPool* thread_pool) {
|
// `ParallelFor` call, may lead to deadlocks due to how task stealing is implemented.
|
||||||
|
template <typename DoFnType, typename ThreadPoolEnv>
|
||||||
|
static void ParallelFor(Index start, Index end, Index granularity, DoFnType&& do_func,
|
||||||
|
ThreadPoolTempl<ThreadPoolEnv>* thread_pool) {
|
||||||
if (start >= end) return;
|
if (start >= end) return;
|
||||||
Barrier barrier(1);
|
Barrier barrier(1);
|
||||||
auto done = [&barrier]() { barrier.Notify(); };
|
auto done = [&barrier]() { barrier.Notify(); };
|
||||||
@ -87,8 +90,8 @@ class ForkJoinScheduler {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
// Schedules `right_thunk`, runs `left_thunk`, and runs other tasks until `right_thunk` has finished.
|
// Schedules `right_thunk`, runs `left_thunk`, and runs other tasks until `right_thunk` has finished.
|
||||||
template <typename LeftType, typename RightType>
|
template <typename LeftType, typename RightType, typename ThreadPoolEnv>
|
||||||
static void ForkJoin(LeftType&& left_thunk, RightType&& right_thunk, ThreadPool* thread_pool) {
|
static void ForkJoin(LeftType&& left_thunk, RightType&& right_thunk, ThreadPoolTempl<ThreadPoolEnv>* thread_pool) {
|
||||||
std::atomic<bool> right_done(false);
|
std::atomic<bool> right_done(false);
|
||||||
auto execute_right = [&right_thunk, &right_done]() {
|
auto execute_right = [&right_thunk, &right_done]() {
|
||||||
std::forward<RightType>(right_thunk)();
|
std::forward<RightType>(right_thunk)();
|
||||||
@ -114,16 +117,16 @@ class ForkJoinScheduler {
|
|||||||
return start + offset;
|
return start + offset;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename DoFnType>
|
template <typename DoFnType, typename ThreadPoolEnv>
|
||||||
static void RunParallelFor(Index start, Index end, Index granularity, DoFnType&& do_func, ThreadPool* thread_pool) {
|
static void RunParallelFor(Index start, Index end, Index granularity, DoFnType&& do_func,
|
||||||
|
ThreadPoolTempl<ThreadPoolEnv>* thread_pool) {
|
||||||
Index mid = ComputeMidpoint(start, end, granularity);
|
Index mid = ComputeMidpoint(start, end, granularity);
|
||||||
if ((end - start) < granularity || mid == start || mid == end) {
|
if ((end - start) < granularity || mid == start || mid == end) {
|
||||||
do_func(start, end);
|
do_func(start, end);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
ForkJoin([start, mid, granularity, &do_func, thread_pool]() {
|
ForkJoin([start, mid, granularity, &do_func,
|
||||||
RunParallelFor(start, mid, granularity, do_func, thread_pool);
|
thread_pool]() { RunParallelFor(start, mid, granularity, do_func, thread_pool); },
|
||||||
},
|
|
||||||
[mid, end, granularity, &do_func, thread_pool]() {
|
[mid, end, granularity, &do_func, thread_pool]() {
|
||||||
RunParallelFor(mid, end, granularity, do_func, thread_pool);
|
RunParallelFor(mid, end, granularity, do_func, thread_pool);
|
||||||
},
|
},
|
||||||
|
Loading…
x
Reference in New Issue
Block a user