This commit is contained in:
takatost 2024-07-17 11:26:33 +08:00
parent 16e2d00157
commit cc96acdae3
4 changed files with 199 additions and 116 deletions

View File

@ -43,6 +43,26 @@ class RouteNodeState(BaseModel):
paused_by: Optional[str] = None paused_by: Optional[str] = None
"""paused by""" """paused by"""
def set_finished(self, run_result: NodeRunResult) -> None:
    """
    Mark this route node state as finished from a node run result.

    :param run_result: run result
    :raises Exception: if this state is already SUCCESS or FAILED, or if the
        run result status is neither SUCCEEDED nor FAILED
    """
    # a state may only be finished once
    terminal_statuses = (RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED)
    if self.status in terminal_statuses:
        raise Exception(f"Route state {self.id} already finished")

    result_status = run_result.status
    if result_status == WorkflowNodeExecutionStatus.FAILED:
        self.status = RouteNodeState.Status.FAILED
        # only a failed run records a reason
        self.failed_reason = run_result.error
    elif result_status == WorkflowNodeExecutionStatus.SUCCEEDED:
        self.status = RouteNodeState.Status.SUCCESS
    else:
        raise Exception(f"Invalid route status {run_result.status}")

    self.node_run_result = run_result
    # current UTC time, stored without tzinfo (naive datetime)
    self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
