This commit is contained in:
takatost 2024-07-23 00:10:23 +08:00
parent a603e01f5e
commit 2c695ded79
9 changed files with 140 additions and 127 deletions

View File

@ -2,7 +2,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Union
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from typing_extensions import deprecated
from core.app.segments import ArrayVariable, ObjectVariable, Variable, factory
@ -21,7 +21,7 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
_variable_dictionary: dict[str, dict[int, Variable]] = Field(
variable_dictionary: dict[str, dict[int, Variable]] = Field(
description='Variables mapping',
default=defaultdict(dict)
)
@ -36,10 +36,12 @@ class VariablePool(BaseModel):
)
environment_variables: Sequence[Variable] = Field(
description="Environment variables."
description="Environment variables.",
default_factory=list
)
def __post_init__(self):
@model_validator(mode="after")
def val_model_after(self):
"""
Append system variables
:return:
@ -52,6 +54,8 @@ class VariablePool(BaseModel):
for var in self.environment_variables or []:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
return self
def add(self, selector: Sequence[str], value: Any, /) -> None:
"""
Adds a variable to the variable pool.
@ -78,7 +82,7 @@ class VariablePool(BaseModel):
v = value
hash_key = hash(tuple(selector[1:]))
self._variable_dictionary[selector[0]][hash_key] = v
self.variable_dictionary[selector[0]][hash_key] = v
def get(self, selector: Sequence[str], /) -> Variable | None:
"""
@ -96,7 +100,7 @@ class VariablePool(BaseModel):
if len(selector) < 2:
raise ValueError('Invalid selector')
hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key)
value = self.variable_dictionary[selector[0]].get(hash_key)
return value
@ -117,7 +121,7 @@ class VariablePool(BaseModel):
if len(selector) < 2:
raise ValueError('Invalid selector')
hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key)
value = self.variable_dictionary[selector[0]].get(hash_key)
if value is None:
return value
@ -140,7 +144,7 @@ class VariablePool(BaseModel):
if not selector:
return
if len(selector) == 1:
self._variable_dictionary[selector[0]] = {}
self.variable_dictionary[selector[0]] = {}
return
hash_key = hash(tuple(selector[1:]))
self._variable_dictionary[selector[0]].pop(hash_key, None)
self.variable_dictionary[selector[0]].pop(hash_key, None)

View File

@ -285,16 +285,18 @@ class Graph(BaseModel):
)
# collect all branches node ids
end_to_node_id: Optional[str] = None
for branch_node_id, node_ids in in_branch_node_ids.items():
for node_id in node_ids:
node_parallel_mapping[node_id] = parallel.id
end_to_node_id: Optional[str] = None
for node_id in node_parallel_mapping:
if not end_to_node_id and edge_mapping.get(node_id):
node_edges = edge_mapping[node_id]
target_node_id = node_edges[0].target_node_id
if node_parallel_mapping.get(target_node_id) == parent_parallel_id:
end_to_node_id = target_node_id
break
if end_to_node_id:
parallel.end_to_node_id = end_to_node_id

View File

@ -244,8 +244,8 @@ class GraphEngine:
next_node_id = final_node_id
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
break
# if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
# break
def _run_parallel_node(self,
flask_app: Flask,
@ -402,10 +402,9 @@ class GraphEngine:
:param variable_value: variable value
:return:
"""
self.graph_runtime_state.variable_pool.append_variable(
node_id=node_id,
variable_key_list=variable_key_list,
value=variable_value # type: ignore[arg-type]
self.graph_runtime_state.variable_pool.add(
[node_id] + variable_key_list,
variable_value
)
# if variable_value is a dict, then recursively append variables

View File

@ -35,7 +35,7 @@ class AnswerNode(BaseNode):
part = cast(VarGenerateRouteChunk, part)
value_selector = part.value_selector
value = self.graph_runtime_state.variable_pool.get(
variable_selector=value_selector
value_selector
)
if value:

View File

@ -1,4 +1,3 @@
import json
import logging
from collections.abc import Generator
from typing import Optional, cast
@ -98,7 +97,7 @@ class AnswerStreamProcessor:
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = []
for edge in self.graph.edge_mapping[node_id]:
for edge in self.graph.edge_mapping.get(node_id, []):
node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids
@ -108,7 +107,7 @@ class AnswerStreamProcessor:
remove target node ids until merge
"""
self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping[node_id]:
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids:
continue
@ -124,8 +123,10 @@ class AnswerStreamProcessor:
"""
for answer_node_id, position in self.route_position.items():
# all depends on answer node id not in rest node ids
if not all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
if (event.route_node_state.node_id != answer_node_id
and (answer_node_id not in self.rest_node_ids
or not all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
continue
route_position = self.route_position[answer_node_id]
@ -145,53 +146,14 @@ class AnswerStreamProcessor:
if not value_selector:
break
value = self.variable_pool.get_variable_value(
variable_selector=value_selector
value = self.variable_pool.get(
value_selector
)
if value is None:
break
text = ''
if isinstance(value, str | int | float):
text = str(value)
elif isinstance(value, FileVar):
# convert file to markdown
text = value.to_markdown()
elif isinstance(value, dict):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
if file_vars:
file_var = file_vars[0]
try:
file_var_obj = FileVar(**file_var)
# convert file to markdown
text = file_var_obj.to_markdown()
except Exception as e:
logger.error(f'Error creating file var: {e}')
if not text:
# other types
text = json.dumps(value, ensure_ascii=False)
elif isinstance(value, list):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
for file_var in file_vars:
try:
file_var_obj = FileVar(**file_var)
except Exception as e:
logger.error(f'Error creating file var: {e}')
continue
# convert file to markdown
text = file_var_obj.to_markdown() + ' '
text = text.strip()
if not text and value:
# other types
text = json.dumps(value, ensure_ascii=False)
text = value.markdown
if text:
yield NodeRunStreamChunkEvent(

View File

@ -1,3 +1,4 @@
from collections.abc import Sequence
from typing import Any, Optional
from core.file.file_obj import FileVar
@ -7,15 +8,15 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
class ConditionProcessor:
def process_conditions(self, variable_pool: VariablePool, conditions: list[Condition]):
def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]):
input_conditions = []
group_result = []
index = 0
for condition in conditions:
index += 1
actual_value = variable_pool.get_variable_value(
variable_selector=condition.variable_selector
actual_value = variable_pool.get_any(
condition.variable_selector
)
expected_value = None
@ -24,8 +25,8 @@ class ConditionProcessor:
variable_selectors = variable_template_parser.extract_variable_selectors()
if variable_selectors:
for variable_selector in variable_selectors:
value = variable_pool.get_variable_value(
variable_selector=variable_selector.value_selector
value = variable_pool.get_any(
variable_selector.value_selector
)
expected_value = variable_template_parser.format({variable_selector.variable: value})
@ -63,37 +64,37 @@ class ConditionProcessor:
:return: bool
"""
if comparison_operator == "contains":
return self._assert_contains(actual_value, expected_value) # type: ignore
return self._assert_contains(actual_value, expected_value)
elif comparison_operator == "not contains":
return self._assert_not_contains(actual_value, expected_value) # type: ignore
return self._assert_not_contains(actual_value, expected_value)
elif comparison_operator == "start with":
return self._assert_start_with(actual_value, expected_value) # type: ignore
return self._assert_start_with(actual_value, expected_value)
elif comparison_operator == "end with":
return self._assert_end_with(actual_value, expected_value) # type: ignore
return self._assert_end_with(actual_value, expected_value)
elif comparison_operator == "is":
return self._assert_is(actual_value, expected_value) # type: ignore
return self._assert_is(actual_value, expected_value)
elif comparison_operator == "is not":
return self._assert_is_not(actual_value, expected_value) # type: ignore
return self._assert_is_not(actual_value, expected_value)
elif comparison_operator == "empty":
return self._assert_empty(actual_value) # type: ignore
return self._assert_empty(actual_value)
elif comparison_operator == "not empty":
return self._assert_not_empty(actual_value) # type: ignore
return self._assert_not_empty(actual_value)
elif comparison_operator == "=":
return self._assert_equal(actual_value, expected_value) # type: ignore
return self._assert_equal(actual_value, expected_value)
elif comparison_operator == "≠":
return self._assert_not_equal(actual_value, expected_value) # type: ignore
return self._assert_not_equal(actual_value, expected_value)
elif comparison_operator == ">":
return self._assert_greater_than(actual_value, expected_value) # type: ignore
return self._assert_greater_than(actual_value, expected_value)
elif comparison_operator == "<":
return self._assert_less_than(actual_value, expected_value) # type: ignore
return self._assert_less_than(actual_value, expected_value)
elif comparison_operator == "≥":
return self._assert_greater_than_or_equal(actual_value, expected_value) # type: ignore
return self._assert_greater_than_or_equal(actual_value, expected_value)
elif comparison_operator == "≤":
return self._assert_less_than_or_equal(actual_value, expected_value) # type: ignore
return self._assert_less_than_or_equal(actual_value, expected_value)
elif comparison_operator == "null":
return self._assert_null(actual_value) # type: ignore
return self._assert_null(actual_value)
elif comparison_operator == "not null":
return self._assert_not_null(actual_value) # type: ignore
return self._assert_not_null(actual_value)
else:
raise ValueError(f"Invalid comparison operator: {comparison_operator}")

View File

@ -10,6 +10,7 @@ from core.workflow.graph_engine.entities.event import (
GraphRunSucceededEvent,
NodeRunFailedEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
@ -139,10 +140,12 @@ def test_run_parallel(mock_close, mock_remove):
assert not isinstance(item, NodeRunFailedEvent)
assert not isinstance(item, GraphRunFailedEvent)
if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ['answer2', 'answer3']:
if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [
'answer2', 'answer3', 'answer4', 'answer5'
]:
assert item.parallel_id is not None
assert len(items) == 12
assert len(items) == 19
assert isinstance(items[0], GraphRunStartedEvent)
assert isinstance(items[1], NodeRunStartedEvent)
assert items[1].route_node_state.node_id == 'start'
@ -291,12 +294,16 @@ def test_run_branch(mock_close, mock_remove):
print(type(item), item)
items.append(item)
assert len(items) == 8
assert len(items) == 10
assert items[3].route_node_state.node_id == 'if-else-1'
assert items[4].route_node_state.node_id == 'if-else-1'
assert items[5].route_node_state.node_id == 'answer-1'
assert items[6].route_node_state.node_id == 'answer-1'
assert items[6].route_node_state.node_run_result.outputs['answer'] == '1 takato'
assert isinstance(items[7], GraphRunSucceededEvent)
assert isinstance(items[5], NodeRunStreamChunkEvent)
assert items[5].chunk_content == '1 '
assert isinstance(items[6], NodeRunStreamChunkEvent)
assert items[6].chunk_content == 'takato'
assert items[7].route_node_state.node_id == 'answer-1'
assert items[8].route_node_state.node_id == 'answer-1'
assert items[8].route_node_state.node_run_result.outputs['answer'] == '1 takato'
assert isinstance(items[9], GraphRunSucceededEvent)
# print(graph_engine.graph_runtime_state.model_dump_json(indent=2))

View File

@ -187,9 +187,8 @@ def test_process():
# " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else ""))
if isinstance(event, NodeRunSucceededEvent):
if 'llm' in event.route_node_state.node_id:
variable_pool.append_variable(
event.route_node_state.node_id,
["text"],
variable_pool.add(
[event.route_node_state.node_id, "text"],
"".join(str(i) for i in range(0, int(event.route_node_state.node_id[-1])))
)
yield event

View File

@ -45,23 +45,23 @@ def test_execute_if_else_result_true():
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
}, user_inputs={})
pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['ab', 'def'])
pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ac', 'def'])
pool.append_variable(node_id='start', variable_key_list=['contains'], value='cabcde')
pool.append_variable(node_id='start', variable_key_list=['not_contains'], value='zacde')
pool.append_variable(node_id='start', variable_key_list=['start_with'], value='abc')
pool.append_variable(node_id='start', variable_key_list=['end_with'], value='zzab')
pool.append_variable(node_id='start', variable_key_list=['is'], value='ab')
pool.append_variable(node_id='start', variable_key_list=['is_not'], value='aab')
pool.append_variable(node_id='start', variable_key_list=['empty'], value='')
pool.append_variable(node_id='start', variable_key_list=['not_empty'], value='aaa')
pool.append_variable(node_id='start', variable_key_list=['equals'], value=22)
pool.append_variable(node_id='start', variable_key_list=['not_equals'], value=23)
pool.append_variable(node_id='start', variable_key_list=['greater_than'], value=23)
pool.append_variable(node_id='start', variable_key_list=['less_than'], value=21)
pool.append_variable(node_id='start', variable_key_list=['greater_than_or_equal'], value=22)
pool.append_variable(node_id='start', variable_key_list=['less_than_or_equal'], value=21)
pool.append_variable(node_id='start', variable_key_list=['not_null'], value='1212')
pool.add(['start', 'array_contains'], ['ab', 'def'])
pool.add(['start', 'array_not_contains'], ['ac', 'def'])
pool.add(['start', 'contains'], 'cabcde')
pool.add(['start', 'not_contains'], 'zacde')
pool.add(['start', 'start_with'], 'abc')
pool.add(['start', 'end_with'], 'zzab')
pool.add(['start', 'is'], 'ab')
pool.add(['start', 'is_not'], 'aab')
pool.add(['start', 'empty'], '')
pool.add(['start', 'not_empty'], 'aaa')
pool.add(['start', 'equals'], 22)
pool.add(['start', 'not_equals'], 23)
pool.add(['start', 'greater_than'], 23)
pool.add(['start', 'less_than'], 21)
pool.add(['start', 'greater_than_or_equal'], 22)
pool.add(['start', 'less_than_or_equal'], 21)
pool.add(['start', 'not_null'], '1212')
node = IfElseNode(
graph_init_params=init_params,
@ -204,13 +204,60 @@ def test_execute_if_else_result_true():
def test_execute_if_else_result_false():
node = IfElseNode(
graph_config = {
"edges": [
{
"id": "start-source-llm-target",
"source": "start",
"target": "llm",
},
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm"
},
]
}
graph = Graph.init(
graph_config=graph_config
)
init_params = GraphInitParams(
tenant_id='1',
app_id='1',
workflow_type=WorkflowType.WORKFLOW,
workflow_id='1',
user_id='1',
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0
)
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['start', 'array_contains'], ['1ab', 'def'])
pool.add(['start', 'array_not_contains'], ['ab', 'def'])
node = IfElseNode(
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(
variable_pool=pool,
start_at=time.perf_counter()
),
config={
'id': 'if-else',
'data': {
@ -233,19 +280,11 @@ def test_execute_if_else_result_false():
}
)
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['start', 'array_contains'], ['1ab', 'def'])
pool.add(['start', 'array_not_contains'], ['ab', 'def'])
# Mock db.session.close()
db.session.close = MagicMock()
# execute node
result = node._run(pool)
result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs['result'] is False