This commit is contained in:
takatost 2024-07-17 21:17:04 +08:00
parent 90e518b05b
commit f67a88f44d
12 changed files with 275 additions and 123 deletions

View File

@ -7,7 +7,6 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeType from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import GenerateRouteChunk
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
@ -19,15 +18,6 @@ class WorkflowStreamGenerateNodes(BaseModel):
stream_node_ids: list[str] stream_node_ids: list[str]
class ChatflowStreamGenerateRoute(BaseModel):
"""
ChatflowStreamGenerateRoute entity
"""
answer_node_id: str
generate_route: list[GenerateRouteChunk]
current_route_position: int = 0
class NodeExecutionInfo(BaseModel): class NodeExecutionInfo(BaseModel):
""" """
NodeExecutionInfo entity NodeExecutionInfo entity

View File

@ -5,6 +5,8 @@ from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeType from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.run_condition import RunCondition from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.nodes.answer.answer_stream_output_manager import AnswerStreamOutputManager
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
class GraphEdge(BaseModel): class GraphEdge(BaseModel):
@ -39,6 +41,10 @@ class Graph(BaseModel):
default_factory=dict, default_factory=dict,
description="graph node parallel mapping (node id: parallel id)" description="graph node parallel mapping (node id: parallel id)"
) )
answer_stream_generate_routes: dict[str, AnswerStreamGenerateRoute] = Field(
default_factory=dict,
description="answer stream generate routes"
)
@classmethod @classmethod
def init(cls, def init(cls,
@ -142,6 +148,12 @@ class Graph(BaseModel):
node_parallel_mapping=node_parallel_mapping node_parallel_mapping=node_parallel_mapping
) )
# init answer stream generate routes
answer_stream_generate_routes = AnswerStreamOutputManager.init_stream_generate_routes(
node_id_config_mapping=node_id_config_mapping,
edge_mapping=edge_mapping
)
# init graph # init graph
graph = cls( graph = cls(
root_node_id=root_node_id, root_node_id=root_node_id,
@ -149,7 +161,8 @@ class Graph(BaseModel):
node_id_config_mapping=node_id_config_mapping, node_id_config_mapping=node_id_config_mapping,
edge_mapping=edge_mapping, edge_mapping=edge_mapping,
parallel_mapping=parallel_mapping, parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping node_parallel_mapping=node_parallel_mapping,
answer_stream_generate_routes=answer_stream_generate_routes
) )
return graph return graph

View File

@ -28,8 +28,8 @@ from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import node_classes
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import node_classes
from extensions.ext_database import db from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType from models.workflow import WorkflowNodeExecutionStatus, WorkflowType

View File

@ -1,33 +0,0 @@
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
node_classes = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.ANSWER: AnswerNode,
NodeType.LLM: LLMNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.IF_ELSE: IfElseNode,
NodeType.CODE: CodeNode,
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.TOOL: ToolNode,
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: IterationNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
}

View File

@ -2,9 +2,9 @@ import json
from typing import cast from typing import cast
from core.file.file_obj import FileVar from core.file.file_obj import FileVar
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.answer.answer_stream_output_manager import AnswerStreamOutputManager
from core.workflow.nodes.answer.entities import ( from core.workflow.nodes.answer.entities import (
AnswerNodeData, AnswerNodeData,
GenerateRouteChunk, GenerateRouteChunk,
@ -29,11 +29,11 @@ class AnswerNode(BaseNode):
node_data = cast(AnswerNodeData, node_data) node_data = cast(AnswerNodeData, node_data)
# generate routes # generate routes
generate_routes = self.extract_generate_route_from_node_data(node_data) generate_routes = AnswerStreamOutputManager.extract_generate_route_from_node_data(node_data)
answer = '' answer = ''
for part in generate_routes: for part in generate_routes:
if part.type == "var": if part.type == GenerateRouteChunk.ChunkType.VAR:
part = cast(VarGenerateRouteChunk, part) part = cast(VarGenerateRouteChunk, part)
value_selector = part.value_selector value_selector = part.value_selector
value = self.graph_runtime_state.variable_pool.get_variable_value( value = self.graph_runtime_state.variable_pool.get_variable_value(
@ -72,67 +72,6 @@ class AnswerNode(BaseNode):
} }
) )
@classmethod
def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
"""
Extract generate route selectors
:param config: node config
:return:
"""
node_data = cls._node_data_cls(**config.get("data", {}))
node_data = cast(cls._node_data_cls, node_data)
return cls.extract_generate_route_from_node_data(node_data)
@classmethod
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
"""
Extract generate route from node data
:param node_data: node data object
:return:
"""
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
value_selector_mapping = {
variable_selector.variable: variable_selector.value_selector
for variable_selector in variable_selectors
}
variable_keys = list(value_selector_mapping.keys())
# format answer template
template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
template_variable_keys = template_parser.variable_keys
# Take the intersection of variable_keys and template_variable_keys
variable_keys = list(set(variable_keys) & set(template_variable_keys))
template = node_data.answer
for var in variable_keys:
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
generate_routes = []
for part in template.split('Ω'):
if part:
if cls._is_variable(part, variable_keys):
var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
value_selector = value_selector_mapping[var_key]
generate_routes.append(VarGenerateRouteChunk(
value_selector=value_selector
))
else:
generate_routes.append(TextGenerateRouteChunk(
text=part
))
return generate_routes
@classmethod
def _is_variable(cls, part, variable_keys):
cleaned_part = part.replace('{{', '').replace('}}', '')
return part.startswith('{{') and cleaned_part in variable_keys
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
""" """
@ -141,7 +80,7 @@ class AnswerNode(BaseNode):
:return: :return:
""" """
node_data = node_data node_data = node_data
node_data = cast(cls._node_data_cls, node_data) node_data = cast(AnswerNodeData, node_data)
variable_template_parser = VariableTemplateParser(template=node_data.answer) variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors() variable_selectors = variable_template_parser.extract_variable_selectors()

