mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 19:29:02 +08:00
Add support for asynchronous evaluation of tensor casting expressions.
This commit is contained in:
parent
28b6786498
commit
1d5af0693c
@ -178,6 +178,27 @@ template <typename Eval, typename EvalPointerType> struct ConversionSubExprEval<
|
||||
}
|
||||
};
|
||||
|
||||
#ifdef EIGEN_USE_THREADS
|
||||
template <bool SameType, typename Eval, typename EvalPointerType,
|
||||
typename EvalSubExprsCallback>
|
||||
struct ConversionSubExprEvalAsync {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(
|
||||
Eval& impl, EvalPointerType, EvalSubExprsCallback done) {
|
||||
impl.evalSubExprsIfNeededAsync(nullptr, std::move(done));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Eval, typename EvalPointerType,
|
||||
typename EvalSubExprsCallback>
|
||||
struct ConversionSubExprEvalAsync<true, Eval, EvalPointerType,
|
||||
EvalSubExprsCallback> {
|
||||
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(
|
||||
Eval& impl, EvalPointerType data, EvalSubExprsCallback done) {
|
||||
impl.evalSubExprsIfNeededAsync(data, std::move(done));
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
namespace internal {
|
||||
|
||||
template <typename SrcType, typename TargetType, bool IsSameT>
|
||||
@ -299,6 +320,16 @@ struct TensorEvaluator<const TensorConversionOp<TargetType, ArgType>, Device>
|
||||
return ConversionSubExprEval<IsSameType, TensorEvaluator<ArgType, Device>, EvaluatorPointerType>::run(m_impl, data);
|
||||
}
|
||||
|
||||
#ifdef EIGEN_USE_THREADS
|
||||
template <typename EvalSubExprsCallback>
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(
|
||||
EvaluatorPointerType data, EvalSubExprsCallback done) {
|
||||
ConversionSubExprEvalAsync<IsSameType, TensorEvaluator<ArgType, Device>,
|
||||
EvaluatorPointerType,
|
||||
EvalSubExprsCallback>::run(m_impl, data, std::move(done));
|
||||
}
|
||||
#endif
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup()
|
||||
{
|
||||
m_impl.cleanup();
|
||||
|
@ -40,19 +40,19 @@ void test_multithread_elementwise()
|
||||
{
|
||||
Tensor<float, 3> in1(200, 30, 70);
|
||||
Tensor<float, 3> in2(200, 30, 70);
|
||||
Tensor<float, 3> out(200, 30, 70);
|
||||
Tensor<double, 3> out(200, 30, 70);
|
||||
|
||||
in1.setRandom();
|
||||
in2.setRandom();
|
||||
|
||||
Eigen::ThreadPool tp(internal::random<int>(3, 11));
|
||||
Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11));
|
||||
out.device(thread_pool_device) = in1 + in2 * 3.14f;
|
||||
out.device(thread_pool_device) = (in1 + in2 * 3.14f).cast<double>();
|
||||
|
||||
for (int i = 0; i < 200; ++i) {
|
||||
for (int j = 0; j < 30; ++j) {
|
||||
for (int k = 0; k < 70; ++k) {
|
||||
VERIFY_IS_APPROX(out(i, j, k), in1(i, j, k) + in2(i, j, k) * 3.14f);
|
||||
VERIFY_IS_APPROX(out(i, j, k), static_cast<double>(in1(i, j, k) + in2(i, j, k) * 3.14f));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -62,7 +62,7 @@ void test_async_multithread_elementwise()
|
||||
{
|
||||
Tensor<float, 3> in1(200, 30, 70);
|
||||
Tensor<float, 3> in2(200, 30, 70);
|
||||
Tensor<float, 3> out(200, 30, 70);
|
||||
Tensor<double, 3> out(200, 30, 70);
|
||||
|
||||
in1.setRandom();
|
||||
in2.setRandom();
|
||||
@ -71,13 +71,13 @@ void test_async_multithread_elementwise()
|
||||
Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11));
|
||||
|
||||
Eigen::Barrier b(1);
|
||||
out.device(thread_pool_device, [&b]() { b.Notify(); }) = in1 + in2 * 3.14f;
|
||||
out.device(thread_pool_device, [&b]() { b.Notify(); }) = (in1 + in2 * 3.14f).cast<double>();
|
||||
b.Wait();
|
||||
|
||||
for (int i = 0; i < 200; ++i) {
|
||||
for (int j = 0; j < 30; ++j) {
|
||||
for (int k = 0; k < 70; ++k) {
|
||||
VERIFY_IS_APPROX(out(i, j, k), in1(i, j, k) + in2(i, j, k) * 3.14f);
|
||||
VERIFY_IS_APPROX(out(i, j, k), static_cast<double>(in1(i, j, k) + in2(i, j, k) * 3.14f));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user