mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-17 23:35:52 +08:00
fix bugs
This commit is contained in:
parent
a603e01f5e
commit
2c695ded79
@ -2,7 +2,7 @@ from collections import defaultdict
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, model_validator
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
from core.app.segments import ArrayVariable, ObjectVariable, Variable, factory
|
from core.app.segments import ArrayVariable, ObjectVariable, Variable, factory
|
||||||
@ -21,7 +21,7 @@ class VariablePool(BaseModel):
|
|||||||
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
||||||
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
||||||
# elements of the selector except the first one.
|
# elements of the selector except the first one.
|
||||||
_variable_dictionary: dict[str, dict[int, Variable]] = Field(
|
variable_dictionary: dict[str, dict[int, Variable]] = Field(
|
||||||
description='Variables mapping',
|
description='Variables mapping',
|
||||||
default=defaultdict(dict)
|
default=defaultdict(dict)
|
||||||
)
|
)
|
||||||
@ -36,10 +36,12 @@ class VariablePool(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
environment_variables: Sequence[Variable] = Field(
|
environment_variables: Sequence[Variable] = Field(
|
||||||
description="Environment variables."
|
description="Environment variables.",
|
||||||
|
default_factory=list
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
@model_validator(mode="after")
|
||||||
|
def val_model_after(self):
|
||||||
"""
|
"""
|
||||||
Append system variables
|
Append system variables
|
||||||
:return:
|
:return:
|
||||||
@ -52,6 +54,8 @@ class VariablePool(BaseModel):
|
|||||||
for var in self.environment_variables or []:
|
for var in self.environment_variables or []:
|
||||||
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
|
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
def add(self, selector: Sequence[str], value: Any, /) -> None:
|
def add(self, selector: Sequence[str], value: Any, /) -> None:
|
||||||
"""
|
"""
|
||||||
Adds a variable to the variable pool.
|
Adds a variable to the variable pool.
|
||||||
@ -78,7 +82,7 @@ class VariablePool(BaseModel):
|
|||||||
v = value
|
v = value
|
||||||
|
|
||||||
hash_key = hash(tuple(selector[1:]))
|
hash_key = hash(tuple(selector[1:]))
|
||||||
self._variable_dictionary[selector[0]][hash_key] = v
|
self.variable_dictionary[selector[0]][hash_key] = v
|
||||||
|
|
||||||
def get(self, selector: Sequence[str], /) -> Variable | None:
|
def get(self, selector: Sequence[str], /) -> Variable | None:
|
||||||
"""
|
"""
|
||||||
@ -96,7 +100,7 @@ class VariablePool(BaseModel):
|
|||||||
if len(selector) < 2:
|
if len(selector) < 2:
|
||||||
raise ValueError('Invalid selector')
|
raise ValueError('Invalid selector')
|
||||||
hash_key = hash(tuple(selector[1:]))
|
hash_key = hash(tuple(selector[1:]))
|
||||||
value = self._variable_dictionary[selector[0]].get(hash_key)
|
value = self.variable_dictionary[selector[0]].get(hash_key)
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@ -117,7 +121,7 @@ class VariablePool(BaseModel):
|
|||||||
if len(selector) < 2:
|
if len(selector) < 2:
|
||||||
raise ValueError('Invalid selector')
|
raise ValueError('Invalid selector')
|
||||||
hash_key = hash(tuple(selector[1:]))
|
hash_key = hash(tuple(selector[1:]))
|
||||||
value = self._variable_dictionary[selector[0]].get(hash_key)
|
value = self.variable_dictionary[selector[0]].get(hash_key)
|
||||||
|
|
||||||
if value is None:
|
if value is None:
|
||||||
return value
|
return value
|
||||||
@ -140,7 +144,7 @@ class VariablePool(BaseModel):
|
|||||||
if not selector:
|
if not selector:
|
||||||
return
|
return
|
||||||
if len(selector) == 1:
|
if len(selector) == 1:
|
||||||
self._variable_dictionary[selector[0]] = {}
|
self.variable_dictionary[selector[0]] = {}
|
||||||
return
|
return
|
||||||
hash_key = hash(tuple(selector[1:]))
|
hash_key = hash(tuple(selector[1:]))
|
||||||
self._variable_dictionary[selector[0]].pop(hash_key, None)
|
self.variable_dictionary[selector[0]].pop(hash_key, None)
|
||||||
|
@ -285,16 +285,18 @@ class Graph(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# collect all branches node ids
|
# collect all branches node ids
|
||||||
end_to_node_id: Optional[str] = None
|
|
||||||
for branch_node_id, node_ids in in_branch_node_ids.items():
|
for branch_node_id, node_ids in in_branch_node_ids.items():
|
||||||
for node_id in node_ids:
|
for node_id in node_ids:
|
||||||
node_parallel_mapping[node_id] = parallel.id
|
node_parallel_mapping[node_id] = parallel.id
|
||||||
|
|
||||||
if not end_to_node_id and edge_mapping.get(node_id):
|
end_to_node_id: Optional[str] = None
|
||||||
node_edges = edge_mapping[node_id]
|
for node_id in node_parallel_mapping:
|
||||||
target_node_id = node_edges[0].target_node_id
|
if not end_to_node_id and edge_mapping.get(node_id):
|
||||||
if node_parallel_mapping.get(target_node_id) == parent_parallel_id:
|
node_edges = edge_mapping[node_id]
|
||||||
end_to_node_id = target_node_id
|
target_node_id = node_edges[0].target_node_id
|
||||||
|
if node_parallel_mapping.get(target_node_id) == parent_parallel_id:
|
||||||
|
end_to_node_id = target_node_id
|
||||||
|
break
|
||||||
|
|
||||||
if end_to_node_id:
|
if end_to_node_id:
|
||||||
parallel.end_to_node_id = end_to_node_id
|
parallel.end_to_node_id = end_to_node_id
|
||||||
|
@ -244,8 +244,8 @@ class GraphEngine:
|
|||||||
|
|
||||||
next_node_id = final_node_id
|
next_node_id = final_node_id
|
||||||
|
|
||||||
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
|
# if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
|
||||||
break
|
# break
|
||||||
|
|
||||||
def _run_parallel_node(self,
|
def _run_parallel_node(self,
|
||||||
flask_app: Flask,
|
flask_app: Flask,
|
||||||
@ -402,10 +402,9 @@ class GraphEngine:
|
|||||||
:param variable_value: variable value
|
:param variable_value: variable value
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
self.graph_runtime_state.variable_pool.append_variable(
|
self.graph_runtime_state.variable_pool.add(
|
||||||
node_id=node_id,
|
[node_id] + variable_key_list,
|
||||||
variable_key_list=variable_key_list,
|
variable_value
|
||||||
value=variable_value # type: ignore[arg-type]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# if variable_value is a dict, then recursively append variables
|
# if variable_value is a dict, then recursively append variables
|
||||||
|
@ -35,7 +35,7 @@ class AnswerNode(BaseNode):
|
|||||||
part = cast(VarGenerateRouteChunk, part)
|
part = cast(VarGenerateRouteChunk, part)
|
||||||
value_selector = part.value_selector
|
value_selector = part.value_selector
|
||||||
value = self.graph_runtime_state.variable_pool.get(
|
value = self.graph_runtime_state.variable_pool.get(
|
||||||
variable_selector=value_selector
|
value_selector
|
||||||
)
|
)
|
||||||
|
|
||||||
if value:
|
if value:
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Optional, cast
|
from typing import Optional, cast
|
||||||
@ -98,7 +97,7 @@ class AnswerStreamProcessor:
|
|||||||
|
|
||||||
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
|
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
|
||||||
node_ids = []
|
node_ids = []
|
||||||
for edge in self.graph.edge_mapping[node_id]:
|
for edge in self.graph.edge_mapping.get(node_id, []):
|
||||||
node_ids.append(edge.target_node_id)
|
node_ids.append(edge.target_node_id)
|
||||||
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
||||||
return node_ids
|
return node_ids
|
||||||
@ -108,7 +107,7 @@ class AnswerStreamProcessor:
|
|||||||
remove target node ids until merge
|
remove target node ids until merge
|
||||||
"""
|
"""
|
||||||
self.rest_node_ids.remove(node_id)
|
self.rest_node_ids.remove(node_id)
|
||||||
for edge in self.graph.edge_mapping[node_id]:
|
for edge in self.graph.edge_mapping.get(node_id, []):
|
||||||
if edge.target_node_id in reachable_node_ids:
|
if edge.target_node_id in reachable_node_ids:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -124,8 +123,10 @@ class AnswerStreamProcessor:
|
|||||||
"""
|
"""
|
||||||
for answer_node_id, position in self.route_position.items():
|
for answer_node_id, position in self.route_position.items():
|
||||||
# all depends on answer node id not in rest node ids
|
# all depends on answer node id not in rest node ids
|
||||||
if not all(dep_id not in self.rest_node_ids
|
if (event.route_node_state.node_id != answer_node_id
|
||||||
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
|
and (answer_node_id not in self.rest_node_ids
|
||||||
|
or not all(dep_id not in self.rest_node_ids
|
||||||
|
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
route_position = self.route_position[answer_node_id]
|
route_position = self.route_position[answer_node_id]
|
||||||
@ -145,53 +146,14 @@ class AnswerStreamProcessor:
|
|||||||
if not value_selector:
|
if not value_selector:
|
||||||
break
|
break
|
||||||
|
|
||||||
value = self.variable_pool.get_variable_value(
|
value = self.variable_pool.get(
|
||||||
variable_selector=value_selector
|
value_selector
|
||||||
)
|
)
|
||||||
|
|
||||||
if value is None:
|
if value is None:
|
||||||
break
|
break
|
||||||
|
|
||||||
text = ''
|
text = value.markdown
|
||||||
if isinstance(value, str | int | float):
|
|
||||||
text = str(value)
|
|
||||||
elif isinstance(value, FileVar):
|
|
||||||
# convert file to markdown
|
|
||||||
text = value.to_markdown()
|
|
||||||
elif isinstance(value, dict):
|
|
||||||
# handle files
|
|
||||||
file_vars = self._fetch_files_from_variable_value(value)
|
|
||||||
if file_vars:
|
|
||||||
file_var = file_vars[0]
|
|
||||||
try:
|
|
||||||
file_var_obj = FileVar(**file_var)
|
|
||||||
|
|
||||||
# convert file to markdown
|
|
||||||
text = file_var_obj.to_markdown()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f'Error creating file var: {e}')
|
|
||||||
|
|
||||||
if not text:
|
|
||||||
# other types
|
|
||||||
text = json.dumps(value, ensure_ascii=False)
|
|
||||||
elif isinstance(value, list):
|
|
||||||
# handle files
|
|
||||||
file_vars = self._fetch_files_from_variable_value(value)
|
|
||||||
for file_var in file_vars:
|
|
||||||
try:
|
|
||||||
file_var_obj = FileVar(**file_var)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f'Error creating file var: {e}')
|
|
||||||
continue
|
|
||||||
|
|
||||||
# convert file to markdown
|
|
||||||
text = file_var_obj.to_markdown() + ' '
|
|
||||||
|
|
||||||
text = text.strip()
|
|
||||||
|
|
||||||
if not text and value:
|
|
||||||
# other types
|
|
||||||
text = json.dumps(value, ensure_ascii=False)
|
|
||||||
|
|
||||||
if text:
|
if text:
|
||||||
yield NodeRunStreamChunkEvent(
|
yield NodeRunStreamChunkEvent(
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from core.file.file_obj import FileVar
|
from core.file.file_obj import FileVar
|
||||||
@ -7,15 +8,15 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
|||||||
|
|
||||||
|
|
||||||
class ConditionProcessor:
|
class ConditionProcessor:
|
||||||
def process_conditions(self, variable_pool: VariablePool, conditions: list[Condition]):
|
def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]):
|
||||||
input_conditions = []
|
input_conditions = []
|
||||||
group_result = []
|
group_result = []
|
||||||
|
|
||||||
index = 0
|
index = 0
|
||||||
for condition in conditions:
|
for condition in conditions:
|
||||||
index += 1
|
index += 1
|
||||||
actual_value = variable_pool.get_variable_value(
|
actual_value = variable_pool.get_any(
|
||||||
variable_selector=condition.variable_selector
|
condition.variable_selector
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_value = None
|
expected_value = None
|
||||||
@ -24,8 +25,8 @@ class ConditionProcessor:
|
|||||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||||
if variable_selectors:
|
if variable_selectors:
|
||||||
for variable_selector in variable_selectors:
|
for variable_selector in variable_selectors:
|
||||||
value = variable_pool.get_variable_value(
|
value = variable_pool.get_any(
|
||||||
variable_selector=variable_selector.value_selector
|
variable_selector.value_selector
|
||||||
)
|
)
|
||||||
expected_value = variable_template_parser.format({variable_selector.variable: value})
|
expected_value = variable_template_parser.format({variable_selector.variable: value})
|
||||||
|
|
||||||
@ -63,37 +64,37 @@ class ConditionProcessor:
|
|||||||
:return: bool
|
:return: bool
|
||||||
"""
|
"""
|
||||||
if comparison_operator == "contains":
|
if comparison_operator == "contains":
|
||||||
return self._assert_contains(actual_value, expected_value) # type: ignore
|
return self._assert_contains(actual_value, expected_value)
|
||||||
elif comparison_operator == "not contains":
|
elif comparison_operator == "not contains":
|
||||||
return self._assert_not_contains(actual_value, expected_value) # type: ignore
|
return self._assert_not_contains(actual_value, expected_value)
|
||||||
elif comparison_operator == "start with":
|
elif comparison_operator == "start with":
|
||||||
return self._assert_start_with(actual_value, expected_value) # type: ignore
|
return self._assert_start_with(actual_value, expected_value)
|
||||||
elif comparison_operator == "end with":
|
elif comparison_operator == "end with":
|
||||||
return self._assert_end_with(actual_value, expected_value) # type: ignore
|
return self._assert_end_with(actual_value, expected_value)
|
||||||
elif comparison_operator == "is":
|
elif comparison_operator == "is":
|
||||||
return self._assert_is(actual_value, expected_value) # type: ignore
|
return self._assert_is(actual_value, expected_value)
|
||||||
elif comparison_operator == "is not":
|
elif comparison_operator == "is not":
|
||||||
return self._assert_is_not(actual_value, expected_value) # type: ignore
|
return self._assert_is_not(actual_value, expected_value)
|
||||||
elif comparison_operator == "empty":
|
elif comparison_operator == "empty":
|
||||||
return self._assert_empty(actual_value) # type: ignore
|
return self._assert_empty(actual_value)
|
||||||
elif comparison_operator == "not empty":
|
elif comparison_operator == "not empty":
|
||||||
return self._assert_not_empty(actual_value) # type: ignore
|
return self._assert_not_empty(actual_value)
|
||||||
elif comparison_operator == "=":
|
elif comparison_operator == "=":
|
||||||
return self._assert_equal(actual_value, expected_value) # type: ignore
|
return self._assert_equal(actual_value, expected_value)
|
||||||
elif comparison_operator == "≠":
|
elif comparison_operator == "≠":
|
||||||
return self._assert_not_equal(actual_value, expected_value) # type: ignore
|
return self._assert_not_equal(actual_value, expected_value)
|
||||||
elif comparison_operator == ">":
|
elif comparison_operator == ">":
|
||||||
return self._assert_greater_than(actual_value, expected_value) # type: ignore
|
return self._assert_greater_than(actual_value, expected_value)
|
||||||
elif comparison_operator == "<":
|
elif comparison_operator == "<":
|
||||||
return self._assert_less_than(actual_value, expected_value) # type: ignore
|
return self._assert_less_than(actual_value, expected_value)
|
||||||
elif comparison_operator == "≥":
|
elif comparison_operator == "≥":
|
||||||
return self._assert_greater_than_or_equal(actual_value, expected_value) # type: ignore
|
return self._assert_greater_than_or_equal(actual_value, expected_value)
|
||||||
elif comparison_operator == "≤":
|
elif comparison_operator == "≤":
|
||||||
return self._assert_less_than_or_equal(actual_value, expected_value) # type: ignore
|
return self._assert_less_than_or_equal(actual_value, expected_value)
|
||||||
elif comparison_operator == "null":
|
elif comparison_operator == "null":
|
||||||
return self._assert_null(actual_value) # type: ignore
|
return self._assert_null(actual_value)
|
||||||
elif comparison_operator == "not null":
|
elif comparison_operator == "not null":
|
||||||
return self._assert_not_null(actual_value) # type: ignore
|
return self._assert_not_null(actual_value)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid comparison operator: {comparison_operator}")
|
raise ValueError(f"Invalid comparison operator: {comparison_operator}")
|
||||||
|
|
||||||
|
@ -10,6 +10,7 @@ from core.workflow.graph_engine.entities.event import (
|
|||||||
GraphRunSucceededEvent,
|
GraphRunSucceededEvent,
|
||||||
NodeRunFailedEvent,
|
NodeRunFailedEvent,
|
||||||
NodeRunStartedEvent,
|
NodeRunStartedEvent,
|
||||||
|
NodeRunStreamChunkEvent,
|
||||||
NodeRunSucceededEvent,
|
NodeRunSucceededEvent,
|
||||||
)
|
)
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
@ -139,10 +140,12 @@ def test_run_parallel(mock_close, mock_remove):
|
|||||||
assert not isinstance(item, NodeRunFailedEvent)
|
assert not isinstance(item, NodeRunFailedEvent)
|
||||||
assert not isinstance(item, GraphRunFailedEvent)
|
assert not isinstance(item, GraphRunFailedEvent)
|
||||||
|
|
||||||
if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ['answer2', 'answer3']:
|
if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [
|
||||||
|
'answer2', 'answer3', 'answer4', 'answer5'
|
||||||
|
]:
|
||||||
assert item.parallel_id is not None
|
assert item.parallel_id is not None
|
||||||
|
|
||||||
assert len(items) == 12
|
assert len(items) == 19
|
||||||
assert isinstance(items[0], GraphRunStartedEvent)
|
assert isinstance(items[0], GraphRunStartedEvent)
|
||||||
assert isinstance(items[1], NodeRunStartedEvent)
|
assert isinstance(items[1], NodeRunStartedEvent)
|
||||||
assert items[1].route_node_state.node_id == 'start'
|
assert items[1].route_node_state.node_id == 'start'
|
||||||
@ -291,12 +294,16 @@ def test_run_branch(mock_close, mock_remove):
|
|||||||
print(type(item), item)
|
print(type(item), item)
|
||||||
items.append(item)
|
items.append(item)
|
||||||
|
|
||||||
assert len(items) == 8
|
assert len(items) == 10
|
||||||
assert items[3].route_node_state.node_id == 'if-else-1'
|
assert items[3].route_node_state.node_id == 'if-else-1'
|
||||||
assert items[4].route_node_state.node_id == 'if-else-1'
|
assert items[4].route_node_state.node_id == 'if-else-1'
|
||||||
assert items[5].route_node_state.node_id == 'answer-1'
|
assert isinstance(items[5], NodeRunStreamChunkEvent)
|
||||||
assert items[6].route_node_state.node_id == 'answer-1'
|
assert items[5].chunk_content == '1 '
|
||||||
assert items[6].route_node_state.node_run_result.outputs['answer'] == '1 takato'
|
assert isinstance(items[6], NodeRunStreamChunkEvent)
|
||||||
assert isinstance(items[7], GraphRunSucceededEvent)
|
assert items[6].chunk_content == 'takato'
|
||||||
|
assert items[7].route_node_state.node_id == 'answer-1'
|
||||||
|
assert items[8].route_node_state.node_id == 'answer-1'
|
||||||
|
assert items[8].route_node_state.node_run_result.outputs['answer'] == '1 takato'
|
||||||
|
assert isinstance(items[9], GraphRunSucceededEvent)
|
||||||
|
|
||||||
# print(graph_engine.graph_runtime_state.model_dump_json(indent=2))
|
# print(graph_engine.graph_runtime_state.model_dump_json(indent=2))
|
||||||
|
@ -187,9 +187,8 @@ def test_process():
|
|||||||
# " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else ""))
|
# " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else ""))
|
||||||
if isinstance(event, NodeRunSucceededEvent):
|
if isinstance(event, NodeRunSucceededEvent):
|
||||||
if 'llm' in event.route_node_state.node_id:
|
if 'llm' in event.route_node_state.node_id:
|
||||||
variable_pool.append_variable(
|
variable_pool.add(
|
||||||
event.route_node_state.node_id,
|
[event.route_node_state.node_id, "text"],
|
||||||
["text"],
|
|
||||||
"".join(str(i) for i in range(0, int(event.route_node_state.node_id[-1])))
|
"".join(str(i) for i in range(0, int(event.route_node_state.node_id[-1])))
|
||||||
)
|
)
|
||||||
yield event
|
yield event
|
||||||
|
@ -45,23 +45,23 @@ def test_execute_if_else_result_true():
|
|||||||
SystemVariable.FILES: [],
|
SystemVariable.FILES: [],
|
||||||
SystemVariable.USER_ID: 'aaa'
|
SystemVariable.USER_ID: 'aaa'
|
||||||
}, user_inputs={})
|
}, user_inputs={})
|
||||||
pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['ab', 'def'])
|
pool.add(['start', 'array_contains'], ['ab', 'def'])
|
||||||
pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ac', 'def'])
|
pool.add(['start', 'array_not_contains'], ['ac', 'def'])
|
||||||
pool.append_variable(node_id='start', variable_key_list=['contains'], value='cabcde')
|
pool.add(['start', 'contains'], 'cabcde')
|
||||||
pool.append_variable(node_id='start', variable_key_list=['not_contains'], value='zacde')
|
pool.add(['start', 'not_contains'], 'zacde')
|
||||||
pool.append_variable(node_id='start', variable_key_list=['start_with'], value='abc')
|
pool.add(['start', 'start_with'], 'abc')
|
||||||
pool.append_variable(node_id='start', variable_key_list=['end_with'], value='zzab')
|
pool.add(['start', 'end_with'], 'zzab')
|
||||||
pool.append_variable(node_id='start', variable_key_list=['is'], value='ab')
|
pool.add(['start', 'is'], 'ab')
|
||||||
pool.append_variable(node_id='start', variable_key_list=['is_not'], value='aab')
|
pool.add(['start', 'is_not'], 'aab')
|
||||||
pool.append_variable(node_id='start', variable_key_list=['empty'], value='')
|
pool.add(['start', 'empty'], '')
|
||||||
pool.append_variable(node_id='start', variable_key_list=['not_empty'], value='aaa')
|
pool.add(['start', 'not_empty'], 'aaa')
|
||||||
pool.append_variable(node_id='start', variable_key_list=['equals'], value=22)
|
pool.add(['start', 'equals'], 22)
|
||||||
pool.append_variable(node_id='start', variable_key_list=['not_equals'], value=23)
|
pool.add(['start', 'not_equals'], 23)
|
||||||
pool.append_variable(node_id='start', variable_key_list=['greater_than'], value=23)
|
pool.add(['start', 'greater_than'], 23)
|
||||||
pool.append_variable(node_id='start', variable_key_list=['less_than'], value=21)
|
pool.add(['start', 'less_than'], 21)
|
||||||
pool.append_variable(node_id='start', variable_key_list=['greater_than_or_equal'], value=22)
|
pool.add(['start', 'greater_than_or_equal'], 22)
|
||||||
pool.append_variable(node_id='start', variable_key_list=['less_than_or_equal'], value=21)
|
pool.add(['start', 'less_than_or_equal'], 21)
|
||||||
pool.append_variable(node_id='start', variable_key_list=['not_null'], value='1212')
|
pool.add(['start', 'not_null'], '1212')
|
||||||
|
|
||||||
node = IfElseNode(
|
node = IfElseNode(
|
||||||
graph_init_params=init_params,
|
graph_init_params=init_params,
|
||||||
@ -204,13 +204,60 @@ def test_execute_if_else_result_true():
|
|||||||
|
|
||||||
|
|
||||||
def test_execute_if_else_result_false():
|
def test_execute_if_else_result_false():
|
||||||
node = IfElseNode(
|
graph_config = {
|
||||||
|
"edges": [
|
||||||
|
{
|
||||||
|
"id": "start-source-llm-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "start"
|
||||||
|
},
|
||||||
|
"id": "start"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm"
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = Graph.init(
|
||||||
|
graph_config=graph_config
|
||||||
|
)
|
||||||
|
|
||||||
|
init_params = GraphInitParams(
|
||||||
tenant_id='1',
|
tenant_id='1',
|
||||||
app_id='1',
|
app_id='1',
|
||||||
|
workflow_type=WorkflowType.WORKFLOW,
|
||||||
workflow_id='1',
|
workflow_id='1',
|
||||||
user_id='1',
|
user_id='1',
|
||||||
user_from=UserFrom.ACCOUNT,
|
user_from=UserFrom.ACCOUNT,
|
||||||
invoke_from=InvokeFrom.DEBUGGER,
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
|
call_depth=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# construct variable pool
|
||||||
|
pool = VariablePool(system_variables={
|
||||||
|
SystemVariable.FILES: [],
|
||||||
|
SystemVariable.USER_ID: 'aaa'
|
||||||
|
}, user_inputs={}, environment_variables=[])
|
||||||
|
pool.add(['start', 'array_contains'], ['1ab', 'def'])
|
||||||
|
pool.add(['start', 'array_not_contains'], ['ab', 'def'])
|
||||||
|
|
||||||
|
node = IfElseNode(
|
||||||
|
graph_init_params=init_params,
|
||||||
|
graph=graph,
|
||||||
|
graph_runtime_state=GraphRuntimeState(
|
||||||
|
variable_pool=pool,
|
||||||
|
start_at=time.perf_counter()
|
||||||
|
),
|
||||||
config={
|
config={
|
||||||
'id': 'if-else',
|
'id': 'if-else',
|
||||||
'data': {
|
'data': {
|
||||||
@ -233,19 +280,11 @@ def test_execute_if_else_result_false():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# construct variable pool
|
|
||||||
pool = VariablePool(system_variables={
|
|
||||||
SystemVariable.FILES: [],
|
|
||||||
SystemVariable.USER_ID: 'aaa'
|
|
||||||
}, user_inputs={}, environment_variables=[])
|
|
||||||
pool.add(['start', 'array_contains'], ['1ab', 'def'])
|
|
||||||
pool.add(['start', 'array_not_contains'], ['ab', 'def'])
|
|
||||||
|
|
||||||
# Mock db.session.close()
|
# Mock db.session.close()
|
||||||
db.session.close = MagicMock()
|
db.session.close = MagicMock()
|
||||||
|
|
||||||
# execute node
|
# execute node
|
||||||
result = node._run(pool)
|
result = node._run()
|
||||||
|
|
||||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||||
assert result.outputs['result'] is False
|
assert result.outputs['result'] is False
|
||||||
|
Loading…
x
Reference in New Issue
Block a user