chore(workflow): max thread submit count

This commit is contained in:
takatost 2024-09-02 20:20:32 +08:00
parent 5ca9df65de
commit 955884b87e
8 changed files with 86 additions and 17 deletions

View File

@ -4,7 +4,7 @@ import os
import threading import threading
import uuid import uuid
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Literal, Union, overload from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
@ -40,6 +40,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
args: dict, args: dict,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: Literal[True] = True, stream: Literal[True] = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
) -> Generator[str, None, None]: ... ) -> Generator[str, None, None]: ...
@overload @overload
@ -50,6 +52,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
args: dict, args: dict,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: Literal[False] = False, stream: Literal[False] = False,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
) -> dict: ... ) -> dict: ...
def generate( def generate(
@ -61,6 +65,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: bool = True, stream: bool = True,
call_depth: int = 0, call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
): ):
""" """
Generate App response. Generate App response.
@ -72,6 +77,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param invoke_from: invoke from source :param invoke_from: invoke from source
:param stream: is stream :param stream: is stream
:param call_depth: call depth :param call_depth: call depth
:param workflow_thread_pool_id: workflow thread pool id
""" """
inputs = args['inputs'] inputs = args['inputs']
@ -119,6 +125,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
invoke_from=invoke_from, invoke_from=invoke_from,
stream=stream, stream=stream,
workflow_thread_pool_id=workflow_thread_pool_id
) )
def _generate( def _generate(
@ -129,6 +136,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
stream: bool = True, stream: bool = True,
workflow_thread_pool_id: Optional[str] = None
) -> dict[str, Any] | Generator[str, None, None]: ) -> dict[str, Any] | Generator[str, None, None]:
""" """
Generate App response. Generate App response.
@ -139,6 +147,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
:param invoke_from: invoke from source :param invoke_from: invoke from source
:param stream: is stream :param stream: is stream
:param workflow_thread_pool_id: workflow thread pool id
""" """
# init queue manager # init queue manager
queue_manager = WorkflowAppQueueManager( queue_manager = WorkflowAppQueueManager(
@ -153,7 +162,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
'flask_app': current_app._get_current_object(), # type: ignore 'flask_app': current_app._get_current_object(), # type: ignore
'application_generate_entity': application_generate_entity, 'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager, 'queue_manager': queue_manager,
'context': contextvars.copy_context() 'context': contextvars.copy_context(),
'workflow_thread_pool_id': workflow_thread_pool_id
}) })
worker_thread.start() worker_thread.start()
@ -231,12 +241,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
def _generate_worker(self, flask_app: Flask, def _generate_worker(self, flask_app: Flask,
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
context: contextvars.Context) -> None: context: contextvars.Context,
workflow_thread_pool_id: Optional[str] = None) -> None:
""" """
Generate worker in a new thread. Generate worker in a new thread.
:param flask_app: Flask app :param flask_app: Flask app
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
:param queue_manager: queue manager :param queue_manager: queue manager
:param workflow_thread_pool_id: workflow thread pool id
:return: :return:
""" """
for var, val in context.items(): for var, val in context.items():
@ -246,7 +258,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
# workflow app # workflow app
runner = WorkflowAppRunner( runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id
) )
runner.run() runner.run()

View File

@ -1,6 +1,6 @@
import logging import logging
import os import os
from typing import cast from typing import Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
@ -29,14 +29,17 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
def __init__( def __init__(
self, self,
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager queue_manager: AppQueueManager,
workflow_thread_pool_id: Optional[str] = None
) -> None: ) -> None:
""" """
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
:param queue_manager: application queue manager :param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
""" """
self.application_generate_entity = application_generate_entity self.application_generate_entity = application_generate_entity
self.queue_manager = queue_manager self.queue_manager = queue_manager
self.workflow_thread_pool_id = workflow_thread_pool_id
def run(self) -> None: def run(self) -> None:
""" """
@ -116,6 +119,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
invoke_from=self.application_generate_entity.invoke_from, invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth, call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool, variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id
) )
generator = workflow_entry.run( generator = workflow_entry.run(

View File

@ -1,7 +1,7 @@
import json import json
import logging import logging
from copy import deepcopy from copy import deepcopy
from typing import Any, Union from typing import Any, Optional, Union
from core.file.file_obj import FileTransferMethod, FileVar from core.file.file_obj import FileTransferMethod, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
@ -18,6 +18,7 @@ class WorkflowTool(Tool):
version: str version: str
workflow_entities: dict[str, Any] workflow_entities: dict[str, Any]
workflow_call_depth: int workflow_call_depth: int
thread_pool_id: Optional[str] = None
label: str label: str
@ -57,6 +58,7 @@ class WorkflowTool(Tool):
invoke_from=self.runtime.invoke_from, invoke_from=self.runtime.invoke_from,
stream=False, stream=False,
call_depth=self.workflow_call_depth + 1, call_depth=self.workflow_call_depth + 1,
workflow_thread_pool_id=self.thread_pool_id
) )
data = result.get('data', {}) data = result.get('data', {})

