mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 00:55:53 +08:00
fix bugs
This commit is contained in:
parent 16e2d00157
commit cc96acdae3
@@ -43,6 +43,26 @@ class RouteNodeState(BaseModel):
    paused_by: Optional[str] = None
    """paused by"""

    def set_finished(self, run_result: NodeRunResult) -> None:
        """
        Node finished

        :param run_result: run result
        """
        if self.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]:
            raise Exception(f"Route state {self.id} already finished")

        if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
            self.status = RouteNodeState.Status.SUCCESS
        elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
            self.status = RouteNodeState.Status.FAILED
            self.failed_reason = run_result.error
        else:
            raise Exception(f"Invalid route status {run_result.status}")

        self.node_run_result = run_result
        self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)


class RuntimeRouteState(BaseModel):
    routes: dict[str, list[str]] = Field(
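Usage sketch (not part of the commit): how the new RouteNodeState.set_finished is meant to be driven. The constructor arguments for RouteNodeState and NodeRunResult below are assumptions for illustration; only the status handling comes from the hunk above.

    from datetime import datetime, timezone

    from core.workflow.entities.node_entities import NodeRunResult
    from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
    from models.workflow import WorkflowNodeExecutionStatus

    # hypothetical: create a state for a node and finish it with a succeeded result
    state = RouteNodeState(node_id="llm", start_at=datetime.now(timezone.utc).replace(tzinfo=None))
    state.set_finished(run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED))
    assert state.status == RouteNodeState.Status.SUCCESS
    assert state.finished_at is not None

    # a second call would raise, because the state is already terminal:
    # state.set_finished(run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error="boom"))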
@@ -87,29 +107,3 @@ class RuntimeRouteState(BaseModel):
        """
        return [self.node_state_mapping[target_state_id]
                for target_state_id in self.routes.get(source_node_state_id, [])]

    def set_node_state_finished(self, node_state_id: str, run_result: NodeRunResult) -> None:
        """
        Node finished

        :param node_state_id: route node state id
        :param run_result: run result
        """
        if node_state_id not in self.node_state_mapping:
            raise Exception(f"Route state {node_state_id} not found")

        route = self.node_state_mapping[node_state_id]

        if route.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]:
            raise Exception(f"Route state {node_state_id} already finished")

        if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
            route.status = RouteNodeState.Status.SUCCESS
        elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
            route.status = RouteNodeState.Status.FAILED
            route.failed_reason = run_result.error
        else:
            raise Exception(f"Invalid route status {run_result.status}")

        route.node_run_result = run_result
        route.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
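For orientation, a rough sketch of how the graph engine hunks below use this class: a node state is created per visited node, routes are recorded between state ids, and finishing now happens on the RouteNodeState itself. The method names come from this commit; default construction of RuntimeRouteState is an assumption.

    node_run_state = RuntimeRouteState()

    start_state = node_run_state.create_node_state(node_id="start")
    llm_state = node_run_state.create_node_state(node_id="llm")

    # record that the llm state was reached from the start state
    node_run_state.add_route(
        source_node_state_id=start_state.id,
        target_node_state_id=llm_state.id,
    )

    # finishing is then done on the state object itself (llm_state.set_finished(...)),
    # which replaces the RuntimeRouteState.set_node_state_finished path shown above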
@@ -3,14 +3,14 @@ import queue
import time
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from typing import Optional

from flask import Flask, current_app
from uritemplate.variable import VariableValue

from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import NodeType, UserFrom
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.graph_engine.entities.event import (
@@ -89,7 +89,7 @@ class GraphEngine:

            # trigger graph run success event
            yield GraphRunSucceededEvent()
        except (GraphRunFailedError, NodeRunFailedError) as e:
        except GraphRunFailedError as e:
            yield GraphRunFailedEvent(reason=e.error)
            return
        except Exception as e:
@@ -112,7 +112,7 @@ class GraphEngine:
                raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))

            # init route node state
            route_node_state = self.graph_runtime_state.create_node_state(
            route_node_state = self.graph_runtime_state.node_run_state.create_node_state(
                node_id=next_node_id
            )
@@ -128,13 +128,13 @@ class GraphEngine:

                # append route
                if previous_route_node_state:
                    if previous_route_node_state.id not in self.graph_runtime_state.node_run_state.routes:
                        self.graph_runtime_state.node_run_state.routes[previous_route_node_state.id] = []

                    self.graph_runtime_state.node_run_state.routes[previous_route_node_state.id].append(
                        route_node_state.id
                    self.graph_runtime_state.node_run_state.add_route(
                        source_node_state_id=previous_route_node_state.id,
                        target_node_state_id=route_node_state.id
                    )
            except Exception as e:
                route_node_state.status = RouteNodeState.Status.FAILED
                route_node_state.failed_reason = str(e)
                yield NodeRunFailedEvent(
                    route_node_state=route_node_state,
                    parallel_id=in_parallel_id
@@ -181,9 +181,9 @@ class GraphEngine:
                        next_node_id = final_node_id
                    else:
                        # if nodes has no run conditions, parallel run all nodes
                        parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].source_node_id)
                        parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
                        if not parallel_id:
                            raise GraphRunFailedError(f'Node {edge_mappings[0].source_node_id} related parallel not found.')
                            raise GraphRunFailedError(f'Node {edge_mappings[0].target_node_id} related parallel not found.')

                        parallel = self.graph.parallel_mapping.get(parallel_id)
                        if not parallel:
@@ -199,18 +199,27 @@ class GraphEngine:
                            self._run_parallel_node,
                            flask_app=current_app._get_current_object(),  # type: ignore
                            parallel_id=parallel_id,
                            parallel_start_node_id=edge.source_node_id,  # source_node_id is start nodes in parallel
                            parallel_start_node_id=edge.target_node_id,
                            q=q
                        ))

                    succeeded_count = 0
                    while True:
                        try:
                            event = q.get(timeout=1)
                            if event is None:
                                break

                            # TODO tag event with parallel id
                            yield event
                            if isinstance(event, GraphRunSucceededEvent):
                                succeeded_count += 1
                                if succeeded_count == len(edge_mappings):
                                    break

                                continue
                            elif isinstance(event, GraphRunFailedEvent):
                                raise GraphRunFailedError(event.reason)
                            else:
                                yield event
                        except queue.Empty:
                            continue
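The fan-out above submits one _run_parallel_node per outgoing edge and then drains a shared queue, counting GraphRunSucceededEvent markers until every branch has reported back. A stripped-down, standalone sketch of that drain pattern (plain strings stand in for the engine's event classes, so nothing here is dify API):

    import queue
    from concurrent.futures import ThreadPoolExecutor


    def drain(q: "queue.Queue[str | None]", branch_count: int):
        """Re-yield branch events until every branch reported success (or a None sentinel arrives)."""
        succeeded = 0
        while True:
            try:
                event = q.get(timeout=1)
            except queue.Empty:
                continue
            if event is None:
                break
            if event == "branch-succeeded":
                succeeded += 1
                if succeeded == branch_count:
                    break
                continue
            if event == "branch-failed":
                raise RuntimeError("a parallel branch failed")
            yield event


    q: queue.Queue = queue.Queue()
    with ThreadPoolExecutor(max_workers=2) as executor:
        for branch in ("a", "b"):
            # each branch pushes its events, then a success marker
            executor.submit(lambda b=branch: (q.put(f"event-from-{b}"), q.put("branch-succeeded")))
        for event in drain(q, branch_count=2):
            print(event)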
@@ -246,19 +255,15 @@ class GraphEngine:

            for item in generator:
                q.put(item)
                if isinstance(item, NodeRunFailedEvent):
                    q.put(GraphRunFailedEvent(reason=item.route_node_state.failed_reason or 'Unknown error.'))
                    return

            # trigger graph run success event
            q.put(GraphRunSucceededEvent())
        except (GraphRunFailedError, NodeRunFailedError) as e:
        except GraphRunFailedError as e:
            q.put(GraphRunFailedEvent(reason=e.error))
        except Exception as e:
            logger.exception("Unknown Error when generating in parallel")
            q.put(GraphRunFailedEvent(reason=str(e)))
        finally:
            q.put(None)
            db.session.remove()

    def _run_node(self,
@@ -268,17 +273,35 @@ class GraphEngine:
        """
        Run node
        """
        # trigger node run start event
        yield NodeRunStartedEvent(
            route_node_state=route_node_state,
            parallel_id=parallel_id
        )

        # get node config
        node_id = route_node_state.node_id
        node_config = self.graph.node_id_config_mapping.get(node_id)
        if not node_config:
            raise GraphRunFailedError(f'Node {node_id} config not found.')
            route_node_state.status = RouteNodeState.Status.FAILED
            route_node_state.failed_reason = f'Node {node_id} config not found.'
            yield NodeRunFailedEvent(
                route_node_state=route_node_state,
                parallel_id=parallel_id
            )
            return

        # convert to specific node
        node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
        node_cls = node_classes.get(node_type)
        if not node_cls:
            raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
            route_node_state.status = RouteNodeState.Status.FAILED
            route_node_state.failed_reason = f'Node {node_id} type {node_type} not found.'
            yield NodeRunFailedEvent(
                route_node_state=route_node_state,
                parallel_id=parallel_id
            )
            return

        # init workflow run state
        node_instance = node_cls(  # type: ignore
@@ -289,12 +312,6 @@ class GraphEngine:
            previous_node_id=previous_node_id
        )

        # trigger node run start event
        yield NodeRunStartedEvent(
            route_node_state=route_node_state,
            parallel_id=parallel_id
        )

        db.session.close()

        # TODO reference from core.workflow.workflow_entry.WorkflowEntry._run_workflow_node
@@ -307,13 +324,7 @@ class GraphEngine:
            for item in generator:
                if isinstance(item, RunCompletedEvent):
                    run_result = item.run_result
                    route_node_state.status = RouteNodeState.Status.SUCCESS \
                        if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED \
                        else RouteNodeState.Status.FAILED
                    route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
                    route_node_state.node_run_result = run_result
                    route_node_state.failed_reason = run_result.error \
                        if run_result.status == WorkflowNodeExecutionStatus.FAILED else None
                    route_node_state.set_finished(run_result=run_result)

                    if run_result.status == WorkflowNodeExecutionStatus.FAILED:
                        yield NodeRunFailedEvent(
@@ -321,10 +332,27 @@ class GraphEngine:
                            route_node_state=route_node_state
                        )
                    elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
                        if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
                            # plus state total_tokens
                            self.graph_runtime_state.total_tokens += int(
                                run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)
                            )

                        # append node output variables to variable pool
                        if run_result.outputs:
                            for variable_key, variable_value in run_result.outputs.items():
                                # append variables to variable pool recursively
                                self._append_variables_recursively(
                                    node_id=node_id,
                                    variable_key_list=[variable_key],
                                    variable_value=variable_value
                                )

                        yield NodeRunSucceededEvent(
                            parallel_id=parallel_id,
                            route_node_state=route_node_state
                        )

                    break
                elif isinstance(item, RunStreamChunkEvent):
                    yield NodeRunStreamChunkEvent(
@@ -340,8 +368,10 @@ class GraphEngine:
                        retriever_resources=item.retriever_resources,
                        context=item.context
                    )
        except GenerateTaskStoppedException as e:
        except GenerateTaskStoppedException:
            # trigger node run failed event
            route_node_state.status = RouteNodeState.Status.FAILED
            route_node_state.failed_reason = "Workflow stopped."
            yield NodeRunFailedEvent(
                route_node_state=route_node_state,
                parallel_id=parallel_id
@@ -353,6 +383,34 @@ class GraphEngine:
        finally:
            db.session.close()

    def _append_variables_recursively(self,
                                      node_id: str,
                                      variable_key_list: list[str],
                                      variable_value: VariableValue):
        """
        Append variables recursively
        :param node_id: node id
        :param variable_key_list: variable key list
        :param variable_value: variable value
        :return:
        """
        self.graph_runtime_state.variable_pool.append_variable(
            node_id=node_id,
            variable_key_list=variable_key_list,
            value=variable_value
        )

        # if variable_value is a dict, then recursively append variables
        if isinstance(variable_value, dict):
            for key, value in variable_value.items():
                # construct new key list
                new_key_list = variable_key_list + [key]
                self._append_variables_recursively(
                    node_id=node_id,
                    variable_key_list=new_key_list,
                    variable_value=value
                )

    def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
        """
        Check timeout
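The recursion above writes one variable-pool entry per key prefix, so a nested dict output becomes addressable at every level. A self-contained sketch of the same flattening, using a plain dict in place of the engine's VariablePool (the pool type here is only for illustration):

    from typing import Any

    pool: dict[tuple[str, ...], Any] = {}


    def append_recursively(node_id: str, key_list: list[str], value: Any) -> None:
        # store the value at the current key path, then recurse into dict values
        pool[(node_id, *key_list)] = value
        if isinstance(value, dict):
            for key, child in value.items():
                append_recursively(node_id, key_list + [key], child)


    append_recursively("http", ["body"], {"status": 200, "headers": {"x": "y"}})
    # pool now has entries for ('http', 'body'), ('http', 'body', 'status'),
    # ('http', 'body', 'headers') and ('http', 'body', 'headers', 'x')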
@@ -366,8 +424,3 @@ class GraphEngine:
class GraphRunFailedError(Exception):
    def __init__(self, error: str):
        self.error = error


class NodeRunFailedError(Exception):
    def __init__(self, error: str):
        self.error = error
@@ -26,7 +26,7 @@ class AnswerNode(BaseNode):
        :return:
        """
        node_data = self.node_data
        node_data = cast(self._node_data_cls, node_data)
        node_data = cast(AnswerNodeData, node_data)

        # generate routes
        generate_routes = self.extract_generate_route_from_node_data(node_data)
@@ -1,9 +1,20 @@
from unittest.mock import patch

from flask import Flask

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import SystemVariable, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
    BaseNodeEvent,
    GraphRunFailedEvent,
    GraphRunStartedEvent,
    NodeRunFailedEvent,
    NodeRunStartedEvent,
    NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.graph_engine.graph_engine import GraphEngine
from models.workflow import WorkflowType
@@ -14,31 +25,29 @@ def test_run(mock_close, mock_remove):
    graph_config = {
        "edges": [
            {
                "id": "llm-source-answer-target",
                "source": "llm",
                "target": "answer",
            },
            {
                "id": "start-source-qc-target",
                "id": "1",
                "source": "start",
                "target": "qc",
                "target": "answer1",
            },
            {
                "id": "qc-1-llm-target",
                "source": "qc",
                "sourceHandle": "1",
                "target": "llm",
            },
            {
                "id": "qc-2-http-target",
                "source": "qc",
                "sourceHandle": "2",
                "target": "http",
            },
            {
                "id": "http-source-answer2-target",
                "source": "http",
                "id": "2",
                "source": "answer1",
                "target": "answer2",
            },
            {
                "id": "3",
                "source": "answer1",
                "target": "answer3",
            },
            {
                "id": "4",
                "source": "answer2",
                "target": "answer4",
            },
            {
                "id": "5",
                "source": "answer3",
                "target": "answer5",
            }
        ],
        "nodes": [
@@ -51,38 +60,43 @@ def test_run(mock_close, mock_remove):
            },
            {
                "data": {
                    "type": "llm",
                    "title": "llm"
                    "type": "answer",
                    "title": "answer1",
                    "answer": "1"
                },
                "id": "llm"
                "id": "answer1"
            },
            {
                "data": {
                    "type": "answer",
                    "title": "answer"
                },
                "id": "answer",
            },
            {
                "data": {
                    "type": "question-classifier",
                    "title": "qc"
                },
                "id": "qc",
            },
            {
                "data": {
                    "type": "http-request",
                    "title": "http"
                },
                "id": "http",
            },
            {
                "data": {
                    "type": "answer",
                    "title": "answer2"
                    "title": "answer2",
                    "answer": "2"
                },
                "id": "answer2",
            },
            {
                "data": {
                    "type": "answer",
                    "title": "answer3",
                    "answer": "3"
                },
                "id": "answer3",
            },
            {
                "data": {
                    "type": "answer",
                    "title": "answer4",
                    "answer": "4"
                },
                "id": "answer4",
            },
            {
                "data": {
                    "type": "answer",
                    "title": "answer5",
                    "answer": "5"
                },
                "id": "answer5",
            }
        ],
    }
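Taken together, the edges and nodes above define the following topology; answer2 and answer3 both fan out from answer1, which is why the assertions below expect a parallel_id on their events:

    start -> answer1 -> answer2 -> answer4
                    \-> answer3 -> answer5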
@@ -115,6 +129,28 @@ def test_run(mock_close, mock_remove):

    print("")

    generator = graph_engine.run()
    for item in generator:
        print(type(item), item)
    app = Flask('test')

    items = []
    with app.app_context():
        generator = graph_engine.run()
        for item in generator:
            print(type(item), item)
            items.append(item)
            if isinstance(item, NodeRunSucceededEvent):
                assert item.route_node_state.status == RouteNodeState.Status.SUCCESS

            assert not isinstance(item, NodeRunFailedEvent)
            assert not isinstance(item, GraphRunFailedEvent)

            if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ['answer2', 'answer3']:
                assert item.parallel_id is not None

    assert len(items) == 12
    assert isinstance(items[0], GraphRunStartedEvent)
    assert isinstance(items[1], NodeRunStartedEvent)
    assert items[1].route_node_state.node_id == 'start'
    assert isinstance(items[2], NodeRunSucceededEvent)
    assert items[2].route_node_state.node_id == 'start'

    print(graph_engine.graph_runtime_state)