This commit is contained in:
takatost 2024-07-17 11:26:33 +08:00
parent 16e2d00157
commit cc96acdae3
4 changed files with 199 additions and 116 deletions

View File

@ -43,6 +43,26 @@ class RouteNodeState(BaseModel):
paused_by: Optional[str] = None
"""paused by"""
def set_finished(self, run_result: NodeRunResult) -> None:
    """
    Mark this route node state as finished from a node run result.

    :param run_result: result of the node run; its status must be
        SUCCEEDED or FAILED
    :raises Exception: if this state is already SUCCESS or FAILED, or if
        the run result carries any other status
    """
    terminal_statuses = (RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED)
    if self.status in terminal_statuses:
        raise Exception(f"Route state {self.id} already finished")

    result_status = run_result.status
    if result_status == WorkflowNodeExecutionStatus.SUCCEEDED:
        self.status = RouteNodeState.Status.SUCCESS
    elif result_status == WorkflowNodeExecutionStatus.FAILED:
        self.status = RouteNodeState.Status.FAILED
        # the failure reason is only recorded for failed runs
        self.failed_reason = run_result.error
    else:
        raise Exception(f"Invalid route status {result_status}")

    # keep the full result and stamp the finish time as naive UTC
    self.node_run_result = run_result
    finished_utc = datetime.now(timezone.utc)
    self.finished_at = finished_utc.replace(tzinfo=None)
