A few cleanups to threaded product code and test.

This commit is contained in:
Rasmus Munk Larsen 2024-08-09 09:35:23 -07:00
parent 59498c96fe
commit 99ffad1971
2 changed files with 2 additions and 2 deletions

View File

@ -71,7 +71,7 @@ inline void setNbThreads(int v) { internal::manage_multi_threading(SetAction, &v
// TODO(rmlarsen): Make the device API available instead of
// storing a local static pointer variable to avoid this issue.
inline ThreadPool* setGemmThreadPool(ThreadPool* new_pool) {
static ThreadPool* pool;
static ThreadPool* pool = nullptr;
if (new_pool != nullptr) {
// This will wait for work in all threads in *pool to finish,
// then destroy the old ThreadPool, and then replace it with new_pool.
@ -232,7 +232,6 @@ EIGEN_STRONG_INLINE void parallelize_gemm(const Functor& func, Index rows, Index
}
#elif defined(EIGEN_GEMM_THREADPOOL)
ei_declare_aligned_stack_constructed_variable(GemmParallelTaskInfo<Index>, meta_info, threads, 0);
Barrier barrier(threads);
auto task = [=, &func, &barrier, &task_info](int i) {
Index actual_threads = threads;

View File

@ -19,6 +19,7 @@ void test_parallelize_gemm() {
c.noalias() = a * b;
ThreadPool pool(num_threads);
Eigen::setGemmThreadPool(&pool);
MatrixXf c_threaded(n, n);
c_threaded.noalias() = a * b;