mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Merged in paultucker/eigen (pull request PR-431)
Optional ThreadPoolDevice allocator Approved-by: Benoit Steiner <benoit.steiner.goog@gmail.com>
This commit is contained in:
commit
93b9e36e10
@ -91,18 +91,31 @@ static EIGEN_STRONG_INLINE void wait_until_ready(SyncType* n) {
|
||||
}
|
||||
}
|
||||
|
||||
// An abstract interface to a device specific memory allocator.
|
||||
class Allocator {
|
||||
public:
|
||||
virtual ~Allocator() {}
|
||||
EIGEN_DEVICE_FUNC virtual void* allocate(size_t num_bytes) const = 0;
|
||||
EIGEN_DEVICE_FUNC virtual void deallocate(void* buffer) const = 0;
|
||||
};
|
||||
|
||||
// Build a thread pool device on top of an existing pool of threads.
|
||||
struct ThreadPoolDevice {
|
||||
// The ownership of the thread pool remains with the caller.
|
||||
ThreadPoolDevice(ThreadPoolInterface* pool, int num_cores) : pool_(pool), num_threads_(num_cores) { }
|
||||
ThreadPoolDevice(ThreadPoolInterface* pool, int num_cores, Allocator* allocator = nullptr)
|
||||
: pool_(pool), num_threads_(num_cores), allocator_(allocator) { }
|
||||
|
||||
EIGEN_STRONG_INLINE void* allocate(size_t num_bytes) const {
|
||||
return internal::aligned_malloc(num_bytes);
|
||||
return allocator_ ? allocator_->allocate(num_bytes)
|
||||
: internal::aligned_malloc(num_bytes);
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void deallocate(void* buffer) const {
|
||||
internal::aligned_free(buffer);
|
||||
if (allocator_) {
|
||||
allocator_->deallocate(buffer);
|
||||
} else {
|
||||
internal::aligned_free(buffer);
|
||||
}
|
||||
}
|
||||
|
||||
EIGEN_STRONG_INLINE void* allocate_temp(size_t num_bytes) const {
|
||||
@ -275,9 +288,13 @@ struct ThreadPoolDevice {
|
||||
// Thread pool accessor.
|
||||
ThreadPoolInterface* getPool() const { return pool_; }
|
||||
|
||||
// Allocator accessor.
|
||||
Allocator* allocator() const { return allocator_; }
|
||||
|
||||
private:
|
||||
ThreadPoolInterface* pool_;
|
||||
int num_threads_;
|
||||
Allocator* allocator_;
|
||||
};
|
||||
|
||||
|
||||
|
@ -16,6 +16,25 @@
|
||||
|
||||
using Eigen::Tensor;
|
||||
|
||||
class TestAllocator : public Allocator {
|
||||
public:
|
||||
~TestAllocator() override {}
|
||||
EIGEN_DEVICE_FUNC void* allocate(size_t num_bytes) const override {
|
||||
const_cast<TestAllocator*>(this)->alloc_count_++;
|
||||
return internal::aligned_malloc(num_bytes);
|
||||
}
|
||||
EIGEN_DEVICE_FUNC void deallocate(void* buffer) const override {
|
||||
const_cast<TestAllocator*>(this)->dealloc_count_++;
|
||||
internal::aligned_free(buffer);
|
||||
}
|
||||
|
||||
int alloc_count() const { return alloc_count_; }
|
||||
int dealloc_count() const { return dealloc_count_; }
|
||||
|
||||
private:
|
||||
int alloc_count_ = 0;
|
||||
int dealloc_count_ = 0;
|
||||
};
|
||||
|
||||
void test_multithread_elementwise()
|
||||
{
|
||||
@ -374,14 +393,14 @@ void test_multithread_random()
|
||||
}
|
||||
|
||||
template<int DataLayout>
|
||||
void test_multithread_shuffle()
|
||||
void test_multithread_shuffle(Allocator* allocator)
|
||||
{
|
||||
Tensor<float, 4, DataLayout> tensor(17,5,7,11);
|
||||
tensor.setRandom();
|
||||
|
||||
const int num_threads = internal::random<int>(2, 11);
|
||||
ThreadPool threads(num_threads);
|
||||
Eigen::ThreadPoolDevice device(&threads, num_threads);
|
||||
Eigen::ThreadPoolDevice device(&threads, num_threads, allocator);
|
||||
|
||||
Tensor<float, 4, DataLayout> shuffle(7,5,11,17);
|
||||
array<ptrdiff_t, 4> shuffles = {{2,1,3,0}};
|
||||
@ -398,6 +417,21 @@ void test_multithread_shuffle()
|
||||
}
|
||||
}
|
||||
|
||||
// Checks that ThreadPoolDevice::allocate() and ThreadPoolDevice::deallocate()
// go through the allocator supplied to the device: after a random number of
// allocate/deallocate round trips, the allocator's call counters must have
// advanced by exactly that number.
//
// allocator must be non-null; it may already have served earlier requests.
void test_threadpool_allocate(TestAllocator* allocator)
{
  // Check the pointer up front, before it is dereferenced for the baseline
  // counts below.
  VERIFY(allocator != nullptr);

  const int num_threads = internal::random<int>(2, 11);
  const int num_allocs = internal::random<int>(2, 11);
  ThreadPool threads(num_threads);
  Eigen::ThreadPoolDevice device(&threads, num_threads, allocator);

  // The caller can pass an allocator that other subtests have used, so the
  // checks are made relative to the counts observed here rather than to zero.
  const int allocs_before = allocator->alloc_count();
  const int deallocs_before = allocator->dealloc_count();

  for (int a = 0; a < num_allocs; ++a) {
    void* ptr = device.allocate(512);
    device.deallocate(ptr);
  }

  VERIFY_IS_EQUAL(allocator->alloc_count() - allocs_before, num_allocs);
  VERIFY_IS_EQUAL(allocator->dealloc_count() - deallocs_before, num_allocs);
}
|
||||
|
||||
EIGEN_DECLARE_TEST(cxx11_tensor_thread_pool)
|
||||
{
|
||||
@ -424,6 +458,9 @@ EIGEN_DECLARE_TEST(cxx11_tensor_thread_pool)
|
||||
|
||||
CALL_SUBTEST_6(test_memcpy());
|
||||
CALL_SUBTEST_6(test_multithread_random());
|
||||
CALL_SUBTEST_6(test_multithread_shuffle<ColMajor>());
|
||||
CALL_SUBTEST_6(test_multithread_shuffle<RowMajor>());
|
||||
|
||||
TestAllocator test_allocator;
|
||||
CALL_SUBTEST_6(test_multithread_shuffle<ColMajor>(nullptr));
|
||||
CALL_SUBTEST_6(test_multithread_shuffle<RowMajor>(&test_allocator));
|
||||
CALL_SUBTEST_6(test_threadpool_allocate(&test_allocator));
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user