View File

@ -0,0 +1,160 @@
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import (
AnswerNodeData,
AnswerStreamGenerateRoute,
GenerateRouteChunk,
TextGenerateRouteChunk,
VarGenerateRouteChunk,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
class AnswerStreamOutputManager:
@classmethod
def init_stream_generate_routes(cls,
node_id_config_mapping: dict[str, dict],
edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined]
) -> dict[str, AnswerStreamGenerateRoute]:
"""
Get stream generate routes.
:return:
"""
# parse stream output node value selectors of answer nodes
stream_generate_routes = {}
for node_id, node_config in node_id_config_mapping.items():
if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value:
continue
# get generate route for stream output
generate_route = cls._extract_generate_route_selectors(node_config)
streaming_node_ids = cls._get_streaming_node_ids(
target_node_id=node_id,
node_id_config_mapping=node_id_config_mapping,
edge_mapping=edge_mapping
)
if not streaming_node_ids:
continue
for streaming_node_id in streaming_node_ids:
stream_generate_routes[streaming_node_id] = AnswerStreamGenerateRoute(
answer_node_id=node_id,
generate_route=generate_route
)
return stream_generate_routes
@classmethod
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
"""
Extract generate route from node data
:param node_data: node data object
:return:
"""
variable_template_parser = VariableTemplateParser(template=node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
value_selector_mapping = {
variable_selector.variable: variable_selector.value_selector
for variable_selector in variable_selectors
}
variable_keys = list(value_selector_mapping.keys())
# format answer template
template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
template_variable_keys = template_parser.variable_keys
# Take the intersection of variable_keys and template_variable_keys
variable_keys = list(set(variable_keys) & set(template_variable_keys))
template = node_data.answer
for var in variable_keys:
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
generate_routes: list[GenerateRouteChunk] = []
for part in template.split('Ω'):
if part:
if cls._is_variable(part, variable_keys):
var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
value_selector = value_selector_mapping[var_key]
generate_routes.append(VarGenerateRouteChunk(
value_selector=value_selector
))
else:
generate_routes.append(TextGenerateRouteChunk(
text=part
))
return generate_routes
@classmethod
def _get_streaming_node_ids(cls,
target_node_id: str,
node_id_config_mapping: dict[str, dict],
edge_mapping: dict[str, list["GraphEdge"]]) -> list[str]: # type: ignore[name-defined]
"""
Get answer stream node IDs.
:param target_node_id: target node ID
:return:
"""
# fetch all ingoing edges from source node
ingoing_graph_edges = []
for graph_edges in edge_mapping.values():
for graph_edge in graph_edges:
if graph_edge.target_node_id == target_node_id:
ingoing_graph_edges.append(graph_edge)
if not ingoing_graph_edges:
return []
streaming_node_ids = []
for ingoing_graph_edge in ingoing_graph_edges:
source_node_id = ingoing_graph_edge.source_node_id
source_node = node_id_config_mapping.get(source_node_id)
if not source_node:
continue
node_type = source_node.get('data', {}).get('type')
if node_type in [
NodeType.ANSWER.value,
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value,
NodeType.ITERATION.value,
NodeType.LOOP.value
]:
# Current node (answer nodes / multi-branch nodes / iteration nodes) cannot be stream node.
streaming_node_ids.append(target_node_id)
elif node_type == NodeType.START.value:
# Current node is START node, can be stream node.
streaming_node_ids.append(source_node_id)
else:
# Find the stream node forward.
sub_streaming_node_ids = cls._get_streaming_node_ids(
target_node_id=source_node_id,
node_id_config_mapping=node_id_config_mapping,
edge_mapping=edge_mapping
)
if sub_streaming_node_ids:
streaming_node_ids.extend(sub_streaming_node_ids)
return streaming_node_ids
@classmethod
def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
"""
Extract generate route selectors
:param config: node config
:return:
"""
node_data = AnswerNodeData(**config.get("data", {}))
return cls.extract_generate_route_from_node_data(node_data)
@classmethod
def _is_variable(cls, part, variable_keys):
cleaned_part = part.replace('{{', '').replace('}}', '')
return part.startswith('{{') and cleaned_part in variable_keys

View File

@ -1,5 +1,6 @@
from enum import Enum
from pydantic import BaseModel from pydantic import BaseModel, Field
from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.base_node_data_entities import BaseNodeData
@ -8,27 +9,44 @@ class AnswerNodeData(BaseNodeData):
""" """
Answer Node Data. Answer Node Data.
""" """
answer: str answer: str = Field(..., description="answer template string")
class GenerateRouteChunk(BaseModel): class GenerateRouteChunk(BaseModel):
""" """
Generate Route Chunk. Generate Route Chunk.
""" """
type: str
class ChunkType(Enum):
VAR = "var"
TEXT = "text"
type: ChunkType = Field(..., description="generate route chunk type")
class VarGenerateRouteChunk(GenerateRouteChunk): class VarGenerateRouteChunk(GenerateRouteChunk):
""" """
Var Generate Route Chunk. Var Generate Route Chunk.
""" """
type: str = "var" type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR
value_selector: list[str] """generate route chunk type"""
value_selector: list[str] = Field(..., description="value selector")
class TextGenerateRouteChunk(GenerateRouteChunk): class TextGenerateRouteChunk(GenerateRouteChunk):
""" """
Text Generate Route Chunk. Text Generate Route Chunk.
""" """
type: str = "text" type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT
text: str """generate route chunk type"""
text: str = Field(..., description="text")
class AnswerStreamGenerateRoute(BaseModel):
"""
ChatflowStreamGenerateRoute entity
"""
answer_node_id: str = Field(..., description="answer node ID")
generate_route: list[GenerateRouteChunk] = Field(..., description="answer stream generate route")
current_route_position: int = 0
"""current generate route position"""

View File

@ -0,0 +1,33 @@
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
node_classes = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.ANSWER: AnswerNode,
NodeType.LLM: LLMNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.IF_ELSE: IfElseNode,
NodeType.CODE: CodeNode,
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.TOOL: ToolNode,
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: IterationNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
}

