From 2c695ded79d4ad8bfce97435d1588329a9f563dd Mon Sep 17 00:00:00 2001
From: takatost
Date: Tue, 23 Jul 2024 00:10:23 +0800
Subject: [PATCH] fix bugs

---
 api/core/workflow/entities/variable_pool.py   | 22 +++--
 .../workflow/graph_engine/entities/graph.py   | 14 +--
 .../workflow/graph_engine/graph_engine.py     | 11 +--
 api/core/workflow/nodes/answer/answer_node.py |  2 +-
 .../nodes/answer/answer_stream_processor.py   | 56 ++---
 .../workflow/utils/condition/processor.py     | 43 ++++-----
 .../graph_engine/test_graph_engine.py         | 21 +++--
 .../answer/test_answer_stream_processor.py    |  5 +-
 .../core/workflow/nodes/test_if_else.py       | 93 +++++++++++++------
 9 files changed, 140 insertions(+), 127 deletions(-)

diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py
index 38d52e0f75..82480d59b1 100644
--- a/api/core/workflow/entities/variable_pool.py
+++ b/api/core/workflow/entities/variable_pool.py
@@ -2,7 +2,7 @@
 from collections import defaultdict
 from collections.abc import Mapping, Sequence
 from typing import Any, Union
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator
 from typing_extensions import deprecated
 
 from core.app.segments import ArrayVariable, ObjectVariable, Variable, factory
@@ -21,7 +21,7 @@ class VariablePool(BaseModel):
     # The first element of the selector is the node id, it's the first-level key in the dictionary.
     # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
     # elements of the selector except the first one.
-    _variable_dictionary: dict[str, dict[int, Variable]] = Field(
+    variable_dictionary: dict[str, dict[int, Variable]] = Field(
         description='Variables mapping',
         default=defaultdict(dict)
     )
@@ -36,10 +36,12 @@ class VariablePool(BaseModel):
     )
 
     environment_variables: Sequence[Variable] = Field(
-        description="Environment variables."
+        description="Environment variables.",
+        default_factory=list
     )
 
-    def __post_init__(self):
+    @model_validator(mode="after")
+    def val_model_after(self):
         """
         Append system variables
         :return:
@@ -52,6 +54,8 @@ class VariablePool(BaseModel):
         for var in self.environment_variables or []:
             self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
 
+        return self
+
     def add(self, selector: Sequence[str], value: Any, /) -> None:
         """
         Adds a variable to the variable pool.
@@ -78,7 +82,7 @@ class VariablePool(BaseModel):
             v = value
 
         hash_key = hash(tuple(selector[1:]))
-        self._variable_dictionary[selector[0]][hash_key] = v
+        self.variable_dictionary[selector[0]][hash_key] = v
 
     def get(self, selector: Sequence[str], /) -> Variable | None:
         """
@@ -96,7 +100,7 @@ class VariablePool(BaseModel):
         if len(selector) < 2:
             raise ValueError('Invalid selector')
         hash_key = hash(tuple(selector[1:]))
-        value = self._variable_dictionary[selector[0]].get(hash_key)
+        value = self.variable_dictionary[selector[0]].get(hash_key)
 
         return value
 
@@ -117,7 +121,7 @@ class VariablePool(BaseModel):
         if len(selector) < 2:
             raise ValueError('Invalid selector')
         hash_key = hash(tuple(selector[1:]))
-        value = self._variable_dictionary[selector[0]].get(hash_key)
+        value = self.variable_dictionary[selector[0]].get(hash_key)
 
         if value is None:
             return value
@@ -140,7 +144,7 @@ class VariablePool(BaseModel):
         if not selector:
             return
         if len(selector) == 1:
-            self._variable_dictionary[selector[0]] = {}
+            self.variable_dictionary[selector[0]] = {}
             return
         hash_key = hash(tuple(selector[1:]))
-        self._variable_dictionary[selector[0]].pop(hash_key, None)
+        self.variable_dictionary[selector[0]].pop(hash_key, None)
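Editorial note, not part of the patch: the hunks above make the variable dictionary a public Pydantic field, replace __post_init__ with a model_validator, and route every access through selector sequences. The snippet below is a minimal usage sketch of that selector API, based only on the signatures visible in this file; the concrete node id and values are invented.

    from core.workflow.entities.variable_pool import VariablePool

    # selector[0] is the node id; the remaining elements are hashed into the
    # second-level key of variable_dictionary, as described in the comments above
    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
    pool.add(['start', 'weather'], 'sunny')     # stored under node id 'start'
    segment = pool.get(['start', 'weather'])    # -> Variable | None, per the signature above
    raw = pool.get_any(['start', 'weather'])    # raw-value accessor used later in this patch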
diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py
index e12f7cec47..714503df86 100644
--- a/api/core/workflow/graph_engine/entities/graph.py
+++ b/api/core/workflow/graph_engine/entities/graph.py
@@ -285,16 +285,18 @@ class Graph(BaseModel):
             )
 
             # collect all branches node ids
-            end_to_node_id: Optional[str] = None
             for branch_node_id, node_ids in in_branch_node_ids.items():
                 for node_id in node_ids:
                     node_parallel_mapping[node_id] = parallel.id
 
-                    if not end_to_node_id and edge_mapping.get(node_id):
-                        node_edges = edge_mapping[node_id]
-                        target_node_id = node_edges[0].target_node_id
-                        if node_parallel_mapping.get(target_node_id) == parent_parallel_id:
-                            end_to_node_id = target_node_id
+            end_to_node_id: Optional[str] = None
+            for node_id in node_parallel_mapping:
+                if not end_to_node_id and edge_mapping.get(node_id):
+                    node_edges = edge_mapping[node_id]
+                    target_node_id = node_edges[0].target_node_id
+                    if node_parallel_mapping.get(target_node_id) == parent_parallel_id:
+                        end_to_node_id = target_node_id
+                        break
 
             if end_to_node_id:
                 parallel.end_to_node_id = end_to_node_id
diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index d11994dc9c..7fcdc46223 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -244,8 +244,8 @@ class GraphEngine:
 
                 next_node_id = final_node_id
 
-                if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
-                    break
+                # if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
+                #     break
 
     def _run_parallel_node(self,
                            flask_app: Flask,
@@ -402,10 +402,9 @@ class GraphEngine:
         :param variable_value: variable value
         :return:
         """
-        self.graph_runtime_state.variable_pool.append_variable(
-            node_id=node_id,
-            variable_key_list=variable_key_list,
-            value=variable_value  # type: ignore[arg-type]
+        self.graph_runtime_state.variable_pool.add(
+            [node_id] + variable_key_list,
+            variable_value
         )
 
         # if variable_value is a dict, then recursively append variables
diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py
index 79659cbab9..b54a0f9cdb 100644
--- a/api/core/workflow/nodes/answer/answer_node.py
+++ b/api/core/workflow/nodes/answer/answer_node.py
@@ -35,7 +35,7 @@ class AnswerNode(BaseNode):
                 part = cast(VarGenerateRouteChunk, part)
                 value_selector = part.value_selector
                 value = self.graph_runtime_state.variable_pool.get(
-                    variable_selector=value_selector
+                    value_selector
                 )
 
                 if value:
diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py
index a28a7786f7..851a66c9ba 100644
--- a/api/core/workflow/nodes/answer/answer_stream_processor.py
+++ b/api/core/workflow/nodes/answer/answer_stream_processor.py
@@ -1,4 +1,3 @@
-import json
 import logging
 from collections.abc import Generator
 from typing import Optional, cast
@@ -98,7 +97,7 @@ class AnswerStreamProcessor:
 
     def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
         node_ids = []
-        for edge in self.graph.edge_mapping[node_id]:
+        for edge in self.graph.edge_mapping.get(node_id, []):
             node_ids.append(edge.target_node_id)
             node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
         return node_ids
@@ -108,7 +107,7 @@ class AnswerStreamProcessor:
         remove target node ids until merge
         """
         self.rest_node_ids.remove(node_id)
-        for edge in self.graph.edge_mapping[node_id]:
+        for edge in self.graph.edge_mapping.get(node_id, []):
             if edge.target_node_id in reachable_node_ids:
                 continue
 
@@ -124,8 +123,10 @@ class AnswerStreamProcessor:
         """
         for answer_node_id, position in self.route_position.items():
             # all depends on answer node id not in rest node ids
-            if not all(dep_id not in self.rest_node_ids
-                       for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
+            if (event.route_node_state.node_id != answer_node_id
+                    and (answer_node_id not in self.rest_node_ids
+                         or not all(dep_id not in self.rest_node_ids
+                                    for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
                 continue
 
             route_position = self.route_position[answer_node_id]
@@ -145,53 +146,14 @@ class AnswerStreamProcessor:
                 if not value_selector:
                     break
 
-                value = self.variable_pool.get_variable_value(
-                    variable_selector=value_selector
+                value = self.variable_pool.get(
+                    value_selector
                 )
 
                 if value is None:
                     break
 
-                text = ''
-                if isinstance(value, str | int | float):
-                    text = str(value)
-                elif isinstance(value, FileVar):
-                    # convert file to markdown
-                    text = value.to_markdown()
-                elif isinstance(value, dict):
-                    # handle files
-                    file_vars = self._fetch_files_from_variable_value(value)
-                    if file_vars:
-                        file_var = file_vars[0]
-                        try:
-                            file_var_obj = FileVar(**file_var)
-
-                            # convert file to markdown
-                            text = file_var_obj.to_markdown()
-                        except Exception as e:
-                            logger.error(f'Error creating file var: {e}')
-
-                    if not text:
-                        # other types
-                        text = json.dumps(value, ensure_ascii=False)
-                elif isinstance(value, list):
-                    # handle files
-                    file_vars = self._fetch_files_from_variable_value(value)
-                    for file_var in file_vars:
-                        try:
-                            file_var_obj = FileVar(**file_var)
-                        except Exception as e:
-                            logger.error(f'Error creating file var: {e}')
-                            continue
-
-                        # convert file to markdown
-                        text = file_var_obj.to_markdown() + ' '
-
-                    text = text.strip()
-
-                    if not text and value:
-                        # other types
-                        text = json.dumps(value, ensure_ascii=False)
+                text = value.markdown
 
                 if text:
                     yield NodeRunStreamChunkEvent(
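Editorial note, not part of the patch: the last hunk above collapses the per-type rendering chain (str/FileVar/dict/list plus json.dumps fallbacks) into a single value.markdown read, making rendering the segment's own responsibility. The toy class below only illustrates that design; it is invented for this note and is not Dify code.

    class DemoFileSegment:
        # stand-in for a segment wrapping a file; real segments are richer
        def __init__(self, url: str):
            self.url = url

        @property
        def markdown(self) -> str:
            # each segment type knows how to render itself for the answer stream
            return f'![file]({self.url})'

    print(DemoFileSegment('https://example.com/a.png').markdown)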
diff --git a/api/core/workflow/utils/condition/processor.py b/api/core/workflow/utils/condition/processor.py
index d617d3abdc..5ff61aab3d 100644
--- a/api/core/workflow/utils/condition/processor.py
+++ b/api/core/workflow/utils/condition/processor.py
@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from typing import Any, Optional
 
 from core.file.file_obj import FileVar
@@ -7,15 +8,15 @@
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 
 
 class ConditionProcessor:
-    def process_conditions(self, variable_pool: VariablePool, conditions: list[Condition]):
+    def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]):
         input_conditions = []
        group_result = []
 
         index = 0
         for condition in conditions:
             index += 1
-            actual_value = variable_pool.get_variable_value(
-                variable_selector=condition.variable_selector
+            actual_value = variable_pool.get_any(
+                condition.variable_selector
             )
 
             expected_value = None
@@ -24,8 +25,8 @@
                 variable_selectors = variable_template_parser.extract_variable_selectors()
                 if variable_selectors:
                     for variable_selector in variable_selectors:
-                        value = variable_pool.get_variable_value(
-                            variable_selector=variable_selector.value_selector
+                        value = variable_pool.get_any(
+                            variable_selector.value_selector
                         )
 
                         expected_value = variable_template_parser.format({variable_selector.variable: value})
@@ -63,37 +64,37 @@
         :return: bool
         """
         if comparison_operator == "contains":
-            return self._assert_contains(actual_value, expected_value)  # type: ignore
+            return self._assert_contains(actual_value, expected_value)
         elif comparison_operator == "not contains":
-            return self._assert_not_contains(actual_value, expected_value)  # type: ignore
+            return self._assert_not_contains(actual_value, expected_value)
         elif comparison_operator == "start with":
-            return self._assert_start_with(actual_value, expected_value)  # type: ignore
+            return self._assert_start_with(actual_value, expected_value)
         elif comparison_operator == "end with":
-            return self._assert_end_with(actual_value, expected_value)  # type: ignore
+            return self._assert_end_with(actual_value, expected_value)
         elif comparison_operator == "is":
-            return self._assert_is(actual_value, expected_value)  # type: ignore
+            return self._assert_is(actual_value, expected_value)
         elif comparison_operator == "is not":
-            return self._assert_is_not(actual_value, expected_value)  # type: ignore
+            return self._assert_is_not(actual_value, expected_value)
         elif comparison_operator == "empty":
-            return self._assert_empty(actual_value)  # type: ignore
+            return self._assert_empty(actual_value)
         elif comparison_operator == "not empty":
-            return self._assert_not_empty(actual_value)  # type: ignore
+            return self._assert_not_empty(actual_value)
         elif comparison_operator == "=":
-            return self._assert_equal(actual_value, expected_value)  # type: ignore
+            return self._assert_equal(actual_value, expected_value)
         elif comparison_operator == "≠":
-            return self._assert_not_equal(actual_value, expected_value)  # type: ignore
+            return self._assert_not_equal(actual_value, expected_value)
         elif comparison_operator == ">":
-            return self._assert_greater_than(actual_value, expected_value)  # type: ignore
+            return self._assert_greater_than(actual_value, expected_value)
         elif comparison_operator == "<":
-            return self._assert_less_than(actual_value, expected_value)  # type: ignore
+            return self._assert_less_than(actual_value, expected_value)
         elif comparison_operator == "≥":
-            return self._assert_greater_than_or_equal(actual_value, expected_value)  # type: ignore
+            return self._assert_greater_than_or_equal(actual_value, expected_value)
         elif comparison_operator == "≤":
-            return self._assert_less_than_or_equal(actual_value, expected_value)  # type: ignore
+            return self._assert_less_than_or_equal(actual_value, expected_value)
         elif comparison_operator == "null":
-            return self._assert_null(actual_value)  # type: ignore
+            return self._assert_null(actual_value)
         elif comparison_operator == "not null":
-            return self._assert_not_null(actual_value)  # type: ignore
+            return self._assert_not_null(actual_value)
         else:
             raise ValueError(f"Invalid comparison operator: {comparison_operator}")
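Editorial note, not part of the patch: process_conditions now accepts any Sequence[Condition] and resolves both the actual and expected values through variable_pool.get_any. The sketch below only illustrates how that entry point is typically driven; the Condition field names, its import path, and the return shape are inferred from the attribute accesses above, so treat every detail outside the diff as an assumption.

    from core.workflow.entities.variable_pool import VariablePool
    from core.workflow.utils.condition.entities import Condition  # import path assumed
    from core.workflow.utils.condition.processor import ConditionProcessor

    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
    pool.add(['start', 'age'], 30)

    condition = Condition(
        variable_selector=['start', 'age'],  # resolved via variable_pool.get_any(...)
        comparison_operator='>',             # one of the operators handled above
        value='18'                           # may itself embed {{#node.var#}} templates
    )
    # return shape is not shown in this hunk; capture it generically
    results = ConditionProcessor().process_conditions(
        variable_pool=pool,
        conditions=[condition]               # any Sequence works now, not just list
    )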
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py
index 69362b7e1b..75e70ee65b 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py
@@ -10,6 +10,7 @@ from core.workflow.graph_engine.entities.event import (
     GraphRunSucceededEvent,
     NodeRunFailedEvent,
     NodeRunStartedEvent,
+    NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
 )
 from core.workflow.graph_engine.entities.graph import Graph
@@ -139,10 +140,12 @@ def test_run_parallel(mock_close, mock_remove):
         assert not isinstance(item, NodeRunFailedEvent)
         assert not isinstance(item, GraphRunFailedEvent)
 
-        if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ['answer2', 'answer3']:
+        if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [
+            'answer2', 'answer3', 'answer4', 'answer5'
+        ]:
             assert item.parallel_id is not None
 
-    assert len(items) == 12
+    assert len(items) == 19
     assert isinstance(items[0], GraphRunStartedEvent)
     assert isinstance(items[1], NodeRunStartedEvent)
     assert items[1].route_node_state.node_id == 'start'
@@ -291,12 +294,16 @@ def test_run_branch(mock_close, mock_remove):
         print(type(item), item)
         items.append(item)
 
-    assert len(items) == 8
+    assert len(items) == 10
     assert items[3].route_node_state.node_id == 'if-else-1'
     assert items[4].route_node_state.node_id == 'if-else-1'
-    assert items[5].route_node_state.node_id == 'answer-1'
-    assert items[6].route_node_state.node_id == 'answer-1'
-    assert items[6].route_node_state.node_run_result.outputs['answer'] == '1 takato'
-    assert isinstance(items[7], GraphRunSucceededEvent)
+    assert isinstance(items[5], NodeRunStreamChunkEvent)
+    assert items[5].chunk_content == '1 '
+    assert isinstance(items[6], NodeRunStreamChunkEvent)
+    assert items[6].chunk_content == 'takato'
+    assert items[7].route_node_state.node_id == 'answer-1'
+    assert items[8].route_node_state.node_id == 'answer-1'
+    assert items[8].route_node_state.node_run_result.outputs['answer'] == '1 takato'
+    assert isinstance(items[9], GraphRunSucceededEvent)
 
     # print(graph_engine.graph_runtime_state.model_dump_json(indent=2))
diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py
index c0c95d7bb8..fe1ebffa4d 100644
--- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py
+++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py
@@ -187,9 +187,8 @@ def test_process():
             #       " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else ""))
             if isinstance(event, NodeRunSucceededEvent):
                 if 'llm' in event.route_node_state.node_id:
-                    variable_pool.append_variable(
-                        event.route_node_state.node_id,
-                        ["text"],
+                    variable_pool.add(
+                        [event.route_node_state.node_id, "text"],
                         "".join(str(i) for i in range(0, int(event.route_node_state.node_id[-1])))
                     )
             yield event
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
index c1ebafa968..9f1e5c4517 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
@@ -45,23 +45,23 @@ def test_execute_if_else_result_true():
         SystemVariable.FILES: [],
         SystemVariable.USER_ID: 'aaa'
     }, user_inputs={})
-    pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['ab', 'def'])
-    pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ac', 'def'])
-    pool.append_variable(node_id='start', variable_key_list=['contains'], value='cabcde')
-    pool.append_variable(node_id='start', variable_key_list=['not_contains'], value='zacde')
-    pool.append_variable(node_id='start', variable_key_list=['start_with'], value='abc')
-    pool.append_variable(node_id='start', variable_key_list=['end_with'], value='zzab')
-    pool.append_variable(node_id='start', variable_key_list=['is'], value='ab')
-    pool.append_variable(node_id='start', variable_key_list=['is_not'], value='aab')
-    pool.append_variable(node_id='start', variable_key_list=['empty'], value='')
-    pool.append_variable(node_id='start', variable_key_list=['not_empty'], value='aaa')
-    pool.append_variable(node_id='start', variable_key_list=['equals'], value=22)
-    pool.append_variable(node_id='start', variable_key_list=['not_equals'], value=23)
-    pool.append_variable(node_id='start', variable_key_list=['greater_than'], value=23)
-    pool.append_variable(node_id='start', variable_key_list=['less_than'], value=21)
-    pool.append_variable(node_id='start', variable_key_list=['greater_than_or_equal'], value=22)
-    pool.append_variable(node_id='start', variable_key_list=['less_than_or_equal'], value=21)
-    pool.append_variable(node_id='start', variable_key_list=['not_null'], value='1212')
+    pool.add(['start', 'array_contains'], ['ab', 'def'])
+    pool.add(['start', 'array_not_contains'], ['ac', 'def'])
+    pool.add(['start', 'contains'], 'cabcde')
+    pool.add(['start', 'not_contains'], 'zacde')
+    pool.add(['start', 'start_with'], 'abc')
+    pool.add(['start', 'end_with'], 'zzab')
+    pool.add(['start', 'is'], 'ab')
+    pool.add(['start', 'is_not'], 'aab')
+    pool.add(['start', 'empty'], '')
+    pool.add(['start', 'not_empty'], 'aaa')
+    pool.add(['start', 'equals'], 22)
+    pool.add(['start', 'not_equals'], 23)
+    pool.add(['start', 'greater_than'], 23)
+    pool.add(['start', 'less_than'], 21)
+    pool.add(['start', 'greater_than_or_equal'], 22)
+    pool.add(['start', 'less_than_or_equal'], 21)
+    pool.add(['start', 'not_null'], '1212')
 
     node = IfElseNode(
         graph_init_params=init_params,
@@ -204,13 +204,60 @@
 
 
 def test_execute_if_else_result_false():
-    node = IfElseNode(
+    graph_config = {
+        "edges": [
+            {
+                "id": "start-source-llm-target",
+                "source": "start",
+                "target": "llm",
+            },
+        ],
+        "nodes": [
+            {
+                "data": {
+                    "type": "start"
+                },
+                "id": "start"
+            },
+            {
+                "data": {
+                    "type": "llm",
+                },
+                "id": "llm"
+            },
+        ]
+    }
+
+    graph = Graph.init(
+        graph_config=graph_config
+    )
+
+    init_params = GraphInitParams(
         tenant_id='1',
         app_id='1',
+        workflow_type=WorkflowType.WORKFLOW,
         workflow_id='1',
         user_id='1',
         user_from=UserFrom.ACCOUNT,
         invoke_from=InvokeFrom.DEBUGGER,
+        call_depth=0
+    )
+
+    # construct variable pool
+    pool = VariablePool(system_variables={
+        SystemVariable.FILES: [],
+        SystemVariable.USER_ID: 'aaa'
+    }, user_inputs={}, environment_variables=[])
+    pool.add(['start', 'array_contains'], ['1ab', 'def'])
+    pool.add(['start', 'array_not_contains'], ['ab', 'def'])
+
+    node = IfElseNode(
+        graph_init_params=init_params,
+        graph=graph,
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=pool,
+            start_at=time.perf_counter()
+        ),
         config={
             'id': 'if-else',
             'data': {
@@ -233,19 +280,11 @@ def test_execute_if_else_result_false():
         }
     )
 
-    # construct variable pool
-    pool = VariablePool(system_variables={
-        SystemVariable.FILES: [],
-        SystemVariable.USER_ID: 'aaa'
-    }, user_inputs={}, environment_variables=[])
-    pool.add(['start', 'array_contains'], ['1ab', 'def'])
-    pool.add(['start', 'array_not_contains'], ['ab', 'def'])
-
     # Mock db.session.close()
     db.session.close = MagicMock()
 
     # execute node
-    result = node._run(pool)
+    result = node._run()
 
     assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
     assert result.outputs['result'] is False
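Editorial note, not part of the patch: the final hunk drops the pool argument from node._run(); a node now reads variables from the GraphRuntimeState it was constructed with, which is exactly how the reworked test wires things up. The sketch below shows only that pattern with invented classes; it is not the real BaseNode or GraphRuntimeState API.

    from core.workflow.entities.variable_pool import VariablePool

    pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
    pool.add(['start', 'array_contains'], ['1ab', 'def'])

    class FakeRuntimeState:
        # stub standing in for GraphRuntimeState; only the attribute the node reads
        def __init__(self, variable_pool):
            self.variable_pool = variable_pool

    class DemoNode:
        # invented node: illustrates the call pattern, not Dify's BaseNode
        def __init__(self, graph_runtime_state):
            self.graph_runtime_state = graph_runtime_state

        def _run(self):
            # no pool parameter any more; read from the shared runtime state
            value = self.graph_runtime_state.variable_pool.get_any(['start', 'array_contains'])
            return {'result': value is not None}

    print(DemoNode(FakeRuntimeState(pool))._run())  # {'result': True}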