mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 03:39:01 +08:00
Fix expression evaluation heuristic for TensorSliceOp
This commit is contained in:
parent 23b958818e
commit 3cd148f983
@@ -479,9 +479,12 @@ class TensorSlicingOp : public TensorBase<TensorSlicingOp<StartIndices, Sizes, X
 // Fixme: figure out the exact threshold
 namespace {
-template <typename Index, typename Device> struct MemcpyTriggerForSlicing {
+template <typename Index, typename Device, bool BlockAccess> struct MemcpyTriggerForSlicing {
   EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const Device& device) : threshold_(2 * device.numThreads()) { }
-  EIGEN_DEVICE_FUNC bool operator ()(Index val) const { return val > threshold_; }
+  EIGEN_DEVICE_FUNC bool operator ()(Index total, Index contiguous) const {
+    const bool prefer_block_evaluation = BlockAccess && total > 32*1024;
+    return !prefer_block_evaluation && contiguous > threshold_;
+  }

  private:
   Index threshold_;
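With the extra BlockAccess template parameter, the primary (CPU) trigger now weighs the slice's total size against its longest contiguous run: when the expression supports block evaluation and the slice is large, the blocked evaluator is preferred over a flat memcpy; otherwise the old thread-count rule still applies. A minimal standalone sketch of that decision, with the illustrative names use_memcpy_for_slice and num_threads (the latter standing in for device.numThreads()), not Eigen API:

#include <cstddef>

// Illustrative only: mirrors the new heuristic in the hunk above.
template <bool BlockAccess>
bool use_memcpy_for_slice(std::ptrdiff_t total, std::ptrdiff_t contiguous,
                          int num_threads) {
  // Large slices that can be block-evaluated skip the memcpy path entirely.
  const bool prefer_block_evaluation = BlockAccess && total > 32 * 1024;
  // Otherwise keep the old rule: memcpy pays off once the contiguous run
  // exceeds a small multiple of the available threads.
  return !prefer_block_evaluation && contiguous > std::ptrdiff_t(2 * num_threads);
}

For example, a 64k-element slice with BlockAccess takes the blocked path, while a small slice whose contiguous run exceeds 8 elements on a 4-thread device is still copied with memcpy.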
@@ -490,18 +493,18 @@ template <typename Index, typename Device> struct MemcpyTriggerForSlicing {
 // It is very expensive to start the memcpy kernel on GPU: we therefore only
 // use it for large copies.
 #ifdef EIGEN_USE_GPU
-template <typename Index> struct MemcpyTriggerForSlicing<Index, GpuDevice> {
+template <typename Index, bool BlockAccess> struct MemcpyTriggerForSlicing<Index, GpuDevice, BlockAccess> {
   EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const GpuDevice&) { }
-  EIGEN_DEVICE_FUNC bool operator ()(Index val) const { return val > 4*1024*1024; }
+  EIGEN_DEVICE_FUNC bool operator ()(Index total, Index contiguous) const { return contiguous > 4*1024*1024; }
 };
 #endif

 // It is very expensive to start the memcpy kernel on GPU: we therefore only
 // use it for large copies.
 #ifdef EIGEN_USE_SYCL
-template <typename Index> struct MemcpyTriggerForSlicing<Index, Eigen::SyclDevice> {
+template <typename Index, bool BlockAccess> struct MemcpyTriggerForSlicing<Index, Eigen::SyclDevice, BlockAccess> {
   EIGEN_DEVICE_FUNC MemcpyTriggerForSlicing(const SyclDevice&) { }
-  EIGEN_DEVICE_FUNC bool operator ()(Index val) const { return val > 4*1024*1024; }
+  EIGEN_DEVICE_FUNC bool operator ()(Index total, Index contiguous) const { return contiguous > 4*1024*1024; }
 };
 #endif

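The device specializations follow the usual partial-specialization pattern: the new BlockAccess parameter is threaded through so the templates still match, but the GPU and SYCL triggers ignore it (and the total size) and keep the fixed 4 MiB bar, since kernel-launch overhead dominates there. A condensed sketch of that pattern using stand-in names (ToyGpuDevice, ToyTrigger), not Eigen's types:

#include <cstddef>

struct ToyGpuDevice {};  // stand-in for GpuDevice / SyclDevice

// Primary template left undefined in this sketch; the CPU heuristic above
// plays that role in the real code.
template <typename Index, typename Device, bool BlockAccess>
struct ToyTrigger;

// Partial specialization for the GPU-like device: BlockAccess and the total
// element count are ignored; only very large contiguous copies justify
// launching a device memcpy.
template <typename Index, bool BlockAccess>
struct ToyTrigger<Index, ToyGpuDevice, BlockAccess> {
  explicit ToyTrigger(const ToyGpuDevice&) {}
  bool operator()(Index /*total*/, Index contiguous) const {
    return contiguous > Index(4) * 1024 * 1024;
  }
};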
@@ -592,8 +595,7 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(EvaluatorPointerType data) {
     m_impl.evalSubExprsIfNeeded(NULL);
     if (!NumTraits<typename internal::remove_const<Scalar>::type>::RequireInitialization
-        && data && m_impl.data()
-        && !BlockAccess) {
+        && data && m_impl.data()) {
       Index contiguous_values = 1;
       if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
         for (int i = 0; i < NumDims; ++i) {
@@ -611,8 +613,8 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi
         }
       }
       // Use memcpy if it's going to be faster than using the regular evaluation.
-      const MemcpyTriggerForSlicing<Index, Device> trigger(m_device);
-      if (trigger(contiguous_values)) {
+      const MemcpyTriggerForSlicing<Index, Device, BlockAccess> trigger(m_device);
+      if (trigger(internal::array_prod(dimensions()), contiguous_values)) {
         EvaluatorPointerType src = (EvaluatorPointerType)m_impl.data();
         for (Index i = 0; i < internal::array_prod(dimensions()); i += contiguous_values) {
           Index offset = srcCoeff(i);
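At the call site, the trigger now receives both the slice's total element count (internal::array_prod(dimensions())) and contiguous_values, the length of the longest run that can be copied with a single memcpy; the memcpy path is no longer disabled outright for block-accessible expressions, the heuristic decides instead. A small sketch of how such a contiguous run can be accumulated for a column-major slice, using illustrative names (contiguous_run, slice_dims, input_dims) rather than Eigen's members:

#include <array>
#include <cstddef>

template <std::size_t NumDims>
std::ptrdiff_t contiguous_run(const std::array<std::ptrdiff_t, NumDims>& slice_dims,
                              const std::array<std::ptrdiff_t, NumDims>& input_dims) {
  std::ptrdiff_t contiguous = 1;
  for (std::size_t i = 0; i < NumDims; ++i) {
    contiguous *= slice_dims[i];
    // Once a dimension is only partially covered by the slice, elements stop
    // being adjacent in the source, so the run cannot grow further.
    if (slice_dims[i] != input_dims[i]) break;
  }
  return contiguous;
}

// Usage sketch: if (trigger(total_elements, contiguous)) { /* memcpy path */ }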