This commit is contained in:
takatost 2024-07-23 00:10:23 +08:00
parent a603e01f5e
commit 2c695ded79
9 changed files with 140 additions and 127 deletions

View File

@ -2,7 +2,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, Union from typing import Any, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, model_validator
from typing_extensions import deprecated from typing_extensions import deprecated
from core.app.segments import ArrayVariable, ObjectVariable, Variable, factory from core.app.segments import ArrayVariable, ObjectVariable, Variable, factory
@ -21,7 +21,7 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary. # The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one. # elements of the selector except the first one.
_variable_dictionary: dict[str, dict[int, Variable]] = Field( variable_dictionary: dict[str, dict[int, Variable]] = Field(
description='Variables mapping', description='Variables mapping',
default=defaultdict(dict) default=defaultdict(dict)
) )
@ -36,10 +36,12 @@ class VariablePool(BaseModel):
) )
environment_variables: Sequence[Variable] = Field( environment_variables: Sequence[Variable] = Field(
description="Environment variables." description="Environment variables.",
default_factory=list
) )
def __post_init__(self): @model_validator(mode="after")
def val_model_after(self):
""" """
Append system variables Append system variables
:return: :return:
@ -52,6 +54,8 @@ class VariablePool(BaseModel):
for var in self.environment_variables or []: for var in self.environment_variables or []:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
return self
def add(self, selector: Sequence[str], value: Any, /) -> None: def add(self, selector: Sequence[str], value: Any, /) -> None:
""" """
Adds a variable to the variable pool. Adds a variable to the variable pool.
@ -78,7 +82,7 @@ class VariablePool(BaseModel):
v = value v = value
hash_key = hash(tuple(selector[1:])) hash_key = hash(tuple(selector[1:]))
self._variable_dictionary[selector[0]][hash_key] = v self.variable_dictionary[selector[0]][hash_key] = v
def get(self, selector: Sequence[str], /) -> Variable | None: def get(self, selector: Sequence[str], /) -> Variable | None:
""" """
@ -96,7 +100,7 @@ class VariablePool(BaseModel):
if len(selector) < 2: if len(selector) < 2:
raise ValueError('Invalid selector') raise ValueError('Invalid selector')
hash_key = hash(tuple(selector[1:])) hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key) value = self.variable_dictionary[selector[0]].get(hash_key)
return value return value
@ -117,7 +121,7 @@ class VariablePool(BaseModel):
if len(selector) < 2: if len(selector) < 2:
raise ValueError('Invalid selector') raise ValueError('Invalid selector')
hash_key = hash(tuple(selector[1:])) hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key) value = self.variable_dictionary[selector[0]].get(hash_key)
if value is None: if value is None:
return value return value
@ -140,7 +144,7 @@ class VariablePool(BaseModel):
if not selector: if not selector:
return return
if len(selector) == 1: if len(selector) == 1:
self._variable_dictionary[selector[0]] = {} self.variable_dictionary[selector[0]] = {}
return return
hash_key = hash(tuple(selector[1:])) hash_key = hash(tuple(selector[1:]))
self._variable_dictionary[selector[0]].pop(hash_key, None) self.variable_dictionary[selector[0]].pop(hash_key, None)

View File

@ -285,16 +285,18 @@ class Graph(BaseModel):
) )
# collect all branches node ids # collect all branches node ids
end_to_node_id: Optional[str] = None
for branch_node_id, node_ids in in_branch_node_ids.items(): for branch_node_id, node_ids in in_branch_node_ids.items():
for node_id in node_ids: for node_id in node_ids:
node_parallel_mapping[node_id] = parallel.id node_parallel_mapping[node_id] = parallel.id
if not end_to_node_id and edge_mapping.get(node_id): end_to_node_id: Optional[str] = None
node_edges = edge_mapping[node_id] for node_id in node_parallel_mapping:
target_node_id = node_edges[0].target_node_id if not end_to_node_id and edge_mapping.get(node_id):
if node_parallel_mapping.get(target_node_id) == parent_parallel_id: node_edges = edge_mapping[node_id]
end_to_node_id = target_node_id target_node_id = node_edges[0].target_node_id
if node_parallel_mapping.get(target_node_id) == parent_parallel_id:
end_to_node_id = target_node_id
break
if end_to_node_id: if end_to_node_id:
parallel.end_to_node_id = end_to_node_id parallel.end_to_node_id = end_to_node_id

