Format TensorDeviceThreadPool.h & use if constexpr for c++20.

Rasmus Munk Larsen 2025-03-08 01:09:36 +00:00
parent 21223f6bb6
commit 464c1d0978


@@ -69,7 +69,7 @@ struct ThreadPoolDevice {
       Barrier barrier(static_cast<int>(num_threads - 1));
       // Launch the last 3 blocks on worker threads.
       for (size_t i = 1; i < num_threads; ++i) {
-        enqueue([n, i, src_ptr, dst_ptr, blocksize, &barrier] {
+        pool_->Schedule([n, i, src_ptr, dst_ptr, blocksize, &barrier] {
           ::memcpy(dst_ptr + i * blocksize, src_ptr + i * blocksize, numext::mini(blocksize, n - (i * blocksize)));
           barrier.Notify();
         });
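
Aside (not part of the commit): the hunk above only swaps enqueue(...) for a direct pool_->Schedule(...), but the surrounding pattern is worth spelling out. A minimal self-contained sketch, with the hypothetical name parallel_memcpy, using std::thread in place of Eigen's pool and join() in place of the Barrier:

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <thread>
#include <vector>

// Hypothetical standalone sketch: one block per thread, blocks 1..k-1 on
// workers, block 0 on the calling thread, then join. Eigen reuses pooled
// threads and a Barrier instead, but the tail clamping mirrors its
// numext::mini(blocksize, n - i * blocksize).
void parallel_memcpy(char* dst, const char* src, size_t n, size_t num_threads) {
  const size_t blocksize = (n + num_threads - 1) / num_threads;  // ceil(n / k)
  std::vector<std::thread> workers;
  for (size_t i = 1; i < num_threads; ++i) {
    if (i * blocksize >= n) break;  // nothing left for further workers
    workers.emplace_back([=] {
      ::memcpy(dst + i * blocksize, src + i * blocksize,
               std::min(blocksize, n - i * blocksize));  // last block may be short
    });
  }
  ::memcpy(dst, src, std::min(blocksize, n));  // caller copies block 0 itself
  for (auto& t : workers) t.join();            // stands in for barrier.Wait()
}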
@@ -120,11 +120,11 @@ struct ThreadPoolDevice {
   template <class Function, class... Args>
   EIGEN_STRONG_INLINE void enqueue(Function&& f, Args&&... args) const {
-    if (sizeof...(args) > 0) {
 #if EIGEN_COMP_CXXVER >= 20
-      auto run_f = [f = std::forward<Function>(f),
-                    ...args = std::forward<Args>(args)]() { f(args...); };
+    if constexpr (sizeof...(args) > 0) {
+      auto run_f = [f = std::forward<Function>(f), ... args = std::forward<Args>(args)]() { f(args...); };
 #else
+    if (sizeof...(args) > 0) {
       auto run_f = [f = std::forward<Function>(f), &args...]() { f(args...); };
 #endif
       pool_->Schedule(std::move(run_f));
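
The #if/#else above is the core of the commit: under C++20, if constexpr prunes the zero-argument case at compile time, and the pack init-capture "... args = std::forward<Args>(args)" moves the arguments into the closure so the deferred task owns them, while the pre-C++20 branch can only capture the pack by reference. A compilable sketch of the C++20 path (hypothetical names, no thread pool):

#include <cstdio>
#include <utility>

template <class Function, class... Args>
void enqueue_sketch(Function&& f, Args&&... args) {
  if constexpr (sizeof...(args) > 0) {
    // C++20 pack init-capture: the closure owns moved-in copies of the args,
    // so it stays valid after the caller's stack frame is gone.
    auto run_f = [f = std::forward<Function>(f),
                  ... args = std::forward<Args>(args)]() { f(args...); };
    run_f();  // the real device hands this to pool_->Schedule(std::move(run_f))
  } else {
    std::forward<Function>(f)();  // no arguments to bind; run f directly
  }
}

int main() {
  enqueue_sketch([](int x, int y) { std::printf("%d\n", x + y); }, 2, 3);  // 5
  enqueue_sketch([] { std::printf("no args\n"); });
}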
@@ -168,9 +168,7 @@ struct ThreadPoolDevice {
     } else {
       // Execute the root in the thread pool to avoid running work on more than
       // numThreads() threads.
-      pool_->Schedule([this, n, &block, &barrier, &f]() {
-        handleRange(0, n, block.size, &barrier, pool_, f);
-      });
+      pool_->Schedule([this, n, &block, &barrier, &f]() { handleRange(0, n, block.size, &barrier, pool_, f); });
     }
     barrier.Wait();
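
The comment retained in this hunk explains a subtle choice: the root of the recursion is itself pushed into the pool so that at most numThreads() threads ever run work, and the caller only blocks on the barrier. A toy C++20 illustration of that division of labor, with std::latch and a single std::jthread standing in for Eigen's Barrier and pool_->Schedule:

#include <cstdio>
#include <latch>
#include <thread>

int main() {
  std::latch barrier(1);           // plays the role of Eigen's Barrier
  std::jthread pool_worker([&] {   // plays the role of pool_->Schedule
    std::printf("root task runs on a pool thread\n");
    barrier.count_down();          // Barrier::Notify() analogue
  });
  barrier.wait();  // the caller only waits; it never becomes an extra worker
}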
@@ -246,14 +244,12 @@ struct ThreadPoolDevice {
  private:
   typedef TensorCostModel<ThreadPoolDevice> CostModel;
-  static void handleRange(Index firstIdx, Index lastIdx, Index granularity,
-                          Barrier* barrier, ThreadPoolInterface* pool, const std::function<void(Index, Index)>& f) {
+  static void handleRange(Index firstIdx, Index lastIdx, Index granularity, Barrier* barrier, ThreadPoolInterface* pool,
+                          const std::function<void(Index, Index)>& f) {
     while (lastIdx - firstIdx > granularity) {
       // Split into halves and schedule the second half on a different thread.
       const Index midIdx = firstIdx + numext::div_ceil((lastIdx - firstIdx) / 2, granularity) * granularity;
-      pool->Schedule([=, &f]() {
-        handleRange(midIdx, lastIdx, granularity, barrier, pool, f);
-      });
+      pool->Schedule([=, &f]() { handleRange(midIdx, lastIdx, granularity, barrier, pool, f); });
       lastIdx = midIdx;
     }
     // Single block or less, execute directly.
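
handleRange, reformatted above, is the whole scheduling strategy: halve the range, push the upper half onto another pool thread, keep the lower half, and recurse until a piece is no larger than granularity. A self-contained sketch of the same strategy (hypothetical handle_range, with std::thread plus join in place of pool->Schedule plus the Barrier):

#include <cstdio>
#include <functional>
#include <thread>

static void handle_range(long first, long last, long granularity,
                         const std::function<void(long, long)>& f) {
  if (last - first <= granularity) {
    f(first, last);  // single block or less, execute directly
    return;
  }
  // Split in half, rounding the midpoint up to a granularity boundary,
  // just as the Eigen code does with numext::div_ceil.
  const long half = (last - first) / 2;
  const long mid = first + (half + granularity - 1) / granularity * granularity;
  std::thread upper([&] { handle_range(mid, last, granularity, f); });
  handle_range(first, mid, granularity, f);  // lower half stays on this thread
  upper.join();  // join replaces the Barrier notify/wait bookkeeping
}

int main() {
  handle_range(0, 10, 3,
               [](long a, long b) { std::printf("[%ld, %ld)\n", a, b); });
}

Eigen's version is iterative on the lower half and counts completions through the Barrier instead of joining, which matters inside a pool: blocking a worker on join() could deadlock once every worker is waiting on work that has no free thread to run on.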