Add test coverage for ThreadPoolDevice optional allocator.

This commit is contained in:
Paul Tucker 2018-07-19 17:43:44 -07:00
parent 4e9848fa86
commit d4afccde5a
2 changed files with 48 additions and 4 deletions

View File

@ -91,6 +91,13 @@ static EIGEN_STRONG_INLINE void wait_until_ready(SyncType* n) {
}
}
// An abstract interface to a device specific memory allocator.
class Allocator {
public:
virtual ~Allocator() {}
EIGEN_DEVICE_FUNC virtual void* allocate(size_t num_bytes) const = 0;
EIGEN_DEVICE_FUNC virtual void deallocate(void* buffer) const = 0;
};
// Build a thread pool device on top the an existing pool of threads.
struct ThreadPoolDevice {

View File

@ -16,6 +16,25 @@
using Eigen::Tensor;
class TestAllocator : public Allocator {
public:
~TestAllocator() override {}
EIGEN_DEVICE_FUNC void* allocate(size_t num_bytes) const override {
const_cast<TestAllocator*>(this)->alloc_count_++;
return internal::aligned_malloc(num_bytes);
}
EIGEN_DEVICE_FUNC void deallocate(void* buffer) const override {
const_cast<TestAllocator*>(this)->dealloc_count_++;
internal::aligned_free(buffer);
}
int alloc_count() const { return alloc_count_; }
int dealloc_count() const { return dealloc_count_; }
private:
int alloc_count_ = 0;
int dealloc_count_ = 0;
};
void test_multithread_elementwise()
{
@ -320,14 +339,14 @@ void test_multithread_random()
}
template<int DataLayout>
void test_multithread_shuffle()
void test_multithread_shuffle(Allocator* allocator)
{
Tensor<float, 4, DataLayout> tensor(17,5,7,11);
tensor.setRandom();
const int num_threads = internal::random<int>(2, 11);
ThreadPool threads(num_threads);
Eigen::ThreadPoolDevice device(&threads, num_threads);
Eigen::ThreadPoolDevice device(&threads, num_threads, allocator);
Tensor<float, 4, DataLayout> shuffle(7,5,11,17);
array<ptrdiff_t, 4> shuffles = {{2,1,3,0}};
@ -344,6 +363,21 @@ void test_multithread_shuffle()
}
}
void test_threadpool_allocate(TestAllocator* allocator)
{
const int num_threads = internal::random<int>(2, 11);
const int num_allocs = internal::random<int>(2, 11);
ThreadPool threads(num_threads);
Eigen::ThreadPoolDevice device(&threads, num_threads, allocator);
for (int a = 0; a < num_allocs; ++a) {
void* ptr = device.allocate(512);
device.deallocate(ptr);
}
VERIFY(allocator != nullptr);
VERIFY_IS_EQUAL(allocator->alloc_count(), num_allocs);
VERIFY_IS_EQUAL(allocator->dealloc_count(), num_allocs);
}
void test_cxx11_tensor_thread_pool()
{
@ -368,6 +402,9 @@ void test_cxx11_tensor_thread_pool()
CALL_SUBTEST_6(test_memcpy());
CALL_SUBTEST_6(test_multithread_random());
CALL_SUBTEST_6(test_multithread_shuffle<ColMajor>());
CALL_SUBTEST_6(test_multithread_shuffle<RowMajor>());
TestAllocator test_allocator;
CALL_SUBTEST_6(test_multithread_shuffle<ColMajor>(nullptr));
CALL_SUBTEST_6(test_multithread_shuffle<RowMajor>(&test_allocator));
CALL_SUBTEST_6(test_threadpool_allocate(&test_allocator));
}