Add async support for 'chip' and 'extract_volume_patches'

Author: adambanas
Date:   2024-06-27 09:34:19 +02:00
Parent: d791d48859
Commit: 33d0937c6b
3 changed files with 126 additions and 26 deletions

unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h

@@ -182,6 +182,13 @@ struct TensorEvaluator<const TensorChippingOp<DimId, ArgType>, Device> {
     return true;
   }
 
+#ifdef EIGEN_USE_THREADS
+  template <typename EvalSubExprsCallback>
+  EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(EvaluatorPointerType /*data*/, EvalSubExprsCallback done) {
+    m_impl.evalSubExprsIfNeededAsync(nullptr, [done](bool) { done(true); });
+  }
+#endif  // EIGEN_USE_THREADS
+
   EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }
 
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
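Note: chipping is a lazy view, so its evaluator owns no buffer of its own; the new evalSubExprsIfNeededAsync simply forwards to the wrapped evaluator and reports success to the caller's callback once the child expression is ready. The volume-patch evaluator below gains the identical entry point. A minimal sketch of the call site that reaches this path, assuming EIGEN_USE_THREADS and the unsupported Tensor module:

    #define EIGEN_USE_THREADS
    #include <unsupported/Eigen/CXX11/Tensor>

    Eigen::ThreadPool pool(4);
    Eigen::ThreadPoolDevice device(&pool, /*num_cores=*/4);
    Eigen::Tensor<float, 2> in(2, 3);
    Eigen::Tensor<float, 1> out(3);
    in.setRandom();

    Eigen::Barrier done(1);
    // Passing a callback as the second device() argument selects the async
    // assignment path, which ends up in evalSubExprsIfNeededAsync above.
    out.device(device, [&done]() { done.Notify(); }) = in.chip(1, 0);
    done.Wait();  // block until the completion callback has fired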

unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h

@@ -365,6 +365,13 @@ struct TensorEvaluator<const TensorVolumePatchOp<Planes, Rows, Cols, ArgType>, Device> {
     return true;
   }
 
+#ifdef EIGEN_USE_THREADS
+  template <typename EvalSubExprsCallback>
+  EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(EvaluatorPointerType /*data*/, EvalSubExprsCallback done) {
+    m_impl.evalSubExprsIfNeededAsync(nullptr, [done](bool) { done(true); });
+  }
+#endif  // EIGEN_USE_THREADS
+
   EIGEN_STRONG_INLINE void cleanup() { m_impl.cleanup(); }
 
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {

test/cxx11_tensor_thread_pool.cpp

@@ -80,6 +80,86 @@ void test_async_multithread_elementwise() {
   }
 }
 
+void test_multithread_chip() {
+  Tensor<float, 5> in(2, 3, 5, 7, 11);
+  Tensor<float, 4> out(3, 5, 7, 11);
+  in.setRandom();
+
+  Eigen::ThreadPool tp(internal::random<int>(3, 11));
+  Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11));
+  out.device(thread_pool_device) = in.chip(1, 0);
+
+  for (int i = 0; i < 3; ++i) {
+    for (int j = 0; j < 5; ++j) {
+      for (int k = 0; k < 7; ++k) {
+        for (int l = 0; l < 11; ++l) {
+          VERIFY_IS_EQUAL(out(i, j, k, l), in(1, i, j, k, l));
+        }
+      }
+    }
+  }
+}
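For reference, chip(offset, dim) takes the slice at index offset along dimension dim and drops that dimension, which is why the check reads in(1, i, j, k, l). A tiny illustration on a rank-2 tensor:

    Eigen::Tensor<float, 2> m(2, 3);
    m.setValues({{0.f, 1.f, 2.f}, {3.f, 4.f, 5.f}});
    Eigen::Tensor<float, 1> row = m.chip(1, 0);
    // row(j) == m(1, j), i.e. {3, 4, 5}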
+
+void test_async_multithread_chip() {
+  Tensor<float, 5> in(2, 3, 5, 7, 11);
+  Tensor<float, 4> out(3, 5, 7, 11);
+  in.setRandom();
+
+  Eigen::ThreadPool tp(internal::random<int>(3, 11));
+  Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11));
+
+  Eigen::Barrier b(1);
+  out.device(thread_pool_device, [&b]() { b.Notify(); }) = in.chip(1, 0);
+  b.Wait();
+
+  for (int i = 0; i < 3; ++i) {
+    for (int j = 0; j < 5; ++j) {
+      for (int k = 0; k < 7; ++k) {
+        for (int l = 0; l < 11; ++l) {
+          VERIFY_IS_EQUAL(out(i, j, k, l), in(1, i, j, k, l));
+        }
+      }
+    }
+  }
+}
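The async variant is identical except for the completion handshake: the two-argument device() returns as soon as the expression is scheduled on the pool, so the test parks on an Eigen::Barrier initialized to one and only runs the verification loops after the done callback has called Notify(). Dropping the Wait() would race the reads in the loop against the still-running assignment.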
+
+void test_multithread_volume_patch() {
+  Tensor<float, 5> in(4, 2, 3, 5, 7);
+  Tensor<float, 6> out(4, 1, 1, 1, 2 * 3 * 5, 7);
+  in.setRandom();
+
+  Eigen::ThreadPool tp(internal::random<int>(3, 11));
+  Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11));
+  out.device(thread_pool_device) = in.extract_volume_patches(1, 1, 1);
+
+  for (int i = 0; i < in.size(); ++i) {
+    VERIFY_IS_EQUAL(in.data()[i], out.data()[i]);
+  }
+}
+
+void test_async_multithread_volume_patch() {
+  Tensor<float, 5> in(4, 2, 3, 5, 7);
+  Tensor<float, 6> out(4, 1, 1, 1, 2 * 3 * 5, 7);
+  in.setRandom();
+
+  Eigen::ThreadPool tp(internal::random<int>(3, 11));
+  Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11));
+
+  Eigen::Barrier b(1);
+  out.device(thread_pool_device, [&b]() { b.Notify(); }) = in.extract_volume_patches(1, 1, 1);
+  b.Wait();
+
+  for (int i = 0; i < in.size(); ++i) {
+    VERIFY_IS_EQUAL(in.data()[i], out.data()[i]);
+  }
+}
+
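Both volume-patch tests lean on the degenerate case: with 1x1x1 patches and unit strides, every voxel becomes its own patch, so for a ColMajor input of shape (4, 2, 3, 5, 7) the output has shape (4, 1, 1, 1, 2*3*5, 7) and its flat buffer is element-for-element identical to the input, which is why a single loop over data() suffices. A sketch of the shape relationship (same includes as above):

    Eigen::Tensor<float, 5> input(4, 2, 3, 5, 7);  // (channels, planes, rows, cols, batch)
    input.setRandom();
    Eigen::Tensor<float, 6> patches = input.extract_volume_patches(1, 1, 1);
    eigen_assert(patches.size() == input.size());  // 1x1x1 patches preserve the element count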
 void test_multithread_compound_assignment() {
   Tensor<float, 3> in1(2, 3, 7);
   Tensor<float, 3> in2(2, 3, 7);
@@ -648,43 +728,49 @@ EIGEN_DECLARE_TEST(cxx11_tensor_thread_pool) {
   CALL_SUBTEST_2(test_multithread_contraction<ColMajor>());
   CALL_SUBTEST_2(test_multithread_contraction<RowMajor>());
-  CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<ColMajor>());
-  CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<RowMajor>());
-  CALL_SUBTEST_3(test_multithread_contraction_with_output_kernel<ColMajor>());
-  CALL_SUBTEST_3(test_multithread_contraction_with_output_kernel<RowMajor>());
+  CALL_SUBTEST_3(test_multithread_chip());
+  CALL_SUBTEST_3(test_async_multithread_chip());
-  CALL_SUBTEST_4(test_async_multithread_contraction_agrees_with_singlethread<ColMajor>());
-  CALL_SUBTEST_4(test_async_multithread_contraction_agrees_with_singlethread<RowMajor>());
+  CALL_SUBTEST_4(test_multithread_volume_patch());
+  CALL_SUBTEST_4(test_async_multithread_volume_patch());
+  CALL_SUBTEST_5(test_multithread_contraction_agrees_with_singlethread<ColMajor>());
+  CALL_SUBTEST_5(test_multithread_contraction_agrees_with_singlethread<RowMajor>());
+  CALL_SUBTEST_5(test_multithread_contraction_with_output_kernel<ColMajor>());
+  CALL_SUBTEST_5(test_multithread_contraction_with_output_kernel<RowMajor>());
+  CALL_SUBTEST_6(test_async_multithread_contraction_agrees_with_singlethread<ColMajor>());
+  CALL_SUBTEST_6(test_async_multithread_contraction_agrees_with_singlethread<RowMajor>());
 
   // Test EvalShardedByInnerDimContext parallelization strategy.
-  CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction<ColMajor>());
-  CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction<RowMajor>());
-  CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction_with_output_kernel<ColMajor>());
-  CALL_SUBTEST_5(test_sharded_by_inner_dim_contraction_with_output_kernel<RowMajor>());
+  CALL_SUBTEST_7(test_sharded_by_inner_dim_contraction<ColMajor>());
+  CALL_SUBTEST_7(test_sharded_by_inner_dim_contraction<RowMajor>());
+  CALL_SUBTEST_7(test_sharded_by_inner_dim_contraction_with_output_kernel<ColMajor>());
+  CALL_SUBTEST_7(test_sharded_by_inner_dim_contraction_with_output_kernel<RowMajor>());
-  CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction<ColMajor>());
-  CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction<RowMajor>());
-  CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction_with_output_kernel<ColMajor>());
-  CALL_SUBTEST_6(test_async_sharded_by_inner_dim_contraction_with_output_kernel<RowMajor>());
+  CALL_SUBTEST_8(test_async_sharded_by_inner_dim_contraction<ColMajor>());
+  CALL_SUBTEST_8(test_async_sharded_by_inner_dim_contraction<RowMajor>());
+  CALL_SUBTEST_8(test_async_sharded_by_inner_dim_contraction_with_output_kernel<ColMajor>());
+  CALL_SUBTEST_8(test_async_sharded_by_inner_dim_contraction_with_output_kernel<RowMajor>());
 
   // Exercise various cases that have been problematic in the past.
-  CALL_SUBTEST_7(test_contraction_corner_cases<ColMajor>());
-  CALL_SUBTEST_7(test_contraction_corner_cases<RowMajor>());
+  CALL_SUBTEST_9(test_contraction_corner_cases<ColMajor>());
+  CALL_SUBTEST_9(test_contraction_corner_cases<RowMajor>());
 
-  CALL_SUBTEST_8(test_full_contraction<ColMajor>());
-  CALL_SUBTEST_8(test_full_contraction<RowMajor>());
+  CALL_SUBTEST_10(test_full_contraction<ColMajor>());
+  CALL_SUBTEST_10(test_full_contraction<RowMajor>());
 
-  CALL_SUBTEST_9(test_multithreaded_reductions<ColMajor>());
-  CALL_SUBTEST_9(test_multithreaded_reductions<RowMajor>());
+  CALL_SUBTEST_11(test_multithreaded_reductions<ColMajor>());
+  CALL_SUBTEST_11(test_multithreaded_reductions<RowMajor>());
 
-  CALL_SUBTEST_10(test_memcpy());
-  CALL_SUBTEST_10(test_multithread_random());
+  CALL_SUBTEST_12(test_memcpy());
+  CALL_SUBTEST_12(test_multithread_random());
 
   TestAllocator test_allocator;
-  CALL_SUBTEST_11(test_multithread_shuffle<ColMajor>(NULL));
-  CALL_SUBTEST_11(test_multithread_shuffle<RowMajor>(&test_allocator));
-  CALL_SUBTEST_11(test_threadpool_allocate(&test_allocator));
+  CALL_SUBTEST_13(test_multithread_shuffle<ColMajor>(NULL));
+  CALL_SUBTEST_13(test_multithread_shuffle<RowMajor>(&test_allocator));
+  CALL_SUBTEST_13(test_threadpool_allocate(&test_allocator));
 
   // Force CMake to split this test.
-  // EIGEN_SUFFIXES;1;2;3;4;5;6;7;8;9;10;11
+  // EIGEN_SUFFIXES;1;2;3;4;5;6;7;8;9;10;11;12;13
 }
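Renumbering note: the chip tests take over subtest 3 and the volume-patch tests subtest 4, so every later subtest shifts up by two. The trailing EIGEN_SUFFIXES comment, which the CMake harness appears to read when splitting the binary into per-suffix test targets (per the "Force CMake to split this test" comment above it), grows from 1..11 to 1..13 to match.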