add run logics

This commit is contained in:
takatost 2024-07-12 19:33:47 +08:00
parent d77b689a99
commit 821e09b259
4 changed files with 256 additions and 152 deletions

View File

@ -22,6 +22,12 @@ class GraphParallel(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
"""random uuid parallel id"""
start_from_node_id: str
"""start from node id"""
end_to_node_id: Optional[str] = None
"""end to node id"""
parent_parallel_id: Optional[str] = None
"""parent parallel id if exists"""
@ -33,6 +39,9 @@ class Graph(BaseModel):
node_ids: list[str] = Field(default_factory=list)
"""graph node ids"""
node_id_config_mapping: dict[str, dict] = Field(default_factory=dict)
"""node configs mapping (node id: node config)"""
edge_mapping: dict[str, list[GraphEdge]] = Field(default_factory=dict)
"""graph edge mapping (source node id: edges)"""
@ -102,6 +111,7 @@ class Graph(BaseModel):
# fetch nodes that have no predecessor node
root_node_configs = []
all_node_id_config_mapping: dict[str, dict] = {}
for node_config in node_configs:
node_id = node_config.get('id')
if not node_id:
@ -110,6 +120,8 @@ class Graph(BaseModel):
if node_id not in target_edge_ids:
root_node_configs.append(node_config)
all_node_id_config_mapping[node_id] = node_config
root_node_ids = [node_config.get('id') for node_config in root_node_configs]
# fetch root node
@ -129,6 +141,8 @@ class Graph(BaseModel):
node_id=root_node_id
)
node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids}
# init parallel mapping
parallel_mapping: dict[str, GraphParallel] = {}
node_parallel_mapping: dict[str, str] = {}
@ -143,6 +157,7 @@ class Graph(BaseModel):
graph = cls(
root_node_id=root_node_id,
node_ids=node_ids,
node_id_config_mapping=node_id_config_mapping,
edge_mapping=edge_mapping,
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping
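To illustrate the node_id_config_mapping built above, a small worked example with hypothetical node configs; it mirrors the comprehension that restricts all_node_id_config_mapping to the reachable node_ids:

node_configs = [
    {"id": "start", "data": {"type": "start"}},
    {"id": "llm", "data": {"type": "llm"}},
]
all_node_id_config_mapping = {c["id"]: c for c in node_configs}
node_ids = ["start", "llm"]  # ids reachable from the root node
node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids}
# -> {"start": {"id": "start", ...}, "llm": {"id": "llm", ...}}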
@ -243,7 +258,10 @@ class Graph(BaseModel):
if all(node_id in node_parallel_mapping for node_id in parallel_node_ids):
parent_parallel_id = node_parallel_mapping[parallel_node_ids[0]]
parallel = GraphParallel(
start_from_node_id=start_node_id,
parent_parallel_id=parent_parallel_id
)
parallel_mapping[parallel.id] = parallel
in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
@ -252,10 +270,20 @@ class Graph(BaseModel):
)
# collect all branches node ids
end_to_node_id: Optional[str] = None
for branch_node_id, node_ids in in_branch_node_ids.items():
for node_id in node_ids:
node_parallel_mapping[node_id] = parallel.id
if not end_to_node_id and edge_mapping.get(node_id):
node_edges = edge_mapping[node_id]
target_node_id = node_edges[0].target_node_id
if node_parallel_mapping.get(target_node_id) == parent_parallel_id:
end_to_node_id = target_node_id
if end_to_node_id:
parallel.end_to_node_id = end_to_node_id
for graph_edge in target_node_edges:
cls._recursively_add_parallels(
edge_mapping=edge_mapping,
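As a worked example of the end-node detection above, consider a hypothetical fork node A with two branches B and C that merge at D. Both branch nodes are mapped to the new parallel, and D, whose own parallel mapping equals the parent parallel id (here None), becomes end_to_node_id:

# A -> B -> D and A -> C -> D (hypothetical node ids)
parallel = GraphParallel(start_from_node_id="A")
node_parallel_mapping = {"B": parallel.id, "C": parallel.id}
# D is not mapped to parallel.id, so node_parallel_mapping.get("D") == parent_parallel_id (None),
# and the loop above records parallel.end_to_node_id = "D"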

View File

@ -1,4 +1,3 @@
from typing import Optional
from pydantic import BaseModel, Field
@ -19,7 +18,7 @@ class GraphRuntimeState(BaseModel):
variable_pool: VariablePool
start_at: float
total_tokens: int = 0
node_run_steps: int = 0
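Because start_at is now a required float, callers are expected to supply it at construction time (the engine below does this with time.perf_counter()); a minimal sketch, assuming an existing VariablePool instance:

import time

graph_runtime_state = GraphRuntimeState(
    variable_pool=variable_pool,  # existing VariablePool instance (assumed)
    start_at=time.perf_counter(),
)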

View File

@ -1,15 +1,25 @@
import logging
import queue
import time
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional, cast
from flask import Flask, current_app
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.base_node import UserFrom
from extensions.ext_database import db
thread_pool = ThreadPoolExecutor(max_workers=500, thread_name_prefix="ThreadGraphParallelRun")
logger = logging.getLogger(__name__)
class GraphEngine:
@ -30,7 +40,8 @@ class GraphEngine:
user_from=user_from,
invoke_from=invoke_from,
call_depth=call_depth,
variable_pool=variable_pool,
start_at=time.perf_counter()
)
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
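For context, a sketch of how both limits might be read from Flask config; WORKFLOW_MAX_EXECUTION_STEPS is taken from the line above, while the time-limit key is an assumption inferred from the max_execution_time attribute used in _run below:

max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")  # hypothetical key name, not shown in this diff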
@ -40,52 +51,217 @@ class GraphEngine:
self.callbacks = callbacks
def run_in_block_mode(self):
# TODO convert generator to result
pass
# def next_node_ids(self, node_state_id: str) -> list[NextGraphNode]:
# """
# Get next node ids
#
# :param node_state_id: source node state id
# """
# # get current node ids in state
# node_run_state = self.graph_runtime_state.node_run_state
# graph = self.graph
# if not node_run_state.routes:
# return [NextGraphNode(node_id=graph.root_node_id)]
#
# route_final_graph_edges: list[GraphEdge] = []
# for route in route_state.routes[graph.root_node_id]:
# graph_edges = graph.edge_mapping.get(route.node_id)
# if not graph_edges:
# continue
#
# for edge in graph_edges:
# if edge.target_node_id not in route_state.routes:
# route_final_graph_edges.append(edge)
#
# next_graph_nodes = []
# for route_final_graph_edge in route_final_graph_edges:
# node_id = route_final_graph_edge.target_node_id
# # check condition
# if route_final_graph_edge.run_condition:
# result = ConditionManager.get_condition_handler(
# run_condition=route_final_graph_edge.run_condition
# ).check(
# source_node_id=route_final_graph_edge.source_node_id,
# target_node_id=route_final_graph_edge.target_node_id,
# graph=self
# )
#
# if not result:
# continue
#
# parallel = None
# if route_final_graph_edge.target_node_id in graph.node_parallel_mapping:
# parallel = graph.parallel_mapping[graph.node_parallel_mapping[node_id]]
#
# next_graph_nodes.append(NextGraphNode(node_id=node_id, parallel=parallel))
#
# return next_graph_nodes
def run(self) -> Generator:
# TODO trigger graph run start event
try:
# run the graph from the root node and stream events as they are produced
yield from self._run(start_node_id=self.graph.root_node_id)
except GraphRunFailedError as e:
# TODO self._graph_run_failed(
#     error=e.error,
#     callbacks=callbacks
# )
pass
except Exception as e:
# TODO self._workflow_run_failed(
#     error=str(e),
#     callbacks=callbacks
# )
pass
# TODO trigger graph run success event
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None):
next_node_id = start_node_id
while True:
# max steps reached
if self.graph_runtime_state.node_run_steps > self.max_execution_steps:
raise GraphRunFailedError('Max steps {} reached.'.format(self.max_execution_steps))
# or max execution time reached
if self._is_timed_out(
start_at=self.graph_runtime_state.start_at,
max_execution_time=self.max_execution_time
):
raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))
# run node TODO generator
yield from self._run_node(node_id=next_node_id)
# todo if failed, break
# get next node ids
edge_mappings = self.graph.edge_mapping.get(next_node_id)
if not edge_mappings:
break
if len(edge_mappings) == 1:
next_node_id = edge_mappings[0].target_node_id
# stop early when the only successor is an END node
if (self.graph.node_id_config_mapping[next_node_id]
.get("data", {}).get("type", "").lower() == NodeType.END.value):
break
else:
if any(edge.run_condition for edge in edge_mappings):
# if the outgoing edges have run conditions, evaluate them to decide which branch to take
final_node_id = None
for edge in edge_mappings:
if edge.run_condition:
result = ConditionManager.get_condition_handler(
run_condition=edge.run_condition
).check(
source_node_id=edge.source_node_id,
target_node_id=edge.target_node_id,
graph=self.graph
)
if result:
final_node_id = edge.target_node_id
break
if not final_node_id:
break
next_node_id = final_node_id
else:
# if the outgoing edges have no run conditions, run all target branches in parallel
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
if not parallel_id:
raise GraphRunFailedError('Node related parallel not found.')
parallel = self.graph.parallel_mapping.get(parallel_id)
if not parallel:
raise GraphRunFailedError('Parallel not found.')
# run the parallel branches in worker threads and collect their events through a shared queue
q: queue.Queue = queue.Queue()
# submit one worker per outgoing edge; each branch starts at the edge's target node
futures = []
for edge in edge_mappings:
futures.append(thread_pool.submit(
self._run_parallel_node,
flask_app=current_app._get_current_object(),
parallel_start_node_id=edge.target_node_id,
q=q
))
# consume events until every branch has sent its None sentinel
finished_branches = 0
while finished_branches < len(futures):
try:
event = q.get(timeout=1)
if event is None:
finished_branches += 1
continue
# TODO tag event with parallel id
yield event
except queue.Empty:
continue
for future in as_completed(futures):
future.result()
# get final node id
final_node_id = parallel.end_to_node_id
if not final_node_id:
break
next_node_id = final_node_id
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
# the next node lies outside the current parallel branch, so stop here
break
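# For clarity, the fan-out/fan-in pattern used above as a standalone sketch:
# each worker streams items into a shared queue and ends with a None sentinel,
# and the consumer stops after seeing one sentinel per worker. The names
# _fan_in_example, workers and handle are hypothetical placeholders.
def _fan_in_example(workers: list, handle) -> None:
    # each worker is a callable taking a queue.Queue; it puts items, then a final None
    q: queue.Queue = queue.Queue()
    futures = [thread_pool.submit(worker, q) for worker in workers]
    finished = 0
    while finished < len(futures):
        item = q.get()
        if item is None:
            finished += 1
            continue
        handle(item)
    for future in as_completed(futures):
        future.result()  # surface any exception raised inside a worker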
def _run_parallel_node(self, flask_app: Flask, parallel_start_node_id: str, q: queue.Queue) -> None:
"""
Run a single parallel branch in its own thread, pushing its events into the shared queue
"""
with flask_app.app_context():
try:
in_parallel_id = self.graph.node_parallel_mapping.get(parallel_start_node_id)
if not in_parallel_id:
q.put(None)
return
# run node TODO generator
rst = self._run(
start_node_id=parallel_start_node_id,
in_parallel_id=in_parallel_id
)
if not rst:
q.put(None)
return
for item in rst:
q.put(item)
q.put(None)
except Exception:
logger.exception("Unknown Error when generating in parallel")
# still send the sentinel so the consumer does not wait on this branch forever
q.put(None)
finally:
db.session.remove()
def _run_node(self, node_id: str) -> Generator:
"""
Run node
"""
# get node config
node_config = self.graph.node_id_config_mapping.get(node_id)
if not node_config:
raise GraphRunFailedError('Node not found.')
# todo convert to specific node
# todo trigger node run start event
db.session.close()
# TODO reference from core.workflow.workflow_entry.WorkflowEntry._run_workflow_node
self.graph_runtime_state.node_run_steps += 1
try:
# run node
rst = node.run(
graph_runtime_state=self.graph_runtime_state,
graph=self.graph,
callbacks=self.callbacks
)
yield from rst
# todo record state
except GenerateTaskStoppedException as e:
# TODO yield failed
# todo trigger node run failed event
pass
except Exception as e:
# logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
# TODO yield failed
# todo trigger node run failed event
pass
# todo trigger node run success event
db.session.close()
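# A hedged sketch of what the "convert to specific node" TODO above might look like,
# reusing the NodeType.value_of / node_classes lookup pattern that appears in
# workflow_entry.py in this same commit; the import location of node_classes and any
# node constructor arguments are assumptions, not confirmed by this diff.
def _resolve_node_cls(node_config: dict):
    node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
    node_cls = node_classes.get(node_type)  # node class registry, assumed importable here
    if not node_cls:
        raise GraphRunFailedError(f'Node class not found for node type: {node_type}')
    return node_cls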
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
"""
Check timeout
:param start_at: start time
:param max_execution_time: max execution time
:return:
"""
return time.perf_counter() - start_at > max_execution_time
class GraphRunFailedError(Exception):
def __init__(self, error: str):
self.error = error
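Once wired up, a caller drives the engine by iterating the generator returned by run(); a minimal usage sketch, with graph_engine assumed to be an already-constructed GraphEngine instance:

generator = graph_engine.run()
for event in generator:
    # events are yielded as nodes (and parallel branches) produce them
    print(event)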

View File

@ -16,7 +16,6 @@ from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, Work
from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
@ -24,7 +23,6 @@ from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.iterable_node import IterableNodeMixin
from core.workflow.nodes.iteration.entities import IterationState
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
@ -152,103 +150,6 @@ class WorkflowEntry:
return rst
def _recursively_add_edges(self, graph: Graph,
source_node_config: dict,
edges_mapping: dict,
nodes_mapping: dict,
root_node_configs: list[dict]) -> None:
"""
Add edges
:param source_node_config: source node config
:param edges_mapping: edges mapping
:param nodes_mapping: nodes mapping
:param root_node_configs: root node configs
"""
source_node_id = source_node_config.get('id')
if not source_node_id:
return
for edge_config in edges_mapping.get(source_node_id, []):
target_node_id = edge_config.get('target')
if not target_node_id:
continue
target_node_config = nodes_mapping.get(target_node_id)
if not target_node_config:
continue
sub_graph: Optional[Graph] = None
target_node_type: NodeType = NodeType.value_of(target_node_config.get('data', {}).get('type'))
target_node_cls = None
if target_node_type:
target_node_cls = node_classes.get(target_node_type)
if not target_node_cls:
raise Exception(f'Node class not found for node type: {target_node_type}')
if target_node_cls and issubclass(target_node_cls, IterableNodeMixin):
# find iteration/loop sub nodes that have no predecessor node
sub_graph_root_node_config = None
for root_node_config in root_node_configs:
if root_node_config.get('parentId') == target_node_id:
sub_graph_root_node_config = root_node_config
break
if sub_graph_root_node_config:
# create sub graph run condition
iterable_node_cls: IterableNodeMixin = cast(IterableNodeMixin, target_node_cls)
sub_graph_run_condition = RunCondition(
type='condition',
conditions=iterable_node_cls.get_conditions(
node_config=target_node_config
)
)
# create sub graph
sub_graph = Graph.init(
root_node_config=sub_graph_root_node_config,
run_condition=sub_graph_run_condition
)
self._recursively_add_edges(
graph=sub_graph,
source_node_config=sub_graph_root_node_config,
edges_mapping=edges_mapping,
nodes_mapping=nodes_mapping,
root_node_configs=root_node_configs
)
# add edge from end node to first node of sub graph
sub_graph_root_node_id = sub_graph.root_node.id
for leaf_node in sub_graph.get_leaf_nodes():
leaf_node.add_child(sub_graph_root_node_id)
# parse run condition
run_condition = None
if edge_config.get('sourceHandle'):
run_condition = RunCondition(
type='branch_identify',
branch_identify=edge_config.get('sourceHandle')
)
# add edge
graph.add_edge(
edge_config=edge_config,
source_node_config=source_node_config,
target_node_config=target_node_config,
target_node_sub_graph=sub_graph,
run_condition=run_condition
)
# recursively add edges
self._recursively_add_edges(
graph=graph,
source_node_config=target_node_config,
edges_mapping=edges_mapping,
nodes_mapping=nodes_mapping,
root_node_configs=root_node_configs
)
def _run_workflow(self, graph_config: dict,
workflow_runtime_state: WorkflowRuntimeState,
callbacks: list[BaseWorkflowCallback],