class RuntimeRouteState(BaseModel): class RuntimeRouteState(BaseModel):
routes: dict[str, list[str]] = Field( routes: dict[str, list[str]] = Field(
@ -87,29 +107,3 @@ class RuntimeRouteState(BaseModel):
""" """
return [self.node_state_mapping[target_state_id] return [self.node_state_mapping[target_state_id]
for target_state_id in self.routes.get(source_node_state_id, [])] for target_state_id in self.routes.get(source_node_state_id, [])]
def set_node_state_finished(self, node_state_id: str, run_result: NodeRunResult) -> None:
    """
    Mark the route node state registered under the given id as finished.

    :param node_state_id: route node state id
    :param run_result: run result
    :raises Exception: if the id is not in ``node_state_mapping``, if the state
        is already SUCCESS or FAILED, or if the run result status is neither
        SUCCEEDED nor FAILED
    """
    if node_state_id not in self.node_state_mapping:
        raise Exception(f"Route state {node_state_id} not found")

    node_state = self.node_state_mapping[node_state_id]

    # a state may only be finished once
    terminal_statuses = (RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED)
    if node_state.status in terminal_statuses:
        raise Exception(f"Route state {node_state_id} already finished")

    result_status = run_result.status
    if result_status == WorkflowNodeExecutionStatus.FAILED:
        node_state.status = RouteNodeState.Status.FAILED
        # only a failed run records a reason
        node_state.failed_reason = run_result.error
    elif result_status == WorkflowNodeExecutionStatus.SUCCEEDED:
        node_state.status = RouteNodeState.Status.SUCCESS
    else:
        raise Exception(f"Invalid route status {run_result.status}")

    node_state.node_run_result = run_result
    # current UTC time, stored without tzinfo (naive datetime)
    node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)

View File

@ -3,14 +3,14 @@ import queue
import time import time
from collections.abc import Generator from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from typing import Optional from typing import Optional
from flask import Flask, current_app from flask import Flask, current_app
from uritemplate.variable import VariableValue
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeType, UserFrom from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, UserFrom
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.event import (
@ -89,7 +89,7 @@ class GraphEngine:
# trigger graph run success event # trigger graph run success event
yield GraphRunSucceededEvent() yield GraphRunSucceededEvent()
except (GraphRunFailedError, NodeRunFailedError) as e: except GraphRunFailedError as e:
yield GraphRunFailedEvent(reason=e.error) yield GraphRunFailedEvent(reason=e.error)
return return
except Exception as e: except Exception as e:
@ -112,7 +112,7 @@ class GraphEngine:
raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time)) raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))
# init route node state # init route node state
route_node_state = self.graph_runtime_state.create_node_state( route_node_state = self.graph_runtime_state.node_run_state.create_node_state(
node_id=next_node_id node_id=next_node_id
) )
@ -128,13 +128,13 @@ class GraphEngine:
# append route # append route
if previous_route_node_state: if previous_route_node_state:
if previous_route_node_state.id not in self.graph_runtime_state.node_run_state.routes: self.graph_runtime_state.node_run_state.add_route(
self.graph_runtime_state.node_run_state.routes[previous_route_node_state.id] = [] source_node_state_id=previous_route_node_state.id,
target_node_state_id=route_node_state.id
self.graph_runtime_state.node_run_state.routes[previous_route_node_state.id].append(
route_node_state.id
) )
except Exception as e: except Exception as e:
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = str(e)
yield NodeRunFailedEvent( yield NodeRunFailedEvent(
route_node_state=route_node_state, route_node_state=route_node_state,
parallel_id=in_parallel_id parallel_id=in_parallel_id
@ -181,9 +181,9 @@ class GraphEngine:
next_node_id = final_node_id next_node_id = final_node_id
else: else:
# if nodes has no run conditions, parallel run all nodes # if nodes has no run conditions, parallel run all nodes
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].source_node_id) parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
if not parallel_id: if not parallel_id:
raise GraphRunFailedError(f'Node {edge_mappings[0].source_node_id} related parallel not found.') raise GraphRunFailedError(f'Node {edge_mappings[0].target_node_id} related parallel not found.')
parallel = self.graph.parallel_mapping.get(parallel_id) parallel = self.graph.parallel_mapping.get(parallel_id)
if not parallel: if not parallel:
@ -199,18 +199,27 @@ class GraphEngine:
self._run_parallel_node, self._run_parallel_node,
flask_app=current_app._get_current_object(), # type: ignore flask_app=current_app._get_current_object(), # type: ignore
parallel_id=parallel_id, parallel_id=parallel_id,
parallel_start_node_id=edge.source_node_id, # source_node_id is start nodes in parallel parallel_start_node_id=edge.target_node_id,
q=q q=q
)) ))
succeeded_count = 0
while True: while True:
try: try:
event = q.get(timeout=1) event = q.get(timeout=1)
if event is None: if event is None:
break break
# TODO tag event with parallel id if isinstance(event, GraphRunSucceededEvent):
yield event succeeded_count += 1
if succeeded_count == len(edge_mappings):
break
continue
elif isinstance(event, GraphRunFailedEvent):
raise GraphRunFailedError(event.reason)
else:
yield event
except queue.Empty: except queue.Empty:
continue continue
@ -246,19 +255,15 @@ class GraphEngine:
for item in generator: for item in generator:
q.put(item) q.put(item)
if isinstance(item, NodeRunFailedEvent):
q.put(GraphRunFailedEvent(reason=item.route_node_state.failed_reason or 'Unknown error.'))
return
# trigger graph run success event # trigger graph run success event
q.put(GraphRunSucceededEvent()) q.put(GraphRunSucceededEvent())
except (GraphRunFailedError, NodeRunFailedError) as e: except GraphRunFailedError as e:
q.put(GraphRunFailedEvent(reason=e.error)) q.put(GraphRunFailedEvent(reason=e.error))
except Exception as e: except Exception as e:
logger.exception("Unknown Error when generating in parallel") logger.exception("Unknown Error when generating in parallel")
q.put(GraphRunFailedEvent(reason=str(e))) q.put(GraphRunFailedEvent(reason=str(e)))
finally: finally:
q.put(None)
db.session.remove() db.session.remove()
def _run_node(self, def _run_node(self,
@ -268,17 +273,35 @@ class GraphEngine:
""" """
Run node Run node
""" """
# trigger node run start event
yield NodeRunStartedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
)
# get node config # get node config
node_id = route_node_state.node_id node_id = route_node_state.node_id
node_config = self.graph.node_id_config_mapping.get(node_id) node_config = self.graph.node_id_config_mapping.get(node_id)
if not node_config: if not node_config:
raise GraphRunFailedError(f'Node {node_id} config not found.') route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = f'Node {node_id} config not found.'
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
)
return
# convert to specific node # convert to specific node
node_type = NodeType.value_of(node_config.get('data', {}).get('type')) node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type) node_cls = node_classes.get(node_type)
if not node_cls: if not node_cls:
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.') route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = f'Node {node_id} type {node_type} not found.'
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
)
return
# init workflow run state # init workflow run state
node_instance = node_cls( # type: ignore node_instance = node_cls( # type: ignore
@ -289,12 +312,6 @@ class GraphEngine:
previous_node_id=previous_node_id previous_node_id=previous_node_id
) )
# trigger node run start event
yield NodeRunStartedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
)
db.session.close() db.session.close()
# TODO reference from core.workflow.workflow_entry.WorkflowEntry._run_workflow_node # TODO reference from core.workflow.workflow_entry.WorkflowEntry._run_workflow_node
@ -307,13 +324,7 @@ class GraphEngine:
for item in generator: for item in generator:
if isinstance(item, RunCompletedEvent): if isinstance(item, RunCompletedEvent):
run_result = item.run_result run_result = item.run_result
route_node_state.status = RouteNodeState.Status.SUCCESS \ route_node_state.set_finished(run_result=run_result)
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED \
else RouteNodeState.Status.FAILED
route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
route_node_state.node_run_result = run_result
route_node_state.failed_reason = run_result.error \
if run_result.status == WorkflowNodeExecutionStatus.FAILED else None
if run_result.status == WorkflowNodeExecutionStatus.FAILED: if run_result.status == WorkflowNodeExecutionStatus.FAILED:
yield NodeRunFailedEvent( yield NodeRunFailedEvent(
@ -321,10 +332,27 @@ class GraphEngine:
route_node_state=route_node_state route_node_state=route_node_state
) )
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)
)
# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_id,
variable_key_list=[variable_key],
variable_value=variable_value
)
yield NodeRunSucceededEvent( yield NodeRunSucceededEvent(
parallel_id=parallel_id, parallel_id=parallel_id,
route_node_state=route_node_state route_node_state=route_node_state
) )
break break
elif isinstance(item, RunStreamChunkEvent): elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent( yield NodeRunStreamChunkEvent(
@ -340,8 +368,10 @@ class GraphEngine:
retriever_resources=item.retriever_resources, retriever_resources=item.retriever_resources,
context=item.context context=item.context
) )
except GenerateTaskStoppedException as e: except GenerateTaskStoppedException:
# trigger node run failed event # trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent( yield NodeRunFailedEvent(
route_node_state=route_node_state, route_node_state=route_node_state,
parallel_id=parallel_id parallel_id=parallel_id
@ -353,6 +383,34 @@ class GraphEngine:
finally: finally:
db.session.close() db.session.close()
def _append_variables_recursively(self,
                                  node_id: str,
                                  variable_key_list: list[str],
                                  variable_value: VariableValue):
    """
    Append a value to the variable pool and, for dict values, every nested entry.

    The value itself is appended under ``variable_key_list``; when it is a
    dict, each entry is appended again under the key list extended by that
    entry's key, recursively.

    :param node_id: node id
    :param variable_key_list: variable key list
    :param variable_value: variable value
    :return:
    """
    # NOTE(review): VariableValue is imported from uritemplate.variable in this
    # file, which looks like an accidental auto-import - confirm the intended type.
    pool = self.graph_runtime_state.variable_pool
    pool.append_variable(
        node_id=node_id,
        variable_key_list=variable_key_list,
        value=variable_value
    )

    # only dict values have nested entries to append
    if not isinstance(variable_value, dict):
        return

    for child_key, child_value in variable_value.items():
        # build a new key list so the caller's list is not mutated
        child_key_list = variable_key_list + [child_key]
        self._append_variables_recursively(
            node_id=node_id,
            variable_key_list=child_key_list,
            variable_value=child_value
        )
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
""" """
Check timeout Check timeout
@ -366,8 +424,3 @@ class GraphEngine:
class GraphRunFailedError(Exception):
    """
    Graph run failure carrying a human-readable reason.

    :param error: failure reason, exposed as ``self.error``
    """

    def __init__(self, error: str):
        # pass the message to Exception so str(e), e.args and tracebacks carry
        # the reason instead of being empty
        super().__init__(error)
        self.error = error
