mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 23:45:51 +08:00
fix bugs
This commit is contained in:
parent
16e2d00157
commit
cc96acdae3
@ -43,6 +43,26 @@ class RouteNodeState(BaseModel):
|
|||||||
paused_by: Optional[str] = None
|
paused_by: Optional[str] = None
|
||||||
"""paused by"""
|
"""paused by"""
|
||||||
|
|
||||||
|
def set_finished(self, run_result: NodeRunResult) -> None:
|
||||||
|
"""
|
||||||
|
Node finished
|
||||||
|
|
||||||
|
:param run_result: run result
|
||||||
|
"""
|
||||||
|
if self.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]:
|
||||||
|
raise Exception(f"Route state {self.id} already finished")
|
||||||
|
|
||||||
|
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||||
|
self.status = RouteNodeState.Status.SUCCESS
|
||||||
|
elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||||
|
self.status = RouteNodeState.Status.FAILED
|
||||||
|
self.failed_reason = run_result.error
|
||||||
|
else:
|
||||||
|
raise Exception(f"Invalid route status {run_result.status}")
|
||||||
|
|
||||||
|
self.node_run_result = run_result
|
||||||
|
self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||||
|
|
||||||
|
|
||||||
class RuntimeRouteState(BaseModel):
|
class RuntimeRouteState(BaseModel):
|
||||||
routes: dict[str, list[str]] = Field(
|
routes: dict[str, list[str]] = Field(
|
||||||
@ -87,29 +107,3 @@ class RuntimeRouteState(BaseModel):
|
|||||||
"""
|
"""
|
||||||
return [self.node_state_mapping[target_state_id]
|
return [self.node_state_mapping[target_state_id]
|
||||||
for target_state_id in self.routes.get(source_node_state_id, [])]
|
for target_state_id in self.routes.get(source_node_state_id, [])]
|
||||||
|
|
||||||
def set_node_state_finished(self, node_state_id: str, run_result: NodeRunResult) -> None:
|
|
||||||
"""
|
|
||||||
Node finished
|
|
||||||
|
|
||||||
:param node_state_id: route node state id
|
|
||||||
:param run_result: run result
|
|
||||||
"""
|
|
||||||
if node_state_id not in self.node_state_mapping:
|
|
||||||
raise Exception(f"Route state {node_state_id} not found")
|
|
||||||
|
|
||||||
route = self.node_state_mapping[node_state_id]
|
|
||||||
|
|
||||||
if route.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]:
|
|
||||||
raise Exception(f"Route state {node_state_id} already finished")
|
|
||||||
|
|
||||||
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
|
||||||
route.status = RouteNodeState.Status.SUCCESS
|
|
||||||
elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
|
||||||
route.status = RouteNodeState.Status.FAILED
|
|
||||||
route.failed_reason = run_result.error
|
|
||||||
else:
|
|
||||||
raise Exception(f"Invalid route status {run_result.status}")
|
|
||||||
|
|
||||||
route.node_run_result = run_result
|
|
||||||
route.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
|
||||||
|
@ -3,14 +3,14 @@ import queue
|
|||||||
import time
|
import time
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
|
from uritemplate.variable import VariableValue
|
||||||
|
|
||||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.workflow.entities.node_entities import NodeType, UserFrom
|
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, UserFrom
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||||
from core.workflow.graph_engine.entities.event import (
|
from core.workflow.graph_engine.entities.event import (
|
||||||
@ -89,7 +89,7 @@ class GraphEngine:
|
|||||||
|
|
||||||
# trigger graph run success event
|
# trigger graph run success event
|
||||||
yield GraphRunSucceededEvent()
|
yield GraphRunSucceededEvent()
|
||||||
except (GraphRunFailedError, NodeRunFailedError) as e:
|
except GraphRunFailedError as e:
|
||||||
yield GraphRunFailedEvent(reason=e.error)
|
yield GraphRunFailedEvent(reason=e.error)
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -112,7 +112,7 @@ class GraphEngine:
|
|||||||
raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))
|
raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))
|
||||||
|
|
||||||
# init route node state
|
# init route node state
|
||||||
route_node_state = self.graph_runtime_state.create_node_state(
|
route_node_state = self.graph_runtime_state.node_run_state.create_node_state(
|
||||||
node_id=next_node_id
|
node_id=next_node_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -128,13 +128,13 @@ class GraphEngine:
|
|||||||
|
|
||||||
# append route
|
# append route
|
||||||
if previous_route_node_state:
|
if previous_route_node_state:
|
||||||
if previous_route_node_state.id not in self.graph_runtime_state.node_run_state.routes:
|
self.graph_runtime_state.node_run_state.add_route(
|
||||||
self.graph_runtime_state.node_run_state.routes[previous_route_node_state.id] = []
|
source_node_state_id=previous_route_node_state.id,
|
||||||
|
target_node_state_id=route_node_state.id
|
||||||
self.graph_runtime_state.node_run_state.routes[previous_route_node_state.id].append(
|
|
||||||
route_node_state.id
|
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
route_node_state.status = RouteNodeState.Status.FAILED
|
||||||
|
route_node_state.failed_reason = str(e)
|
||||||
yield NodeRunFailedEvent(
|
yield NodeRunFailedEvent(
|
||||||
route_node_state=route_node_state,
|
route_node_state=route_node_state,
|
||||||
parallel_id=in_parallel_id
|
parallel_id=in_parallel_id
|
||||||
@ -181,9 +181,9 @@ class GraphEngine:
|
|||||||
next_node_id = final_node_id
|
next_node_id = final_node_id
|
||||||
else:
|
else:
|
||||||
# if nodes has no run conditions, parallel run all nodes
|
# if nodes has no run conditions, parallel run all nodes
|
||||||
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].source_node_id)
|
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
|
||||||
if not parallel_id:
|
if not parallel_id:
|
||||||
raise GraphRunFailedError(f'Node {edge_mappings[0].source_node_id} related parallel not found.')
|
raise GraphRunFailedError(f'Node {edge_mappings[0].target_node_id} related parallel not found.')
|
||||||
|
|
||||||
parallel = self.graph.parallel_mapping.get(parallel_id)
|
parallel = self.graph.parallel_mapping.get(parallel_id)
|
||||||
if not parallel:
|
if not parallel:
|
||||||
@ -199,18 +199,27 @@ class GraphEngine:
|
|||||||
self._run_parallel_node,
|
self._run_parallel_node,
|
||||||
flask_app=current_app._get_current_object(), # type: ignore
|
flask_app=current_app._get_current_object(), # type: ignore
|
||||||
parallel_id=parallel_id,
|
parallel_id=parallel_id,
|
||||||
parallel_start_node_id=edge.source_node_id, # source_node_id is start nodes in parallel
|
parallel_start_node_id=edge.target_node_id,
|
||||||
q=q
|
q=q
|
||||||
))
|
))
|
||||||
|
|
||||||
|
succeeded_count = 0
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
event = q.get(timeout=1)
|
event = q.get(timeout=1)
|
||||||
if event is None:
|
if event is None:
|
||||||
break
|
break
|
||||||
|
|
||||||
# TODO tag event with parallel id
|
if isinstance(event, GraphRunSucceededEvent):
|
||||||
yield event
|
succeeded_count += 1
|
||||||
|
if succeeded_count == len(edge_mappings):
|
||||||
|
break
|
||||||
|
|
||||||
|
continue
|
||||||
|
elif isinstance(event, GraphRunFailedEvent):
|
||||||
|
raise GraphRunFailedError(event.reason)
|
||||||
|
else:
|
||||||
|
yield event
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -246,19 +255,15 @@ class GraphEngine:
|
|||||||
|
|
||||||
for item in generator:
|
for item in generator:
|
||||||
q.put(item)
|
q.put(item)
|
||||||
if isinstance(item, NodeRunFailedEvent):
|
|
||||||
q.put(GraphRunFailedEvent(reason=item.route_node_state.failed_reason or 'Unknown error.'))
|
|
||||||
return
|
|
||||||
|
|
||||||
# trigger graph run success event
|
# trigger graph run success event
|
||||||
q.put(GraphRunSucceededEvent())
|
q.put(GraphRunSucceededEvent())
|
||||||
except (GraphRunFailedError, NodeRunFailedError) as e:
|
except GraphRunFailedError as e:
|
||||||
q.put(GraphRunFailedEvent(reason=e.error))
|
q.put(GraphRunFailedEvent(reason=e.error))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Unknown Error when generating in parallel")
|
logger.exception("Unknown Error when generating in parallel")
|
||||||
q.put(GraphRunFailedEvent(reason=str(e)))
|
q.put(GraphRunFailedEvent(reason=str(e)))
|
||||||
finally:
|
finally:
|
||||||
q.put(None)
|
|
||||||
db.session.remove()
|
db.session.remove()
|
||||||
|
|
||||||
def _run_node(self,
|
def _run_node(self,
|
||||||
@ -268,17 +273,35 @@ class GraphEngine:
|
|||||||
"""
|
"""
|
||||||
Run node
|
Run node
|
||||||
"""
|
"""
|
||||||
|
# trigger node run start event
|
||||||
|
yield NodeRunStartedEvent(
|
||||||
|
route_node_state=route_node_state,
|
||||||
|
parallel_id=parallel_id
|
||||||
|
)
|
||||||
|
|
||||||
# get node config
|
# get node config
|
||||||
node_id = route_node_state.node_id
|
node_id = route_node_state.node_id
|
||||||
node_config = self.graph.node_id_config_mapping.get(node_id)
|
node_config = self.graph.node_id_config_mapping.get(node_id)
|
||||||
if not node_config:
|
if not node_config:
|
||||||
raise GraphRunFailedError(f'Node {node_id} config not found.')
|
route_node_state.status = RouteNodeState.Status.FAILED
|
||||||
|
route_node_state.failed_reason = f'Node {node_id} config not found.'
|
||||||
|
yield NodeRunFailedEvent(
|
||||||
|
route_node_state=route_node_state,
|
||||||
|
parallel_id=parallel_id
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# convert to specific node
|
# convert to specific node
|
||||||
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
|
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
|
||||||
node_cls = node_classes.get(node_type)
|
node_cls = node_classes.get(node_type)
|
||||||
if not node_cls:
|
if not node_cls:
|
||||||
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
|
route_node_state.status = RouteNodeState.Status.FAILED
|
||||||
|
route_node_state.failed_reason = f'Node {node_id} type {node_type} not found.'
|
||||||
|
yield NodeRunFailedEvent(
|
||||||
|
route_node_state=route_node_state,
|
||||||
|
parallel_id=parallel_id
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# init workflow run state
|
# init workflow run state
|
||||||
node_instance = node_cls( # type: ignore
|
node_instance = node_cls( # type: ignore
|
||||||
@ -289,12 +312,6 @@ class GraphEngine:
|
|||||||
previous_node_id=previous_node_id
|
previous_node_id=previous_node_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# trigger node run start event
|
|
||||||
yield NodeRunStartedEvent(
|
|
||||||
route_node_state=route_node_state,
|
|
||||||
parallel_id=parallel_id
|
|
||||||
)
|
|
||||||
|
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
# TODO reference from core.workflow.workflow_entry.WorkflowEntry._run_workflow_node
|
# TODO reference from core.workflow.workflow_entry.WorkflowEntry._run_workflow_node
|
||||||
@ -307,13 +324,7 @@ class GraphEngine:
|
|||||||
for item in generator:
|
for item in generator:
|
||||||
if isinstance(item, RunCompletedEvent):
|
if isinstance(item, RunCompletedEvent):
|
||||||
run_result = item.run_result
|
run_result = item.run_result
|
||||||
route_node_state.status = RouteNodeState.Status.SUCCESS \
|
route_node_state.set_finished(run_result=run_result)
|
||||||
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED \
|
|
||||||
else RouteNodeState.Status.FAILED
|
|
||||||
route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
|
||||||
route_node_state.node_run_result = run_result
|
|
||||||
route_node_state.failed_reason = run_result.error \
|
|
||||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED else None
|
|
||||||
|
|
||||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||||
yield NodeRunFailedEvent(
|
yield NodeRunFailedEvent(
|
||||||
@ -321,10 +332,27 @@ class GraphEngine:
|
|||||||
route_node_state=route_node_state
|
route_node_state=route_node_state
|
||||||
)
|
)
|
||||||
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||||
|
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||||
|
# plus state total_tokens
|
||||||
|
self.graph_runtime_state.total_tokens += int(
|
||||||
|
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)
|
||||||
|
)
|
||||||
|
|
||||||
|
# append node output variables to variable pool
|
||||||
|
if run_result.outputs:
|
||||||
|
for variable_key, variable_value in run_result.outputs.items():
|
||||||
|
# append variables to variable pool recursively
|
||||||
|
self._append_variables_recursively(
|
||||||
|
node_id=node_id,
|
||||||
|
variable_key_list=[variable_key],
|
||||||
|
variable_value=variable_value
|
||||||
|
)
|
||||||
|
|
||||||
yield NodeRunSucceededEvent(
|
yield NodeRunSucceededEvent(
|
||||||
parallel_id=parallel_id,
|
parallel_id=parallel_id,
|
||||||
route_node_state=route_node_state
|
route_node_state=route_node_state
|
||||||
)
|
)
|
||||||
|
|
||||||
break
|
break
|
||||||
elif isinstance(item, RunStreamChunkEvent):
|
elif isinstance(item, RunStreamChunkEvent):
|
||||||
yield NodeRunStreamChunkEvent(
|
yield NodeRunStreamChunkEvent(
|
||||||
@ -340,8 +368,10 @@ class GraphEngine:
|
|||||||
retriever_resources=item.retriever_resources,
|
retriever_resources=item.retriever_resources,
|
||||||
context=item.context
|
context=item.context
|
||||||
)
|
)
|
||||||
except GenerateTaskStoppedException as e:
|
except GenerateTaskStoppedException:
|
||||||
# trigger node run failed event
|
# trigger node run failed event
|
||||||
|
route_node_state.status = RouteNodeState.Status.FAILED
|
||||||
|
route_node_state.failed_reason = "Workflow stopped."
|
||||||
yield NodeRunFailedEvent(
|
yield NodeRunFailedEvent(
|
||||||
route_node_state=route_node_state,
|
route_node_state=route_node_state,
|
||||||
parallel_id=parallel_id
|
parallel_id=parallel_id
|
||||||
@ -353,6 +383,34 @@ class GraphEngine:
|
|||||||
finally:
|
finally:
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
|
def _append_variables_recursively(self,
|
||||||
|
node_id: str,
|
||||||
|
variable_key_list: list[str],
|
||||||
|
variable_value: VariableValue):
|
||||||
|
"""
|
||||||
|
Append variables recursively
|
||||||
|
:param node_id: node id
|
||||||
|
:param variable_key_list: variable key list
|
||||||
|
:param variable_value: variable value
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
self.graph_runtime_state.variable_pool.append_variable(
|
||||||
|
node_id=node_id,
|
||||||
|
variable_key_list=variable_key_list,
|
||||||
|
value=variable_value
|
||||||
|
)
|
||||||
|
|
||||||
|
# if variable_value is a dict, then recursively append variables
|
||||||
|
if isinstance(variable_value, dict):
|
||||||
|
for key, value in variable_value.items():
|
||||||
|
# construct new key list
|
||||||
|
new_key_list = variable_key_list + [key]
|
||||||
|
self._append_variables_recursively(
|
||||||
|
node_id=node_id,
|
||||||
|
variable_key_list=new_key_list,
|
||||||
|
variable_value=value
|
||||||
|
)
|
||||||
|
|
||||||
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
|
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
|
||||||
"""
|
"""
|
||||||
Check timeout
|
Check timeout
|
||||||
@ -366,8 +424,3 @@ class GraphEngine:
|
|||||||
class GraphRunFailedError(Exception):
|
class GraphRunFailedError(Exception):
|
||||||
def __init__(self, error: str):
|
def __init__(self, error: str):
|
||||||
self.error = error
|
self.error = error
|
||||||
|
|
||||||
|
|
||||||
class NodeRunFailedError(Exception):
|
|
||||||
def __init__(self, error: str):
|
|
||||||
self.error = error
|
|
||||||
|
@ -26,7 +26,7 @@ class AnswerNode(BaseNode):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
node_data = self.node_data
|
node_data = self.node_data
|
||||||
node_data = cast(self._node_data_cls, node_data)
|
node_data = cast(AnswerNodeData, node_data)
|
||||||
|
|
||||||
# generate routes
|
# generate routes
|
||||||
generate_routes = self.extract_generate_route_from_node_data(node_data)
|
generate_routes = self.extract_generate_route_from_node_data(node_data)
|
||||||
|
@ -1,9 +1,20 @@
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from flask import Flask
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.workflow.entities.node_entities import SystemVariable, UserFrom
|
from core.workflow.entities.node_entities import SystemVariable, UserFrom
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.graph_engine.entities.event import (
|
||||||
|
BaseNodeEvent,
|
||||||
|
GraphRunFailedEvent,
|
||||||
|
GraphRunStartedEvent,
|
||||||
|
NodeRunFailedEvent,
|
||||||
|
NodeRunStartedEvent,
|
||||||
|
NodeRunSucceededEvent,
|
||||||
|
)
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
|
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
|
|
||||||
@ -14,31 +25,29 @@ def test_run(mock_close, mock_remove):
|
|||||||
graph_config = {
|
graph_config = {
|
||||||
"edges": [
|
"edges": [
|
||||||
{
|
{
|
||||||
"id": "llm-source-answer-target",
|
"id": "1",
|
||||||
"source": "llm",
|
|
||||||
"target": "answer",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "start-source-qc-target",
|
|
||||||
"source": "start",
|
"source": "start",
|
||||||
"target": "qc",
|
"target": "answer1",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"id": "qc-1-llm-target",
|
"id": "2",
|
||||||
"source": "qc",
|
"source": "answer1",
|
||||||
"sourceHandle": "1",
|
|
||||||
"target": "llm",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "qc-2-http-target",
|
|
||||||
"source": "qc",
|
|
||||||
"sourceHandle": "2",
|
|
||||||
"target": "http",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "http-source-answer2-target",
|
|
||||||
"source": "http",
|
|
||||||
"target": "answer2",
|
"target": "answer2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "3",
|
||||||
|
"source": "answer1",
|
||||||
|
"target": "answer3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "4",
|
||||||
|
"source": "answer2",
|
||||||
|
"target": "answer4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "5",
|
||||||
|
"source": "answer3",
|
||||||
|
"target": "answer5",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"nodes": [
|
"nodes": [
|
||||||
@ -51,38 +60,43 @@ def test_run(mock_close, mock_remove):
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"type": "llm",
|
"type": "answer",
|
||||||
"title": "llm"
|
"title": "answer1",
|
||||||
|
"answer": "1"
|
||||||
},
|
},
|
||||||
"id": "llm"
|
"id": "answer1"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"type": "answer",
|
"type": "answer",
|
||||||
"title": "answer"
|
"title": "answer2",
|
||||||
},
|
"answer": "2"
|
||||||
"id": "answer",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "question-classifier",
|
|
||||||
"title": "qc"
|
|
||||||
},
|
|
||||||
"id": "qc",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "http-request",
|
|
||||||
"title": "http"
|
|
||||||
},
|
|
||||||
"id": "http",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "answer",
|
|
||||||
"title": "answer2"
|
|
||||||
},
|
},
|
||||||
"id": "answer2",
|
"id": "answer2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "answer",
|
||||||
|
"title": "answer3",
|
||||||
|
"answer": "3"
|
||||||
|
},
|
||||||
|
"id": "answer3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "answer",
|
||||||
|
"title": "answer4",
|
||||||
|
"answer": "4"
|
||||||
|
},
|
||||||
|
"id": "answer4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "answer",
|
||||||
|
"title": "answer5",
|
||||||
|
"answer": "5"
|
||||||
|
},
|
||||||
|
"id": "answer5",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@ -115,6 +129,28 @@ def test_run(mock_close, mock_remove):
|
|||||||
|
|
||||||
print("")
|
print("")
|
||||||
|
|
||||||
generator = graph_engine.run()
|
app = Flask('test')
|
||||||
for item in generator:
|
|
||||||
print(type(item), item)
|
items = []
|
||||||
|
with app.app_context():
|
||||||
|
generator = graph_engine.run()
|
||||||
|
for item in generator:
|
||||||
|
print(type(item), item)
|
||||||
|
items.append(item)
|
||||||
|
if isinstance(item, NodeRunSucceededEvent):
|
||||||
|
assert item.route_node_state.status == RouteNodeState.Status.SUCCESS
|
||||||
|
|
||||||
|
assert not isinstance(item, NodeRunFailedEvent)
|
||||||
|
assert not isinstance(item, GraphRunFailedEvent)
|
||||||
|
|
||||||
|
if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ['answer2', 'answer3']:
|
||||||
|
assert item.parallel_id is not None
|
||||||
|
|
||||||
|
assert len(items) == 12
|
||||||
|
assert isinstance(items[0], GraphRunStartedEvent)
|
||||||
|
assert isinstance(items[1], NodeRunStartedEvent)
|
||||||
|
assert items[1].route_node_state.node_id == 'start'
|
||||||
|
assert isinstance(items[2], NodeRunSucceededEvent)
|
||||||
|
assert items[2].route_node_state.node_id == 'start'
|
||||||
|
|
||||||
|
print(graph_engine.graph_runtime_state)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user