From 821e09b259b97c5bf84947bdf7131e383dde4423 Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 12 Jul 2024 19:33:47 +0800 Subject: [PATCH] add run logics --- .../workflow/graph_engine/entities/graph.py | 30 +- .../entities/graph_runtime_state.py | 3 +- .../workflow/graph_engine/graph_engine.py | 276 ++++++++++++++---- api/core/workflow/workflow_entry.py | 99 ------- 4 files changed, 256 insertions(+), 152 deletions(-) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 806c1398c7..deddd255c0 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -22,6 +22,12 @@ class GraphParallel(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4())) """random uuid parallel id""" + start_from_node_id: str + """start from node id""" + + end_to_node_id: Optional[str] = None + """end to node id""" + parent_parallel_id: Optional[str] = None """parent parallel id if exists""" @@ -33,6 +39,9 @@ class Graph(BaseModel): node_ids: list[str] = Field(default_factory=list) """graph node ids""" + node_id_config_mapping: dict[str, dict] = Field(default_factory=list) + """node configs mapping (node id: node config)""" + edge_mapping: dict[str, list[GraphEdge]] = Field(default_factory=dict) """graph edge mapping (source node id: edges)""" @@ -102,6 +111,7 @@ class Graph(BaseModel): # fetch nodes that have no predecessor node root_node_configs = [] + all_node_id_config_mapping: dict[str, dict] = {} for node_config in node_configs: node_id = node_config.get('id') if not node_id: @@ -110,6 +120,8 @@ class Graph(BaseModel): if node_id not in target_edge_ids: root_node_configs.append(node_config) + all_node_id_config_mapping[node_id] = node_config + root_node_ids = [node_config.get('id') for node_config in root_node_configs] # fetch root node @@ -129,6 +141,8 @@ class Graph(BaseModel): node_id=root_node_id ) + node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids} + # init parallel mapping parallel_mapping: dict[str, GraphParallel] = {} node_parallel_mapping: dict[str, str] = {} @@ -143,6 +157,7 @@ class Graph(BaseModel): graph = cls( root_node_id=root_node_id, node_ids=node_ids, + node_id_config_mapping=node_id_config_mapping, edge_mapping=edge_mapping, parallel_mapping=parallel_mapping, node_parallel_mapping=node_parallel_mapping @@ -243,7 +258,10 @@ class Graph(BaseModel): if all(node_id in node_parallel_mapping for node_id in parallel_node_ids): parent_parallel_id = node_parallel_mapping[parallel_node_ids[0]] - parallel = GraphParallel(parent_parallel_id=parent_parallel_id) + parallel = GraphParallel( + start_from_node_id=start_node_id, + parent_parallel_id=parent_parallel_id + ) parallel_mapping[parallel.id] = parallel in_branch_node_ids = cls._fetch_all_node_ids_in_parallels( @@ -252,10 +270,20 @@ class Graph(BaseModel): ) # collect all branches node ids + end_to_node_id: Optional[str] = None for branch_node_id, node_ids in in_branch_node_ids.items(): for node_id in node_ids: node_parallel_mapping[node_id] = parallel.id + if not end_to_node_id and edge_mapping.get(node_id): + node_edges = edge_mapping[node_id] + target_node_id = node_edges[0].target_node_id + if node_parallel_mapping.get(target_node_id) == parent_parallel_id: + end_to_node_id = target_node_id + + if end_to_node_id: + parallel.end_to_node_id = end_to_node_id + for graph_edge in target_node_edges: cls._recursively_add_parallels( edge_mapping=edge_mapping, diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py index 7075ff75ba..1eeff95d13 100644 --- a/api/core/workflow/graph_engine/entities/graph_runtime_state.py +++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py @@ -1,4 +1,3 @@ -from typing import Optional from pydantic import BaseModel, Field @@ -19,7 +18,7 @@ class GraphRuntimeState(BaseModel): variable_pool: VariablePool - start_at: Optional[float] = None + start_at: float total_tokens: int = 0 node_run_steps: int = 0 diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index c0f5acdb9a..76a7080a87 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -1,15 +1,25 @@ +import logging +import queue import time from collections.abc import Generator -from typing import cast +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Optional, cast -from flask import current_app +from flask import Flask, current_app +from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback +from core.workflow.entities.node_entities import NodeType from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.base_node import UserFrom +from extensions.ext_database import db + +thread_pool = ThreadPoolExecutor(max_workers=500, thread_name_prefix="ThreadGraphParallelRun") +logger = logging.getLogger(__name__) class GraphEngine: @@ -30,7 +40,8 @@ class GraphEngine: user_from=user_from, invoke_from=invoke_from, call_depth=call_depth, - variable_pool=variable_pool + variable_pool=variable_pool, + start_at=time.perf_counter() ) max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS") @@ -40,52 +51,217 @@ class GraphEngine: self.callbacks = callbacks - def run(self) -> Generator: - self.graph_runtime_state.start_at = time.perf_counter() + def run_in_block_mode(self): + # TODO convert generator to result pass - # def next_node_ids(self, node_state_id: str) -> list[NextGraphNode]: - # """ - # Get next node ids - # - # :param node_state_id: source node state id - # """ - # # get current node ids in state - # node_run_state = self.graph_runtime_state.node_run_state - # graph = self.graph - # if not node_run_state.routes: - # return [NextGraphNode(node_id=graph.root_node_id)] - # - # route_final_graph_edges: list[GraphEdge] = [] - # for route in route_state.routes[graph.root_node_id]: - # graph_edges = graph.edge_mapping.get(route.node_id) - # if not graph_edges: - # continue - # - # for edge in graph_edges: - # if edge.target_node_id not in route_state.routes: - # route_final_graph_edges.append(edge) - # - # next_graph_nodes = [] - # for route_final_graph_edge in route_final_graph_edges: - # node_id = route_final_graph_edge.target_node_id - # # check condition - # if route_final_graph_edge.run_condition: - # result = ConditionManager.get_condition_handler( - # run_condition=route_final_graph_edge.run_condition - # ).check( - # source_node_id=route_final_graph_edge.source_node_id, - # target_node_id=route_final_graph_edge.target_node_id, - # graph=self - # ) - # - # if not result: - # continue - # - # parallel = None - # if route_final_graph_edge.target_node_id in graph.node_parallel_mapping: - # parallel = graph.parallel_mapping[graph.node_parallel_mapping[node_id]] - # - # next_graph_nodes.append(NextGraphNode(node_id=node_id, parallel=parallel)) - # - # return next_graph_nodes + def run(self) -> Generator: + # TODO trigger graph run start event + + try: + # TODO run graph + rst = self._run(start_node_id=self.graph.root_node_id) + except GraphRunFailedError as e: + # TODO self._graph_run_failed( + # error=e.error, + # callbacks=callbacks + # ) + pass + except Exception as e: + # TODO self._workflow_run_failed( + # error=str(e), + # callbacks=callbacks + # ) + pass + + # TODO trigger graph run success event + + yield rst + + def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None): + next_node_id = start_node_id + while True: + # max steps reached + if self.graph_runtime_state.node_run_steps > self.max_execution_steps: + raise GraphRunFailedError('Max steps {} reached.'.format(self.max_execution_steps)) + + # or max execution time reached + if self._is_timed_out( + start_at=self.graph_runtime_state.start_at, + max_execution_time=self.max_execution_time + ): + raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time)) + + # run node TODO generator + yield from self._run_node(node_id=next_node_id) + + # todo if failed, break + + # get next node ids + edge_mappings = self.graph.edge_mapping.get(next_node_id) + if not edge_mappings: + break + + if len(edge_mappings) == 1: + next_node_id = edge_mappings[0].target_node_id + + # It may not be necessary, but it is necessary. :) + if (self.graph.node_id_config_mapping[next_node_id] + .get("data", {}).get("type", "").lower() == NodeType.END.value): + break + else: + if any(edge.run_condition for edge in edge_mappings): + # if nodes has run conditions, get node id which branch to take based on the run condition results + final_node_id = None + for edge in edge_mappings: + if edge.run_condition: + result = ConditionManager.get_condition_handler( + run_condition=edge.run_condition + ).check( + source_node_id=edge.source_node_id, + target_node_id=edge.target_node_id, + graph=self.graph + ) + + if result: + final_node_id = edge.target_node_id + break + + if not final_node_id: + break + + next_node_id = final_node_id + else: + # if nodes has no run conditions, parallel run all nodes + parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].source_node_id) + if not parallel_id: + raise GraphRunFailedError('Node related parallel not found.') + + parallel = self.graph.parallel_mapping.get(parallel_id) + if not parallel: + raise GraphRunFailedError('Parallel not found.') + + # run parallel nodes, run in new thread and use queue to get results + q: queue.Queue = queue.Queue() + + # new thread + futures = [] + for edge in edge_mappings: + futures.append(thread_pool.submit( + self._run_parallel_node, + flask_app=current_app._get_current_object(), + parallel_start_node_id=edge.source_node_id, + q=q + )) + + while True: + try: + event = q.get(timeout=1) + if event is None: + break + + # TODO tag event with parallel id + yield event + except queue.Empty: + continue + + for future in as_completed(futures): + future.result() + + # get final node id + final_node_id = parallel.end_to_node_id + if not final_node_id: + break + + next_node_id = final_node_id + + if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id: + break + + def _run_parallel_node(self, flask_app: Flask, parallel_start_node_id: str, q: queue.Queue) -> None: + """ + Run parallel nodes + """ + with flask_app.app_context(): + try: + in_parallel_id = self.graph.node_parallel_mapping.get(parallel_start_node_id) + if not in_parallel_id: + q.put(None) + return + + # run node TODO generator + rst = self._run( + start_node_id=parallel_start_node_id, + in_parallel_id=in_parallel_id + ) + + if not rst: + q.put(None) + return + + for item in rst: + q.put(item) + + q.put(None) + except Exception: + logger.exception("Unknown Error when generating in parallel") + finally: + db.session.remove() + + def _run_node(self, node_id: str) -> Generator: + """ + Run node + """ + # get node config + node_config = self.graph.node_id_config_mapping.get(node_id) + if not node_config: + raise GraphRunFailedError('Node not found.') + + # todo convert to specific node + + # todo trigger node run start event + + db.session.close() + + # TODO reference from core.workflow.workflow_entry.WorkflowEntry._run_workflow_node + + self.graph_runtime_state.node_run_steps += 1 + + try: + # run node + rst = node.run( + graph_runtime_state=self.graph_runtime_state, + graph=self.graph, + callbacks=self.callbacks + ) + + yield from rst + + # todo record state + except GenerateTaskStoppedException as e: + # TODO yield failed + # todo trigger node run failed event + pass + except Exception as e: + # logger.exception(f"Node {node.node_data.title} run failed: {str(e)}") + # TODO yield failed + # todo trigger node run failed event + pass + + # todo trigger node run success event + + db.session.close() + + def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: + """ + Check timeout + :param start_at: start time + :param max_execution_time: max execution time + :return: + """ + return time.perf_counter() - start_at > max_execution_time + + +class GraphRunFailedError(Exception): + def __init__(self, error: str): + self.error = error diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 0412ba572e..840b12878c 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -16,7 +16,6 @@ from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, Work from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.graph import Graph -from core.workflow.graph_engine.entities.run_condition import RunCondition from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.nodes.answer.answer_node import AnswerNode from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom @@ -24,7 +23,6 @@ from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.end.end_node import EndNode from core.workflow.nodes.http_request.http_request_node import HttpRequestNode from core.workflow.nodes.if_else.if_else_node import IfElseNode -from core.workflow.nodes.iterable_node import IterableNodeMixin from core.workflow.nodes.iteration.entities import IterationState from core.workflow.nodes.iteration.iteration_node import IterationNode from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode @@ -152,103 +150,6 @@ class WorkflowEntry: return rst - def _recursively_add_edges(self, graph: Graph, - source_node_config: dict, - edges_mapping: dict, - nodes_mapping: dict, - root_node_configs: list[dict]) -> None: - """ - Add edges - - :param source_node_config: source node config - :param edges_mapping: edges mapping - :param nodes_mapping: nodes mapping - :param root_node_configs: root node configs - """ - source_node_id = source_node_config.get('id') - if not source_node_id: - return - - for edge_config in edges_mapping.get(source_node_id, []): - target_node_id = edge_config.get('target') - if not target_node_id: - continue - - target_node_config = nodes_mapping.get(target_node_id) - if not target_node_config: - continue - - sub_graph: Optional[Graph] = None - target_node_type: NodeType = NodeType.value_of(target_node_config.get('data', {}).get('type')) - target_node_cls = None - if target_node_type: - target_node_cls = node_classes.get(target_node_type) - if not target_node_cls: - raise Exception(f'Node class not found for node type: {target_node_type}') - - if target_node_cls and issubclass(target_node_cls, IterableNodeMixin): - # find iteration/loop sub nodes that have no predecessor node - sub_graph_root_node_config = None - for root_node_config in root_node_configs: - if root_node_config.get('parentId') == target_node_id: - sub_graph_root_node_config = root_node_config - break - - if sub_graph_root_node_config: - # create sub graph run condition - iterable_node_cls: IterableNodeMixin = cast(IterableNodeMixin, target_node_cls) - sub_graph_run_condition = RunCondition( - type='condition', - conditions=iterable_node_cls.get_conditions( - node_config=target_node_config - ) - ) - - # create sub graph - sub_graph = Graph.init( - root_node_config=sub_graph_root_node_config, - run_condition=sub_graph_run_condition - ) - - self._recursively_add_edges( - graph=sub_graph, - source_node_config=sub_graph_root_node_config, - edges_mapping=edges_mapping, - nodes_mapping=nodes_mapping, - root_node_configs=root_node_configs - ) - - # add edge from end node to first node of sub graph - sub_graph_root_node_id = sub_graph.root_node.id - for leaf_node in sub_graph.get_leaf_nodes(): - leaf_node.add_child(sub_graph_root_node_id) - - # parse run condition - run_condition = None - if edge_config.get('sourceHandle'): - run_condition = RunCondition( - type='branch_identify', - branch_identify=edge_config.get('sourceHandle') - ) - - # add edge - graph.add_edge( - edge_config=edge_config, - source_node_config=source_node_config, - target_node_config=target_node_config, - target_node_sub_graph=sub_graph, - run_condition=run_condition - ) - - # recursively add edges - self._recursively_add_edges( - graph=graph, - source_node_config=target_node_config, - edges_mapping=edges_mapping, - nodes_mapping=nodes_mapping, - root_node_configs=root_node_configs - ) - def _run_workflow(self, graph_config: dict, workflow_runtime_state: WorkflowRuntimeState, callbacks: list[BaseWorkflowCallback],