From f67a88f44d61ce51c3981b5df1269da1b3e5994c Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 17 Jul 2024 21:17:04 +0800 Subject: [PATCH] fix test --- api/core/app/entities/task_entities.py | 10 -- .../workflow/graph_engine/entities/graph.py | 15 +- .../workflow/graph_engine/graph_engine.py | 2 +- api/core/workflow/nodes/__init__.py | 33 ---- api/core/workflow/nodes/answer/answer_node.py | 69 +------- .../answer/answer_stream_output_manager.py | 160 ++++++++++++++++++ api/core/workflow/nodes/answer/entities.py | 32 +++- api/core/workflow/nodes/node_mapping.py | 33 ++++ api/core/workflow/workflow_entry.py | 2 +- api/services/workflow_service.py | 2 +- .../core/workflow/graph_engine/test_graph.py | 20 +++ .../graph_engine/test_graph_engine.py | 20 ++- 12 files changed, 275 insertions(+), 123 deletions(-) create mode 100644 api/core/workflow/nodes/answer/answer_stream_output_manager.py create mode 100644 api/core/workflow/nodes/node_mapping.py diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 7bc5598984..ddf8200c77 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -7,7 +7,6 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeType -from core.workflow.nodes.answer.entities import GenerateRouteChunk from models.workflow import WorkflowNodeExecutionStatus @@ -19,15 +18,6 @@ class WorkflowStreamGenerateNodes(BaseModel): stream_node_ids: list[str] -class ChatflowStreamGenerateRoute(BaseModel): - """ - ChatflowStreamGenerateRoute entity - """ - answer_node_id: str - generate_route: list[GenerateRouteChunk] - current_route_position: int = 0 - - class NodeExecutionInfo(BaseModel): """ NodeExecutionInfo entity diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index dab4e30da6..2d52eab5f4 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -5,6 +5,8 @@ from pydantic import BaseModel, Field from core.workflow.entities.node_entities import NodeType from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.nodes.answer.answer_stream_output_manager import AnswerStreamOutputManager +from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute class GraphEdge(BaseModel): @@ -39,6 +41,10 @@ class Graph(BaseModel): default_factory=dict, description="graph node parallel mapping (node id: parallel id)" ) + answer_stream_generate_routes: dict[str, AnswerStreamGenerateRoute] = Field( + default_factory=dict, + description="answer stream generate routes" + ) @classmethod def init(cls, @@ -142,6 +148,12 @@ class Graph(BaseModel): node_parallel_mapping=node_parallel_mapping ) + # init answer stream generate routes + answer_stream_generate_routes = AnswerStreamOutputManager.init_stream_generate_routes( + node_id_config_mapping=node_id_config_mapping, + edge_mapping=edge_mapping + ) + # init graph graph = cls( root_node_id=root_node_id, @@ -149,7 +161,8 @@ class Graph(BaseModel): node_id_config_mapping=node_id_config_mapping, edge_mapping=edge_mapping, parallel_mapping=parallel_mapping, - node_parallel_mapping=node_parallel_mapping + node_parallel_mapping=node_parallel_mapping, + answer_stream_generate_routes=answer_stream_generate_routes ) return graph diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 627eba19d4..6289677556 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -28,8 +28,8 @@ from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState -from core.workflow.nodes import node_classes from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent +from core.workflow.nodes.node_mapping import node_classes from extensions.ext_database import db from models.workflow import WorkflowNodeExecutionStatus, WorkflowType diff --git a/api/core/workflow/nodes/__init__.py b/api/core/workflow/nodes/__init__.py index df1eb98989..e69de29bb2 100644 --- a/api/core/workflow/nodes/__init__.py +++ b/api/core/workflow/nodes/__init__.py @@ -1,33 +0,0 @@ -from core.workflow.entities.node_entities import NodeType -from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.http_request.http_request_node import HttpRequestNode -from core.workflow.nodes.if_else.if_else_node import IfElseNode -from core.workflow.nodes.iteration.iteration_node import IterationNode -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode -from core.workflow.nodes.llm.llm_node import LLMNode -from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode -from core.workflow.nodes.tool.tool_node import ToolNode -from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode - -node_classes = { - NodeType.START: StartNode, - NodeType.END: EndNode, - NodeType.ANSWER: AnswerNode, - NodeType.LLM: LLMNode, - NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, - NodeType.IF_ELSE: IfElseNode, - NodeType.CODE: CodeNode, - NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode, - NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, - NodeType.HTTP_REQUEST: HttpRequestNode, - NodeType.TOOL: ToolNode, - NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, - NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR - NodeType.ITERATION: IterationNode, - NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode -} diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index bf5c1617b5..a667c9ab73 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -2,9 +2,9 @@ import json from typing import cast from core.file.file_obj import FileVar -from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.nodes.answer.answer_stream_output_manager import AnswerStreamOutputManager from core.workflow.nodes.answer.entities import ( AnswerNodeData, GenerateRouteChunk, @@ -29,11 +29,11 @@ class AnswerNode(BaseNode): node_data = cast(AnswerNodeData, node_data) # generate routes - generate_routes = self.extract_generate_route_from_node_data(node_data) + generate_routes = AnswerStreamOutputManager.extract_generate_route_from_node_data(node_data) answer = '' for part in generate_routes: - if part.type == "var": + if part.type == GenerateRouteChunk.ChunkType.VAR: part = cast(VarGenerateRouteChunk, part) value_selector = part.value_selector value = self.graph_runtime_state.variable_pool.get_variable_value( @@ -72,67 +72,6 @@ class AnswerNode(BaseNode): } ) - @classmethod - def extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]: - """ - Extract generate route selectors - :param config: node config - :return: - """ - node_data = cls._node_data_cls(**config.get("data", {})) - node_data = cast(cls._node_data_cls, node_data) - - return cls.extract_generate_route_from_node_data(node_data) - - @classmethod - def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]: - """ - Extract generate route from node data - :param node_data: node data object - :return: - """ - variable_template_parser = VariableTemplateParser(template=node_data.answer) - variable_selectors = variable_template_parser.extract_variable_selectors() - - value_selector_mapping = { - variable_selector.variable: variable_selector.value_selector - for variable_selector in variable_selectors - } - - variable_keys = list(value_selector_mapping.keys()) - - # format answer template - template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True) - template_variable_keys = template_parser.variable_keys - - # Take the intersection of variable_keys and template_variable_keys - variable_keys = list(set(variable_keys) & set(template_variable_keys)) - - template = node_data.answer - for var in variable_keys: - template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω') - - generate_routes = [] - for part in template.split('Ω'): - if part: - if cls._is_variable(part, variable_keys): - var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '') - value_selector = value_selector_mapping[var_key] - generate_routes.append(VarGenerateRouteChunk( - value_selector=value_selector - )) - else: - generate_routes.append(TextGenerateRouteChunk( - text=part - )) - - return generate_routes - - @classmethod - def _is_variable(cls, part, variable_keys): - cleaned_part = part.replace('{{', '').replace('}}', '') - return part.startswith('{{') and cleaned_part in variable_keys - @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: """ @@ -141,7 +80,7 @@ class AnswerNode(BaseNode): :return: """ node_data = node_data - node_data = cast(cls._node_data_cls, node_data) + node_data = cast(AnswerNodeData, node_data) variable_template_parser = VariableTemplateParser(template=node_data.answer) variable_selectors = variable_template_parser.extract_variable_selectors() diff --git a/api/core/workflow/nodes/answer/answer_stream_output_manager.py b/api/core/workflow/nodes/answer/answer_stream_output_manager.py new file mode 100644 index 0000000000..ff2d955cd2 --- /dev/null +++ b/api/core/workflow/nodes/answer/answer_stream_output_manager.py @@ -0,0 +1,160 @@ +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.answer.entities import ( + AnswerNodeData, + AnswerStreamGenerateRoute, + GenerateRouteChunk, + TextGenerateRouteChunk, + VarGenerateRouteChunk, +) +from core.workflow.utils.variable_template_parser import VariableTemplateParser + + +class AnswerStreamOutputManager: + + @classmethod + def init_stream_generate_routes(cls, + node_id_config_mapping: dict[str, dict], + edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined] + ) -> dict[str, AnswerStreamGenerateRoute]: + """ + Get stream generate routes. + :return: + """ + # parse stream output node value selectors of answer nodes + stream_generate_routes = {} + for node_id, node_config in node_id_config_mapping.items(): + if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value: + continue + + # get generate route for stream output + generate_route = cls._extract_generate_route_selectors(node_config) + streaming_node_ids = cls._get_streaming_node_ids( + target_node_id=node_id, + node_id_config_mapping=node_id_config_mapping, + edge_mapping=edge_mapping + ) + + if not streaming_node_ids: + continue + + for streaming_node_id in streaming_node_ids: + stream_generate_routes[streaming_node_id] = AnswerStreamGenerateRoute( + answer_node_id=node_id, + generate_route=generate_route + ) + + return stream_generate_routes + + @classmethod + def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]: + """ + Extract generate route from node data + :param node_data: node data object + :return: + """ + variable_template_parser = VariableTemplateParser(template=node_data.answer) + variable_selectors = variable_template_parser.extract_variable_selectors() + + value_selector_mapping = { + variable_selector.variable: variable_selector.value_selector + for variable_selector in variable_selectors + } + + variable_keys = list(value_selector_mapping.keys()) + + # format answer template + template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True) + template_variable_keys = template_parser.variable_keys + + # Take the intersection of variable_keys and template_variable_keys + variable_keys = list(set(variable_keys) & set(template_variable_keys)) + + template = node_data.answer + for var in variable_keys: + template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω') + + generate_routes: list[GenerateRouteChunk] = [] + for part in template.split('Ω'): + if part: + if cls._is_variable(part, variable_keys): + var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '') + value_selector = value_selector_mapping[var_key] + generate_routes.append(VarGenerateRouteChunk( + value_selector=value_selector + )) + else: + generate_routes.append(TextGenerateRouteChunk( + text=part + )) + + return generate_routes + + @classmethod + def _get_streaming_node_ids(cls, + target_node_id: str, + node_id_config_mapping: dict[str, dict], + edge_mapping: dict[str, list["GraphEdge"]]) -> list[str]: # type: ignore[name-defined] + """ + Get answer stream node IDs. + :param target_node_id: target node ID + :return: + """ + # fetch all ingoing edges from source node + ingoing_graph_edges = [] + for graph_edges in edge_mapping.values(): + for graph_edge in graph_edges: + if graph_edge.target_node_id == target_node_id: + ingoing_graph_edges.append(graph_edge) + + if not ingoing_graph_edges: + return [] + + streaming_node_ids = [] + for ingoing_graph_edge in ingoing_graph_edges: + source_node_id = ingoing_graph_edge.source_node_id + source_node = node_id_config_mapping.get(source_node_id) + if not source_node: + continue + + node_type = source_node.get('data', {}).get('type') + + if node_type in [ + NodeType.ANSWER.value, + NodeType.IF_ELSE.value, + NodeType.QUESTION_CLASSIFIER.value, + NodeType.ITERATION.value, + NodeType.LOOP.value + ]: + # Current node (answer nodes / multi-branch nodes / iteration nodes) cannot be stream node. + streaming_node_ids.append(target_node_id) + elif node_type == NodeType.START.value: + # Current node is START node, can be stream node. + streaming_node_ids.append(source_node_id) + else: + # Find the stream node forward. + sub_streaming_node_ids = cls._get_streaming_node_ids( + target_node_id=source_node_id, + node_id_config_mapping=node_id_config_mapping, + edge_mapping=edge_mapping + ) + + if sub_streaming_node_ids: + streaming_node_ids.extend(sub_streaming_node_ids) + + return streaming_node_ids + + @classmethod + def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]: + """ + Extract generate route selectors + :param config: node config + :return: + """ + node_data = AnswerNodeData(**config.get("data", {})) + return cls.extract_generate_route_from_node_data(node_data) + + @classmethod + def _is_variable(cls, part, variable_keys): + cleaned_part = part.replace('{{', '').replace('}}', '') + return part.startswith('{{') and cleaned_part in variable_keys diff --git a/api/core/workflow/nodes/answer/entities.py b/api/core/workflow/nodes/answer/entities.py index 9effbbbe67..6f29af2027 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/core/workflow/nodes/answer/entities.py @@ -1,5 +1,6 @@ +from enum import Enum -from pydantic import BaseModel +from pydantic import BaseModel, Field from core.workflow.entities.base_node_data_entities import BaseNodeData @@ -8,27 +9,44 @@ class AnswerNodeData(BaseNodeData): """ Answer Node Data. """ - answer: str + answer: str = Field(..., description="answer template string") class GenerateRouteChunk(BaseModel): """ Generate Route Chunk. """ - type: str + + class ChunkType(Enum): + VAR = "var" + TEXT = "text" + + type: ChunkType = Field(..., description="generate route chunk type") class VarGenerateRouteChunk(GenerateRouteChunk): """ Var Generate Route Chunk. """ - type: str = "var" - value_selector: list[str] + type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR + """generate route chunk type""" + value_selector: list[str] = Field(..., description="value selector") class TextGenerateRouteChunk(GenerateRouteChunk): """ Text Generate Route Chunk. """ - type: str = "text" - text: str + type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT + """generate route chunk type""" + text: str = Field(..., description="text") + + +class AnswerStreamGenerateRoute(BaseModel): + """ + ChatflowStreamGenerateRoute entity + """ + answer_node_id: str = Field(..., description="answer node ID") + generate_route: list[GenerateRouteChunk] = Field(..., description="answer stream generate route") + current_route_position: int = 0 + """current generate route position""" diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py new file mode 100644 index 0000000000..df1eb98989 --- /dev/null +++ b/api/core/workflow/nodes/node_mapping.py @@ -0,0 +1,33 @@ +from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.answer.answer_node import AnswerNode +from core.workflow.nodes.code.code_node import CodeNode +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.http_request.http_request_node import HttpRequestNode +from core.workflow.nodes.if_else.if_else_node import IfElseNode +from core.workflow.nodes.iteration.iteration_node import IterationNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from core.workflow.nodes.llm.llm_node import LLMNode +from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from core.workflow.nodes.tool.tool_node import ToolNode +from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode + +node_classes = { + NodeType.START: StartNode, + NodeType.END: EndNode, + NodeType.ANSWER: AnswerNode, + NodeType.LLM: LLMNode, + NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, + NodeType.IF_ELSE: IfElseNode, + NodeType.CODE: CodeNode, + NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode, + NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, + NodeType.HTTP_REQUEST: HttpRequestNode, + NodeType.TOOL: ToolNode, + NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, + NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR + NodeType.ITERATION: IterationNode, + NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode +} diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 7ea9a76d5b..98baa024ae 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -17,10 +17,10 @@ from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.nodes import node_classes from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom from core.workflow.nodes.iteration.entities import IterationState from core.workflow.nodes.llm.entities import LLMNodeData +from core.workflow.nodes.node_mapping import node_classes from core.workflow.nodes.start.start_node import StartNode from extensions.ext_database import db from models.workflow import ( diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 93c01315ea..ea76cfa2e8 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -8,7 +8,7 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.entities.node_entities import NodeType from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.nodes import node_classes +from core.workflow.nodes.node_mapping import node_classes from core.workflow.workflow_engine_manager import WorkflowEngineManager from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py index 2987343693..a29ba16d9f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py @@ -50,6 +50,8 @@ def test_init(): { "data": { "type": "answer", + "title": "answer", + "answer": "1" }, "id": "answer", }, @@ -68,6 +70,8 @@ def test_init(): { "data": { "type": "answer", + "title": "answer", + "answer": "1" }, "id": "answer2", } @@ -141,6 +145,8 @@ def test__init_iteration_graph(): { "data": { "type": "answer", + "title": "answer", + "answer": "1" }, "id": "answer", }, @@ -167,6 +173,8 @@ def test__init_iteration_graph(): { "data": { "type": "answer", + "title": "answer", + "answer": "1" }, "id": "answer-in-iteration", "parentId": "iteration", @@ -270,6 +278,8 @@ def test_parallels_graph(): { "data": { "type": "answer", + "title": "answer", + "answer": "1" }, "id": "answer", }, @@ -350,6 +360,8 @@ def test_parallels_graph2(): { "data": { "type": "answer", + "title": "answer", + "answer": "1" }, "id": "answer", }, @@ -422,6 +434,8 @@ def test_parallels_graph3(): { "data": { "type": "answer", + "title": "answer", + "answer": "1" }, "id": "answer", }, @@ -538,6 +552,8 @@ def test_parallels_graph4(): { "data": { "type": "answer", + "title": "answer", + "answer": "1" }, "id": "answer", }, @@ -673,6 +689,8 @@ def test_parallels_graph5(): { "data": { "type": "answer", + "title": "answer", + "answer": "1" }, "id": "answer", }, @@ -816,6 +834,8 @@ def test_parallels_graph6(): { "data": { "type": "answer", + "title": "answer", + "answer": "1" }, "id": "answer", }, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index ef6c5bf289..b59f4af6a1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -9,6 +9,7 @@ from core.workflow.graph_engine.entities.event import ( BaseNodeEvent, GraphRunFailedEvent, GraphRunStartedEvent, + GraphRunSucceededEvent, NodeRunFailedEvent, NodeRunStartedEvent, NodeRunSucceededEvent, @@ -187,12 +188,19 @@ def test_run_branch(mock_close, mock_remove): "data": { "title": "Start", "type": "start", - "variables": [] + "variables": [{ + "label": "uid", + "max_length": 48, + "options": [], + "required": True, + "type": "text-input", + "variable": "uid" + }] }, "id": "start" }, { "data": { - "answer": "1", + "answer": "1 {{#start.uid#}}", "title": "Answer", "type": "answer", "variables": [] @@ -261,7 +269,9 @@ def test_run_branch(mock_close, mock_remove): SystemVariable.FILES: [], SystemVariable.CONVERSATION_ID: 'abababa', SystemVariable.USER_ID: 'aaa' - }, user_inputs={}) + }, user_inputs={ + "uid": "takato" + }) graph_engine = GraphEngine( tenant_id="111", @@ -286,7 +296,7 @@ def test_run_branch(mock_close, mock_remove): with app.app_context(): generator = graph_engine.run() for item in generator: - # print(type(item), item) + print(type(item), item) items.append(item) assert len(items) == 8 @@ -294,5 +304,7 @@ def test_run_branch(mock_close, mock_remove): assert items[4].route_node_state.node_id == 'if-else-1' assert items[5].route_node_state.node_id == 'answer-1' assert items[6].route_node_state.node_id == 'answer-1' + assert items[6].route_node_state.node_run_result.outputs['answer'] == '1 takato' + assert isinstance(items[7], GraphRunSucceededEvent) # print(graph_engine.graph_runtime_state.model_dump_json(indent=2))