From 955884b87ee4aaf71cde489f16e49febf04117d1 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 2 Sep 2024 20:20:32 +0800 Subject: [PATCH] chore(workflow): max thread submit count --- api/core/app/apps/workflow/app_generator.py | 21 +++++-- api/core/app/apps/workflow/app_runner.py | 8 ++- api/core/tools/tool/workflow_tool.py | 4 +- api/core/tools/tool_engine.py | 2 + .../workflow/graph_engine/graph_engine.py | 55 ++++++++++++++++--- api/core/workflow/nodes/base_node.py | 4 +- api/core/workflow/nodes/tool/tool_node.py | 1 + api/core/workflow/workflow_entry.py | 8 ++- 8 files changed, 86 insertions(+), 17 deletions(-) diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 5b7635e0e8..4347e5277b 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -4,7 +4,7 @@ import os import threading import uuid from collections.abc import Generator -from typing import Any, Literal, Union, overload +from typing import Any, Literal, Optional, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -40,6 +40,8 @@ class WorkflowAppGenerator(BaseAppGenerator): args: dict, invoke_from: InvokeFrom, stream: Literal[True] = True, + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None ) -> Generator[str, None, None]: ... @overload @@ -50,6 +52,8 @@ class WorkflowAppGenerator(BaseAppGenerator): args: dict, invoke_from: InvokeFrom, stream: Literal[False] = False, + call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None ) -> dict: ... def generate( @@ -61,6 +65,7 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from: InvokeFrom, stream: bool = True, call_depth: int = 0, + workflow_thread_pool_id: Optional[str] = None ): """ Generate App response. @@ -72,6 +77,7 @@ class WorkflowAppGenerator(BaseAppGenerator): :param invoke_from: invoke from source :param stream: is stream :param call_depth: call depth + :param workflow_thread_pool_id: workflow thread pool id """ inputs = args['inputs'] @@ -119,6 +125,7 @@ class WorkflowAppGenerator(BaseAppGenerator): application_generate_entity=application_generate_entity, invoke_from=invoke_from, stream=stream, + workflow_thread_pool_id=workflow_thread_pool_id ) def _generate( @@ -129,6 +136,7 @@ class WorkflowAppGenerator(BaseAppGenerator): application_generate_entity: WorkflowAppGenerateEntity, invoke_from: InvokeFrom, stream: bool = True, + workflow_thread_pool_id: Optional[str] = None ) -> dict[str, Any] | Generator[str, None, None]: """ Generate App response. @@ -139,6 +147,7 @@ class WorkflowAppGenerator(BaseAppGenerator): :param application_generate_entity: application generate entity :param invoke_from: invoke from source :param stream: is stream + :param workflow_thread_pool_id: workflow thread pool id """ # init queue manager queue_manager = WorkflowAppQueueManager( @@ -153,7 +162,8 @@ class WorkflowAppGenerator(BaseAppGenerator): 'flask_app': current_app._get_current_object(), # type: ignore 'application_generate_entity': application_generate_entity, 'queue_manager': queue_manager, - 'context': contextvars.copy_context() + 'context': contextvars.copy_context(), + 'workflow_thread_pool_id': workflow_thread_pool_id }) worker_thread.start() @@ -231,12 +241,14 @@ class WorkflowAppGenerator(BaseAppGenerator): def _generate_worker(self, flask_app: Flask, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager, - context: contextvars.Context) -> None: + context: contextvars.Context, + workflow_thread_pool_id: Optional[str] = None) -> None: """ Generate worker in a new thread. :param flask_app: Flask app :param application_generate_entity: application generate entity :param queue_manager: queue manager + :param workflow_thread_pool_id: workflow thread pool id :return: """ for var, val in context.items(): @@ -246,7 +258,8 @@ class WorkflowAppGenerator(BaseAppGenerator): # workflow app runner = WorkflowAppRunner( application_generate_entity=application_generate_entity, - queue_manager=queue_manager + queue_manager=queue_manager, + workflow_thread_pool_id=workflow_thread_pool_id ) runner.run() diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index f8c8c7ddc3..836ce16c86 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -1,6 +1,6 @@ import logging import os -from typing import cast +from typing import Optional, cast from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfig @@ -29,14 +29,17 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): def __init__( self, application_generate_entity: WorkflowAppGenerateEntity, - queue_manager: AppQueueManager + queue_manager: AppQueueManager, + workflow_thread_pool_id: Optional[str] = None ) -> None: """ :param application_generate_entity: application generate entity :param queue_manager: application queue manager + :param workflow_thread_pool_id: workflow thread pool id """ self.application_generate_entity = application_generate_entity self.queue_manager = queue_manager + self.workflow_thread_pool_id = workflow_thread_pool_id def run(self) -> None: """ @@ -116,6 +119,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): invoke_from=self.application_generate_entity.invoke_from, call_depth=self.application_generate_entity.call_depth, variable_pool=variable_pool, + thread_pool_id=self.workflow_thread_pool_id ) generator = workflow_entry.run( diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py index 12e498e76d..15e915628e 100644 --- a/api/core/tools/tool/workflow_tool.py +++ b/api/core/tools/tool/workflow_tool.py @@ -1,7 +1,7 @@ import json import logging from copy import deepcopy -from typing import Any, Union +from typing import Any, Optional, Union from core.file.file_obj import FileTransferMethod, FileVar from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType @@ -18,6 +18,7 @@ class WorkflowTool(Tool): version: str workflow_entities: dict[str, Any] workflow_call_depth: int + thread_pool_id: Optional[str] = None label: str @@ -57,6 +58,7 @@ class WorkflowTool(Tool): invoke_from=self.runtime.invoke_from, stream=False, call_depth=self.workflow_call_depth + 1, + workflow_thread_pool_id=self.thread_pool_id ) data = result.get('data', {}) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 0e15151aa4..6c0e906628 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -128,6 +128,7 @@ class ToolEngine: user_id: str, workflow_tool_callback: DifyWorkflowCallbackHandler, workflow_call_depth: int, + thread_pool_id: Optional[str] = None ) -> list[ToolInvokeMessage]: """ Workflow invokes the tool with the given arguments. @@ -141,6 +142,7 @@ class ToolEngine: if isinstance(tool, WorkflowTool): tool.workflow_call_depth = workflow_call_depth + 1 + tool.thread_pool_id = thread_pool_id if tool.runtime and tool.runtime.runtime_parameters: tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters} diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index cf3ccf387d..65d9ab8446 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -1,12 +1,12 @@ import logging import queue import time +import uuid from collections.abc import Generator, Mapping from concurrent.futures import ThreadPoolExecutor, wait from typing import Any, Optional from flask import Flask, current_app -from uritemplate.variable import VariableValue from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.entities.app_invoke_entities import InvokeFrom @@ -15,7 +15,7 @@ from core.workflow.entities.node_entities import ( NodeType, UserFrom, ) -from core.workflow.entities.variable_pool import VariablePool +from core.workflow.entities.variable_pool import VariablePool, VariableValue from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager from core.workflow.graph_engine.entities.event import ( BaseIterationEvent, @@ -47,7 +47,28 @@ from models.workflow import WorkflowNodeExecutionStatus, WorkflowType logger = logging.getLogger(__name__) +class GraphEngineThreadPool(ThreadPoolExecutor): + def __init__(self, max_workers=None, thread_name_prefix='', + initializer=None, initargs=(), max_submit_count=100) -> None: + super().__init__(max_workers, thread_name_prefix, initializer, initargs) + self.max_submit_count = max_submit_count + self.submit_count = 0 + + def submit(self, fn, *args, **kwargs): + self.submit_count += 1 + self.check_is_full() + + return super().submit(fn, *args, **kwargs) + + def check_is_full(self) -> None: + print(f"submit_count: {self.submit_count}, max_submit_count: {self.max_submit_count}") + if self.submit_count > self.max_submit_count: + raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.") + + class GraphEngine: + workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {} + def __init__( self, tenant_id: str, @@ -62,10 +83,26 @@ class GraphEngine: graph_config: Mapping[str, Any], variable_pool: VariablePool, max_execution_steps: int, - max_execution_time: int + max_execution_time: int, + thread_pool_id: Optional[str] = None ) -> None: + thread_pool_max_submit_count = 100 + thread_pool_max_workers = 10 + ## init thread pool - self.thread_pool = ThreadPoolExecutor(max_workers=10) + if thread_pool_id: + if not thread_pool_id in GraphEngine.workflow_thread_pool_mapping: + raise ValueError(f"Max submit count {thread_pool_max_submit_count} of workflow thread pool reached.") + + self.thread_pool_id = thread_pool_id + self.thread_pool = GraphEngine.workflow_thread_pool_mapping[thread_pool_id] + self.is_main_thread_pool = False + else: + self.thread_pool = GraphEngineThreadPool(max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count) + self.thread_pool_id = str(uuid.uuid4()) + self.is_main_thread_pool = True + GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool + self.graph = graph self.init_params = GraphInitParams( tenant_id=tenant_id, @@ -144,6 +181,9 @@ class GraphEngine: logger.exception("Unknown Error when graph running") yield GraphRunFailedEvent(error=str(e)) raise e + finally: + if self.is_main_thread_pool and self.thread_pool_id in GraphEngine.workflow_thread_pool_mapping: + del GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] def _run( self, @@ -196,7 +236,8 @@ class GraphEngine: graph_init_params=self.init_params, graph=self.graph, graph_runtime_state=self.graph_runtime_state, - previous_node_id=previous_node_id + previous_node_id=previous_node_id, + thread_pool_id=self.thread_pool_id ) try: @@ -357,10 +398,10 @@ class GraphEngine: node_id = edge_mappings[0].target_node_id node_config = self.graph.node_id_config_mapping.get(node_id) if not node_config: - raise GraphRunFailedError(f'Node {node_id} related parallel not found.') + raise GraphRunFailedError(f'Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches.') node_title = node_config.get('data', {}).get('title') - raise GraphRunFailedError(f'Node {node_title} related parallel not found.') + raise GraphRunFailedError(f'Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches.') parallel = self.graph.parallel_mapping.get(parallel_id) if not parallel: diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 3807bbb2d5..b9912314f1 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -21,7 +21,8 @@ class BaseNode(ABC): graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState, - previous_node_id: Optional[str] = None) -> None: + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None) -> None: self.id = id self.tenant_id = graph_init_params.tenant_id self.app_id = graph_init_params.app_id @@ -35,6 +36,7 @@ class BaseNode(ABC): self.graph = graph self.graph_runtime_state = graph_runtime_state self.previous_node_id = previous_node_id + self.thread_pool_id = thread_pool_id node_id = config.get("id") if not node_id: diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 7f889e654b..feedeb6dad 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -66,6 +66,7 @@ class ToolNode(BaseNode): user_id=self.user_id, workflow_tool_callback=DifyWorkflowCallbackHandler(), workflow_call_depth=self.workflow_call_depth, + thread_pool_id=self.thread_pool_id, ) except Exception as e: return NodeRunResult( diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index b84e46f280..a359bd606e 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -44,7 +44,8 @@ class WorkflowEntry: user_from: UserFrom, invoke_from: InvokeFrom, call_depth: int, - variable_pool: VariablePool + variable_pool: VariablePool, + thread_pool_id: Optional[str] = None ) -> None: """ Init workflow entry @@ -59,7 +60,9 @@ class WorkflowEntry: :param invoke_from: invoke from :param call_depth: call depth :param variable_pool: variable pool + :param thread_pool_id: thread pool id """ + # check call depth workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH if call_depth > workflow_call_max_depth: raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) @@ -78,7 +81,8 @@ class WorkflowEntry: graph_config=graph_config, variable_pool=variable_pool, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, - max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME + max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME, + thread_pool_id=thread_pool_id ) def run(