From 33d0937c6bdf5ec999939fb17f2a553183d14a74 Mon Sep 17 00:00:00 2001
From: adambanas
Date: Thu, 27 Jun 2024 09:34:19 +0200
Subject: [PATCH] Add async support for 'chip' and 'extract_volume_patches'

---
 .../Eigen/CXX11/src/Tensor/TensorChipping.h   |   7 +
 .../CXX11/src/Tensor/TensorVolumePatch.h      |   7 +
 unsupported/test/cxx11_tensor_thread_pool.cpp | 138 ++++++++++++++----
 3 files changed, 126 insertions(+), 26 deletions(-)

diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h b/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h
index 32980c79e..000b1fb58 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h
@@ -182,6 +182,13 @@ struct TensorEvaluator<const TensorChippingOp<DimId, ArgType>, Device> {
     return true;
   }
 
+#ifdef EIGEN_USE_THREADS
+  template <typename EvalSubExprsCallback>
+  EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(EvaluatorPointerType /*data*/, EvalSubExprsCallback done) {
+    m_impl.evalSubExprsIfNeededAsync(nullptr, [done](bool) { done(true); });
+  }
+#endif  // EIGEN_USE_THREADS
+
   EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }
 
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h b/unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h
index 75063f587..d8faa4d6c 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h
@@ -365,6 +365,13 @@ struct TensorEvaluator<const TensorVolumePatchOp<Planes, Rows, Cols, ArgType>, D
     return true;
   }
 
+#ifdef EIGEN_USE_THREADS
+  template <typename EvalSubExprsCallback>
+  EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(EvaluatorPointerType /*data*/, EvalSubExprsCallback done) {
+    m_impl.evalSubExprsIfNeededAsync(nullptr, [done](bool) { done(true); });
+  }
+#endif  // EIGEN_USE_THREADS
+
   EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }
 
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
diff --git a/unsupported/test/cxx11_tensor_thread_pool.cpp b/unsupported/test/cxx11_tensor_thread_pool.cpp
index 8961c8463..a566d7e73 100644
--- a/unsupported/test/cxx11_tensor_thread_pool.cpp
+++ b/unsupported/test/cxx11_tensor_thread_pool.cpp
@@ -80,6 +80,86 @@ void test_async_multithread_elementwise() {
   }
 }
 
+void test_multithread_chip() {
+  Tensor<float, 5> in(2, 3, 5, 7, 11);
+  Tensor<float, 4> out(3, 5, 7, 11);
+
+  in.setRandom();
+
+  Eigen::ThreadPool tp(internal::random<int>(3, 11));
+  Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11));
+
+  out.device(thread_pool_device) = in.chip(1, 0);
+
+  for (int i = 0; i < 3; ++i) {
+    for (int j = 0; j < 5; ++j) {
+      for (int k = 0; k < 7; ++k) {
+        for (int l = 0; l < 11; ++l) {
+          VERIFY_IS_EQUAL(out(i, j, k, l), in(1, i, j, k, l));
+        }
+      }
+    }
+  }
+}
+
+void test_async_multithread_chip() {
+  Tensor<float, 5> in(2, 3, 5, 7, 11);
+  Tensor<float, 4> out(3, 5, 7, 11);
+
+  in.setRandom();
+
+  Eigen::ThreadPool tp(internal::random<int>(3, 11));
+  Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11));
+
+  Eigen::Barrier b(1);
+  out.device(thread_pool_device, [&b]() { b.Notify(); }) = in.chip(1, 0);
+  b.Wait();
+
+  for (int i = 0; i < 3; ++i) {
+    for (int j = 0; j < 5; ++j) {
+      for (int k = 0; k < 7; ++k) {
+        for (int l = 0; l < 11; ++l) {
+          VERIFY_IS_EQUAL(out(i, j, k, l), in(1, i, j, k, l));
+        }
+      }
+    }
+  }
+}
+
+void test_multithread_volume_patch() {
+  Tensor<float, 5> in(4, 2, 3, 5, 7);
+  Tensor<float, 6> out(4, 1, 1, 1, 2 * 3 * 5, 7);
+
+  in.setRandom();
+
+  Eigen::ThreadPool tp(internal::random<int>(3, 11));
+  Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11));
+
+  out.device(thread_pool_device) = in.extract_volume_patches(1, 1, 1);
+
+  for (int i = 0; i < in.size(); ++i) {
+    VERIFY_IS_EQUAL(in.data()[i], out.data()[i]);
+  }
+}
+
+void test_async_multithread_volume_patch() {
+  Tensor<float, 5> in(4, 2, 3, 5, 7);
+  Tensor<float, 6> out(4, 1, 1, 1, 2 * 3 * 5, 7);
+
+  in.setRandom();
+
+  Eigen::ThreadPool tp(internal::random<int>(3, 11));
+  Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11));
+
+  Eigen::Barrier b(1);
+  out.device(thread_pool_device, [&b]() { b.Notify(); }) = in.extract_volume_patches(1, 1, 1);
+  b.Wait();
+
+  for (int i = 0; i < in.size(); ++i) {
+    VERIFY_IS_EQUAL(in.data()[i], out.data()[i]);
+  }
+}
+
 void test_multithread_compound_assignment() {
   Tensor<float, 3> in1(2, 3, 7);
   Tensor<float, 3> in2(2, 3, 7);
@@ -648,43 +728,49 @@ EIGEN_DECLARE_TEST(cxx11_tensor_thread_pool) {
   CALL_SUBTEST_2(test_multithread_contraction<ColMajor>());
   CALL_SUBTEST_2(test_multithread_contraction<RowMajor>());
 
-  CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<ColMajor>());
-  CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<RowMajor>());
-  CALL_SUBTEST_3(test_multithread_contraction_with_output_kernel<ColMajor>());
-  CALL_SUBTEST_3(test_multithread_contraction_with_output_kernel<RowMajor>());
+  CALL_SUBTEST_3(test_multithread_chip());
+  CALL_SUBTEST_3(test_async_multithread_chip());
 