View File

@ -128,6 +128,7 @@ class ToolEngine:
user_id: str, user_id: str,
workflow_tool_callback: DifyWorkflowCallbackHandler, workflow_tool_callback: DifyWorkflowCallbackHandler,
workflow_call_depth: int, workflow_call_depth: int,
thread_pool_id: Optional[str] = None
) -> list[ToolInvokeMessage]: ) -> list[ToolInvokeMessage]:
""" """
Workflow invokes the tool with the given arguments. Workflow invokes the tool with the given arguments.
@ -141,6 +142,7 @@ class ToolEngine:
if isinstance(tool, WorkflowTool): if isinstance(tool, WorkflowTool):
tool.workflow_call_depth = workflow_call_depth + 1 tool.workflow_call_depth = workflow_call_depth + 1
tool.thread_pool_id = thread_pool_id
if tool.runtime and tool.runtime.runtime_parameters: if tool.runtime and tool.runtime.runtime_parameters:
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters} tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}

View File

@ -1,12 +1,12 @@
import logging import logging
import queue import queue
import time import time
import uuid
from collections.abc import Generator, Mapping from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Optional from typing import Any, Optional
from flask import Flask, current_app from flask import Flask, current_app
from uritemplate.variable import VariableValue
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
@ -15,7 +15,7 @@ from core.workflow.entities.node_entities import (
NodeType, NodeType,
UserFrom, UserFrom,
) )
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.event import (
BaseIterationEvent, BaseIterationEvent,
@ -47,7 +47,28 @@ from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class GraphEngineThreadPool(ThreadPoolExecutor):
    """ThreadPoolExecutor that caps the *total* number of tasks ever submitted.

    Unlike ``max_workers`` (which bounds concurrency), ``max_submit_count``
    bounds the cumulative number of ``submit()`` calls over the pool's
    lifetime, so a runaway workflow cannot queue unbounded work.

    :param max_workers: forwarded to ThreadPoolExecutor
    :param thread_name_prefix: forwarded to ThreadPoolExecutor
    :param initializer: forwarded to ThreadPoolExecutor
    :param initargs: forwarded to ThreadPoolExecutor
    :param max_submit_count: maximum cumulative number of submitted tasks
    """

    def __init__(self, max_workers=None, thread_name_prefix='',
                 initializer=None, initargs=(), max_submit_count=100) -> None:
        super().__init__(max_workers, thread_name_prefix, initializer, initargs)
        self.max_submit_count = max_submit_count
        self.submit_count = 0
        # submit() may be called from multiple worker threads (nested workflow
        # nodes share this pool), and `+= 1` is not atomic — serialize the
        # increment-and-check so the cap cannot be bypassed by a race.
        self._submit_count_lock = threading.Lock()

    def submit(self, fn, *args, **kwargs):
        """Submit a task, counting it against ``max_submit_count``.

        :raises ValueError: if the cumulative submit cap has been exceeded.
        """
        with self._submit_count_lock:
            self.submit_count += 1
            self.check_is_full()

        return super().submit(fn, *args, **kwargs)

    def check_is_full(self) -> None:
        """Raise ValueError if the cumulative submit cap has been exceeded."""
        if self.submit_count > self.max_submit_count:
            raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")
class GraphEngine: class GraphEngine:
workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {}
def __init__( def __init__(
self, self,
tenant_id: str, tenant_id: str,
@ -62,10 +83,26 @@ class GraphEngine:
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
variable_pool: VariablePool, variable_pool: VariablePool,
max_execution_steps: int, max_execution_steps: int,
max_execution_time: int max_execution_time: int,
thread_pool_id: Optional[str] = None
) -> None: ) -> None:
thread_pool_max_submit_count = 100
thread_pool_max_workers = 10
## init thread pool ## init thread pool
self.thread_pool = ThreadPoolExecutor(max_workers=10) if thread_pool_id:
if not thread_pool_id in GraphEngine.workflow_thread_pool_mapping:
raise ValueError(f"Max submit count {thread_pool_max_submit_count} of workflow thread pool reached.")
self.thread_pool_id = thread_pool_id
self.thread_pool = GraphEngine.workflow_thread_pool_mapping[thread_pool_id]
self.is_main_thread_pool = False
else:
self.thread_pool = GraphEngineThreadPool(max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count)
self.thread_pool_id = str(uuid.uuid4())
self.is_main_thread_pool = True
GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool
self.graph = graph self.graph = graph
self.init_params = GraphInitParams( self.init_params = GraphInitParams(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -144,6 +181,9 @@ class GraphEngine:
logger.exception("Unknown Error when graph running") logger.exception("Unknown Error when graph running")
yield GraphRunFailedEvent(error=str(e)) yield GraphRunFailedEvent(error=str(e))
raise e raise e
finally:
if self.is_main_thread_pool and self.thread_pool_id in GraphEngine.workflow_thread_pool_mapping:
del GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id]
def _run( def _run(
self, self,
@ -196,7 +236,8 @@ class GraphEngine:
graph_init_params=self.init_params, graph_init_params=self.init_params,
graph=self.graph, graph=self.graph,
graph_runtime_state=self.graph_runtime_state, graph_runtime_state=self.graph_runtime_state,
previous_node_id=previous_node_id previous_node_id=previous_node_id,
thread_pool_id=self.thread_pool_id
) )
try: try:
@ -357,10 +398,10 @@ class GraphEngine:
node_id = edge_mappings[0].target_node_id node_id = edge_mappings[0].target_node_id
node_config = self.graph.node_id_config_mapping.get(node_id) node_config = self.graph.node_id_config_mapping.get(node_id)
if not node_config: if not node_config:
raise GraphRunFailedError(f'Node {node_id} related parallel not found.') raise GraphRunFailedError(f'Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches.')
node_title = node_config.get('data', {}).get('title') node_title = node_config.get('data', {}).get('title')
raise GraphRunFailedError(f'Node {node_title} related parallel not found.') raise GraphRunFailedError(f'Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches.')
parallel = self.graph.parallel_mapping.get(parallel_id) parallel = self.graph.parallel_mapping.get(parallel_id)
if not parallel: if not parallel:

View File

@ -21,7 +21,8 @@ class BaseNode(ABC):
graph_init_params: GraphInitParams, graph_init_params: GraphInitParams,
graph: Graph, graph: Graph,
graph_runtime_state: GraphRuntimeState, graph_runtime_state: GraphRuntimeState,
previous_node_id: Optional[str] = None) -> None: previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None) -> None:
self.id = id self.id = id
self.tenant_id = graph_init_params.tenant_id self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id self.app_id = graph_init_params.app_id
@ -35,6 +36,7 @@ class BaseNode(ABC):
self.graph = graph self.graph = graph
self.graph_runtime_state = graph_runtime_state self.graph_runtime_state = graph_runtime_state
self.previous_node_id = previous_node_id self.previous_node_id = previous_node_id
self.thread_pool_id = thread_pool_id
node_id = config.get("id") node_id = config.get("id")
if not node_id: if not node_id:

