diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h
index e96f31537..fa329bfe6 100644
--- a/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h
+++ b/unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h
@@ -178,6 +178,27 @@ template <typename Eval, typename EvalPointerType> struct ConversionSubExprEval<
   }
 };
 
+#ifdef EIGEN_USE_THREADS
+template <bool SameType, typename Eval, typename EvalPointerType,
+          typename EvalSubExprsCallback>
+struct ConversionSubExprEvalAsync {
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(
+      Eval& impl, EvalPointerType, EvalSubExprsCallback done) {
+    impl.evalSubExprsIfNeededAsync(nullptr, std::move(done));
+  }
+};
+
+template <typename Eval, typename EvalPointerType,
+          typename EvalSubExprsCallback>
+struct ConversionSubExprEvalAsync<true, Eval, EvalPointerType,
+                                  EvalSubExprsCallback> {
+  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(
+      Eval& impl, EvalPointerType data, EvalSubExprsCallback done) {
+    impl.evalSubExprsIfNeededAsync(data, std::move(done));
+  }
+};
+#endif
+
 namespace internal {
 
 template <typename SrcType, typename TargetType, bool IsSameT>
@@ -299,6 +320,16 @@ struct TensorEvaluator<const TensorConversionOp<TargetType, ArgType>, Device>
     return ConversionSubExprEval<IsSameType, TensorEvaluator<ArgType, Device>, EvaluatorPointerType>::run(m_impl, data);
   }
 
+#ifdef EIGEN_USE_THREADS
+  template <typename EvalSubExprsCallback>
+  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalSubExprsIfNeededAsync(
+      EvaluatorPointerType data, EvalSubExprsCallback done) {
+    ConversionSubExprEvalAsync<IsSameType, TensorEvaluator<ArgType, Device>,
+                               EvaluatorPointerType,
+        EvalSubExprsCallback>::run(m_impl, data, std::move(done));
+  }
+#endif
+
   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup()
   {
     m_impl.cleanup();
diff --git a/unsupported/test/cxx11_tensor_thread_pool.cpp b/unsupported/test/cxx11_tensor_thread_pool.cpp
index dae7b0335..b772a1d60 100644
--- a/unsupported/test/cxx11_tensor_thread_pool.cpp
+++ b/unsupported/test/cxx11_tensor_thread_pool.cpp
@@ -40,19 +40,19 @@ void test_multithread_elementwise()
 {
   Tensor<float, 3> in1(200, 30, 70);
   Tensor<float, 3> in2(200, 30, 70);
-  Tensor<float, 3> out(200, 30, 70);
+  Tensor<double, 3> out(200, 30, 70);
 
   in1.setRandom();
   in2.setRandom();
 
   Eigen::ThreadPool tp(internal::random<int>(3, 11));
   Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11));
-  out.device(thread_pool_device) = in1 + in2 * 3.14f;
+  out.device(thread_pool_device) = (in1 + in2 * 3.14f).cast<double>();
 
   for (int i = 0; i < 200; ++i) {
     for (int j = 0; j < 30; ++j) {
       for (int k = 0; k < 70; ++k) {
-        VERIFY_IS_APPROX(out(i, j, k), in1(i, j, k) + in2(i, j, k) * 3.14f);
+        VERIFY_IS_APPROX(out(i, j, k), static_cast<double>(in1(i, j, k) + in2(i, j, k) * 3.14f));
       }
     }
   }
@@ -62,7 +62,7 @@ void test_async_multithread_elementwise()
 {
   Tensor<float, 3> in1(200, 30, 70);
   Tensor<float, 3> in2(200, 30, 70);
-  Tensor<float, 3> out(200, 30, 70);
+  Tensor<double, 3> out(200, 30, 70);
 
   in1.setRandom();
   in2.setRandom();
@@ -71,13 +71,13 @@ void test_async_multithread_elementwise()
   Eigen::ThreadPoolDevice thread_pool_device(&tp, internal::random<int>(3, 11));
 
   Eigen::Barrier b(1);
-  out.device(thread_pool_device, [&b]() { b.Notify(); }) = in1 + in2 * 3.14f;
+  out.device(thread_pool_device, [&b]() { b.Notify(); }) = (in1 + in2 * 3.14f).cast<double>();
   b.Wait();
 
   for (int i = 0; i < 200; ++i) {
     for (int j = 0; j < 30; ++j) {
       for (int k = 0; k < 70; ++k) {
-        VERIFY_IS_APPROX(out(i, j, k), in1(i, j, k) + in2(i, j, k) * 3.14f);
+        VERIFY_IS_APPROX(out(i, j, k), static_cast<double>(in1(i, j, k) + in2(i, j, k) * 3.14f));
       }
     }
   }