mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-17 09:26:02 +08:00
add run logics
This commit is contained in:
parent
d77b689a99
commit
821e09b259
@ -22,6 +22,12 @@ class GraphParallel(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
"""random uuid parallel id"""
|
||||
|
||||
start_from_node_id: str
|
||||
"""start from node id"""
|
||||
|
||||
end_to_node_id: Optional[str] = None
|
||||
"""end to node id"""
|
||||
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if exists"""
|
||||
|
||||
@ -33,6 +39,9 @@ class Graph(BaseModel):
|
||||
node_ids: list[str] = Field(default_factory=list)
|
||||
"""graph node ids"""
|
||||
|
||||
# Default must be an empty dict (not a list) so it matches the annotated mapping type.
node_id_config_mapping: dict[str, dict] = Field(default_factory=dict)
|
||||
"""node configs mapping (node id: node config)"""
|
||||
|
||||
edge_mapping: dict[str, list[GraphEdge]] = Field(default_factory=dict)
|
||||
"""graph edge mapping (source node id: edges)"""
|
||||
|
||||
@ -102,6 +111,7 @@ class Graph(BaseModel):
|
||||
|
||||
# fetch nodes that have no predecessor node
|
||||
root_node_configs = []
|
||||
all_node_id_config_mapping: dict[str, dict] = {}
|
||||
for node_config in node_configs:
|
||||
node_id = node_config.get('id')
|
||||
if not node_id:
|
||||
@ -110,6 +120,8 @@ class Graph(BaseModel):
|
||||
if node_id not in target_edge_ids:
|
||||
root_node_configs.append(node_config)
|
||||
|
||||
all_node_id_config_mapping[node_id] = node_config
|
||||
|
||||
root_node_ids = [node_config.get('id') for node_config in root_node_configs]
|
||||
|
||||
# fetch root node
|
||||
@ -129,6 +141,8 @@ class Graph(BaseModel):
|
||||
node_id=root_node_id
|
||||
)
|
||||
|
||||
node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids}
|
||||
|
||||
# init parallel mapping
|
||||
parallel_mapping: dict[str, GraphParallel] = {}
|
||||
node_parallel_mapping: dict[str, str] = {}
|
||||
@ -143,6 +157,7 @@ class Graph(BaseModel):
|
||||
graph = cls(
|
||||
root_node_id=root_node_id,
|
||||
node_ids=node_ids,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
edge_mapping=edge_mapping,
|
||||
parallel_mapping=parallel_mapping,
|
||||
node_parallel_mapping=node_parallel_mapping
|
||||
@ -243,7 +258,10 @@ class Graph(BaseModel):
|
||||
if all(node_id in node_parallel_mapping for node_id in parallel_node_ids):
|
||||
parent_parallel_id = node_parallel_mapping[parallel_node_ids[0]]
|
||||
|
||||
parallel = GraphParallel(parent_parallel_id=parent_parallel_id)
|
||||
parallel = GraphParallel(
|
||||
start_from_node_id=start_node_id,
|
||||
parent_parallel_id=parent_parallel_id
|
||||
)
|
||||
parallel_mapping[parallel.id] = parallel
|
||||
|
||||
in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
|
||||
@ -252,10 +270,20 @@ class Graph(BaseModel):
|
||||
)
|
||||
|
||||
# collect all branches node ids
|
||||
end_to_node_id: Optional[str] = None
|
||||
for branch_node_id, node_ids in in_branch_node_ids.items():
|
||||
for node_id in node_ids:
|
||||
node_parallel_mapping[node_id] = parallel.id
|
||||
|
||||
if not end_to_node_id and edge_mapping.get(node_id):
|
||||
node_edges = edge_mapping[node_id]
|
||||
target_node_id = node_edges[0].target_node_id
|
||||
if node_parallel_mapping.get(target_node_id) == parent_parallel_id:
|
||||
end_to_node_id = target_node_id
|
||||
|
||||
if end_to_node_id:
|
||||
parallel.end_to_node_id = end_to_node_id
|
||||
|
||||
for graph_edge in target_node_edges:
|
||||
cls._recursively_add_parallels(
|
||||
edge_mapping=edge_mapping,
|
||||
|
@ -1,4 +1,3 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@ -19,7 +18,7 @@ class GraphRuntimeState(BaseModel):
|
||||
|
||||
variable_pool: VariablePool
|
||||
|
||||
start_at: Optional[float] = None
|
||||
start_at: float
|
||||
total_tokens: int = 0
|
||||
node_run_steps: int = 0
|
||||
|
||||
|
@ -1,15 +1,25 @@
|
||||
import logging
|
||||
import queue
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import cast
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Optional, cast
|
||||
|
||||
from flask import current_app
|
||||
from flask import Flask, current_app
|
||||
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from extensions.ext_database import db
|
||||
|
||||
thread_pool = ThreadPoolExecutor(max_workers=500, thread_name_prefix="ThreadGraphParallelRun")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GraphEngine:
|
||||
@ -30,7 +40,8 @@ class GraphEngine:
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=call_depth,
|
||||
variable_pool=variable_pool
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter()
|
||||
)
|
||||
|
||||
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
|
||||
@ -40,52 +51,217 @@ class GraphEngine:
|
||||
|
||||
self.callbacks = callbacks
|
||||
|
||||
def run(self) -> Generator:
|
||||
self.graph_runtime_state.start_at = time.perf_counter()
|
||||
def run_in_block_mode(self):
    """Placeholder for a blocking run mode; it currently does nothing and returns ``None``."""
    # TODO convert generator to result
    return None
|
||||
|
||||
# def next_node_ids(self, node_state_id: str) -> list[NextGraphNode]:
|
||||
# """
|
||||
# Get next node ids
|
||||
#
|
||||
# :param node_state_id: source node state id
|
||||
# """
|
||||
# # get current node ids in state
|
||||
# node_run_state = self.graph_runtime_state.node_run_state
|
||||
# graph = self.graph
|
||||
# if not node_run_state.routes:
|
||||
# return [NextGraphNode(node_id=graph.root_node_id)]
|
||||
#
|
||||
# route_final_graph_edges: list[GraphEdge] = []
|
||||
# for route in route_state.routes[graph.root_node_id]:
|
||||
# graph_edges = graph.edge_mapping.get(route.node_id)
|
||||
# if not graph_edges:
|
||||
# continue
|
||||
#
|
||||
# for edge in graph_edges:
|
||||
# if edge.target_node_id not in route_state.routes:
|
||||
# route_final_graph_edges.append(edge)
|
||||
#
|
||||
# next_graph_nodes = []
|
||||
# for route_final_graph_edge in route_final_graph_edges:
|
||||
# node_id = route_final_graph_edge.target_node_id
|
||||
# # check condition
|
||||
# if route_final_graph_edge.run_condition:
|
||||
# result = ConditionManager.get_condition_handler(
|
||||
# run_condition=route_final_graph_edge.run_condition
|
||||
# ).check(
|
||||
# source_node_id=route_final_graph_edge.source_node_id,
|
||||
# target_node_id=route_final_graph_edge.target_node_id,
|
||||
# graph=self
|
||||
# )
|
||||
#
|
||||
# if not result:
|
||||
# continue
|
||||
#
|
||||
# parallel = None
|
||||
# if route_final_graph_edge.target_node_id in graph.node_parallel_mapping:
|
||||
# parallel = graph.parallel_mapping[graph.node_parallel_mapping[node_id]]
|
||||
#
|
||||
# next_graph_nodes.append(NextGraphNode(node_id=node_id, parallel=parallel))
|
||||
#
|
||||
# return next_graph_nodes
|
||||
def run(self) -> Generator:
    """
    Run the graph from its root node and stream the events produced along the way.

    The events of ``_run`` are re-yielded one by one. Failures are logged and then
    re-raised to the consumer; the dedicated failure/success events are still TODO.

    :return: generator of the events yielded by ``_run``
    :raises GraphRunFailedError: when ``_run`` aborts the graph run
    """
    # TODO trigger graph run start event

    try:
        # ``_run`` is a generator function: nothing executes (and nothing can fail) until it
        # is iterated, so the iteration itself has to happen inside the ``try``.
        yield from self._run(start_node_id=self.graph.root_node_id)
    except GraphRunFailedError as e:
        # TODO self._graph_run_failed(
        #     error=e.error,
        #     callbacks=callbacks
        # )
        logger.error("Graph run failed: %s", e.error)
        raise
    except Exception:
        # TODO self._workflow_run_failed(
        #     error=str(e),
        #     callbacks=callbacks
        # )
        logger.exception("Unknown error while running graph")
        raise

    # TODO trigger graph run success event
|
||||
|
||||
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None):
    """
    Walk the graph starting at ``start_node_id`` and yield the events of every node that runs.

    After each node the outgoing edges decide what happens next:

    * no edge: the walk ends;
    * one edge: follow it (the walk ends without running the target if it is an END node);
    * several edges, some with a run condition: follow the first edge whose condition passes;
    * several edges, none with a run condition: run every target in its own worker thread,
      forward their events, then continue from the parallel's ``end_to_node_id``.

    :param start_node_id: id of the first node to run
    :param in_parallel_id: id of the parallel this walk is a branch of, if any; the walk ends
        as soon as the next node is not mapped to that parallel
    :raises GraphRunFailedError: when the step limit or time limit is exceeded, or when the
        parallel for a fan-out cannot be resolved
    """
    next_node_id = start_node_id
    while True:
        # max steps reached
        if self.graph_runtime_state.node_run_steps > self.max_execution_steps:
            raise GraphRunFailedError('Max steps {} reached.'.format(self.max_execution_steps))

        # or max execution time reached
        if self._is_timed_out(
            start_at=self.graph_runtime_state.start_at,
            max_execution_time=self.max_execution_time
        ):
            raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))

        # run node TODO generator
        yield from self._run_node(node_id=next_node_id)

        # todo if failed, break

        # get next node ids
        edge_mappings = self.graph.edge_mapping.get(next_node_id)
        if not edge_mappings:
            break

        if len(edge_mappings) == 1:
            next_node_id = edge_mappings[0].target_node_id

            # It may not be necessary, but it is necessary. :)
            if (self.graph.node_id_config_mapping[next_node_id]
                    .get("data", {}).get("type", "").lower() == NodeType.END.value):
                break
        else:
            if any(edge.run_condition for edge in edge_mappings):
                # if nodes has run conditions, get node id which branch to take based on the run condition results
                final_node_id = None
                for edge in edge_mappings:
                    if edge.run_condition:
                        result = ConditionManager.get_condition_handler(
                            run_condition=edge.run_condition
                        ).check(
                            source_node_id=edge.source_node_id,
                            target_node_id=edge.target_node_id,
                            graph=self.graph
                        )

                        if result:
                            final_node_id = edge.target_node_id
                            break

                if not final_node_id:
                    break

                next_node_id = final_node_id
            else:
                # if nodes has no run conditions, parallel run all nodes
                # Resolve the parallel through a branch node (an edge target): the source of
                # these edges is the node that just ran, not a member of the branches.
                parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
                if not parallel_id:
                    raise GraphRunFailedError('Node related parallel not found.')

                parallel = self.graph.parallel_mapping.get(parallel_id)
                if not parallel:
                    raise GraphRunFailedError('Parallel not found.')

                # run parallel nodes, run in new thread and use queue to get results
                q: queue.Queue = queue.Queue()

                # new thread
                futures = []
                for edge in edge_mappings:
                    futures.append(thread_pool.submit(
                        self._run_parallel_node,
                        flask_app=current_app._get_current_object(),
                        # each branch starts at the edge target; the source has already run
                        parallel_start_node_id=edge.target_node_id,
                        q=q
                    ))

                # Every worker posts one ``None`` when it is done; keep draining until all
                # branches have reported, not just the first one.
                unfinished_branches = len(futures)
                while unfinished_branches > 0:
                    try:
                        event = q.get(timeout=1)
                    except queue.Empty:
                        # Safety net: if every worker has exited and nothing is left to read,
                        # a sentinel was lost and waiting any longer would hang forever.
                        if all(future.done() for future in futures) and q.empty():
                            break
                        continue

                    if event is None:
                        unfinished_branches -= 1
                        continue

                    # TODO tag event with parallel id
                    yield event

                # surface any exception raised inside a worker
                for future in as_completed(futures):
                    future.result()

                # get final node id
                final_node_id = parallel.end_to_node_id
                if not final_node_id:
                    break

                next_node_id = final_node_id

        # A branch walk must stop once the next node lies outside its own parallel;
        # that node is run by the walk which started the parallel.
        if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
            break
