mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 03:35:57 +08:00
fix bugs
This commit is contained in:
parent
a603e01f5e
commit
2c695ded79
@ -2,7 +2,7 @@ from collections import defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.app.segments import ArrayVariable, ObjectVariable, Variable, factory
|
||||
@ -21,7 +21,7 @@ class VariablePool(BaseModel):
|
||||
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
||||
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
||||
# elements of the selector except the first one.
|
||||
_variable_dictionary: dict[str, dict[int, Variable]] = Field(
|
||||
variable_dictionary: dict[str, dict[int, Variable]] = Field(
|
||||
description='Variables mapping',
|
||||
default=defaultdict(dict)
|
||||
)
|
||||
@ -36,10 +36,12 @@ class VariablePool(BaseModel):
|
||||
)
|
||||
|
||||
environment_variables: Sequence[Variable] = Field(
|
||||
description="Environment variables."
|
||||
description="Environment variables.",
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
@model_validator(mode="after")
|
||||
def val_model_after(self):
|
||||
"""
|
||||
Append system variables
|
||||
:return:
|
||||
@ -52,6 +54,8 @@ class VariablePool(BaseModel):
|
||||
for var in self.environment_variables or []:
|
||||
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
|
||||
|
||||
return self
|
||||
|
||||
def add(self, selector: Sequence[str], value: Any, /) -> None:
|
||||
"""
|
||||
Adds a variable to the variable pool.
|
||||
@ -78,7 +82,7 @@ class VariablePool(BaseModel):
|
||||
v = value
|
||||
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
self._variable_dictionary[selector[0]][hash_key] = v
|
||||
self.variable_dictionary[selector[0]][hash_key] = v
|
||||
|
||||
def get(self, selector: Sequence[str], /) -> Variable | None:
|
||||
"""
|
||||
@ -96,7 +100,7 @@ class VariablePool(BaseModel):
|
||||
if len(selector) < 2:
|
||||
raise ValueError('Invalid selector')
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
value = self._variable_dictionary[selector[0]].get(hash_key)
|
||||
value = self.variable_dictionary[selector[0]].get(hash_key)
|
||||
|
||||
return value
|
||||
|
||||
@ -117,7 +121,7 @@ class VariablePool(BaseModel):
|
||||
if len(selector) < 2:
|
||||
raise ValueError('Invalid selector')
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
value = self._variable_dictionary[selector[0]].get(hash_key)
|
||||
value = self.variable_dictionary[selector[0]].get(hash_key)
|
||||
|
||||
if value is None:
|
||||
return value
|
||||
@ -140,7 +144,7 @@ class VariablePool(BaseModel):
|
||||
if not selector:
|
||||
return
|
||||
if len(selector) == 1:
|
||||
self._variable_dictionary[selector[0]] = {}
|
||||
self.variable_dictionary[selector[0]] = {}
|
||||
return
|
||||
hash_key = hash(tuple(selector[1:]))
|
||||
self._variable_dictionary[selector[0]].pop(hash_key, None)
|
||||
self.variable_dictionary[selector[0]].pop(hash_key, None)
|
||||
|
@ -285,16 +285,18 @@ class Graph(BaseModel):
|
||||
)
|
||||
|
||||
# collect all branches node ids
|
||||
end_to_node_id: Optional[str] = None
|
||||
for branch_node_id, node_ids in in_branch_node_ids.items():
|
||||
for node_id in node_ids:
|
||||
node_parallel_mapping[node_id] = parallel.id
|
||||
|
||||
if not end_to_node_id and edge_mapping.get(node_id):
|
||||
node_edges = edge_mapping[node_id]
|
||||
target_node_id = node_edges[0].target_node_id
|
||||
if node_parallel_mapping.get(target_node_id) == parent_parallel_id:
|
||||
end_to_node_id = target_node_id
|
||||
end_to_node_id: Optional[str] = None
|
||||
for node_id in node_parallel_mapping:
|
||||
if not end_to_node_id and edge_mapping.get(node_id):
|
||||
node_edges = edge_mapping[node_id]
|
||||
target_node_id = node_edges[0].target_node_id
|
||||
if node_parallel_mapping.get(target_node_id) == parent_parallel_id:
|
||||
end_to_node_id = target_node_id
|
||||
break
|
||||
|
||||
if end_to_node_id:
|
||||
parallel.end_to_node_id = end_to_node_id
|
||||
|
@ -244,8 +244,8 @@ class GraphEngine:
|
||||
|
||||
next_node_id = final_node_id
|
||||
|
||||
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
|
||||
break
|
||||
# if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
|
||||
# break
|
||||
|
||||
def _run_parallel_node(self,
|
||||
flask_app: Flask,
|
||||
@ -402,10 +402,9 @@ class GraphEngine:
|
||||
:param variable_value: variable value
|
||||
:return:
|
||||
"""
|
||||
self.graph_runtime_state.variable_pool.append_variable(
|
||||
node_id=node_id,
|
||||
variable_key_list=variable_key_list,
|
||||
value=variable_value # type: ignore[arg-type]
|
||||
self.graph_runtime_state.variable_pool.add(
|
||||
[node_id] + variable_key_list,
|
||||
variable_value
|
||||
)
|
||||
|
||||
# if variable_value is a dict, then recursively append variables
|
||||
|
@ -35,7 +35,7 @@ class AnswerNode(BaseNode):
|
||||
part = cast(VarGenerateRouteChunk, part)
|
||||
value_selector = part.value_selector
|
||||
value = self.graph_runtime_state.variable_pool.get(
|
||||
variable_selector=value_selector
|
||||
value_selector
|
||||
)
|
||||
|
||||
if value:
|
||||
|
@ -1,4 +1,3 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, cast
|
||||
@ -98,7 +97,7 @@ class AnswerStreamProcessor:
|
||||
|
||||
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
|
||||
node_ids = []
|
||||
for edge in self.graph.edge_mapping[node_id]:
|
||||
for edge in self.graph.edge_mapping.get(node_id, []):
|
||||
node_ids.append(edge.target_node_id)
|
||||
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
||||
return node_ids
|
||||
@ -108,7 +107,7 @@ class AnswerStreamProcessor:
|
||||
remove target node ids until merge
|
||||
"""
|
||||
self.rest_node_ids.remove(node_id)
|
||||
for edge in self.graph.edge_mapping[node_id]:
|
||||
for edge in self.graph.edge_mapping.get(node_id, []):
|
||||
if edge.target_node_id in reachable_node_ids:
|
||||
continue
|
||||
|
||||
@ -124,8 +123,10 @@ class AnswerStreamProcessor:
|
||||
"""
|
||||
for answer_node_id, position in self.route_position.items():
|
||||
# all depends on answer node id not in rest node ids
|
||||
if not all(dep_id not in self.rest_node_ids
|
||||
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
|
||||
if (event.route_node_state.node_id != answer_node_id
|
||||
and (answer_node_id not in self.rest_node_ids
|
||||
or not all(dep_id not in self.rest_node_ids
|
||||
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
|
||||
continue
|
||||
|
||||
route_position = self.route_position[answer_node_id]
|
||||
@ -145,53 +146,14 @@ class AnswerStreamProcessor:
|
||||
if not value_selector:
|
||||
break
|
||||
|
||||
value = self.variable_pool.get_variable_value(
|
||||
variable_selector=value_selector
|
||||
value = self.variable_pool.get(
|
||||
value_selector
|
||||
)
|
||||
|
||||
if value is None:
|
||||
break
|
||||
|
||||
text = ''
|
||||
if isinstance(value, str | int | float):
|
||||
text = str(value)
|
||||
elif isinstance(value, FileVar):
|
||||
# convert file to markdown
|
||||
text = value.to_markdown()
|
||||
elif isinstance(value, dict):
|
||||
# handle files
|
||||
file_vars = self._fetch_files_from_variable_value(value)
|
||||
if file_vars:
|
||||
file_var = file_vars[0]
|
||||
try:
|
||||
file_var_obj = FileVar(**file_var)
|
||||
|
||||
# convert file to markdown
|
||||
text = file_var_obj.to_markdown()
|
||||
except Exception as e:
|
||||
logger.error(f'Error creating file var: {e}')
|
||||
|
||||
if not text:
|
||||
# other types
|
||||
text = json.dumps(value, ensure_ascii=False)
|
||||
elif isinstance(value, list):
|
||||
# handle files
|
||||
file_vars = self._fetch_files_from_variable_value(value)
|
||||
for file_var in file_vars:
|
||||
try:
|
||||
file_var_obj = FileVar(**file_var)
|
||||
except Exception as e:
|
||||
logger.error(f'Error creating file var: {e}')
|
||||
continue
|
||||
|
||||
# convert file to markdown
|
||||
text = file_var_obj.to_markdown() + ' '
|
||||
|
||||
text = text.strip()
|
||||
|
||||
if not text and value:
|
||||
# other types
|
||||
text = json.dumps(value, ensure_ascii=False)
|
||||
text = value.markdown
|
||||
|
||||
if text:
|
||||
yield NodeRunStreamChunkEvent(
|
||||
|
@ -1,3 +1,4 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.file.file_obj import FileVar
|
||||
@ -7,15 +8,15 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
|
||||
|
||||
class ConditionProcessor:
|
||||
def process_conditions(self, variable_pool: VariablePool, conditions: list[Condition]):
|
||||
def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]):
|
||||
input_conditions = []
|
||||
group_result = []
|
||||
|
||||
index = 0
|
||||
for condition in conditions:
|
||||
index += 1
|
||||
actual_value = variable_pool.get_variable_value(
|
||||
variable_selector=condition.variable_selector
|
||||
actual_value = variable_pool.get_any(
|
||||
condition.variable_selector
|
||||
)
|
||||
|
||||
expected_value = None
|
||||
@ -24,8 +25,8 @@ class ConditionProcessor:
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
if variable_selectors:
|
||||
for variable_selector in variable_selectors:
|
||||
value = variable_pool.get_variable_value(
|
||||
variable_selector=variable_selector.value_selector
|
||||
value = variable_pool.get_any(
|
||||
variable_selector.value_selector
|
||||
)
|
||||
expected_value = variable_template_parser.format({variable_selector.variable: value})
|
||||
|
||||
@ -63,37 +64,37 @@ class ConditionProcessor:
|
||||
:return: bool
|
||||
"""
|
||||
if comparison_operator == "contains":
|
||||
return self._assert_contains(actual_value, expected_value) # type: ignore
|
||||
return self._assert_contains(actual_value, expected_value)
|
||||
elif comparison_operator == "not contains":
|
||||
return self._assert_not_contains(actual_value, expected_value) # type: ignore
|
||||
return self._assert_not_contains(actual_value, expected_value)
|
||||
elif comparison_operator == "start with":
|
||||
return self._assert_start_with(actual_value, expected_value) # type: ignore
|
||||
return self._assert_start_with(actual_value, expected_value)
|
||||
elif comparison_operator == "end with":
|
||||
return self._assert_end_with(actual_value, expected_value) # type: ignore
|
||||
return self._assert_end_with(actual_value, expected_value)
|
||||
elif comparison_operator == "is":
|
||||
return self._assert_is(actual_value, expected_value) # type: ignore
|
||||
return self._assert_is(actual_value, expected_value)
|
||||
elif comparison_operator == "is not":
|
||||
return self._assert_is_not(actual_value, expected_value) # type: ignore
|
||||
return self._assert_is_not(actual_value, expected_value)
|
||||
elif comparison_operator == "empty":
|
||||
return self._assert_empty(actual_value) # type: ignore
|
||||
return self._assert_empty(actual_value)
|
||||
elif comparison_operator == "not empty":
|
||||
return self._assert_not_empty(actual_value) # type: ignore
|
||||
return self._assert_not_empty(actual_value)
|
||||
elif comparison_operator == "=":
|
||||
return self._assert_equal(actual_value, expected_value) # type: ignore
|
||||
return self._assert_equal(actual_value, expected_value)
|
||||
elif comparison_operator == "≠":
|
||||
return self._assert_not_equal(actual_value, expected_value) # type: ignore
|
||||
return self._assert_not_equal(actual_value, expected_value)
|
||||
elif comparison_operator == ">":
|
||||
return self._assert_greater_than(actual_value, expected_value) # type: ignore
|
||||
return self._assert_greater_than(actual_value, expected_value)
|
||||
elif comparison_operator == "<":
|
||||
return self._assert_less_than(actual_value, expected_value) # type: ignore
|
||||
return self._assert_less_than(actual_value, expected_value)
|
||||
elif comparison_operator == "≥":
|
||||
return self._assert_greater_than_or_equal(actual_value, expected_value) # type: ignore
|
||||
return self._assert_greater_than_or_equal(actual_value, expected_value)
|
||||
elif comparison_operator == "≤":
|
||||
return self._assert_less_than_or_equal(actual_value, expected_value) # type: ignore
|
||||
return self._assert_less_than_or_equal(actual_value, expected_value)
|
||||
elif comparison_operator == "null":
|
||||
return self._assert_null(actual_value) # type: ignore
|
||||
return self._assert_null(actual_value)
|
||||
elif comparison_operator == "not null":
|
||||
return self._assert_not_null(actual_value) # type: ignore
|
||||
return self._assert_not_null(actual_value)
|
||||
else:
|
||||
raise ValueError(f"Invalid comparison operator: {comparison_operator}")
|
||||
|
||||
|
@ -10,6 +10,7 @@ from core.workflow.graph_engine.entities.event import (
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
@ -139,10 +140,12 @@ def test_run_parallel(mock_close, mock_remove):
|
||||
assert not isinstance(item, NodeRunFailedEvent)
|
||||
assert not isinstance(item, GraphRunFailedEvent)
|
||||
|
||||
if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ['answer2', 'answer3']:
|
||||
if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [
|
||||
'answer2', 'answer3', 'answer4', 'answer5'
|
||||
]:
|
||||
assert item.parallel_id is not None
|
||||
|
||||
assert len(items) == 12
|
||||
assert len(items) == 19
|
||||
assert isinstance(items[0], GraphRunStartedEvent)
|
||||
assert isinstance(items[1], NodeRunStartedEvent)
|
||||
assert items[1].route_node_state.node_id == 'start'
|
||||
@ -291,12 +294,16 @@ def test_run_branch(mock_close, mock_remove):
|
||||
print(type(item), item)
|
||||
items.append(item)
|
||||
|
||||
assert len(items) == 8
|
||||
assert len(items) == 10
|
||||
assert items[3].route_node_state.node_id == 'if-else-1'
|
||||
assert items[4].route_node_state.node_id == 'if-else-1'
|
||||
assert items[5].route_node_state.node_id == 'answer-1'
|
||||
assert items[6].route_node_state.node_id == 'answer-1'
|
||||
assert items[6].route_node_state.node_run_result.outputs['answer'] == '1 takato'
|
||||
assert isinstance(items[7], GraphRunSucceededEvent)
|
||||
assert isinstance(items[5], NodeRunStreamChunkEvent)
|
||||
assert items[5].chunk_content == '1 '
|
||||
assert isinstance(items[6], NodeRunStreamChunkEvent)
|
||||
assert items[6].chunk_content == 'takato'
|
||||
assert items[7].route_node_state.node_id == 'answer-1'
|
||||
assert items[8].route_node_state.node_id == 'answer-1'
|
||||
assert items[8].route_node_state.node_run_result.outputs['answer'] == '1 takato'
|
||||
assert isinstance(items[9], GraphRunSucceededEvent)
|
||||
|
||||
# print(graph_engine.graph_runtime_state.model_dump_json(indent=2))
|
||||
|
@ -187,9 +187,8 @@ def test_process():
|
||||
# " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else ""))
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
if 'llm' in event.route_node_state.node_id:
|
||||
variable_pool.append_variable(
|
||||
event.route_node_state.node_id,
|
||||
["text"],
|
||||
variable_pool.add(
|
||||
[event.route_node_state.node_id, "text"],
|
||||
"".join(str(i) for i in range(0, int(event.route_node_state.node_id[-1])))
|
||||
)
|
||||
yield event
|
||||
|
@ -45,23 +45,23 @@ def test_execute_if_else_result_true():
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
}, user_inputs={})
|
||||
pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['ab', 'def'])
|
||||
pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ac', 'def'])
|
||||
pool.append_variable(node_id='start', variable_key_list=['contains'], value='cabcde')
|
||||
pool.append_variable(node_id='start', variable_key_list=['not_contains'], value='zacde')
|
||||
pool.append_variable(node_id='start', variable_key_list=['start_with'], value='abc')
|
||||
pool.append_variable(node_id='start', variable_key_list=['end_with'], value='zzab')
|
||||
pool.append_variable(node_id='start', variable_key_list=['is'], value='ab')
|
||||
pool.append_variable(node_id='start', variable_key_list=['is_not'], value='aab')
|
||||
pool.append_variable(node_id='start', variable_key_list=['empty'], value='')
|
||||
pool.append_variable(node_id='start', variable_key_list=['not_empty'], value='aaa')
|
||||
pool.append_variable(node_id='start', variable_key_list=['equals'], value=22)
|
||||
pool.append_variable(node_id='start', variable_key_list=['not_equals'], value=23)
|
||||
pool.append_variable(node_id='start', variable_key_list=['greater_than'], value=23)
|
||||
pool.append_variable(node_id='start', variable_key_list=['less_than'], value=21)
|
||||
pool.append_variable(node_id='start', variable_key_list=['greater_than_or_equal'], value=22)
|
||||
pool.append_variable(node_id='start', variable_key_list=['less_than_or_equal'], value=21)
|
||||
pool.append_variable(node_id='start', variable_key_list=['not_null'], value='1212')
|
||||
pool.add(['start', 'array_contains'], ['ab', 'def'])
|
||||
pool.add(['start', 'array_not_contains'], ['ac', 'def'])
|
||||
pool.add(['start', 'contains'], 'cabcde')
|
||||
pool.add(['start', 'not_contains'], 'zacde')
|
||||
pool.add(['start', 'start_with'], 'abc')
|
||||
pool.add(['start', 'end_with'], 'zzab')
|
||||
pool.add(['start', 'is'], 'ab')
|
||||
pool.add(['start', 'is_not'], 'aab')
|
||||
pool.add(['start', 'empty'], '')
|
||||
pool.add(['start', 'not_empty'], 'aaa')
|
||||
pool.add(['start', 'equals'], 22)
|
||||
pool.add(['start', 'not_equals'], 23)
|
||||
pool.add(['start', 'greater_than'], 23)
|
||||
pool.add(['start', 'less_than'], 21)
|
||||
pool.add(['start', 'greater_than_or_equal'], 22)
|
||||
pool.add(['start', 'less_than_or_equal'], 21)
|
||||
pool.add(['start', 'not_null'], '1212')
|
||||
|
||||
node = IfElseNode(
|
||||
graph_init_params=init_params,
|
||||
@ -204,13 +204,60 @@ def test_execute_if_else_result_true():
|
||||
|
||||
|
||||
def test_execute_if_else_result_false():
|
||||
node = IfElseNode(
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-llm-target",
|
||||
"source": "start",
|
||||
"target": "llm",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"data": {
|
||||
"type": "start"
|
||||
},
|
||||
"id": "start"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm"
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
graph = Graph.init(
|
||||
graph_config=graph_config
|
||||
)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id='1',
|
||||
app_id='1',
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id='1',
|
||||
user_id='1',
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
pool.add(['start', 'array_contains'], ['1ab', 'def'])
|
||||
pool.add(['start', 'array_not_contains'], ['ab', 'def'])
|
||||
|
||||
node = IfElseNode(
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=pool,
|
||||
start_at=time.perf_counter()
|
||||
),
|
||||
config={
|
||||
'id': 'if-else',
|
||||
'data': {
|
||||
@ -233,19 +280,11 @@ def test_execute_if_else_result_false():
|
||||
}
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={
|
||||
SystemVariable.FILES: [],
|
||||
SystemVariable.USER_ID: 'aaa'
|
||||
}, user_inputs={}, environment_variables=[])
|
||||
pool.add(['start', 'array_contains'], ['1ab', 'def'])
|
||||
pool.add(['start', 'array_not_contains'], ['ab', 'def'])
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# execute node
|
||||
result = node._run(pool)
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs['result'] is False
|
||||
|
Loading…
x
Reference in New Issue
Block a user