Fix expression evaluation heuristic for TensorSliceOp

commit 3cd148f983
parent 23b958818e
Author: Eugene Zhulenev
Date:   2019-07-09 12:10:26 -07:00


@@ -479,9 +479,12 @@ class TensorSlicingOp : public TensorBase<TensorSlicingOp<StartIndices, Sizes, X
 // Fixme: figure out the exact threshold
 namespace {
-template <typename Index, typename Device> struct MemcpyTriggerForSlicing {
+template <typename Index, typename Device, bool BlockAccess> struct MemcpyTriggerForSlicing {
   EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const Device& device) : threshold_(2 * device.numThreads()) { }
-  EIGEN_DEVICE_FUNC bool operator ()(Index val) const { return val > threshold_; }
+  EIGEN_DEVICE_FUNC bool operator ()(Index total, Index contiguous) const {
+    const bool prefer_block_evaluation = BlockAccess && total > 32*1024;
+    return !prefer_block_evaluation && contiguous > threshold_;
+  }
  private:
   Index threshold_;
@@ -490,18 +493,18 @@ template <typename Index, typename Device> struct MemcpyTriggerForSlicing {
 // It is very expensive to start the memcpy kernel on GPU: we therefore only
 // use it for large copies.
 #ifdef EIGEN_USE_GPU
-template <typename Index> struct MemcpyTriggerForSlicing<Index, GpuDevice> {
+template <typename Index, bool BlockAccess> struct MemcpyTriggerForSlicing<Index, GpuDevice, BlockAccess> {
   EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const GpuDevice&) { }
-  EIGEN_DEVICE_FUNC bool operator ()(Index val) const { return val > 4*1024*1024; }
+  EIGEN_DEVICE_FUNC bool operator ()(Index total, Index contiguous) const { return contiguous > 4*1024*1024; }
 };
 #endif
 
 // It is very expensive to start the memcpy kernel on GPU: we therefore only
 // use it for large copies.
 #ifdef EIGEN_USE_SYCL
-template <typename Index> struct MemcpyTriggerForSlicing<Index, Eigen::SyclDevice> {
+template <typename Index, bool BlockAccess> struct MemcpyTriggerForSlicing<Index, Eigen::SyclDevice, BlockAccess> {
   EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const SyclDevice&) { }
-  EIGEN_DEVICE_FUNC bool operator ()(Index val) const { return val > 4*1024*1024; }
+  EIGEN_DEVICE_FUNC bool operator ()(Index total, Index contiguous) const { return contiguous > 4*1024*1024; }
 };
 #endif
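
The effect of the new two-argument trigger is easiest to see in isolation. Below is a minimal, self-contained C++ sketch of the decision logic, not Eigen's actual code: the 32*1024 cutoff and the 2 * numThreads threshold are taken from the diff, while memcpy_trigger and its num_threads parameter are illustrative stand-ins for the MemcpyTriggerForSlicing functor and its device argument.

    #include <cstddef>
    #include <iostream>

    using Index = std::ptrdiff_t;

    // CPU heuristic from the patch: for large expressions whose evaluator
    // supports block access, prefer the regular (block-based) evaluation
    // path; otherwise use memcpy once contiguous runs are long enough.
    template <bool BlockAccess>
    bool memcpy_trigger(Index total, Index contiguous, Index num_threads) {
      const Index threshold = 2 * num_threads;
      const bool prefer_block_evaluation = BlockAccess && total > 32 * 1024;
      return !prefer_block_evaluation && contiguous > threshold;
    }

    int main() {
      std::cout << memcpy_trigger<true>(1024, 256, 8) << "\n";       // 1: small slice, memcpy
      std::cout << memcpy_trigger<true>(64 * 1024, 256, 8) << "\n";  // 0: large slice, block eval preferred
      std::cout << memcpy_trigger<false>(64 * 1024, 256, 8) << "\n"; // 1: no block access, memcpy fires
    }

Note that the GpuDevice and SyclDevice specializations above ignore total entirely: kernel launch overhead dominates on those devices, so only the fixed 4*1024*1024 contiguous-run threshold matters there.
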
@@ -592,8 +595,7 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
     m_impl.evalSubExprsIfNeeded(NULL);
     if (!NumTraits<typename internal::remove_const<Scalar>::type>::RequireInitialization
-        && data && m_impl.data()
-        && !BlockAccess) {
+        && data && m_impl.data()) {
       Index contiguous_values = 1;
       if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
         for (int i = 0; i < NumDims; ++i) {
@@ -611,8 +613,8 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi
         }
       }
       // Use memcpy if it's going to be faster than using the regular evaluation.
-      const MemcpyTriggerForSlicing<Index, Device> trigger(m_device);
-      if (trigger(contiguous_values)) {
+      const MemcpyTriggerForSlicing<Index, Device, BlockAccess> trigger(m_device);
+      if (trigger(internal::array_prod(dimensions()), contiguous_values)) {
         EvaluatorPointerType src = (EvaluatorPointerType)m_impl.data();
         for (Index i = 0; i < internal::array_prod(dimensions()); i += contiguous_values) {
           Index offset = srcCoeff(i);
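
At the call site, the trigger now sees both the slice's total coefficient count, internal::array_prod(dimensions()), and contiguous_values. The computation of contiguous_values is only partly visible in the hunk above; the following sketch of the ColMajor case uses slice_dims and input_dims as hypothetical stand-ins for dimensions() and m_impl.dimensions().

    #include <cstddef>
    #include <vector>

    using Index = std::ptrdiff_t;

    // Innermost dimensions that the slice covers entirely form one
    // contiguous run in the input buffer; the first partially covered
    // dimension ends the run (its extent still scales the run length).
    Index contiguous_run(const std::vector<Index>& slice_dims,
                         const std::vector<Index>& input_dims) {
      Index contiguous_values = 1;
      for (std::size_t i = 0; i < slice_dims.size(); ++i) {
        contiguous_values *= slice_dims[i];
        if (slice_dims[i] != input_dims[i]) break;
      }
      return contiguous_values;
    }

The copy loop above then advances in steps of contiguous_values, issuing one memcpy per run starting at the source offset srcCoeff(i).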