Allow move-only done callback in TensorAsyncDevice

Eugene Zhulenev 2019-09-03 17:20:56 -07:00
parent a8d264fa9c
commit 47fefa235f
6 changed files with 70 additions and 54 deletions
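Why this change matters: std::function<void()> requires a copy-constructible callable, so a done callback that captures move-only state (a std::unique_ptr, for example) could not be passed to the async evaluation path. Templating the callback type (DoneCallback) removes that restriction and also avoids std::function's type erasure. A minimal standalone sketch of the difference; all names below are illustrative, not part of this commit:

    #include <functional>
    #include <memory>
    #include <utility>

    // A done callback that owns a move-only resource: capturing the
    // unique_ptr by move makes the closure itself move-only.
    auto make_done() {
      auto state = std::make_unique<int>(42);
      return [s = std::move(state)]() { /* signal completion, may use *s */ };
    }

    // std::function must be able to copy its target, so this line
    // would not compile with the move-only closure above:
    //   std::function<void()> f = make_done();  // error: deleted copy ctor

    // A deduced template parameter accepts the exact closure type and
    // only ever moves it, which is the approach this commit takes:
    template <typename DoneCallback>
    void run_async(DoneCallback done) {
      std::move(done)();  // invoke the callback
    }

    int main() { run_async(make_done()); }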


@@ -1065,12 +1065,12 @@ class TensorBase : public TensorBase<Derived, ReadOnlyAccessors> {
 #ifdef EIGEN_USE_THREADS
     // Select the async device on which to evaluate the expression.
-    template <typename DeviceType>
+    template <typename DeviceType, typename DoneCallback>
     typename internal::enable_if<
         internal::is_same<DeviceType, ThreadPoolDevice>::value,
-        TensorAsyncDevice<Derived, DeviceType>>::type
-    device(const DeviceType& dev, std::function<void()> done) {
-      return TensorAsyncDevice<Derived, DeviceType>(dev, derived(), std::move(done));
+        TensorAsyncDevice<Derived, DeviceType, DoneCallback>>::type
+    device(const DeviceType& dev, DoneCallback done) {
+      return TensorAsyncDevice<Derived, DeviceType, DoneCallback>(dev, derived(), std::move(done));
     }
 #endif  // EIGEN_USE_THREADS
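With the templated device() overload in place, a call site can pass a move-only callback directly. A usage sketch, assuming a standard thread-pool setup (pool size, tensor shapes, and variable names are assumptions, not from the diff):

    #define EIGEN_USE_THREADS
    #include <unsupported/Eigen/CXX11/Tensor>
    #include <memory>
    #include <utility>

    int main() {
      Eigen::ThreadPool pool(4);
      Eigen::ThreadPoolDevice device(&pool, 4);

      Eigen::Tensor<float, 2> A(64, 64), B(64, 64), C(64, 64);
      A.setRandom();
      B.setRandom();

      // The callback may now capture move-only state; the previous
      // std::function<void()> parameter would have rejected it.
      Eigen::Barrier barrier(1);
      auto state = std::make_unique<int>(0);
      auto done = [&barrier, s = std::move(state)]() { barrier.Notify(); };

      C.device(device, std::move(done)) = A + B;
      barrier.Wait();  // block until the async evaluation signals completion
    }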


@@ -73,21 +73,21 @@ template <typename ExpressionType, typename DeviceType> class TensorDevice {
  * ThreadPoolDevice).
  *
  * Example:
- *   std::function<void()> done = []() {};
+ *   auto done = []() { ... expression evaluation done ... };
  *   C.device(EIGEN_THREAD_POOL, std::move(done)) = A + B;
  */
-template <typename ExpressionType, typename DeviceType>
+template <typename ExpressionType, typename DeviceType, typename DoneCallback>
 class TensorAsyncDevice {
  public:
   TensorAsyncDevice(const DeviceType& device, ExpressionType& expression,
-                    std::function<void()> done)
+                    DoneCallback done)
       : m_device(device), m_expression(expression), m_done(std::move(done)) {}
 
   template <typename OtherDerived>
   EIGEN_STRONG_INLINE TensorAsyncDevice& operator=(const OtherDerived& other) {
     typedef TensorAssignOp<ExpressionType, const OtherDerived> Assign;
-    typedef internal::TensorAsyncExecutor<const Assign, DeviceType> Executor;
+    typedef internal::TensorAsyncExecutor<const Assign, DeviceType, DoneCallback> Executor;
 
     // WARNING: After assignment 'm_done' callback will be in undefined state.
     Assign assign(m_expression, other);
@@ -99,7 +99,7 @@ class TensorAsyncDevice {
  protected:
   const DeviceType& m_device;
   ExpressionType& m_expression;
-  std::function<void()> m_done;
+  DoneCallback m_done;
 };
 
 #endif  // EIGEN_USE_THREADS
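The WARNING above is the price of supporting move-only callbacks: operator= moves m_done into the executor, so a TensorAsyncDevice is effectively single-shot and must be assigned to exactly once. A sketch of the implied contract (hypothetical usage, reusing names from the earlier example):

    auto async_dev = C.device(device, std::move(done));
    async_dev = A + B;    // OK: m_done is moved into the executor here
    // async_dev = A - B; // don't: m_done is already in a moved-from state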


@@ -101,8 +101,8 @@ class TensorExecutor {
  * Default async execution strategy is not implemented. Currently it's only
  * available for ThreadPoolDevice (see definition below).
  */
-template <typename Expression, typename Device, bool Vectorizable,
-          bool Tileable>
+template <typename Expression, typename Device, typename DoneCallback,
+          bool Vectorizable, bool Tileable>
 class TensorAsyncExecutor {};
 
 /**
@@ -419,15 +419,17 @@ class TensorExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable*/ true> {
   }
 };
 
-template <typename Expression, bool Vectorizable, bool Tileable>
-class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, Tileable> {
+template <typename Expression, typename DoneCallback, bool Vectorizable,
+          bool Tileable>
+class TensorAsyncExecutor<Expression, ThreadPoolDevice, DoneCallback,
+                          Vectorizable, Tileable> {
  public:
   typedef typename Expression::Index StorageIndex;
   typedef TensorEvaluator<Expression, ThreadPoolDevice> Evaluator;
 
   static EIGEN_STRONG_INLINE void runAsync(const Expression& expr,
                                            const ThreadPoolDevice& device,
-                                           std::function<void()> done) {
+                                           DoneCallback done) {
     TensorAsyncExecutorContext* const ctx =
         new TensorAsyncExecutorContext(expr, device, std::move(done));
@@ -455,7 +457,7 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, Tileable>
   struct TensorAsyncExecutorContext {
     TensorAsyncExecutorContext(const Expression& expr,
                                const ThreadPoolDevice& thread_pool,
-                               std::function<void()> done)
+                               DoneCallback done)
         : evaluator(expr, thread_pool), on_done(std::move(done)) {}
 
     ~TensorAsyncExecutorContext() {
@@ -466,12 +468,13 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, Tileable>
     Evaluator evaluator;
 
    private:
-    std::function<void()> on_done;
+    DoneCallback on_done;
   };
 };
 
-template <typename Expression, bool Vectorizable>
-class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable*/ true> {
+template <typename Expression, typename DoneCallback, bool Vectorizable>
+class TensorAsyncExecutor<Expression, ThreadPoolDevice, DoneCallback,
+                          Vectorizable, /*Tileable*/ true> {
  public:
   typedef typename traits<Expression>::Index StorageIndex;
   typedef typename traits<Expression>::Scalar Scalar;
@@ -485,7 +488,7 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable*/ true> {
   static EIGEN_STRONG_INLINE void runAsync(const Expression& expr,
                                            const ThreadPoolDevice& device,
-                                           std::function<void()> done) {
+                                           DoneCallback done) {
     TensorAsyncExecutorContext* const ctx =
         new TensorAsyncExecutorContext(expr, device, std::move(done));
@@ -494,9 +497,10 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable*/ true> {
     if (total_size < cache_size &&
         !ExpressionHasTensorBroadcastingOp<Expression>::value) {
-      internal::TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable,
-                                    /*Tileable*/ false>::runAsync(
-          expr, device, [ctx]() { delete ctx; });
+      auto delete_ctx = [ctx]() { delete ctx; };
+      internal::TensorAsyncExecutor<
+          Expression, ThreadPoolDevice, decltype(delete_ctx), Vectorizable,
+          /*Tileable*/ false>::runAsync(expr, device, std::move(delete_ctx));
       return;
     }
@@ -532,7 +536,7 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable*/ true> {
   struct TensorAsyncExecutorContext {
     TensorAsyncExecutorContext(const Expression& expr,
                                const ThreadPoolDevice& thread_pool,
-                               std::function<void()> done)
+                               DoneCallback done)
         : device(thread_pool),
           evaluator(expr, thread_pool),
           on_done(std::move(done)) {}
@@ -548,7 +552,7 @@ class TensorAsyncExecutor<Expression, ThreadPoolDevice, Vectorizable, /*Tileable*/ true> {
     TilingContext tiling;
 
    private:
-    std::function<void()> on_done;
+    DoneCallback on_done;
   };
 };
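Two details are worth noting here. First, the heap-allocated TensorAsyncExecutorContext owns the callback and (as the delete_ctx fallback implies) fires it during destruction, so 'delete ctx' doubles as the completion signal. Second, the tiled executor's small-expression fallback now forwards the lambda delete_ctx itself as the done callback; closure types are unnamed, so decltype(delete_ctx) is how its type is spelled, and it is precisely this kind of callback that std::function previously had to type-erase. A condensed sketch of the lifetime pattern, with simplified, illustrative names:

    #include <utility>

    template <typename DoneCallback>
    struct AsyncContext {
      explicit AsyncContext(DoneCallback done) : on_done(std::move(done)) {}
      ~AsyncContext() { on_done(); }  // destruction signals completion
     private:
      DoneCallback on_done;
    };

    template <typename DoneCallback>
    void run_async(DoneCallback done) {
      auto* ctx = new AsyncContext<DoneCallback>(std::move(done));
      auto delete_ctx = [ctx]() { delete ctx; };  // deleting ctx runs on_done()
      // A real executor would hand delete_ctx to the last task to finish;
      // it is invoked directly here to keep the sketch self-contained.
      delete_ctx();
    }

    int main() { run_async([]() { /* done */ }); }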


@@ -94,7 +94,7 @@ template<typename XprType, template <class> class MakePointer_ = MakePointer> cl
 template<typename XprType> class TensorForcedEvalOp;
 
 template<typename ExpressionType, typename DeviceType> class TensorDevice;
-template<typename ExpressionType, typename DeviceType> class TensorAsyncDevice;
+template<typename ExpressionType, typename DeviceType, typename DoneCallback> class TensorAsyncDevice;
 template<typename Derived, typename Device> struct TensorEvaluator;
 
 struct NoOpOutputKernel;
@@ -168,7 +168,7 @@ template <typename Expression, typename Device,
           bool Tileable = IsTileable<Device, Expression>::value>
 class TensorExecutor;
 
-template <typename Expression, typename Device,
+template <typename Expression, typename Device, typename DoneCallback,
           bool Vectorizable = IsVectorizable<Device, Expression>::value,
           bool Tileable = IsTileable<Device, Expression>::value>
 class TensorAsyncExecutor;
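Note the placement of DoneCallback in the forward declaration: in a class template, a parameter without a default cannot follow parameters with defaults, so DoneCallback has to slot in before the defaulted Vectorizable/Tileable flags. A reduced illustration (defaults simplified; the real ones are the IsVectorizable/IsTileable traits shown above):

    template <typename Expression, typename Device, typename DoneCallback,
              bool Vectorizable = true, bool Tileable = true>
    class TensorAsyncExecutor;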


@@ -578,11 +578,15 @@ static void test_async_execute_unary_expr(Device d)
   src.setRandom();
   const auto expr = src.square();
 
-  using Assign = TensorAssignOp<decltype(dst), const decltype(expr)>;
-  using Executor = internal::TensorAsyncExecutor<const Assign, Device,
-                                                 Vectorizable, Tileable>;
-
   Eigen::Barrier done(1);
-  Executor::runAsync(Assign(dst, expr), d, [&done]() { done.Notify(); });
+  auto on_done = [&done]() { done.Notify(); };
+
+  using Assign = TensorAssignOp<decltype(dst), const decltype(expr)>;
+  using DoneCallback = decltype(on_done);
+  using Executor = internal::TensorAsyncExecutor<const Assign, Device, DoneCallback,
+                                                 Vectorizable, Tileable>;
+
+  Executor::runAsync(Assign(dst, expr), d, on_done);
   done.Wait();
 
   for (Index i = 0; i < dst.dimensions().TotalSize(); ++i) {
@@ -610,12 +614,15 @@ static void test_async_execute_binary_expr(Device d)
   const auto expr = lhs + rhs;
 
+  Eigen::Barrier done(1);
+  auto on_done = [&done]() { done.Notify(); };
+
   using Assign = TensorAssignOp<decltype(dst), const decltype(expr)>;
-  using Executor = internal::TensorAsyncExecutor<const Assign, Device,
+  using DoneCallback = decltype(on_done);
+  using Executor = internal::TensorAsyncExecutor<const Assign, Device, DoneCallback,
                                                  Vectorizable, Tileable>;
 
-  Eigen::Barrier done(1);
-  Executor::runAsync(Assign(dst, expr), d, [&done]() { done.Notify(); });
+  Executor::runAsync(Assign(dst, expr), d, on_done);
   done.Wait();
 
   for (Index i = 0; i < dst.dimensions().TotalSize(); ++i) {
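The tests bind the callback type with 'using DoneCallback = decltype(on_done)' because closure types are unnamed; decltype is the only way to spell them when instantiating the executor explicitly. Even textually identical lambdas have distinct types, which is why the type must come from the specific object. A standalone illustration:

    #include <type_traits>

    int main() {
      auto a = []() {};
      auto b = []() {};  // same text, but a distinct closure type
      static_assert(!std::is_same<decltype(a), decltype(b)>::value,
                    "each lambda expression has a unique type");
    }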


@@ -683,34 +683,39 @@ EIGEN_DECLARE_TEST(cxx11_tensor_thread_pool)
   CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<RowMajor>());
   CALL_SUBTEST_3(test_multithread_contraction_with_output_kernel<ColMajor>());
   CALL_SUBTEST_3(test_multithread_contraction_with_output_kernel<RowMajor>());
-  CALL_SUBTEST_3(test_async_multithread_contraction_agrees_with_singlethread<ColMajor>());
-  CALL_SUBTEST_3(test_async_multithread_contraction_agrees_with_singlethread<RowMajor>());
+
+  CALL_SUBTEST_4(test_async_multithread_contraction_agrees_with_singlethread<ColMajor>());
+  CALL_SUBTEST_4(test_async_multithread_contraction_agrees_with_singlethread<RowMajor>());
 
   // Test EvalShardedByInnerDimContext parallelization strategy.
-  CALL_SUBTEST_4(test_sharded_by_inner_dim_contraction<ColMajor>());
-  CALL_SUBTEST_4(test_sharded_by_inner_dim_contraction<RowMajor>());
-  CALL_SUBTEST_4(test_sharded_by_inner_dim_contraction_with_output_kernel<ColMajor>());
-  CALL_SUBTEST_4(test_sharded_by_inner_dim_contraction_with_output_kernel<RowMajor>());
-  CALL_SUBTEST_4(test_async_sharded_by_inner_dim_contraction<ColMajor>());
-  CALL_SUBTEST_4(test_async_sharded_by_inner_dim_contraction<RowMajor>());
-  CALL_SUBTEST_4(test_async_sharded_by_inner_dim_contraction_with_output_kernel<ColMajor>());
-  CALL_SUBTEST_4(test_async_sharded_by_inner_dim_contraction_with_output_kernel<RowMajor>());
+  CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction<ColMajor>());
+  CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction<RowMajor>());
+  CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction_with_output_kernel<ColMajor>());
+  CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction_with_output_kernel<RowMajor>());
+
+  CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction<ColMajor>());
+  CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction<RowMajor>());
+  CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction_with_output_kernel<ColMajor>());
+  CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction_with_output_kernel<RowMajor>());
 
   // Exercise various cases that have been problematic in the past.
-  CALL_SUBTEST_5(test_contraction_corner_cases<ColMajor>());
-  CALL_SUBTEST_5(test_contraction_corner_cases<RowMajor>());
+  CALL_SUBTEST_7(test_contraction_corner_cases<ColMajor>());
+  CALL_SUBTEST_7(test_contraction_corner_cases<RowMajor>());
 
-  CALL_SUBTEST_6(test_full_contraction<ColMajor>());
-  CALL_SUBTEST_6(test_full_contraction<RowMajor>());
+  CALL_SUBTEST_8(test_full_contraction<ColMajor>());
+  CALL_SUBTEST_8(test_full_contraction<RowMajor>());
 
-  CALL_SUBTEST_7(test_multithreaded_reductions<ColMajor>());
-  CALL_SUBTEST_7(test_multithreaded_reductions<RowMajor>());
+  CALL_SUBTEST_9(test_multithreaded_reductions<ColMajor>());
+  CALL_SUBTEST_9(test_multithreaded_reductions<RowMajor>());
 
-  CALL_SUBTEST_7(test_memcpy());
-  CALL_SUBTEST_7(test_multithread_random());
+  CALL_SUBTEST_10(test_memcpy());
+  CALL_SUBTEST_10(test_multithread_random());
 
   TestAllocator test_allocator;
-  CALL_SUBTEST_7(test_multithread_shuffle<ColMajor>(NULL));
-  CALL_SUBTEST_7(test_multithread_shuffle<RowMajor>(&test_allocator));
-  CALL_SUBTEST_7(test_threadpool_allocate(&test_allocator));
+  CALL_SUBTEST_11(test_multithread_shuffle<ColMajor>(NULL));
+  CALL_SUBTEST_11(test_multithread_shuffle<RowMajor>(&test_allocator));
+  CALL_SUBTEST_11(test_threadpool_allocate(&test_allocator));
+
+  // Force CMake to split this test.
+  // EIGEN_SUFFIXES;1;2;3;4;5;6;7;8;9;10;11
 }