From cc96acdae31195fe2b0ea09359cc47bde00c2a9a Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 17 Jul 2024 11:26:33 +0800 Subject: [PATCH] fix bugs --- .../entities/runtime_route_state.py | 46 +++--- .../workflow/graph_engine/graph_engine.py | 133 +++++++++++------ api/core/workflow/nodes/answer/answer_node.py | 2 +- .../graph_engine/test_graph_engine.py | 134 +++++++++++------- 4 files changed, 199 insertions(+), 116 deletions(-) diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index a300d8b2f4..90c918e370 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -43,6 +43,26 @@ class RouteNodeState(BaseModel): paused_by: Optional[str] = None """paused by""" + def set_finished(self, run_result: NodeRunResult) -> None: + """ + Node finished + + :param run_result: run result + """ + if self.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]: + raise Exception(f"Route state {self.id} already finished") + + if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + self.status = RouteNodeState.Status.SUCCESS + elif run_result.status == WorkflowNodeExecutionStatus.FAILED: + self.status = RouteNodeState.Status.FAILED + self.failed_reason = run_result.error + else: + raise Exception(f"Invalid route status {run_result.status}") + + self.node_run_result = run_result + self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + class RuntimeRouteState(BaseModel): routes: dict[str, list[str]] = Field( @@ -87,29 +107,3 @@ class RuntimeRouteState(BaseModel): """ return [self.node_state_mapping[target_state_id] for target_state_id in self.routes.get(source_node_state_id, [])] - - def set_node_state_finished(self, node_state_id: str, run_result: NodeRunResult) -> None: - """ - Node finished - - :param node_state_id: route node state id - :param run_result: run result - """ - if node_state_id not in self.node_state_mapping: - raise Exception(f"Route state {node_state_id} not found") - - route = self.node_state_mapping[node_state_id] - - if route.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]: - raise Exception(f"Route state {node_state_id} already finished") - - if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: - route.status = RouteNodeState.Status.SUCCESS - elif run_result.status == WorkflowNodeExecutionStatus.FAILED: - route.status = RouteNodeState.Status.FAILED - route.failed_reason = run_result.error - else: - raise Exception(f"Invalid route status {run_result.status}") - - route.node_run_result = run_result - route.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 211d0df5fb..00ce48571b 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -3,14 +3,14 @@ import queue import time from collections.abc import Generator from concurrent.futures import ThreadPoolExecutor -from datetime import datetime, timezone from typing import Optional from flask import Flask, current_app +from uritemplate.variable import VariableValue from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import NodeType, UserFrom +from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager from core.workflow.graph_engine.entities.event import ( @@ -89,7 +89,7 @@ class GraphEngine: # trigger graph run success event yield GraphRunSucceededEvent() - except (GraphRunFailedError, NodeRunFailedError) as e: + except GraphRunFailedError as e: yield GraphRunFailedEvent(reason=e.error) return except Exception as e: @@ -112,7 +112,7 @@ class GraphEngine: raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time)) # init route node state - route_node_state = self.graph_runtime_state.create_node_state( + route_node_state = self.graph_runtime_state.node_run_state.create_node_state( node_id=next_node_id ) @@ -128,13 +128,13 @@ class GraphEngine: # append route if previous_route_node_state: - if previous_route_node_state.id not in self.graph_runtime_state.node_run_state.routes: - self.graph_runtime_state.node_run_state.routes[previous_route_node_state.id] = [] - - self.graph_runtime_state.node_run_state.routes[previous_route_node_state.id].append( - route_node_state.id + self.graph_runtime_state.node_run_state.add_route( + source_node_state_id=previous_route_node_state.id, + target_node_state_id=route_node_state.id ) except Exception as e: + route_node_state.status = RouteNodeState.Status.FAILED + route_node_state.failed_reason = str(e) yield NodeRunFailedEvent( route_node_state=route_node_state, parallel_id=in_parallel_id @@ -181,9 +181,9 @@ class GraphEngine: next_node_id = final_node_id else: # if nodes has no run conditions, parallel run all nodes - parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].source_node_id) + parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id) if not parallel_id: - raise GraphRunFailedError(f'Node {edge_mappings[0].source_node_id} related parallel not found.') + raise GraphRunFailedError(f'Node {edge_mappings[0].target_node_id} related parallel not found.') parallel = self.graph.parallel_mapping.get(parallel_id) if not parallel: @@ -199,18 +199,27 @@ class GraphEngine: self._run_parallel_node, flask_app=current_app._get_current_object(), # type: ignore parallel_id=parallel_id, - parallel_start_node_id=edge.source_node_id, # source_node_id is start nodes in parallel + parallel_start_node_id=edge.target_node_id, q=q )) + succeeded_count = 0 while True: try: event = q.get(timeout=1) if event is None: break - # TODO tag event with parallel id - yield event + if isinstance(event, GraphRunSucceededEvent): + succeeded_count += 1 + if succeeded_count == len(edge_mappings): + break + + continue + elif isinstance(event, GraphRunFailedEvent): + raise GraphRunFailedError(event.reason) + else: + yield event except queue.Empty: continue @@ -246,19 +255,15 @@ class GraphEngine: for item in generator: q.put(item) - if isinstance(item, NodeRunFailedEvent): - q.put(GraphRunFailedEvent(reason=item.route_node_state.failed_reason or 'Unknown error.')) - return # trigger graph run success event q.put(GraphRunSucceededEvent()) - except (GraphRunFailedError, NodeRunFailedError) as e: + except GraphRunFailedError as e: q.put(GraphRunFailedEvent(reason=e.error)) except Exception as e: logger.exception("Unknown Error when generating in parallel") q.put(GraphRunFailedEvent(reason=str(e))) finally: - q.put(None) db.session.remove() def _run_node(self, @@ -268,17 +273,35 @@ class GraphEngine: """ Run node """ + # trigger node run start event + yield NodeRunStartedEvent( + route_node_state=route_node_state, + parallel_id=parallel_id + ) + # get node config node_id = route_node_state.node_id node_config = self.graph.node_id_config_mapping.get(node_id) if not node_config: - raise GraphRunFailedError(f'Node {node_id} config not found.') + route_node_state.status = RouteNodeState.Status.FAILED + route_node_state.failed_reason = f'Node {node_id} config not found.' + yield NodeRunFailedEvent( + route_node_state=route_node_state, + parallel_id=parallel_id + ) + return # convert to specific node node_type = NodeType.value_of(node_config.get('data', {}).get('type')) node_cls = node_classes.get(node_type) if not node_cls: - raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.') + route_node_state.status = RouteNodeState.Status.FAILED + route_node_state.failed_reason = f'Node {node_id} type {node_type} not found.' + yield NodeRunFailedEvent( + route_node_state=route_node_state, + parallel_id=parallel_id + ) + return # init workflow run state node_instance = node_cls( # type: ignore @@ -289,12 +312,6 @@ class GraphEngine: previous_node_id=previous_node_id ) - # trigger node run start event - yield NodeRunStartedEvent( - route_node_state=route_node_state, - parallel_id=parallel_id - ) - db.session.close() # TODO reference from core.workflow.workflow_entry.WorkflowEntry._run_workflow_node @@ -307,13 +324,7 @@ class GraphEngine: for item in generator: if isinstance(item, RunCompletedEvent): run_result = item.run_result - route_node_state.status = RouteNodeState.Status.SUCCESS \ - if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED \ - else RouteNodeState.Status.FAILED - route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) - route_node_state.node_run_result = run_result - route_node_state.failed_reason = run_result.error \ - if run_result.status == WorkflowNodeExecutionStatus.FAILED else None + route_node_state.set_finished(run_result=run_result) if run_result.status == WorkflowNodeExecutionStatus.FAILED: yield NodeRunFailedEvent( @@ -321,10 +332,27 @@ class GraphEngine: route_node_state=route_node_state ) elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + # plus state total_tokens + self.graph_runtime_state.total_tokens += int( + run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) + ) + + # append node output variables to variable pool + if run_result.outputs: + for variable_key, variable_value in run_result.outputs.items(): + # append variables to variable pool recursively + self._append_variables_recursively( + node_id=node_id, + variable_key_list=[variable_key], + variable_value=variable_value + ) + yield NodeRunSucceededEvent( parallel_id=parallel_id, route_node_state=route_node_state ) + break elif isinstance(item, RunStreamChunkEvent): yield NodeRunStreamChunkEvent( @@ -340,8 +368,10 @@ class GraphEngine: retriever_resources=item.retriever_resources, context=item.context ) - except GenerateTaskStoppedException as e: + except GenerateTaskStoppedException: # trigger node run failed event + route_node_state.status = RouteNodeState.Status.FAILED + route_node_state.failed_reason = "Workflow stopped." yield NodeRunFailedEvent( route_node_state=route_node_state, parallel_id=parallel_id @@ -353,6 +383,34 @@ class GraphEngine: finally: db.session.close() + def _append_variables_recursively(self, + node_id: str, + variable_key_list: list[str], + variable_value: VariableValue): + """ + Append variables recursively + :param node_id: node id + :param variable_key_list: variable key list + :param variable_value: variable value + :return: + """ + self.graph_runtime_state.variable_pool.append_variable( + node_id=node_id, + variable_key_list=variable_key_list, + value=variable_value + ) + + # if variable_value is a dict, then recursively append variables + if isinstance(variable_value, dict): + for key, value in variable_value.items(): + # construct new key list + new_key_list = variable_key_list + [key] + self._append_variables_recursively( + node_id=node_id, + variable_key_list=new_key_list, + variable_value=value + ) + def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: """ Check timeout @@ -366,8 +424,3 @@ class GraphEngine: class GraphRunFailedError(Exception): def __init__(self, error: str): self.error = error - - -class NodeRunFailedError(Exception): - def __init__(self, error: str): - self.error = error diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index fb61a6890f..bf5c1617b5 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -26,7 +26,7 @@ class AnswerNode(BaseNode): :return: """ node_data = self.node_data - node_data = cast(self._node_data_cls, node_data) + node_data = cast(AnswerNodeData, node_data) # generate routes generate_routes = self.extract_generate_route_from_node_data(node_data) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 5032bac330..4b25b2a3fb 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -1,9 +1,20 @@ from unittest.mock import patch +from flask import Flask + from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import SystemVariable, UserFrom from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.event import ( + BaseNodeEvent, + GraphRunFailedEvent, + GraphRunStartedEvent, + NodeRunFailedEvent, + NodeRunStartedEvent, + NodeRunSucceededEvent, +) from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.graph_engine.graph_engine import GraphEngine from models.workflow import WorkflowType @@ -14,31 +25,29 @@ def test_run(mock_close, mock_remove): graph_config = { "edges": [ { - "id": "llm-source-answer-target", - "source": "llm", - "target": "answer", - }, - { - "id": "start-source-qc-target", + "id": "1", "source": "start", - "target": "qc", + "target": "answer1", }, { - "id": "qc-1-llm-target", - "source": "qc", - "sourceHandle": "1", - "target": "llm", - }, - { - "id": "qc-2-http-target", - "source": "qc", - "sourceHandle": "2", - "target": "http", - }, - { - "id": "http-source-answer2-target", - "source": "http", + "id": "2", + "source": "answer1", "target": "answer2", + }, + { + "id": "3", + "source": "answer1", + "target": "answer3", + }, + { + "id": "4", + "source": "answer2", + "target": "answer4", + }, + { + "id": "5", + "source": "answer3", + "target": "answer5", } ], "nodes": [ @@ -51,38 +60,43 @@ def test_run(mock_close, mock_remove): }, { "data": { - "type": "llm", - "title": "llm" + "type": "answer", + "title": "answer1", + "answer": "1" }, - "id": "llm" + "id": "answer1" }, { "data": { "type": "answer", - "title": "answer" - }, - "id": "answer", - }, - { - "data": { - "type": "question-classifier", - "title": "qc" - }, - "id": "qc", - }, - { - "data": { - "type": "http-request", - "title": "http" - }, - "id": "http", - }, - { - "data": { - "type": "answer", - "title": "answer2" + "title": "answer2", + "answer": "2" }, "id": "answer2", + }, + { + "data": { + "type": "answer", + "title": "answer3", + "answer": "3" + }, + "id": "answer3", + }, + { + "data": { + "type": "answer", + "title": "answer4", + "answer": "4" + }, + "id": "answer4", + }, + { + "data": { + "type": "answer", + "title": "answer5", + "answer": "5" + }, + "id": "answer5", } ], } @@ -115,6 +129,28 @@ def test_run(mock_close, mock_remove): print("") - generator = graph_engine.run() - for item in generator: - print(type(item), item) + app = Flask('test') + + items = [] + with app.app_context(): + generator = graph_engine.run() + for item in generator: + print(type(item), item) + items.append(item) + if isinstance(item, NodeRunSucceededEvent): + assert item.route_node_state.status == RouteNodeState.Status.SUCCESS + + assert not isinstance(item, NodeRunFailedEvent) + assert not isinstance(item, GraphRunFailedEvent) + + if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ['answer2', 'answer3']: + assert item.parallel_id is not None + + assert len(items) == 12 + assert isinstance(items[0], GraphRunStartedEvent) + assert isinstance(items[1], NodeRunStartedEvent) + assert items[1].route_node_state.node_id == 'start' + assert isinstance(items[2], NodeRunSucceededEvent) + assert items[2].route_node_state.node_id == 'start' + + print(graph_engine.graph_runtime_state)