class NodeRunFailedError(Exception):
    """
    Node run failure carrying a human-readable reason.

    :param error: failure reason, exposed as ``self.error``
    """

    def __init__(self, error: str):
        # pass the message to Exception so str(e), e.args and tracebacks carry
        # the reason instead of being empty
        super().__init__(error)
        self.error = error

View File

@ -26,7 +26,7 @@ class AnswerNode(BaseNode):
:return: :return:
""" """
node_data = self.node_data node_data = self.node_data
node_data = cast(self._node_data_cls, node_data) node_data = cast(AnswerNodeData, node_data)
# generate routes # generate routes
generate_routes = self.extract_generate_route_from_node_data(node_data) generate_routes = self.extract_generate_route_from_node_data(node_data)

View File

@ -1,9 +1,20 @@
from unittest.mock import patch from unittest.mock import patch
from flask import Flask
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import SystemVariable, UserFrom from core.workflow.entities.node_entities import SystemVariable, UserFrom
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
BaseNodeEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.graph_engine.graph_engine import GraphEngine
from models.workflow import WorkflowType from models.workflow import WorkflowType
@ -14,31 +25,29 @@ def test_run(mock_close, mock_remove):
graph_config = { graph_config = {
"edges": [ "edges": [
{ {
"id": "llm-source-answer-target", "id": "1",
"source": "llm",
"target": "answer",
},
{
"id": "start-source-qc-target",
"source": "start", "source": "start",
"target": "qc", "target": "answer1",
}, },
{ {
"id": "qc-1-llm-target", "id": "2",
"source": "qc", "source": "answer1",
"sourceHandle": "1",
"target": "llm",
},
{
"id": "qc-2-http-target",
"source": "qc",
"sourceHandle": "2",
"target": "http",
},
{
"id": "http-source-answer2-target",
"source": "http",
"target": "answer2", "target": "answer2",
},
{
"id": "3",
"source": "answer1",
"target": "answer3",
},
{
"id": "4",
"source": "answer2",
"target": "answer4",
},
{
"id": "5",
"source": "answer3",
"target": "answer5",
} }
], ],
"nodes": [ "nodes": [
@ -51,38 +60,43 @@ def test_run(mock_close, mock_remove):
}, },
{ {
"data": { "data": {
"type": "llm", "type": "answer",
"title": "llm" "title": "answer1",
"answer": "1"
}, },
"id": "llm" "id": "answer1"
}, },
{ {
"data": { "data": {
"type": "answer", "type": "answer",
"title": "answer" "title": "answer2",
}, "answer": "2"
"id": "answer",
},
{
"data": {
"type": "question-classifier",
"title": "qc"
},
"id": "qc",
},
{
"data": {
"type": "http-request",
"title": "http"
},
"id": "http",
},
{
"data": {
"type": "answer",
"title": "answer2"
}, },
"id": "answer2", "id": "answer2",
},
{
"data": {
"type": "answer",
"title": "answer3",
"answer": "3"
},
"id": "answer3",
},
{
"data": {
"type": "answer",
"title": "answer4",
"answer": "4"
},
"id": "answer4",
},
{
"data": {
"type": "answer",
"title": "answer5",
"answer": "5"
},
"id": "answer5",
} }
], ],
} }
@ -115,6 +129,28 @@ def test_run(mock_close, mock_remove):
print("") print("")
generator = graph_engine.run() app = Flask('test')
for item in generator:
print(type(item), item) items = []
with app.app_context():
generator = graph_engine.run()
for item in generator:
print(type(item), item)
items.append(item)
if isinstance(item, NodeRunSucceededEvent):
assert item.route_node_state.status == RouteNodeState.Status.SUCCESS
assert not isinstance(item, NodeRunFailedEvent)
assert not isinstance(item, GraphRunFailedEvent)
if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ['answer2', 'answer3']:
assert item.parallel_id is not None
assert len(items) == 12
assert isinstance(items[0], GraphRunStartedEvent)
assert isinstance(items[1], NodeRunStartedEvent)
assert items[1].route_node_state.node_id == 'start'
assert isinstance(items[2], NodeRunSucceededEvent)
assert items[2].route_node_state.node_id == 'start'
print(graph_engine.graph_runtime_state)