mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 13:15:58 +08:00
chore(workflow): max thread submit count
This commit is contained in:
parent
5ca9df65de
commit
955884b87e
@ -4,7 +4,7 @@ import os
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Literal, Union, overload
|
||||
from typing import Any, Literal, Optional, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@ -40,6 +40,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[True] = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None
|
||||
) -> Generator[str, None, None]: ...
|
||||
|
||||
@overload
|
||||
@ -50,6 +52,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
args: dict,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: Literal[False] = False,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None
|
||||
) -> dict: ...
|
||||
|
||||
def generate(
|
||||
@ -61,6 +65,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Generate App response.
|
||||
@ -72,6 +77,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
:param call_depth: call depth
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
inputs = args['inputs']
|
||||
|
||||
@ -119,6 +125,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=invoke_from,
|
||||
stream=stream,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id
|
||||
)
|
||||
|
||||
def _generate(
|
||||
@ -129,6 +136,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
invoke_from: InvokeFrom,
|
||||
stream: bool = True,
|
||||
workflow_thread_pool_id: Optional[str] = None
|
||||
) -> dict[str, Any] | Generator[str, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
@ -139,6 +147,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
:param application_generate_entity: application generate entity
|
||||
:param invoke_from: invoke from source
|
||||
:param stream: is stream
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
# init queue manager
|
||||
queue_manager = WorkflowAppQueueManager(
|
||||
@ -153,7 +162,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
'flask_app': current_app._get_current_object(), # type: ignore
|
||||
'application_generate_entity': application_generate_entity,
|
||||
'queue_manager': queue_manager,
|
||||
'context': contextvars.copy_context()
|
||||
'context': contextvars.copy_context(),
|
||||
'workflow_thread_pool_id': workflow_thread_pool_id
|
||||
})
|
||||
|
||||
worker_thread.start()
|
||||
@ -231,12 +241,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
def _generate_worker(self, flask_app: Flask,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
context: contextvars.Context) -> None:
|
||||
context: contextvars.Context,
|
||||
workflow_thread_pool_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
:param flask_app: Flask app
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
:return:
|
||||
"""
|
||||
for var, val in context.items():
|
||||
@ -246,7 +258,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
# workflow app
|
||||
runner = WorkflowAppRunner(
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager
|
||||
queue_manager=queue_manager,
|
||||
workflow_thread_pool_id=workflow_thread_pool_id
|
||||
)
|
||||
|
||||
runner.run()
|
||||
|
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import cast
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
||||
@ -29,14 +29,17 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
def __init__(
|
||||
self,
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
queue_manager: AppQueueManager
|
||||
queue_manager: AppQueueManager,
|
||||
workflow_thread_pool_id: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: application queue manager
|
||||
:param workflow_thread_pool_id: workflow thread pool id
|
||||
"""
|
||||
self.application_generate_entity = application_generate_entity
|
||||
self.queue_manager = queue_manager
|
||||
self.workflow_thread_pool_id = workflow_thread_pool_id
|
||||
|
||||
def run(self) -> None:
|
||||
"""
|
||||
@ -116,6 +119,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
variable_pool=variable_pool,
|
||||
thread_pool_id=self.workflow_thread_pool_id
|
||||
)
|
||||
|
||||
generator = workflow_entry.run(
|
||||
|
@ -1,7 +1,7 @@
|
||||
import json
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from core.file.file_obj import FileTransferMethod, FileVar
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||
@ -18,6 +18,7 @@ class WorkflowTool(Tool):
|
||||
version: str
|
||||
workflow_entities: dict[str, Any]
|
||||
workflow_call_depth: int
|
||||
thread_pool_id: Optional[str] = None
|
||||
|
||||
label: str
|
||||
|
||||
@ -57,6 +58,7 @@ class WorkflowTool(Tool):
|
||||
invoke_from=self.runtime.invoke_from,
|
||||
stream=False,
|
||||
call_depth=self.workflow_call_depth + 1,
|
||||
workflow_thread_pool_id=self.thread_pool_id
|
||||
)
|
||||
|
||||
data = result.get('data', {})
|
||||
|
@ -128,6 +128,7 @@ class ToolEngine:
|
||||
user_id: str,
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
||||
workflow_call_depth: int,
|
||||
thread_pool_id: Optional[str] = None
|
||||
) -> list[ToolInvokeMessage]:
|
||||
"""
|
||||
Workflow invokes the tool with the given arguments.
|
||||
@ -141,6 +142,7 @@ class ToolEngine:
|
||||
|
||||
if isinstance(tool, WorkflowTool):
|
||||
tool.workflow_call_depth = workflow_call_depth + 1
|
||||
tool.thread_pool_id = thread_pool_id
|
||||
|
||||
if tool.runtime and tool.runtime.runtime_parameters:
|
||||
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
|
||||
|
@ -1,12 +1,12 @@
|
||||
import logging
|
||||
import queue
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
from uritemplate.variable import VariableValue
|
||||
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -15,7 +15,7 @@ from core.workflow.entities.node_entities import (
|
||||
NodeType,
|
||||
UserFrom,
|
||||
)
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseIterationEvent,
|
||||
@ -47,7 +47,28 @@ from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GraphEngineThreadPool(ThreadPoolExecutor):
|
||||
def __init__(self, max_workers=None, thread_name_prefix='',
|
||||
initializer=None, initargs=(), max_submit_count=100) -> None:
|
||||
super().__init__(max_workers, thread_name_prefix, initializer, initargs)
|
||||
self.max_submit_count = max_submit_count
|
||||
self.submit_count = 0
|
||||
|
||||
def submit(self, fn, *args, **kwargs):
|
||||
self.submit_count += 1
|
||||
self.check_is_full()
|
||||
|
||||
return super().submit(fn, *args, **kwargs)
|
||||
|
||||
def check_is_full(self) -> None:
|
||||
print(f"submit_count: {self.submit_count}, max_submit_count: {self.max_submit_count}")
|
||||
if self.submit_count > self.max_submit_count:
|
||||
raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")
|
||||
|
||||
|
||||
class GraphEngine:
|
||||
workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
@ -62,10 +83,26 @@ class GraphEngine:
|
||||
graph_config: Mapping[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
max_execution_steps: int,
|
||||
max_execution_time: int
|
||||
max_execution_time: int,
|
||||
thread_pool_id: Optional[str] = None
|
||||
) -> None:
|
||||
thread_pool_max_submit_count = 100
|
||||
thread_pool_max_workers = 10
|
||||
|
||||
## init thread pool
|
||||
self.thread_pool = ThreadPoolExecutor(max_workers=10)
|
||||
if thread_pool_id:
|
||||
if not thread_pool_id in GraphEngine.workflow_thread_pool_mapping:
|
||||
raise ValueError(f"Max submit count {thread_pool_max_submit_count} of workflow thread pool reached.")
|
||||
|
||||
self.thread_pool_id = thread_pool_id
|
||||
self.thread_pool = GraphEngine.workflow_thread_pool_mapping[thread_pool_id]
|
||||
self.is_main_thread_pool = False
|
||||
else:
|
||||
self.thread_pool = GraphEngineThreadPool(max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count)
|
||||
self.thread_pool_id = str(uuid.uuid4())
|
||||
self.is_main_thread_pool = True
|
||||
GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool
|
||||
|
||||
self.graph = graph
|
||||
self.init_params = GraphInitParams(
|
||||
tenant_id=tenant_id,
|
||||
@ -144,6 +181,9 @@ class GraphEngine:
|
||||
logger.exception("Unknown Error when graph running")
|
||||
yield GraphRunFailedEvent(error=str(e))
|
||||
raise e
|
||||
finally:
|
||||
if self.is_main_thread_pool and self.thread_pool_id in GraphEngine.workflow_thread_pool_mapping:
|
||||
del GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id]
|
||||
|
||||
def _run(
|
||||
self,
|
||||
@ -196,7 +236,8 @@ class GraphEngine:
|
||||
graph_init_params=self.init_params,
|
||||
graph=self.graph,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
previous_node_id=previous_node_id
|
||||
previous_node_id=previous_node_id,
|
||||
thread_pool_id=self.thread_pool_id
|
||||
)
|
||||
|
||||
try:
|
||||
@ -357,10 +398,10 @@ class GraphEngine:
|
||||
node_id = edge_mappings[0].target_node_id
|
||||
node_config = self.graph.node_id_config_mapping.get(node_id)
|
||||
if not node_config:
|
||||
raise GraphRunFailedError(f'Node {node_id} related parallel not found.')
|
||||
raise GraphRunFailedError(f'Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches.')
|
||||
|
||||
node_title = node_config.get('data', {}).get('title')
|
||||
raise GraphRunFailedError(f'Node {node_title} related parallel not found.')
|
||||
raise GraphRunFailedError(f'Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches.')
|
||||
|
||||
parallel = self.graph.parallel_mapping.get(parallel_id)
|
||||
if not parallel:
|
||||
|
@ -21,7 +21,8 @@ class BaseNode(ABC):
|
||||
graph_init_params: GraphInitParams,
|
||||
graph: Graph,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
previous_node_id: Optional[str] = None) -> None:
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None) -> None:
|
||||
self.id = id
|
||||
self.tenant_id = graph_init_params.tenant_id
|
||||
self.app_id = graph_init_params.app_id
|
||||
@ -35,6 +36,7 @@ class BaseNode(ABC):
|
||||
self.graph = graph
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.previous_node_id = previous_node_id
|
||||
self.thread_pool_id = thread_pool_id
|
||||
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
|
@ -66,6 +66,7 @@ class ToolNode(BaseNode):
|
||||
user_id=self.user_id,
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
thread_pool_id=self.thread_pool_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
|
@ -44,7 +44,8 @@ class WorkflowEntry:
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
call_depth: int,
|
||||
variable_pool: VariablePool
|
||||
variable_pool: VariablePool,
|
||||
thread_pool_id: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Init workflow entry
|
||||
@ -59,7 +60,9 @@ class WorkflowEntry:
|
||||
:param invoke_from: invoke from
|
||||
:param call_depth: call depth
|
||||
:param variable_pool: variable pool
|
||||
:param thread_pool_id: thread pool id
|
||||
"""
|
||||
# check call depth
|
||||
workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
|
||||
if call_depth > workflow_call_max_depth:
|
||||
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
|
||||
@ -78,7 +81,8 @@ class WorkflowEntry:
|
||||
graph_config=graph_config,
|
||||
variable_pool=variable_pool,
|
||||
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
||||
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
|
||||
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||
thread_pool_id=thread_pool_id
|
||||
)
|
||||
|
||||
def run(
|
||||
|
Loading…
x
Reference in New Issue
Block a user