View File

@ -17,10 +17,10 @@ from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState
from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes import node_classes
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
from core.workflow.nodes.iteration.entities import IterationState from core.workflow.nodes.iteration.entities import IterationState
from core.workflow.nodes.llm.entities import LLMNodeData from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.node_mapping import node_classes
from core.workflow.nodes.start.start_node import StartNode from core.workflow.nodes.start.start_node import StartNode
from extensions.ext_database import db from extensions.ext_database import db
from models.workflow import ( from models.workflow import (

View File

@ -8,7 +8,7 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import NodeType from core.workflow.entities.node_entities import NodeType
from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.nodes import node_classes from core.workflow.nodes.node_mapping import node_classes
from core.workflow.workflow_engine_manager import WorkflowEngineManager from core.workflow.workflow_engine_manager import WorkflowEngineManager
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db from extensions.ext_database import db

View File

@ -50,6 +50,8 @@ def test_init():
{ {
"data": { "data": {
"type": "answer", "type": "answer",
"title": "answer",
"answer": "1"
}, },
"id": "answer", "id": "answer",
}, },
@ -68,6 +70,8 @@ def test_init():
{ {
"data": { "data": {
"type": "answer", "type": "answer",
"title": "answer",
"answer": "1"
}, },
"id": "answer2", "id": "answer2",
} }
@ -141,6 +145,8 @@ def test__init_iteration_graph():
{ {
"data": { "data": {
"type": "answer", "type": "answer",
"title": "answer",
"answer": "1"
}, },
"id": "answer", "id": "answer",
}, },
@ -167,6 +173,8 @@ def test__init_iteration_graph():
{ {
"data": { "data": {
"type": "answer", "type": "answer",
"title": "answer",
"answer": "1"
}, },
"id": "answer-in-iteration", "id": "answer-in-iteration",
"parentId": "iteration", "parentId": "iteration",
@ -270,6 +278,8 @@ def test_parallels_graph():
{ {
"data": { "data": {
"type": "answer", "type": "answer",
"title": "answer",
"answer": "1"
}, },
"id": "answer", "id": "answer",
}, },
@ -350,6 +360,8 @@ def test_parallels_graph2():
{ {
"data": { "data": {
"type": "answer", "type": "answer",
"title": "answer",
"answer": "1"
}, },
"id": "answer", "id": "answer",
}, },
@ -422,6 +434,8 @@ def test_parallels_graph3():
{ {
"data": { "data": {
"type": "answer", "type": "answer",
"title": "answer",
"answer": "1"
}, },
"id": "answer", "id": "answer",
}, },
@ -538,6 +552,8 @@ def test_parallels_graph4():
{ {
"data": { "data": {
"type": "answer", "type": "answer",
"title": "answer",
"answer": "1"
}, },
"id": "answer", "id": "answer",
}, },
@ -673,6 +689,8 @@ def test_parallels_graph5():
{ {
"data": { "data": {
"type": "answer", "type": "answer",
"title": "answer",
"answer": "1"
}, },
"id": "answer", "id": "answer",
}, },
@ -816,6 +834,8 @@ def test_parallels_graph6():
{ {
"data": { "data": {
"type": "answer", "type": "answer",
"title": "answer",
"answer": "1"
}, },
"id": "answer", "id": "answer",
}, },

View File

@ -9,6 +9,7 @@ from core.workflow.graph_engine.entities.event import (
BaseNodeEvent, BaseNodeEvent,
GraphRunFailedEvent, GraphRunFailedEvent,
GraphRunStartedEvent, GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunFailedEvent, NodeRunFailedEvent,
NodeRunStartedEvent, NodeRunStartedEvent,
NodeRunSucceededEvent, NodeRunSucceededEvent,
@ -187,12 +188,19 @@ def test_run_branch(mock_close, mock_remove):
"data": { "data": {
"title": "Start", "title": "Start",
"type": "start", "type": "start",
"variables": [] "variables": [{
"label": "uid",
"max_length": 48,
"options": [],
"required": True,
"type": "text-input",
"variable": "uid"
}]
}, },
"id": "start" "id": "start"
}, { }, {
"data": { "data": {
"answer": "1", "answer": "1 {{#start.uid#}}",
"title": "Answer", "title": "Answer",
"type": "answer", "type": "answer",
"variables": [] "variables": []
@ -261,7 +269,9 @@ def test_run_branch(mock_close, mock_remove):
SystemVariable.FILES: [], SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa', SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa' SystemVariable.USER_ID: 'aaa'
}, user_inputs={}) }, user_inputs={
"uid": "takato"
})
graph_engine = GraphEngine( graph_engine = GraphEngine(
tenant_id="111", tenant_id="111",
@ -286,7 +296,7 @@ def test_run_branch(mock_close, mock_remove):
with app.app_context(): with app.app_context():
generator = graph_engine.run() generator = graph_engine.run()
for item in generator: for item in generator:
# print(type(item), item) print(type(item), item)
items.append(item) items.append(item)
assert len(items) == 8 assert len(items) == 8
@ -294,5 +304,7 @@ def test_run_branch(mock_close, mock_remove):
assert items[4].route_node_state.node_id == 'if-else-1' assert items[4].route_node_state.node_id == 'if-else-1'
assert items[5].route_node_state.node_id == 'answer-1' assert items[5].route_node_state.node_id == 'answer-1'
assert items[6].route_node_state.node_id == 'answer-1' assert items[6].route_node_state.node_id == 'answer-1'
assert items[6].route_node_state.node_run_result.outputs['answer'] == '1 takato'
assert isinstance(items[7], GraphRunSucceededEvent)
# print(graph_engine.graph_runtime_state.model_dump_json(indent=2)) # print(graph_engine.graph_runtime_state.model_dump_json(indent=2))