class RuntimeRouteState(BaseModel):
routes: dict[str, list[str]] = Field(
@ -87,29 +107,3 @@ class RuntimeRouteState(BaseModel):
"""
return [self.node_state_mapping[target_state_id]
for target_state_id in self.routes.get(source_node_state_id, [])]
def set_node_state_finished(self, node_state_id: str, run_result: NodeRunResult) -> None:
    """
    Record the outcome of a node run on the route node state with the given id.

    :param node_state_id: route node state id
    :param run_result: run result; its status must be SUCCEEDED or FAILED
    :raises Exception: if the id is not in ``node_state_mapping``, the state
        is already SUCCESS or FAILED, or the run result has any other status
    """
    known_states = self.node_state_mapping
    if node_state_id not in known_states:
        raise Exception(f"Route state {node_state_id} not found")

    node_state = known_states[node_state_id]
    terminal_statuses = (RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED)
    if node_state.status in terminal_statuses:
        raise Exception(f"Route state {node_state_id} already finished")

    result_status = run_result.status
    if result_status == WorkflowNodeExecutionStatus.SUCCEEDED:
        node_state.status = RouteNodeState.Status.SUCCESS
    elif result_status == WorkflowNodeExecutionStatus.FAILED:
        node_state.status = RouteNodeState.Status.FAILED
        # the failure reason is only recorded for failed runs
        node_state.failed_reason = run_result.error
    else:
        raise Exception(f"Invalid route status {result_status}")

    # keep the full result and stamp the finish time as naive UTC
    node_state.node_run_result = run_result
    finished_utc = datetime.now(timezone.utc)
    node_state.finished_at = finished_utc.replace(tzinfo=None)

View File

@ -3,14 +3,14 @@ import queue
import time
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from typing import Optional
from flask import Flask, current_app
from uritemplate.variable import VariableValue
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeType, UserFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
@ -89,7 +89,7 @@ class GraphEngine:
# trigger graph run success event
yield GraphRunSucceededEvent()
except (GraphRunFailedError, NodeRunFailedError) as e:
except GraphRunFailedError as e:
yield GraphRunFailedEvent(reason=e.error)
return
except Exception as e:
@ -112,7 +112,7 @@ class GraphEngine:
raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))
# init route node state
route_node_state = self.graph_runtime_state.create_node_state(
route_node_state = self.graph_runtime_state.node_run_state.create_node_state(
node_id=next_node_id
)
@ -128,13 +128,13 @@ class GraphEngine:
# append route
if previous_route_node_state:
if previous_route_node_state.id not in self.graph_runtime_state.node_run_state.routes:
self.graph_runtime_state.node_run_state.routes[previous_route_node_state.id] = []
self.graph_runtime_state.node_run_state.routes[previous_route_node_state.id].append(
route_node_state.id
self.graph_runtime_state.node_run_state.add_route(
source_node_state_id=previous_route_node_state.id,
target_node_state_id=route_node_state.id
)
except Exception as e:
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = str(e)
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=in_parallel_id
@ -181,9 +181,9 @@ class GraphEngine:
next_node_id = final_node_id
else:
# if the nodes have no run conditions, run all of them in parallel
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].source_node_id)
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
if not parallel_id:
raise GraphRunFailedError(f'Node {edge_mappings[0].source_node_id} related parallel not found.')
raise GraphRunFailedError(f'Node {edge_mappings[0].target_node_id} related parallel not found.')
parallel = self.graph.parallel_mapping.get(parallel_id)
if not parallel:
@ -199,18 +199,27 @@ class GraphEngine:
self._run_parallel_node,
flask_app=current_app._get_current_object(), # type: ignore
parallel_id=parallel_id,
parallel_start_node_id=edge.source_node_id, # source_node_id is start nodes in parallel
parallel_start_node_id=edge.target_node_id,
q=q
))
succeeded_count = 0
while True:
try:
event = q.get(timeout=1)
if event is None:
break
# TODO tag event with parallel id
yield event
if isinstance(event, GraphRunSucceededEvent):
succeeded_count += 1
if succeeded_count == len(edge_mappings):
break
continue
elif isinstance(event, GraphRunFailedEvent):
raise GraphRunFailedError(event.reason)
else:
yield event
except queue.Empty:
continue
@ -246,19 +255,15 @@ class GraphEngine:
for item in generator:
q.put(item)
if isinstance(item, NodeRunFailedEvent):
q.put(GraphRunFailedEvent(reason=item.route_node_state.failed_reason or 'Unknown error.'))
return
# trigger graph run success event
q.put(GraphRunSucceededEvent())
except (GraphRunFailedError, NodeRunFailedError) as e:
except GraphRunFailedError as e:
q.put(GraphRunFailedEvent(reason=e.error))
except Exception as e:
logger.exception("Unknown Error when generating in parallel")
q.put(GraphRunFailedEvent(reason=str(e)))
finally:
q.put(None)
db.session.remove()
def _run_node(self,
@ -268,17 +273,35 @@ class GraphEngine:
"""
Run node
"""
# trigger node run start event
yield NodeRunStartedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
)
# get node config
node_id = route_node_state.node_id
node_config = self.graph.node_id_config_mapping.get(node_id)
if not node_config:
raise GraphRunFailedError(f'Node {node_id} config not found.')
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = f'Node {node_id} config not found.'
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
)
return
# convert to specific node
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
node_cls = node_classes.get(node_type)
if not node_cls:
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = f'Node {node_id} type {node_type} not found.'
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
)
return
# init workflow run state
node_instance = node_cls( # type: ignore
@ -289,12 +312,6 @@ class GraphEngine:
previous_node_id=previous_node_id
)
# trigger node run start event
yield NodeRunStartedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
)
db.session.close()
# TODO reference from core.workflow.workflow_entry.WorkflowEntry._run_workflow_node
@ -307,13 +324,7 @@ class GraphEngine:
for item in generator:
if isinstance(item, RunCompletedEvent):
run_result = item.run_result
route_node_state.status = RouteNodeState.Status.SUCCESS \
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED \
else RouteNodeState.Status.FAILED
route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
route_node_state.node_run_result = run_result
route_node_state.failed_reason = run_result.error \
if run_result.status == WorkflowNodeExecutionStatus.FAILED else None
route_node_state.set_finished(run_result=run_result)
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
yield NodeRunFailedEvent(
@ -321,10 +332,27 @@ class GraphEngine:
route_node_state=route_node_state
)
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# add this node's token usage to the state's total_tokens
self.graph_runtime_state.total_tokens += int(
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)
)
# append node output variables to variable pool
if run_result.outputs:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
node_id=node_id,
variable_key_list=[variable_key],
variable_value=variable_value
)
yield NodeRunSucceededEvent(
parallel_id=parallel_id,
route_node_state=route_node_state
)
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
@ -340,8 +368,10 @@ class GraphEngine:
retriever_resources=item.retriever_resources,
context=item.context
)
except GenerateTaskStoppedException as e:
except GenerateTaskStoppedException:
# trigger node run failed event
route_node_state.status = RouteNodeState.Status.FAILED
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
route_node_state=route_node_state,
parallel_id=parallel_id
@ -353,6 +383,34 @@ class GraphEngine:
finally:
db.session.close()
def _append_variables_recursively(self,
                                  node_id: str,
                                  variable_key_list: list[str],
                                  variable_value: VariableValue):
    """
    Append a value to the variable pool, descending into dict values.

    The value is appended under ``variable_key_list`` first. If it is a
    dict, each entry is then appended under the key list extended by that
    entry's key, and so on for nested dicts.

    :param node_id: node id
    :param variable_key_list: variable key list
    :param variable_value: variable value
    :return:
    """
    pool = self.graph_runtime_state.variable_pool
    pool.append_variable(
        node_id=node_id,
        variable_key_list=variable_key_list,
        value=variable_value
    )

    # only dict values carry nested variables
    if not isinstance(variable_value, dict):
        return

    for child_key, child_value in variable_value.items():
        # each nested entry gets its own extended key path
        child_key_list = variable_key_list + [child_key]
        self._append_variables_recursively(
            node_id=node_id,
            variable_key_list=child_key_list,
            variable_value=child_value
        )
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
"""
Check timeout
@ -366,8 +424,3 @@ class GraphEngine:
class GraphRunFailedError(Exception):
    """
    Error carrying the reason a graph run failed.

    :param error: failure reason; available as ``error`` and as the
        exception message
    """

    def __init__(self, error: str):
        # forward the reason to Exception so str(e) and e.args carry it even
        # when the error is passed by keyword
        super().__init__(error)
        self.error = error
class NodeRunFailedError(Exception):
    """
    Error carrying a failure reason in ``error``.

    :param error: failure reason; available as ``error`` and as the
        exception message
    """

    def __init__(self, error: str):
        # forward the reason to Exception so str(e) and e.args carry it even
        # when the error is passed by keyword
        super().__init__(error)
        self.error = error

View File

@ -26,7 +26,7 @@ class AnswerNode(BaseNode):
:return:
"""
node_data = self.node_data
node_data = cast(self._node_data_cls, node_data)
node_data = cast(AnswerNodeData, node_data)
# generate routes
generate_routes = self.extract_generate_route_from_node_data(node_data)

View File

@ -1,9 +1,20 @@
from unittest.mock import patch
from flask import Flask
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import SystemVariable, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
BaseNodeEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from models.workflow import WorkflowType
@ -14,31 +25,29 @@ def test_run(mock_close, mock_remove):
graph_config = {
"edges": [
{
"id": "llm-source-answer-target",
"source": "llm",
"target": "answer",
},
{
"id": "start-source-qc-target",
"id": "1",
"source": "start",
"target": "qc",
"target": "answer1",
},
{
"id": "qc-1-llm-target",
"source": "qc",
"sourceHandle": "1",
"target": "llm",
},
{
"id": "qc-2-http-target",
"source": "qc",
"sourceHandle": "2",
"target": "http",
},
{
"id": "http-source-answer2-target",
"source": "http",
"id": "2",
"source": "answer1",
"target": "answer2",
},
{
"id": "3",
"source": "answer1",
"target": "answer3",
},
{
"id": "4",
"source": "answer2",
"target": "answer4",
},
{
"id": "5",
"source": "answer3",
"target": "answer5",
}
],
"nodes": [
@ -51,38 +60,43 @@ def test_run(mock_close, mock_remove):
},
{
"data": {
"type": "llm",
"title": "llm"
"type": "answer",
"title": "answer1",
"answer": "1"
},
"id": "llm"
"id": "answer1"
},
{
"data": {
"type": "answer",
"title": "answer"
},
"id": "answer",
},
{
"data": {
"type": "question-classifier",
"title": "qc"
},
"id": "qc",
},
{
"data": {
"type": "http-request",
"title": "http"
},
"id": "http",
},
{
"data": {
"type": "answer",
"title": "answer2"
"title": "answer2",
"answer": "2"
},
"id": "answer2",
},
{
"data": {
"type": "answer",
"title": "answer3",
"answer": "3"
},
"id": "answer3",
},
{
"data": {
"type": "answer",
"title": "answer4",
"answer": "4"
},
"id": "answer4",
},
{
"data": {
"type": "answer",
"title": "answer5",
"answer": "5"
},
"id": "answer5",
}
],
}
@ -115,6 +129,28 @@ def test_run(mock_close, mock_remove):
print("")
generator = graph_engine.run()
for item in generator:
print(type(item), item)
app = Flask('test')
items = []
with app.app_context():
generator = graph_engine.run()
for item in generator:
print(type(item), item)
items.append(item)
if isinstance(item, NodeRunSucceededEvent):
assert item.route_node_state.status == RouteNodeState.Status.SUCCESS
assert not isinstance(item, NodeRunFailedEvent)
assert not isinstance(item, GraphRunFailedEvent)
if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ['answer2', 'answer3']:
assert item.parallel_id is not None
assert len(items) == 12
assert isinstance(items[0], GraphRunStartedEvent)
assert isinstance(items[1], NodeRunStartedEvent)
assert items[1].route_node_state.node_id == 'start'
assert isinstance(items[2], NodeRunSucceededEvent)
assert items[2].route_node_state.node_id == 'start'
print(graph_engine.graph_runtime_state)