-  CALL_SUBTEST_4(test_async_multithread_contraction_agrees_with_singlethread<ColMajor>());
-  CALL_SUBTEST_4(test_async_multithread_contraction_agrees_with_singlethread<RowMajor>());
+  CALL_SUBTEST_4(test_multithread_volume_patch());
+  CALL_SUBTEST_4(test_async_multithread_volume_patch());
+
+  CALL_SUBTEST_5(test_multithread_contraction_agrees_with_singlethread<ColMajor>());
+  CALL_SUBTEST_5(test_multithread_contraction_agrees_with_singlethread<RowMajor>());
+  CALL_SUBTEST_5(test_multithread_contraction_with_output_kernel<ColMajor>());
+  CALL_SUBTEST_5(test_multithread_contraction_with_output_kernel<RowMajor>());
+
+  CALL_SUBTEST_6(test_async_multithread_contraction_agrees_with_singlethread<ColMajor>());
+  CALL_SUBTEST_6(test_async_multithread_contraction_agrees_with_singlethread<RowMajor>());
 
   // Test EvalShardedByInnerDimContext parallelization strategy.
-  CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction<ColMajor>());
-  CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction<RowMajor>());
-  CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction_with_output_kernel<ColMajor>());
-  CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction_with_output_kernel<RowMajor>());
+  CALL_SUBTEST_7(test_sharded_by_inner_dim_contraction<ColMajor>());
+  CALL_SUBTEST_7(test_sharded_by_inner_dim_contraction<RowMajor>());
+  CALL_SUBTEST_7(test_sharded_by_inner_dim_contraction_with_output_kernel<ColMajor>());
+  CALL_SUBTEST_7(test_sharded_by_inner_dim_contraction_with_output_kernel<RowMajor>());
 
-  CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction<ColMajor>());
-  CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction<RowMajor>());
-  CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction_with_output_kernel<ColMajor>());
-  CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction_with_output_kernel<RowMajor>());
+  CALL_SUBTEST_8(test_async_sharded_by_inner_dim_contraction<ColMajor>());
+  CALL_SUBTEST_8(test_async_sharded_by_inner_dim_contraction<RowMajor>());
+  CALL_SUBTEST_8(test_async_sharded_by_inner_dim_contraction_with_output_kernel<ColMajor>());
+  CALL_SUBTEST_8(test_async_sharded_by_inner_dim_contraction_with_output_kernel<RowMajor>());
 
   // Exercise various cases that have been problematic in the past.
-  CALL_SUBTEST_7(test_contraction_corner_cases<ColMajor>());
-  CALL_SUBTEST_7(test_contraction_corner_cases<RowMajor>());
+  CALL_SUBTEST_9(test_contraction_corner_cases<ColMajor>());
+  CALL_SUBTEST_9(test_contraction_corner_cases<RowMajor>());
 
-  CALL_SUBTEST_8(test_full_contraction<ColMajor>());
-  CALL_SUBTEST_8(test_full_contraction<RowMajor>());
+  CALL_SUBTEST_10(test_full_contraction<ColMajor>());
+  CALL_SUBTEST_10(test_full_contraction<RowMajor>());
 
-  CALL_SUBTEST_9(test_multithreaded_reductions<ColMajor>());
-  CALL_SUBTEST_9(test_multithreaded_reductions<RowMajor>());
+  CALL_SUBTEST_11(test_multithreaded_reductions<ColMajor>());
+  CALL_SUBTEST_11(test_multithreaded_reductions<RowMajor>());
 
-  CALL_SUBTEST_10(test_memcpy());
-  CALL_SUBTEST_10(test_multithread_random());
+  CALL_SUBTEST_12(test_memcpy());
+  CALL_SUBTEST_12(test_multithread_random());
 
   TestAllocator test_allocator;
-  CALL_SUBTEST_11(test_multithread_shuffle<ColMajor>(NULL));
-  CALL_SUBTEST_11(test_multithread_shuffle<RowMajor>(&test_allocator));
-  CALL_SUBTEST_11(test_threadpool_allocate(&test_allocator));
+  CALL_SUBTEST_13(test_multithread_shuffle<ColMajor>(NULL));
+  CALL_SUBTEST_13(test_multithread_shuffle<RowMajor>(&test_allocator));
+  CALL_SUBTEST_13(test_threadpool_allocate(&test_allocator));
 
   // Force CMake to split this test.
-  // EIGEN_SUFFIXES;1;2;3;4;5;6;7;8;9;10;11
+  // EIGEN_SUFFIXES;1;2;3;4;5;6;7;8;9;10;11;12;13
 }