mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-22 09:39:34 +08:00
Format TensorDeviceThreadPool.h & use if constexpr for c++20.
This commit is contained in:
parent
21223f6bb6
commit
464c1d0978
@ -69,7 +69,7 @@ struct ThreadPoolDevice {
|
||||
Barrier barrier(static_cast<int>(num_threads - 1));
|
||||
// Launch the last 3 blocks on worker threads.
|
||||
for (size_t i = 1; i < num_threads; ++i) {
|
||||
enqueue([n, i, src_ptr, dst_ptr, blocksize, &barrier] {
|
||||
pool_->Schedule([n, i, src_ptr, dst_ptr, blocksize, &barrier] {
|
||||
::memcpy(dst_ptr + i * blocksize, src_ptr + i * blocksize, numext::mini(blocksize, n - (i * blocksize)));
|
||||
barrier.Notify();
|
||||
});
|
||||
@ -120,11 +120,11 @@ struct ThreadPoolDevice {
|
||||
|
||||
template <class Function, class... Args>
|
||||
EIGEN_STRONG_INLINE void enqueue(Function&& f, Args&&... args) const {
|
||||
if (sizeof...(args) > 0) {
|
||||
#if EIGEN_COMP_CXXVER >= 20
|
||||
auto run_f = [f = std::forward<Function>(f),
|
||||
...args = std::forward<Args>(args)]() { f(args...); };
|
||||
if constexpr (sizeof...(args) > 0) {
|
||||
auto run_f = [f = std::forward<Function>(f), ... args = std::forward<Args>(args)]() { f(args...); };
|
||||
#else
|
||||
if (sizeof...(args) > 0) {
|
||||
auto run_f = [f = std::forward<Function>(f), &args...]() { f(args...); };
|
||||
#endif
|
||||
pool_->Schedule(std::move(run_f));
|
||||
@ -168,9 +168,7 @@ struct ThreadPoolDevice {
|
||||
} else {
|
||||
// Execute the root in the thread pool to avoid running work on more than
|
||||
// numThreads() threads.
|
||||
pool_->Schedule([this, n, &block, &barrier, &f]() {
|
||||
handleRange(0, n, block.size, &barrier, pool_, f);
|
||||
});
|
||||
pool_->Schedule([this, n, &block, &barrier, &f]() { handleRange(0, n, block.size, &barrier, pool_, f); });
|
||||
}
|
||||
|
||||
barrier.Wait();
|
||||
@ -246,14 +244,12 @@ struct ThreadPoolDevice {
|
||||
private:
|
||||
typedef TensorCostModel<ThreadPoolDevice> CostModel;
|
||||
|
||||
static void handleRange(Index firstIdx, Index lastIdx, Index granularity,
|
||||
Barrier* barrier, ThreadPoolInterface* pool, const std::function<void(Index, Index)>& f) {
|
||||
static void handleRange(Index firstIdx, Index lastIdx, Index granularity, Barrier* barrier, ThreadPoolInterface* pool,
|
||||
const std::function<void(Index, Index)>& f) {
|
||||
while (lastIdx - firstIdx > granularity) {
|
||||
// Split into halves and schedule the second half on a different thread.
|
||||
const Index midIdx = firstIdx + numext::div_ceil((lastIdx - firstIdx) / 2, granularity) * granularity;
|
||||
pool->Schedule([=, &f]() {
|
||||
handleRange(midIdx, lastIdx, granularity, barrier, pool, f);
|
||||
});
|
||||
pool->Schedule([=, &f]() { handleRange(midIdx, lastIdx, granularity, barrier, pool, f); });
|
||||
lastIdx = midIdx;
|
||||
}
|
||||
// Single block or less, execute directly.
|
||||
|
Loading…
x
Reference in New Issue
Block a user