View File

@ -244,8 +244,8 @@ class GraphEngine:
next_node_id = final_node_id next_node_id = final_node_id
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id: # if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') == in_parallel_id:
break # break
def _run_parallel_node(self, def _run_parallel_node(self,
flask_app: Flask, flask_app: Flask,
@ -402,10 +402,9 @@ class GraphEngine:
:param variable_value: variable value :param variable_value: variable value
:return: :return:
""" """
self.graph_runtime_state.variable_pool.append_variable( self.graph_runtime_state.variable_pool.add(
node_id=node_id, [node_id] + variable_key_list,
variable_key_list=variable_key_list, variable_value
value=variable_value # type: ignore[arg-type]
) )
# if variable_value is a dict, then recursively append variables # if variable_value is a dict, then recursively append variables

View File

@ -35,7 +35,7 @@ class AnswerNode(BaseNode):
part = cast(VarGenerateRouteChunk, part) part = cast(VarGenerateRouteChunk, part)
value_selector = part.value_selector value_selector = part.value_selector
value = self.graph_runtime_state.variable_pool.get( value = self.graph_runtime_state.variable_pool.get(
variable_selector=value_selector value_selector
) )
if value: if value:

View File

@ -1,4 +1,3 @@
import json
import logging import logging
from collections.abc import Generator from collections.abc import Generator
from typing import Optional, cast from typing import Optional, cast
@ -98,7 +97,7 @@ class AnswerStreamProcessor:
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]: def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
node_ids = [] node_ids = []
for edge in self.graph.edge_mapping[node_id]: for edge in self.graph.edge_mapping.get(node_id, []):
node_ids.append(edge.target_node_id) node_ids.append(edge.target_node_id)
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
return node_ids return node_ids
@ -108,7 +107,7 @@ class AnswerStreamProcessor:
remove target node ids until merge remove target node ids until merge
""" """
self.rest_node_ids.remove(node_id) self.rest_node_ids.remove(node_id)
for edge in self.graph.edge_mapping[node_id]: for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids: if edge.target_node_id in reachable_node_ids:
continue continue
@ -124,8 +123,10 @@ class AnswerStreamProcessor:
""" """
for answer_node_id, position in self.route_position.items(): for answer_node_id, position in self.route_position.items():
# all depends on answer node id not in rest node ids # all depends on answer node id not in rest node ids
if not all(dep_id not in self.rest_node_ids if (event.route_node_state.node_id != answer_node_id
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]): and (answer_node_id not in self.rest_node_ids
or not all(dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
continue continue
route_position = self.route_position[answer_node_id] route_position = self.route_position[answer_node_id]
@ -145,53 +146,14 @@ class AnswerStreamProcessor:
if not value_selector: if not value_selector:
break break
value = self.variable_pool.get_variable_value( value = self.variable_pool.get(
variable_selector=value_selector value_selector
) )
if value is None: if value is None:
break break
text = '' text = value.markdown
if isinstance(value, str | int | float):
text = str(value)
elif isinstance(value, FileVar):
# convert file to markdown
text = value.to_markdown()
elif isinstance(value, dict):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
if file_vars:
file_var = file_vars[0]
try:
file_var_obj = FileVar(**file_var)
# convert file to markdown
text = file_var_obj.to_markdown()
except Exception as e:
logger.error(f'Error creating file var: {e}')
if not text:
# other types
text = json.dumps(value, ensure_ascii=False)
elif isinstance(value, list):
# handle files
file_vars = self._fetch_files_from_variable_value(value)
for file_var in file_vars:
try:
file_var_obj = FileVar(**file_var)
except Exception as e:
logger.error(f'Error creating file var: {e}')
continue
# convert file to markdown
text = file_var_obj.to_markdown() + ' '
text = text.strip()
if not text and value:
# other types
text = json.dumps(value, ensure_ascii=False)
if text: if text:
yield NodeRunStreamChunkEvent( yield NodeRunStreamChunkEvent(

View File

@ -1,3 +1,4 @@
from collections.abc import Sequence
from typing import Any, Optional from typing import Any, Optional
from core.file.file_obj import FileVar from core.file.file_obj import FileVar
@ -7,15 +8,15 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
class ConditionProcessor: class ConditionProcessor:
def process_conditions(self, variable_pool: VariablePool, conditions: list[Condition]): def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]):
input_conditions = [] input_conditions = []
group_result = [] group_result = []
index = 0 index = 0
for condition in conditions: for condition in conditions:
index += 1 index += 1
actual_value = variable_pool.get_variable_value( actual_value = variable_pool.get_any(
variable_selector=condition.variable_selector condition.variable_selector
) )
expected_value = None expected_value = None
@ -24,8 +25,8 @@ class ConditionProcessor:
variable_selectors = variable_template_parser.extract_variable_selectors() variable_selectors = variable_template_parser.extract_variable_selectors()
if variable_selectors: if variable_selectors:
for variable_selector in variable_selectors: for variable_selector in variable_selectors:
value = variable_pool.get_variable_value( value = variable_pool.get_any(
variable_selector=variable_selector.value_selector variable_selector.value_selector
) )
expected_value = variable_template_parser.format({variable_selector.variable: value}) expected_value = variable_template_parser.format({variable_selector.variable: value})
@ -63,37 +64,37 @@ class ConditionProcessor:
:return: bool :return: bool
""" """
if comparison_operator == "contains": if comparison_operator == "contains":
return self._assert_contains(actual_value, expected_value) # type: ignore return self._assert_contains(actual_value, expected_value)
elif comparison_operator == "not contains": elif comparison_operator == "not contains":
return self._assert_not_contains(actual_value, expected_value) # type: ignore return self._assert_not_contains(actual_value, expected_value)
elif comparison_operator == "start with": elif comparison_operator == "start with":
return self._assert_start_with(actual_value, expected_value) # type: ignore return self._assert_start_with(actual_value, expected_value)
elif comparison_operator == "end with": elif comparison_operator == "end with":
return self._assert_end_with(actual_value, expected_value) # type: ignore return self._assert_end_with(actual_value, expected_value)
elif comparison_operator == "is": elif comparison_operator == "is":
return self._assert_is(actual_value, expected_value) # type: ignore return self._assert_is(actual_value, expected_value)
elif comparison_operator == "is not": elif comparison_operator == "is not":
return self._assert_is_not(actual_value, expected_value) # type: ignore return self._assert_is_not(actual_value, expected_value)
elif comparison_operator == "empty": elif comparison_operator == "empty":
return self._assert_empty(actual_value) # type: ignore return self._assert_empty(actual_value)
elif comparison_operator == "not empty": elif comparison_operator == "not empty":
return self._assert_not_empty(actual_value) # type: ignore return self._assert_not_empty(actual_value)
elif comparison_operator == "=": elif comparison_operator == "=":
return self._assert_equal(actual_value, expected_value) # type: ignore return self._assert_equal(actual_value, expected_value)
elif comparison_operator == "≠": elif comparison_operator == "≠":
return self._assert_not_equal(actual_value, expected_value) # type: ignore return self._assert_not_equal(actual_value, expected_value)
elif comparison_operator == ">": elif comparison_operator == ">":
return self._assert_greater_than(actual_value, expected_value) # type: ignore return self._assert_greater_than(actual_value, expected_value)
elif comparison_operator == "<": elif comparison_operator == "<":
return self._assert_less_than(actual_value, expected_value) # type: ignore return self._assert_less_than(actual_value, expected_value)
elif comparison_operator == "≥": elif comparison_operator == "≥":
return self._assert_greater_than_or_equal(actual_value, expected_value) # type: ignore return self._assert_greater_than_or_equal(actual_value, expected_value)
elif comparison_operator == "≤": elif comparison_operator == "≤":
return self._assert_less_than_or_equal(actual_value, expected_value) # type: ignore return self._assert_less_than_or_equal(actual_value, expected_value)
elif comparison_operator == "null": elif comparison_operator == "null":
return self._assert_null(actual_value) # type: ignore return self._assert_null(actual_value)
elif comparison_operator == "not null": elif comparison_operator == "not null":
return self._assert_not_null(actual_value) # type: ignore return self._assert_not_null(actual_value)
else: else:
raise ValueError(f"Invalid comparison operator: {comparison_operator}") raise ValueError(f"Invalid comparison operator: {comparison_operator}")

View File

@ -10,6 +10,7 @@ from core.workflow.graph_engine.entities.event import (
GraphRunSucceededEvent, GraphRunSucceededEvent,
NodeRunFailedEvent, NodeRunFailedEvent,
NodeRunStartedEvent, NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent, NodeRunSucceededEvent,
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
@ -139,10 +140,12 @@ def test_run_parallel(mock_close, mock_remove):
assert not isinstance(item, NodeRunFailedEvent) assert not isinstance(item, NodeRunFailedEvent)
assert not isinstance(item, GraphRunFailedEvent) assert not isinstance(item, GraphRunFailedEvent)
if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ['answer2', 'answer3']: if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [
'answer2', 'answer3', 'answer4', 'answer5'
]:
assert item.parallel_id is not None assert item.parallel_id is not None
assert len(items) == 12 assert len(items) == 19
assert isinstance(items[0], GraphRunStartedEvent) assert isinstance(items[0], GraphRunStartedEvent)
assert isinstance(items[1], NodeRunStartedEvent) assert isinstance(items[1], NodeRunStartedEvent)
assert items[1].route_node_state.node_id == 'start' assert items[1].route_node_state.node_id == 'start'
@ -291,12 +294,16 @@ def test_run_branch(mock_close, mock_remove):
print(type(item), item) print(type(item), item)
items.append(item) items.append(item)
assert len(items) == 8 assert len(items) == 10
assert items[3].route_node_state.node_id == 'if-else-1' assert items[3].route_node_state.node_id == 'if-else-1'
assert items[4].route_node_state.node_id == 'if-else-1' assert items[4].route_node_state.node_id == 'if-else-1'
assert items[5].route_node_state.node_id == 'answer-1' assert isinstance(items[5], NodeRunStreamChunkEvent)
assert items[6].route_node_state.node_id == 'answer-1' assert items[5].chunk_content == '1 '
assert items[6].route_node_state.node_run_result.outputs['answer'] == '1 takato' assert isinstance(items[6], NodeRunStreamChunkEvent)
assert isinstance(items[7], GraphRunSucceededEvent) assert items[6].chunk_content == 'takato'
assert items[7].route_node_state.node_id == 'answer-1'
assert items[8].route_node_state.node_id == 'answer-1'
assert items[8].route_node_state.node_run_result.outputs['answer'] == '1 takato'
assert isinstance(items[9], GraphRunSucceededEvent)
# print(graph_engine.graph_runtime_state.model_dump_json(indent=2)) # print(graph_engine.graph_runtime_state.model_dump_json(indent=2))

View File

@ -187,9 +187,8 @@ def test_process():
# " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else "")) # " " + (event.chunk_content if isinstance(event, NodeRunStreamChunkEvent) else ""))
if isinstance(event, NodeRunSucceededEvent): if isinstance(event, NodeRunSucceededEvent):
if 'llm' in event.route_node_state.node_id: if 'llm' in event.route_node_state.node_id:
variable_pool.append_variable( variable_pool.add(
event.route_node_state.node_id, [event.route_node_state.node_id, "text"],
["text"],
"".join(str(i) for i in range(0, int(event.route_node_state.node_id[-1]))) "".join(str(i) for i in range(0, int(event.route_node_state.node_id[-1])))
) )
yield event yield event

View File

@ -45,23 +45,23 @@ def test_execute_if_else_result_true():
SystemVariable.FILES: [], SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa' SystemVariable.USER_ID: 'aaa'
}, user_inputs={}) }, user_inputs={})
pool.append_variable(node_id='start', variable_key_list=['array_contains'], value=['ab', 'def']) pool.add(['start', 'array_contains'], ['ab', 'def'])
pool.append_variable(node_id='start', variable_key_list=['array_not_contains'], value=['ac', 'def']) pool.add(['start', 'array_not_contains'], ['ac', 'def'])
pool.append_variable(node_id='start', variable_key_list=['contains'], value='cabcde') pool.add(['start', 'contains'], 'cabcde')
pool.append_variable(node_id='start', variable_key_list=['not_contains'], value='zacde') pool.add(['start', 'not_contains'], 'zacde')
pool.append_variable(node_id='start', variable_key_list=['start_with'], value='abc') pool.add(['start', 'start_with'], 'abc')
pool.append_variable(node_id='start', variable_key_list=['end_with'], value='zzab') pool.add(['start', 'end_with'], 'zzab')
pool.append_variable(node_id='start', variable_key_list=['is'], value='ab') pool.add(['start', 'is'], 'ab')
pool.append_variable(node_id='start', variable_key_list=['is_not'], value='aab') pool.add(['start', 'is_not'], 'aab')
pool.append_variable(node_id='start', variable_key_list=['empty'], value='') pool.add(['start', 'empty'], '')
pool.append_variable(node_id='start', variable_key_list=['not_empty'], value='aaa') pool.add(['start', 'not_empty'], 'aaa')
pool.append_variable(node_id='start', variable_key_list=['equals'], value=22) pool.add(['start', 'equals'], 22)
pool.append_variable(node_id='start', variable_key_list=['not_equals'], value=23) pool.add(['start', 'not_equals'], 23)
pool.append_variable(node_id='start', variable_key_list=['greater_than'], value=23) pool.add(['start', 'greater_than'], 23)
pool.append_variable(node_id='start', variable_key_list=['less_than'], value=21) pool.add(['start', 'less_than'], 21)
pool.append_variable(node_id='start', variable_key_list=['greater_than_or_equal'], value=22) pool.add(['start', 'greater_than_or_equal'], 22)
pool.append_variable(node_id='start', variable_key_list=['less_than_or_equal'], value=21) pool.add(['start', 'less_than_or_equal'], 21)
pool.append_variable(node_id='start', variable_key_list=['not_null'], value='1212') pool.add(['start', 'not_null'], '1212')
node = IfElseNode( node = IfElseNode(
graph_init_params=init_params, graph_init_params=init_params,
@ -204,13 +204,60 @@ def test_execute_if_else_result_true():
def test_execute_if_else_result_false(): def test_execute_if_else_result_false():
node = IfElseNode( graph_config = {
"edges": [
{
"id": "start-source-llm-target",
"source": "start",
"target": "llm",
},
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm"
},
]
}
graph = Graph.init(
graph_config=graph_config
)
init_params = GraphInitParams(
tenant_id='1', tenant_id='1',
app_id='1', app_id='1',
workflow_type=WorkflowType.WORKFLOW,
workflow_id='1', workflow_id='1',
user_id='1', user_id='1',
user_from=UserFrom.ACCOUNT, user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
call_depth=0
)
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['start', 'array_contains'], ['1ab', 'def'])
pool.add(['start', 'array_not_contains'], ['ab', 'def'])
node = IfElseNode(
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(
variable_pool=pool,
start_at=time.perf_counter()
),
config={ config={
'id': 'if-else', 'id': 'if-else',
'data': { 'data': {
@ -233,19 +280,11 @@ def test_execute_if_else_result_false():
} }
) )
# construct variable pool
pool = VariablePool(system_variables={
SystemVariable.FILES: [],
SystemVariable.USER_ID: 'aaa'
}, user_inputs={}, environment_variables=[])
pool.add(['start', 'array_contains'], ['1ab', 'def'])
pool.add(['start', 'array_not_contains'], ['ab', 'def'])
# Mock db.session.close() # Mock db.session.close()
db.session.close = MagicMock() db.session.close = MagicMock()
# execute node # execute node
result = node._run(pool) result = node._run()
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs['result'] is False assert result.outputs['result'] is False