chore(workflow): max thread submit count

takatost 2024-09-02 20:20:32 +08:00
parent 5ca9df65de
commit 955884b87e
8 changed files with 86 additions and 17 deletions
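
The heart of the change is a ThreadPoolExecutor subclass that counts every submission and fails fast once a cap is reached, so a single workflow run cannot queue unbounded work. A minimal standalone sketch of that pattern (illustrative names; the real class is GraphEngineThreadPool in the graph engine diff below):

from concurrent.futures import ThreadPoolExecutor

class CappedThreadPool(ThreadPoolExecutor):
    # Reject new work once total submissions exceed a fixed cap.
    def __init__(self, max_workers=None, max_submit_count=100):
        super().__init__(max_workers)
        self.max_submit_count = max_submit_count
        self.submit_count = 0

    def submit(self, fn, *args, **kwargs):
        self.submit_count += 1
        if self.submit_count > self.max_submit_count:
            raise ValueError(f"Max submit count {self.max_submit_count} reached.")
        return super().submit(fn, *args, **kwargs)

pool = CappedThreadPool(max_workers=2, max_submit_count=2)
pool.submit(print, "first")   # ok
pool.submit(print, "second")  # ok, hits the cap
# pool.submit(print, "third") # would raise ValueError

Note that the counter is never decremented, so the cap bounds total submissions over the pool's lifetime, not the number of tasks in flight at once.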

View File

@@ -4,7 +4,7 @@ import os
import threading
import uuid
from collections.abc import Generator
-from typing import Any, Literal, Union, overload
+from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -40,6 +40,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
args: dict,
invoke_from: InvokeFrom,
stream: Literal[True] = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
) -> Generator[str, None, None]: ...
@overload
@@ -50,6 +52,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
args: dict,
invoke_from: InvokeFrom,
stream: Literal[False] = False,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
) -> dict: ...
def generate(
@@ -61,6 +65,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
stream: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None
):
"""
Generate App response.
@@ -72,6 +77,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
:param call_depth: call depth
:param workflow_thread_pool_id: workflow thread pool id
"""
inputs = args['inputs']
@@ -119,6 +125,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
stream=stream,
workflow_thread_pool_id=workflow_thread_pool_id
)
def _generate(
@@ -129,6 +136,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
stream: bool = True,
workflow_thread_pool_id: Optional[str] = None
) -> dict[str, Any] | Generator[str, None, None]:
"""
Generate App response.
@@ -139,6 +147,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
:param stream: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
# init queue manager
queue_manager = WorkflowAppQueueManager(
@@ -153,7 +162,8 @@
'flask_app': current_app._get_current_object(), # type: ignore
'application_generate_entity': application_generate_entity,
'queue_manager': queue_manager,
-'context': contextvars.copy_context()
+'context': contextvars.copy_context(),
+'workflow_thread_pool_id': workflow_thread_pool_id
})
worker_thread.start()
@@ -231,12 +241,14 @@
def _generate_worker(self, flask_app: Flask,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
-context: contextvars.Context) -> None:
+context: contextvars.Context,
+workflow_thread_pool_id: Optional[str] = None) -> None:
"""
Generate worker in a new thread.
:param flask_app: Flask app
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
for var, val in context.items():
@@ -246,7 +258,8 @@
# workflow app
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
-queue_manager=queue_manager
+queue_manager=queue_manager,
+workflow_thread_pool_id=workflow_thread_pool_id
)
runner.run()
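
The worker thread above receives a copy of the caller's context variables along with the optional pool id. A minimal sketch of that hand-off, using the same contextvars pattern as _generate_worker (the worker function and its arguments here are illustrative):

import contextvars
import threading
from typing import Optional

def worker(context: contextvars.Context, workflow_thread_pool_id: Optional[str] = None) -> None:
    # Restore the parent thread's context variables before doing any work.
    for var, val in context.items():
        var.set(val)
    print(f"worker running, pool id: {workflow_thread_pool_id}")

worker_thread = threading.Thread(target=worker, kwargs={
    'context': contextvars.copy_context(),
    'workflow_thread_pool_id': None,  # forwarded from generate(), when set
})
worker_thread.start()
worker_thread.join()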

View File

@@ -1,6 +1,6 @@
import logging
import os
-from typing import cast
+from typing import Optional, cast
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
@@ -29,14 +29,17 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
def __init__(
self,
application_generate_entity: WorkflowAppGenerateEntity,
-queue_manager: AppQueueManager
+queue_manager: AppQueueManager,
+workflow_thread_pool_id: Optional[str] = None
) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
"""
self.application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self.workflow_thread_pool_id = workflow_thread_pool_id
def run(self) -> None:
"""
@@ -116,6 +119,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
invoke_from=self.application_generate_entity.invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
thread_pool_id=self.workflow_thread_pool_id
)
generator = workflow_entry.run(

View File

@@ -1,7 +1,7 @@
import json
import logging
from copy import deepcopy
-from typing import Any, Union
+from typing import Any, Optional, Union
from core.file.file_obj import FileTransferMethod, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
@@ -18,6 +18,7 @@ class WorkflowTool(Tool):
version: str
workflow_entities: dict[str, Any]
workflow_call_depth: int
thread_pool_id: Optional[str] = None
label: str
@@ -57,6 +58,7 @@
invoke_from=self.runtime.invoke_from,
stream=False,
call_depth=self.workflow_call_depth + 1,
workflow_thread_pool_id=self.thread_pool_id
)
data = result.get('data', {})
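
The effect is that a workflow invoked as a tool spends its parent run's thread budget rather than getting a fresh pool: the tool forwards the parent's thread_pool_id and bumps call_depth by one. A toy sketch of the shared-budget idea (the budgets dict and run_workflow function are hypothetical stand-ins, not Dify APIs):

from typing import Optional

# Hypothetical stand-in for the shared pool budget: pool id -> submits left.
budgets: dict[str, int] = {}

def run_workflow(thread_pool_id: Optional[str] = None, call_depth: int = 0) -> None:
    pool_id = thread_pool_id or 'pool-1'
    budgets.setdefault(pool_id, 3)
    budgets[pool_id] -= 1  # every run, nested or not, draws on one budget
    if budgets[pool_id] < 0:
        raise ValueError('Max submit count of workflow thread pool reached.')
    if call_depth < 2:  # simulate a workflow tool invoking a sub-workflow
        run_workflow(pool_id, call_depth + 1)

run_workflow()  # three nested runs share one budget of 3 submits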

View File

@@ -128,6 +128,7 @@ class ToolEngine:
user_id: str,
workflow_tool_callback: DifyWorkflowCallbackHandler,
workflow_call_depth: int,
thread_pool_id: Optional[str] = None
) -> list[ToolInvokeMessage]:
"""
Workflow invokes the tool with the given arguments.
@@ -141,6 +142,7 @@
if isinstance(tool, WorkflowTool):
tool.workflow_call_depth = workflow_call_depth + 1
tool.thread_pool_id = thread_pool_id
if tool.runtime and tool.runtime.runtime_parameters:
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}

View File

@@ -1,12 +1,12 @@
import logging
import queue
import time
import uuid
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Optional
from flask import Flask, current_app
-from uritemplate.variable import VariableValue
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -15,7 +15,7 @@ from core.workflow.entities.node_entities import (
NodeType,
UserFrom,
)
-from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
BaseIterationEvent,
@@ -47,7 +47,28 @@ from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
logger = logging.getLogger(__name__)
class GraphEngineThreadPool(ThreadPoolExecutor):
def __init__(self, max_workers=None, thread_name_prefix='',
initializer=None, initargs=(), max_submit_count=100) -> None:
super().__init__(max_workers, thread_name_prefix, initializer, initargs)
self.max_submit_count = max_submit_count
self.submit_count = 0
def submit(self, fn, *args, **kwargs):
self.submit_count += 1
self.check_is_full()
return super().submit(fn, *args, **kwargs)
def check_is_full(self) -> None:
if self.submit_count > self.max_submit_count:
raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")
class GraphEngine:
workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {}
def __init__(
self,
tenant_id: str,
@@ -62,10 +83,26 @@ class GraphEngine:
graph_config: Mapping[str, Any],
variable_pool: VariablePool,
max_execution_steps: int,
-max_execution_time: int
+max_execution_time: int,
+thread_pool_id: Optional[str] = None
) -> None:
thread_pool_max_submit_count = 100
thread_pool_max_workers = 10
## init thread pool
-self.thread_pool = ThreadPoolExecutor(max_workers=10)
if thread_pool_id:
if thread_pool_id not in GraphEngine.workflow_thread_pool_mapping:
raise ValueError(f"Workflow thread pool {thread_pool_id} not found.")
self.thread_pool_id = thread_pool_id
self.thread_pool = GraphEngine.workflow_thread_pool_mapping[thread_pool_id]
self.is_main_thread_pool = False
else:
self.thread_pool = GraphEngineThreadPool(max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count)
self.thread_pool_id = str(uuid.uuid4())
self.is_main_thread_pool = True
GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool
self.graph = graph
self.init_params = GraphInitParams(
tenant_id=tenant_id,
@@ -144,6 +181,9 @@ class GraphEngine:
logger.exception("Unknown Error when graph running")
yield GraphRunFailedEvent(error=str(e))
raise e
finally:
if self.is_main_thread_pool and self.thread_pool_id in GraphEngine.workflow_thread_pool_mapping:
del GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id]
def _run(
self,
@@ -196,7 +236,8 @@
graph_init_params=self.init_params,
graph=self.graph,
graph_runtime_state=self.graph_runtime_state,
-previous_node_id=previous_node_id
+previous_node_id=previous_node_id,
+thread_pool_id=self.thread_pool_id
)
try:
@@ -357,10 +398,10 @@
node_id = edge_mappings[0].target_node_id
node_config = self.graph.node_id_config_mapping.get(node_id)
if not node_config:
-raise GraphRunFailedError(f'Node {node_id} related parallel not found.')
+raise GraphRunFailedError(f'Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches.')
node_title = node_config.get('data', {}).get('title')
-raise GraphRunFailedError(f'Node {node_title} related parallel not found.')
+raise GraphRunFailedError(f'Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches.')
parallel = self.graph.parallel_mapping.get(parallel_id)
if not parallel:
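
Taken together, the lifecycle is: a top-level run builds a fresh GraphEngineThreadPool, registers it in workflow_thread_pool_mapping under a generated uuid, child engines constructed with that id look the pool up and reuse it, and only the owning engine removes the mapping entry in its finally block. A condensed sketch of that registry pattern, with simplified names:

import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

class Engine:
    # Class-level registry mirroring GraphEngine.workflow_thread_pool_mapping.
    pools: dict[str, ThreadPoolExecutor] = {}

    def __init__(self, thread_pool_id: Optional[str] = None) -> None:
        if thread_pool_id:
            if thread_pool_id not in Engine.pools:
                raise ValueError(f"Thread pool {thread_pool_id} not found.")
            self.thread_pool_id = thread_pool_id
            self.pool = Engine.pools[thread_pool_id]
            self.is_main = False  # borrowed pool: never clean it up
        else:
            self.pool = ThreadPoolExecutor(max_workers=10)
            self.thread_pool_id = str(uuid.uuid4())
            self.is_main = True   # owner: responsible for cleanup
            Engine.pools[self.thread_pool_id] = self.pool

    def run(self) -> None:
        try:
            self.pool.submit(print, "node executed").result()
        finally:
            # Only the owning engine unregisters the shared pool.
            if self.is_main:
                Engine.pools.pop(self.thread_pool_id, None)

parent = Engine()
child = Engine(thread_pool_id=parent.thread_pool_id)  # reuses parent's pool
child.run()   # leaves the registry entry in place
parent.run()  # owner unregisters on exit

Keeping cleanup with the pool's creator means a nested run that exits early cannot tear down a pool its parent is still using.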

View File

@@ -21,7 +21,8 @@ class BaseNode(ABC):
graph_init_params: GraphInitParams,
graph: Graph,
graph_runtime_state: GraphRuntimeState,
-previous_node_id: Optional[str] = None) -> None:
+previous_node_id: Optional[str] = None,
+thread_pool_id: Optional[str] = None) -> None:
self.id = id
self.tenant_id = graph_init_params.tenant_id
self.app_id = graph_init_params.app_id
@@ -35,6 +36,7 @@
self.graph = graph
self.graph_runtime_state = graph_runtime_state
self.previous_node_id = previous_node_id
self.thread_pool_id = thread_pool_id
node_id = config.get("id")
if not node_id:

View File

@@ -66,6 +66,7 @@ class ToolNode(BaseNode):
user_id=self.user_id,
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
thread_pool_id=self.thread_pool_id,
)
except Exception as e:
return NodeRunResult(

View File

@@ -44,7 +44,8 @@ class WorkflowEntry:
user_from: UserFrom,
invoke_from: InvokeFrom,
call_depth: int,
-variable_pool: VariablePool
+variable_pool: VariablePool,
+thread_pool_id: Optional[str] = None
) -> None:
"""
Init workflow entry
@@ -59,7 +60,9 @@
:param invoke_from: invoke from
:param call_depth: call depth
:param variable_pool: variable pool
:param thread_pool_id: thread pool id
"""
# check call depth
workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
if call_depth > workflow_call_max_depth:
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
@@ -78,7 +81,8 @@
graph_config=graph_config,
variable_pool=variable_pool,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
-max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
+max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
+thread_pool_id=thread_pool_id
)
def run(
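
The call depth check at the top of the constructor keeps a chain of workflow-as-tool invocations from recursing without bound, since each nesting level passes call_depth + 1. A self-contained sketch of the guard (the module-level constant stands in for dify_config.WORKFLOW_CALL_MAX_DEPTH):

WORKFLOW_CALL_MAX_DEPTH = 5  # stand-in for dify_config.WORKFLOW_CALL_MAX_DEPTH

def check_call_depth(call_depth: int) -> None:
    # Fail before any engine is built, so runaway nesting stops early.
    if call_depth > WORKFLOW_CALL_MAX_DEPTH:
        raise ValueError('Max workflow call depth {} reached.'.format(WORKFLOW_CALL_MAX_DEPTH))

check_call_depth(0)   # ok
# check_call_depth(6) # would raise ValueError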