|
||||
|
||||
def _run_parallel_node(self, flask_app: Flask, parallel_start_node_id: str, q: queue.Queue) -> None:
    """
    Run one parallel branch inside a Flask app context and push its events onto ``q``.

    Exactly one ``None`` is always put on the queue when the branch is finished, whether it
    succeeded, was skipped or failed, so the consumer can tell when this worker is done.

    :param flask_app: Flask application whose app context is entered for the run
    :param parallel_start_node_id: id of the first node of the branch
    :param q: queue receiving the branch events, followed by a single ``None`` sentinel
    """
    with flask_app.app_context():
        try:
            in_parallel_id = self.graph.node_parallel_mapping.get(parallel_start_node_id)
            # a start node that is not mapped to any parallel has nothing to run
            if in_parallel_id:
                # run node TODO generator
                branch_events = self._run(
                    start_node_id=parallel_start_node_id,
                    in_parallel_id=in_parallel_id
                )

                for item in branch_events:
                    q.put(item)
        except Exception:
            logger.exception("Unknown Error when generating in parallel")
        finally:
            # Post the sentinel on every path; without it a failed branch would leave the
            # consumer waiting on the queue forever.
            q.put(None)
            db.session.remove()
|
||||
|
||||
def _run_node(self, node_id: str) -> Generator:
    """
    Run a single node and yield whatever its ``run`` call produces.

    Errors raised while the node runs are logged and not re-raised; the dedicated
    node failed/succeeded events are still TODO.

    :param node_id: id of the node to run
    :return: generator of the items yielded by the node run
    :raises GraphRunFailedError: when the graph holds no config for ``node_id``
    """
    # get node config
    node_config = self.graph.node_id_config_mapping.get(node_id)
    if not node_config:
        raise GraphRunFailedError('Node not found.')

    # todo convert to specific node
    # NOTE(review): nothing in this method binds ``node`` yet, so the ``node.run`` call below
    # raises NameError until the conversion above is implemented; the handler at the bottom
    # logs it instead of hiding it.

    # todo trigger node run start event

    db.session.close()

    # TODO reference from core.workflow.workflow_entry.WorkflowEntry._run_workflow_node

    self.graph_runtime_state.node_run_steps += 1

    try:
        # run node
        rst = node.run(
            graph_runtime_state=self.graph_runtime_state,
            graph=self.graph,
            callbacks=self.callbacks
        )

        yield from rst

        # todo record state
    except GenerateTaskStoppedException:
        # TODO yield failed
        # todo trigger node run failed event
        logger.info("Node %s run stopped", node_id)
    except Exception:
        # TODO yield failed
        # todo trigger node run failed event
        logger.exception("Node %s run failed", node_id)

    # todo trigger node run success event

    db.session.close()