View File

@ -66,6 +66,7 @@ class ToolNode(BaseNode):
user_id=self.user_id, user_id=self.user_id,
workflow_tool_callback=DifyWorkflowCallbackHandler(), workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth, workflow_call_depth=self.workflow_call_depth,
thread_pool_id=self.thread_pool_id,
) )
except Exception as e: except Exception as e:
return NodeRunResult( return NodeRunResult(

View File

@ -44,7 +44,8 @@ class WorkflowEntry:
user_from: UserFrom, user_from: UserFrom,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
call_depth: int, call_depth: int,
variable_pool: VariablePool variable_pool: VariablePool,
thread_pool_id: Optional[str] = None
) -> None: ) -> None:
""" """
Init workflow entry Init workflow entry
@ -59,7 +60,9 @@ class WorkflowEntry:
:param invoke_from: invoke from :param invoke_from: invoke from
:param call_depth: call depth :param call_depth: call depth
:param variable_pool: variable pool :param variable_pool: variable pool
:param thread_pool_id: thread pool id
""" """
# check call depth
workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
if call_depth > workflow_call_max_depth: if call_depth > workflow_call_max_depth:
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth)) raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
@ -78,7 +81,8 @@ class WorkflowEntry:
graph_config=graph_config, graph_config=graph_config,
variable_pool=variable_pool, variable_pool=variable_pool,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=thread_pool_id
) )
def run( def run(