Sharded the tensor thread pool test

This commit is contained in:
Benoit Steiner 2016-01-30 10:43:57 -08:00
parent ba27c8a7de
commit d0db95f730

View File

@ -17,7 +17,7 @@
using Eigen::Tensor; using Eigen::Tensor;
static void test_multithread_elementwise() void test_multithread_elementwise()
{ {
Tensor<float, 3> in1(2,3,7); Tensor<float, 3> in1(2,3,7);
Tensor<float, 3> in2(2,3,7); Tensor<float, 3> in2(2,3,7);
@ -40,7 +40,7 @@ static void test_multithread_elementwise()
} }
static void test_multithread_compound_assignment() void test_multithread_compound_assignment()
{ {
Tensor<float, 3> in1(2,3,7); Tensor<float, 3> in1(2,3,7);
Tensor<float, 3> in2(2,3,7); Tensor<float, 3> in2(2,3,7);
@ -64,7 +64,7 @@ static void test_multithread_compound_assignment()
} }
template<int DataLayout> template<int DataLayout>
static void test_multithread_contraction() void test_multithread_contraction()
{ {
Tensor<float, 4, DataLayout> t_left(30, 50, 37, 31); Tensor<float, 4, DataLayout> t_left(30, 50, 37, 31);
Tensor<float, 5, DataLayout> t_right(37, 31, 70, 2, 10); Tensor<float, 5, DataLayout> t_right(37, 31, 70, 2, 10);
@ -99,7 +99,7 @@ static void test_multithread_contraction()
} }
template<int DataLayout> template<int DataLayout>
static void test_contraction_corner_cases() void test_contraction_corner_cases()
{ {
Tensor<float, 2, DataLayout> t_left(32, 500); Tensor<float, 2, DataLayout> t_left(32, 500);
Tensor<float, 2, DataLayout> t_right(32, 28*28); Tensor<float, 2, DataLayout> t_right(32, 28*28);
@ -186,7 +186,7 @@ static void test_contraction_corner_cases()
} }
template<int DataLayout> template<int DataLayout>
static void test_multithread_contraction_agrees_with_singlethread() { void test_multithread_contraction_agrees_with_singlethread() {
int contract_size = internal::random<int>(1, 5000); int contract_size = internal::random<int>(1, 5000);
Tensor<float, 3, DataLayout> left(internal::random<int>(1, 80), Tensor<float, 3, DataLayout> left(internal::random<int>(1, 80),
@ -229,7 +229,7 @@ static void test_multithread_contraction_agrees_with_singlethread() {
template<int DataLayout> template<int DataLayout>
static void test_multithreaded_reductions() { void test_multithreaded_reductions() {
const int num_threads = internal::random<int>(3, 11); const int num_threads = internal::random<int>(3, 11);
ThreadPool thread_pool(num_threads); ThreadPool thread_pool(num_threads);
Eigen::ThreadPoolDevice thread_pool_device(&thread_pool, num_threads); Eigen::ThreadPoolDevice thread_pool_device(&thread_pool, num_threads);
@ -251,7 +251,7 @@ static void test_multithreaded_reductions() {
} }
static void test_memcpy() { void test_memcpy() {
for (int i = 0; i < 5; ++i) { for (int i = 0; i < 5; ++i) {
const int num_threads = internal::random<int>(3, 11); const int num_threads = internal::random<int>(3, 11);
@ -270,7 +270,7 @@ static void test_memcpy() {
} }
static void test_multithread_random() void test_multithread_random()
{ {
Eigen::ThreadPool tp(2); Eigen::ThreadPool tp(2);
Eigen::ThreadPoolDevice device(&tp, 2); Eigen::ThreadPoolDevice device(&tp, 2);
@ -281,23 +281,22 @@ static void test_multithread_random()
void test_cxx11_tensor_thread_pool() void test_cxx11_tensor_thread_pool()
{ {
CALL_SUBTEST(test_multithread_elementwise()); CALL_SUBTEST_1(test_multithread_elementwise());
CALL_SUBTEST(test_multithread_compound_assignment()); CALL_SUBTEST_1(test_multithread_compound_assignment());
CALL_SUBTEST(test_multithread_contraction<ColMajor>()); CALL_SUBTEST_2(test_multithread_contraction<ColMajor>());
CALL_SUBTEST(test_multithread_contraction<RowMajor>()); CALL_SUBTEST_2(test_multithread_contraction<RowMajor>());
CALL_SUBTEST(test_multithread_contraction_agrees_with_singlethread<ColMajor>()); CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<ColMajor>());
CALL_SUBTEST(test_multithread_contraction_agrees_with_singlethread<RowMajor>()); CALL_SUBTEST_3(test_multithread_contraction_agrees_with_singlethread<RowMajor>());
// Exercise various cases that have been problematic in the past. // Exercise various cases that have been problematic in the past.
CALL_SUBTEST(test_contraction_corner_cases<ColMajor>()); CALL_SUBTEST_4(test_contraction_corner_cases<ColMajor>());
CALL_SUBTEST(test_contraction_corner_cases<RowMajor>()); CALL_SUBTEST_4(test_contraction_corner_cases<RowMajor>());
CALL_SUBTEST(test_multithreaded_reductions<ColMajor>()); CALL_SUBTEST_5(test_multithreaded_reductions<ColMajor>());
CALL_SUBTEST(test_multithreaded_reductions<RowMajor>()); CALL_SUBTEST_5(test_multithreaded_reductions<RowMajor>());
CALL_SUBTEST(test_memcpy()); CALL_SUBTEST_6(test_memcpy());
CALL_SUBTEST_6(test_multithread_random());
CALL_SUBTEST(test_multithread_random());
} }