From 47fefa235f73315bc57d685a7bc9cd8d3577349f Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Tue, 3 Sep 2019 17:20:56 -0700 Subject: [PATCH] Allow move-only done callback in TensorAsyncDevice --- .../Eigen/CXX11/src/Tensor/TensorBase.h | 8 ++-- .../Eigen/CXX11/src/Tensor/TensorDevice.h | 10 ++-- .../Eigen/CXX11/src/Tensor/TensorExecutor.h | 34 ++++++++------ .../src/Tensor/TensorForwardDeclarations.h | 4 +- unsupported/test/cxx11_tensor_executor.cpp | 21 ++++++--- unsupported/test/cxx11_tensor_thread_pool.cpp | 47 ++++++++++--------- 6 files changed, 70 insertions(+), 54 deletions(-) diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index 095c85dc4..f2aa37256 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -1065,12 +1065,12 @@ class TensorBase : public TensorBase { #ifdef EIGEN_USE_THREADS // Select the async device on which to evaluate the expression. - template + template typename internal::enable_if< internal::is_same::value, - TensorAsyncDevice>::type - device(const DeviceType& dev, std::function done) { - return TensorAsyncDevice(dev, derived(), std::move(done)); + TensorAsyncDevice>::type + device(const DeviceType& dev, DoneCallback done) { + return TensorAsyncDevice(dev, derived(), std::move(done)); } #endif // EIGEN_USE_THREADS diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h index 5122b3623..cc9c65702 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h @@ -73,21 +73,21 @@ template class TensorDevice { * ThreadPoolDevice). * * Example: - * std::function done = []() {}; + * auto done = []() { ... expression evaluation done ... }; * C.device(EIGEN_THREAD_POOL, std::move(done)) = A + B; */ -template +template class TensorAsyncDevice { public: TensorAsyncDevice(const DeviceType& device, ExpressionType& expression, - std::function done) + DoneCallback done) : m_device(device), m_expression(expression), m_done(std::move(done)) {} template EIGEN_STRONG_INLINE TensorAsyncDevice& operator=(const OtherDerived& other) { typedef TensorAssignOp Assign; - typedef internal::TensorAsyncExecutor Executor; + typedef internal::TensorAsyncExecutor Executor; // WARNING: After assignment 'm_done' callback will be in undefined state. Assign assign(m_expression, other); @@ -99,7 +99,7 @@ class TensorAsyncDevice { protected: const DeviceType& m_device; ExpressionType& m_expression; - std::function m_done; + DoneCallback m_done; }; #endif // EIGEN_USE_THREADS diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h index 10339e5e7..cf07656b3 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h @@ -101,8 +101,8 @@ class TensorExecutor { * Default async execution strategy is not implemented. Currently it's only * available for ThreadPoolDevice (see definition below). */ -template +template class TensorAsyncExecutor {}; /** @@ -419,15 +419,17 @@ class TensorExecutor -class TensorAsyncExecutor { +template +class TensorAsyncExecutor { public: typedef typename Expression::Index StorageIndex; typedef TensorEvaluator Evaluator; static EIGEN_STRONG_INLINE void runAsync(const Expression& expr, const ThreadPoolDevice& device, - std::function done) { + DoneCallback done) { TensorAsyncExecutorContext* const ctx = new TensorAsyncExecutorContext(expr, device, std::move(done)); @@ -455,7 +457,7 @@ class TensorAsyncExecutor struct TensorAsyncExecutorContext { TensorAsyncExecutorContext(const Expression& expr, const ThreadPoolDevice& thread_pool, - std::function done) + DoneCallback done) : evaluator(expr, thread_pool), on_done(std::move(done)) {} ~TensorAsyncExecutorContext() { @@ -466,12 +468,13 @@ class TensorAsyncExecutor Evaluator evaluator; private: - std::function on_done; + DoneCallback on_done; }; }; -template -class TensorAsyncExecutor { +template +class TensorAsyncExecutor { public: typedef typename traits::Index StorageIndex; typedef typename traits::Scalar Scalar; @@ -485,7 +488,7 @@ class TensorAsyncExecutor done) { + DoneCallback done) { TensorAsyncExecutorContext* const ctx = new TensorAsyncExecutorContext(expr, device, std::move(done)); @@ -494,9 +497,10 @@ class TensorAsyncExecutor::value) { - internal::TensorAsyncExecutor::runAsync( - expr, device, [ctx]() { delete ctx; }); + auto delete_ctx = [ctx]() { delete ctx; }; + internal::TensorAsyncExecutor< + Expression, ThreadPoolDevice, decltype(delete_ctx), Vectorizable, + /*Tileable*/ false>::runAsync(expr, device, std::move(delete_ctx)); return; } @@ -532,7 +536,7 @@ class TensorAsyncExecutor done) + DoneCallback done) : device(thread_pool), evaluator(expr, thread_pool), on_done(std::move(done)) {} @@ -548,7 +552,7 @@ class TensorAsyncExecutor on_done; + DoneCallback on_done; }; }; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h index e823bd932..772dbbe35 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h @@ -94,7 +94,7 @@ template class MakePointer_ = MakePointer> cl template class TensorForcedEvalOp; template class TensorDevice; -template class TensorAsyncDevice; +template class TensorAsyncDevice; template struct TensorEvaluator; struct NoOpOutputKernel; @@ -168,7 +168,7 @@ template ::value> class TensorExecutor; -template ::value, bool Tileable = IsTileable::value> class TensorAsyncExecutor; diff --git a/unsupported/test/cxx11_tensor_executor.cpp b/unsupported/test/cxx11_tensor_executor.cpp index f4d0401da..aa4ab0b80 100644 --- a/unsupported/test/cxx11_tensor_executor.cpp +++ b/unsupported/test/cxx11_tensor_executor.cpp @@ -578,11 +578,15 @@ static void test_async_execute_unary_expr(Device d) src.setRandom(); const auto expr = src.square(); - using Assign = TensorAssignOp; - using Executor = internal::TensorAsyncExecutor; Eigen::Barrier done(1); - Executor::runAsync(Assign(dst, expr), d, [&done]() { done.Notify(); }); + auto on_done = [&done]() { done.Notify(); }; + + using Assign = TensorAssignOp; + using DoneCallback = decltype(on_done); + using Executor = internal::TensorAsyncExecutor; + + Executor::runAsync(Assign(dst, expr), d, on_done); done.Wait(); for (Index i = 0; i < dst.dimensions().TotalSize(); ++i) { @@ -610,12 +614,15 @@ static void test_async_execute_binary_expr(Device d) const auto expr = lhs + rhs; + Eigen::Barrier done(1); + auto on_done = [&done]() { done.Notify(); }; + using Assign = TensorAssignOp; - using Executor = internal::TensorAsyncExecutor; - Eigen::Barrier done(1); - Executor::runAsync(Assign(dst, expr), d, [&done]() { done.Notify(); }); + Executor::runAsync(Assign(dst, expr), d, on_done); done.Wait(); for (Index i = 0; i < dst.dimensions().TotalSize(); ++i) { diff --git a/unsupported/test/cxx11_tensor_thread_pool.cpp b/unsupported/test/cxx11_tensor_thread_pool.cpp index 62973cd08..dae7b0335 100644 --- a/unsupported/test/cxx11_tensor_thread_pool.cpp +++ b/unsupported/test/cxx11_tensor_thread_pool.cpp @@ -683,34 +683,39 @@ EIGEN_DECLARE_TEST(cxx11_tensor_thread_pool) CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread()); CALL_SUBTEST_3(test_multithread_contraction_with_output_kernel()); CALL_SUBTEST_3(test_multithread_contraction_with_output_kernel()); - CALL_SUBTEST_3(test_async_multithread_contraction_agrees_with_singlethread()); - CALL_SUBTEST_3(test_async_multithread_contraction_agrees_with_singlethread()); + + CALL_SUBTEST_4(test_async_multithread_contraction_agrees_with_singlethread()); + CALL_SUBTEST_4(test_async_multithread_contraction_agrees_with_singlethread()); // Test EvalShardedByInnerDimContext parallelization strategy. - CALL_SUBTEST_4(test_sharded_by_inner_dim_contraction()); - CALL_SUBTEST_4(test_sharded_by_inner_dim_contraction()); - CALL_SUBTEST_4(test_sharded_by_inner_dim_contraction_with_output_kernel()); - CALL_SUBTEST_4(test_sharded_by_inner_dim_contraction_with_output_kernel()); - CALL_SUBTEST_4(test_async_sharded_by_inner_dim_contraction()); - CALL_SUBTEST_4(test_async_sharded_by_inner_dim_contraction()); - CALL_SUBTEST_4(test_async_sharded_by_inner_dim_contraction_with_output_kernel()); - CALL_SUBTEST_4(test_async_sharded_by_inner_dim_contraction_with_output_kernel()); + CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction()); + CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction()); + CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction_with_output_kernel()); + CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction_with_output_kernel()); + + CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction()); + CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction()); + CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction_with_output_kernel()); + CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction_with_output_kernel()); // Exercise various cases that have been problematic in the past. - CALL_SUBTEST_5(test_contraction_corner_cases()); - CALL_SUBTEST_5(test_contraction_corner_cases()); + CALL_SUBTEST_7(test_contraction_corner_cases()); + CALL_SUBTEST_7(test_contraction_corner_cases()); - CALL_SUBTEST_6(test_full_contraction()); - CALL_SUBTEST_6(test_full_contraction()); + CALL_SUBTEST_8(test_full_contraction()); + CALL_SUBTEST_8(test_full_contraction()); - CALL_SUBTEST_7(test_multithreaded_reductions()); - CALL_SUBTEST_7(test_multithreaded_reductions()); + CALL_SUBTEST_9(test_multithreaded_reductions()); + CALL_SUBTEST_9(test_multithreaded_reductions()); - CALL_SUBTEST_7(test_memcpy()); - CALL_SUBTEST_7(test_multithread_random()); + CALL_SUBTEST_10(test_memcpy()); + CALL_SUBTEST_10(test_multithread_random()); TestAllocator test_allocator; - CALL_SUBTEST_7(test_multithread_shuffle(NULL)); - CALL_SUBTEST_7(test_multithread_shuffle(&test_allocator)); - CALL_SUBTEST_7(test_threadpool_allocate(&test_allocator)); + CALL_SUBTEST_11(test_multithread_shuffle(NULL)); + CALL_SUBTEST_11(test_multithread_shuffle(&test_allocator)); + CALL_SUBTEST_11(test_threadpool_allocate(&test_allocator)); + + // Force CMake to split this test. + // EIGEN_SUFFIXES;1;2;3;4;5;6;7;8;9;10;11 }