Use signed integers more consistently to encode the number of threads to use to evaluate a tensor expression.

Benoit Steiner 2016-06-09 08:25:22 -07:00
parent 8f92c26319
commit 14a112ee15
2 changed files with 11 additions and 11 deletions
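Editorial note, not part of the commit itself: Eigen's Index is signed, while the old thread counts below were size_t, and mixing the two in one expression converts the signed operand to unsigned. A minimal standalone sketch of that pitfall (all names here are illustrative):

#include <cstddef>
#include <iostream>

// Illustrative only, not code from this commit: mixing a signed
// Index-style value with an unsigned thread count converts the signed
// operand to unsigned, so a negative value wraps to a huge one.
int main() {
  std::ptrdiff_t remaining = -2;  // signed, like Eigen's Index
  std::size_t num_threads = 4;    // unsigned, like the old num_threads_

  // After conversion, -2 becomes 2^64 - 2 on a 64-bit target, so the
  // test is false even though -2 < 4 arithmetically.
  if (static_cast<std::size_t>(remaining) < num_threads) {
    std::cout << "not reached\n";
  }

  // With a signed thread count, as after this commit, the comparison
  // means what it says.
  int num_threads_signed = 4;
  if (remaining < num_threads_signed) {
    std::cout << "signed comparison: -2 < 4\n";
  }
  return 0;
}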

unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h

@@ -202,7 +202,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
     // across k dimension.
     const TensorOpCost cost =
         contractionCost(m, n, bm, bn, bk, shard_by_col, false);
-    Index num_threads = TensorCostModel<ThreadPoolDevice>::numThreads(
+    int num_threads = TensorCostModel<ThreadPoolDevice>::numThreads(
         static_cast<double>(n) * m, cost, this->m_device.numThreads());
     // TODO(dvyukov): this is a stop-gap to prevent regressions while the cost
@@ -301,7 +301,7 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
   class Context {
    public:
     Context(const Device& device, int num_threads, LhsMapper& lhs,
-            RhsMapper& rhs, Scalar* buffer, Index m, Index n, Index k, Index bm,
+            RhsMapper& rhs, Scalar* buffer, Index tm, Index tn, Index tk, Index bm,
             Index bn, Index bk, Index nm, Index nn, Index nk, Index gm,
             Index gn, Index nm0, Index nn0, bool shard_by_col,
             bool parallel_pack)
@@ -309,13 +309,13 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
           lhs_(lhs),
           rhs_(rhs),
           buffer_(buffer),
-          output_(buffer, m),
+          output_(buffer, tm),
           num_threads_(num_threads),
           shard_by_col_(shard_by_col),
           parallel_pack_(parallel_pack),
-          m_(m),
-          n_(n),
-          k_(k),
+          m_(tm),
+          n_(tn),
+          k_(tk),
           bm_(bm),
           bn_(bn),
           bk_(bk),

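Editorial note on the first hunk above, not from the commit: TensorCostModel<ThreadPoolDevice>::numThreads returns an int, which the change above now matches, so the declared type of num_threads is identical to what the cost model produces. A self-contained sketch of that shape, using a made-up estimateNumThreads in place of Eigen's real cost model:

#include <algorithm>

// Hypothetical mini cost model, shaped like (but much simpler than)
// TensorCostModel<Device>::numThreads: it estimates a thread count,
// clamps it to [1, max_threads], and returns int.
static int estimateNumThreads(double output_size, double cost_per_coeff,
                              int max_threads) {
  const double kCyclesPerThread = 100000.0;  // made-up tuning constant
  const double total_cost = output_size * cost_per_coeff;
  const int threads = static_cast<int>(total_cost / kCyclesPerThread);
  return std::min(max_threads, std::max(1, threads));
}

int main() {
  // Storing the estimate in an int, as the hunk above now does, keeps
  // the whole thread-count computation in one signed type.
  const int num_threads = estimateNumThreads(1 << 20, 2.0, 8);
  return num_threads == 8 ? 0 : 1;
}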
unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h

@@ -106,7 +106,7 @@ static EIGEN_STRONG_INLINE void wait_until_ready(SyncType* n) {
 // Build a thread pool device on top the an existing pool of threads.
 struct ThreadPoolDevice {
   // The ownership of the thread pool remains with the caller.
-  ThreadPoolDevice(ThreadPoolInterface* pool, size_t num_cores) : pool_(pool), num_threads_(num_cores) { }
+  ThreadPoolDevice(ThreadPoolInterface* pool, int num_cores) : pool_(pool), num_threads_(num_cores) { }

   EIGEN_STRONG_INLINE void* allocate(size_t num_bytes) const {
     return internal::aligned_malloc(num_bytes);
@@ -130,7 +130,7 @@ struct ThreadPoolDevice {
     ::memset(buffer, c, n);
   }

-  EIGEN_STRONG_INLINE size_t numThreads() const {
+  EIGEN_STRONG_INLINE int numThreads() const {
     return num_threads_;
   }
@@ -182,7 +182,7 @@ struct ThreadPoolDevice {
                    std::function<void(Index, Index)> f) const {
     typedef TensorCostModel<ThreadPoolDevice> CostModel;
     if (n <= 1 || numThreads() == 1 ||
-        CostModel::numThreads(n, cost, numThreads()) == 1) {
+        CostModel::numThreads(n, cost, static_cast<int>(numThreads())) == 1) {
       f(0, n);
       return;
     }
@@ -242,7 +242,7 @@ struct ThreadPoolDevice {
     // Recursively divide size into halves until we reach block_size.
     // Division code rounds mid to block_size, so we are guaranteed to get
     // block_count leaves that do actual computations.
-    Barrier barrier(block_count);
+    Barrier barrier(static_cast<unsigned int>(block_count));
     std::function<void(Index, Index)> handleRange;
     handleRange = [=, &handleRange, &barrier, &f](Index first, Index last) {
       if (last - first <= block_size) {
@@ -268,7 +268,7 @@ struct ThreadPoolDevice {
  private:
   ThreadPoolInterface* pool_;
-  size_t num_threads_;
+  int num_threads_;
 };