|
||||
|
||||
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
    """
    Tell whether more than ``max_execution_time`` has elapsed since ``start_at``.

    :param start_at: start time, a ``time.perf_counter()`` reading
    :param max_execution_time: allowed duration, in the same unit as ``time.perf_counter()``
    :return: True when the elapsed time is strictly greater than the allowed duration
    """
    elapsed = time.perf_counter() - start_at
    return elapsed > max_execution_time
|
||||
|
||||
|
||||
class GraphRunFailedError(Exception):
    """
    Raised when a graph run has to be aborted.

    :param error: human-readable reason, also kept on the ``error`` attribute
    """

    def __init__(self, error: str):
        # Forward the message to ``Exception`` so ``str(exc)``, ``exc.args`` and logged
        # tracebacks carry it instead of being empty.
        super().__init__(error)
        self.error = error
|
||||
|
@ -16,7 +16,6 @@ from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, Work
|
||||
from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
|
||||
@ -24,7 +23,6 @@ from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
|
||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||
from core.workflow.nodes.iterable_node import IterableNodeMixin
|
||||
from core.workflow.nodes.iteration.entities import IterationState
|
||||
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
@ -152,103 +150,6 @@ class WorkflowEntry:
|
||||
|
||||
return rst
|
||||
|
||||
def _recursively_add_edges(self, graph: Graph,
|
||||
source_node_config: dict,
|
||||
edges_mapping: dict,
|
||||
nodes_mapping: dict,
|
||||
root_node_configs: list[dict]) -> None:
|
||||
"""
|
||||
Add edges
|
||||
|
||||
:param source_node_config: source node config
|
||||
:param edges_mapping: edges mapping
|
||||
:param nodes_mapping: nodes mapping
|
||||
:param root_node_configs: root node configs
|
||||
"""
|
||||
source_node_id = source_node_config.get('id')
|
||||
if not source_node_id:
|
||||
return
|
||||
|
||||
for edge_config in edges_mapping.get(source_node_id, []):
|
||||
target_node_id = edge_config.get('target')
|
||||
if not target_node_id:
|
||||
continue
|
||||
|
||||
target_node_config = nodes_mapping.get(target_node_id)
|
||||
if not target_node_config:
|
||||
continue
|
||||
|
||||
sub_graph: Optional[Graph] = None
|
||||
target_node_type: NodeType = NodeType.value_of(target_node_config.get('data', {}).get('type'))
|
||||
target_node_cls = None
|
||||
if target_node_type:
|
||||
target_node_cls = node_classes.get(target_node_type)
|
||||
if not target_node_cls:
|
||||
raise Exception(f'Node class not found for node type: {target_node_type}')
|
||||
|
||||
if target_node_cls and issubclass(target_node_cls, IterableNodeMixin):
|
||||
# find iteration/loop sub nodes that have no predecessor node
|
||||
sub_graph_root_node_config = None
|
||||
for root_node_config in root_node_configs:
|
||||
if root_node_config.get('parentId') == target_node_id:
|
||||
sub_graph_root_node_config = root_node_config
|
||||
break
|
||||
|
||||
if sub_graph_root_node_config:
|
||||
# create sub graph run condition
|
||||
iterable_node_cls: IterableNodeMixin = cast(IterableNodeMixin, target_node_cls)
|
||||
sub_graph_run_condition = RunCondition(
|
||||
type='condition',
|
||||
conditions=iterable_node_cls.get_conditions(
|
||||
node_config=target_node_config
|
||||
)
|
||||
)
|
||||
|
||||
# create sub graph
|
||||
sub_graph = Graph.init(
|
||||
root_node_config=sub_graph_root_node_config,
|
||||
run_condition=sub_graph_run_condition
|
||||
)
|
||||
|
||||
self._recursively_add_edges(
|
||||
graph=sub_graph,
|
||||
source_node_config=sub_graph_root_node_config,
|
||||
edges_mapping=edges_mapping,
|
||||
nodes_mapping=nodes_mapping,
|
||||
root_node_configs=root_node_configs
|
||||
)
|
||||
|
||||
# add edge from end node to first node of sub graph
|
||||
sub_graph_root_node_id = sub_graph.root_node.id
|
||||
for leaf_node in sub_graph.get_leaf_nodes():
|
||||
leaf_node.add_child(sub_graph_root_node_id)
|
||||
|
||||
# parse run condition
|
||||
run_condition = None
|
||||
if edge_config.get('sourceHandle'):
|
||||
run_condition = RunCondition(
|
||||
type='branch_identify',
|
||||
branch_identify=edge_config.get('sourceHandle')
|
||||
)
|
||||
|
||||
# add edge
|
||||
graph.add_edge(
|
||||
edge_config=edge_config,
|
||||
source_node_config=source_node_config,
|
||||
target_node_config=target_node_config,
|
||||
target_node_sub_graph=sub_graph,
|
||||
run_condition=run_condition
|
||||
)
|
||||
|
||||
# recursively add edges
|
||||
self._recursively_add_edges(
|
||||
graph=graph,
|
||||
source_node_config=target_node_config,
|
||||
edges_mapping=edges_mapping,
|
||||
nodes_mapping=nodes_mapping,
|
||||
root_node_configs=root_node_configs
|
||||
)
|
||||
|
||||
def _run_workflow(self, graph_config: dict,
|
||||
workflow_runtime_state: WorkflowRuntimeState,
|
||||
callbacks: list[BaseWorkflowCallback],
|
||||
|
Loading…
x
Reference in New Issue
Block a user