This commit is contained in:
takatost 2024-07-17 16:54:49 +08:00
parent cc96acdae3
commit 90e518b05b
6 changed files with 173 additions and 20 deletions

View File

@ -1,7 +1,7 @@
from enum import Enum from enum import Enum
from typing import Any, Optional, Union from typing import Any, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, model_validator
from core.file.file_obj import FileVar from core.file.file_obj import FileVar
from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.node_entities import SystemVariable
@ -38,9 +38,21 @@ class VariablePool(BaseModel):
description='System variables', description='System variables',
) )
def __post_init__(self): @model_validator(mode='before')
for system_variable, value in self.system_variables.items(): def append_system_variables(cls, v: dict) -> dict:
self.append_variable('sys', [system_variable.value], value) """
Append system variables
:param v: params
:return:
"""
v['variables_mapping'] = {
'sys': {}
}
system_variables = v['system_variables']
for system_variable, value in system_variables.items():
variable_key_list_hash = hash((system_variable.value,))
v['variables_mapping']['sys'][variable_key_list_hash] = value
return v
def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None: def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None:
""" """

View File

@ -4,6 +4,7 @@ from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.run_condition import RunCondition from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
class RunConditionHandler(ABC): class RunConditionHandler(ABC):
@ -18,13 +19,13 @@ class RunConditionHandler(ABC):
@abstractmethod @abstractmethod
def check(self, def check(self,
graph_runtime_state: GraphRuntimeState, graph_runtime_state: GraphRuntimeState,
source_node_id: str, previous_route_node_state: RouteNodeState,
target_node_id: str) -> bool: target_node_id: str) -> bool:
""" """
Check if the condition can be executed Check if the condition can be executed
:param graph_runtime_state: graph runtime state :param graph_runtime_state: graph runtime state
:param source_node_id: source node id :param previous_route_node_state: previous route node state
:param target_node_id: target node id :param target_node_id: target node id
:return: bool :return: bool
""" """

View File

@ -1,29 +1,26 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
class BranchIdentifyRunConditionHandler(RunConditionHandler): class BranchIdentifyRunConditionHandler(RunConditionHandler):
def check(self, def check(self,
graph_runtime_state: GraphRuntimeState, graph_runtime_state: GraphRuntimeState,
source_node_id: str, previous_route_node_state: RouteNodeState,
target_node_id: str) -> bool: target_node_id: str) -> bool:
""" """
Check if the condition can be executed Check if the condition can be executed
:param graph_runtime_state: graph runtime state :param graph_runtime_state: graph runtime state
:param source_node_id: source node id :param previous_route_node_state: previous route node state
:param target_node_id: target node id :param target_node_id: target node id
:return: bool :return: bool
""" """
if not self.condition.branch_identify: if not self.condition.branch_identify:
raise Exception("Branch identify is required") raise Exception("Branch identify is required")
node_route_state = graph_runtime_state.node_run_state.node_state_mapping.get(source_node_id) run_result = previous_route_node_state.node_run_result
if not node_route_state:
return False
run_result = node_route_state.node_run_result
if not run_result: if not run_result:
return False return False

View File

@ -1,18 +1,19 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.utils.condition.processor import ConditionProcessor from core.workflow.utils.condition.processor import ConditionProcessor
class ConditionRunConditionHandlerHandler(RunConditionHandler): class ConditionRunConditionHandlerHandler(RunConditionHandler):
def check(self, def check(self,
graph_runtime_state: GraphRuntimeState, graph_runtime_state: GraphRuntimeState,
source_node_id: str, previous_route_node_state: RouteNodeState,
target_node_id: str) -> bool: target_node_id: str) -> bool:
""" """
Check if the condition can be executed Check if the condition can be executed
:param graph_runtime_state: graph runtime state :param graph_runtime_state: graph runtime state
:param source_node_id: source node id :param previous_route_node_state: previous route node state
:param target_node_id: target node id :param target_node_id: target node id
:return: bool :return: bool
""" """

View File

@ -167,7 +167,7 @@ class GraphEngine:
run_condition=edge.run_condition, run_condition=edge.run_condition,
).check( ).check(
graph_runtime_state=self.graph_runtime_state, graph_runtime_state=self.graph_runtime_state,
source_node_id=edge.source_node_id, previous_route_node_state=previous_route_node_state,
target_node_id=edge.target_node_id, target_node_id=edge.target_node_id,
) )

View File

@ -21,7 +21,7 @@ from models.workflow import WorkflowType
@patch('extensions.ext_database.db.session.remove') @patch('extensions.ext_database.db.session.remove')
@patch('extensions.ext_database.db.session.close') @patch('extensions.ext_database.db.session.close')
def test_run(mock_close, mock_remove): def test_run_parallel(mock_close, mock_remove):
graph_config = { graph_config = {
"edges": [ "edges": [
{ {
@ -127,7 +127,7 @@ def test_run(mock_close, mock_remove):
max_execution_time=1200 max_execution_time=1200
) )
print("") # print("")
app = Flask('test') app = Flask('test')
@ -135,7 +135,7 @@ def test_run(mock_close, mock_remove):
with app.app_context(): with app.app_context():
generator = graph_engine.run() generator = graph_engine.run()
for item in generator: for item in generator:
print(type(item), item) # print(type(item), item)
items.append(item) items.append(item)
if isinstance(item, NodeRunSucceededEvent): if isinstance(item, NodeRunSucceededEvent):
assert item.route_node_state.status == RouteNodeState.Status.SUCCESS assert item.route_node_state.status == RouteNodeState.Status.SUCCESS
@ -153,4 +153,146 @@ def test_run(mock_close, mock_remove):
assert isinstance(items[2], NodeRunSucceededEvent) assert isinstance(items[2], NodeRunSucceededEvent)
assert items[2].route_node_state.node_id == 'start' assert items[2].route_node_state.node_id == 'start'
print(graph_engine.graph_runtime_state)
@patch('extensions.ext_database.db.session.remove')
@patch('extensions.ext_database.db.session.close')
def test_run_branch(mock_close, mock_remove):
graph_config = {
"edges": [{
"id": "1",
"source": "start",
"target": "if-else-1",
}, {
"id": "2",
"source": "if-else-1",
"sourceHandle": "true",
"target": "answer-1",
}, {
"id": "3",
"source": "if-else-1",
"sourceHandle": "false",
"target": "if-else-2",
}, {
"id": "4",
"source": "if-else-2",
"sourceHandle": "true",
"target": "answer-2",
}, {
"id": "5",
"source": "if-else-2",
"sourceHandle": "false",
"target": "answer-3",
}],
"nodes": [{
"data": {
"title": "Start",
"type": "start",
"variables": []
},
"id": "start"
}, {
"data": {
"answer": "1",
"title": "Answer",
"type": "answer",
"variables": []
},
"id": "answer-1",
}, {
"data": {
"cases": [{
"case_id": "true",
"conditions": [{
"comparison_operator": "contains",
"id": "b0f02473-08b6-4a81-af91-15345dcb2ec8",
"value": "hi",
"varType": "string",
"variable_selector": ["sys", "query"]
}],
"id": "true",
"logical_operator": "and"
}],
"desc": "",
"title": "IF/ELSE",
"type": "if-else"
},
"id": "if-else-1",
}, {
"data": {
"cases": [{
"case_id": "true",
"conditions": [{
"comparison_operator": "contains",
"id": "ae895199-5608-433b-b5f0-0997ae1431e4",
"value": "takatost",
"varType": "string",
"variable_selector": ["sys", "query"]
}],
"id": "true",
"logical_operator": "and"
}],
"title": "IF/ELSE 2",
"type": "if-else"
},
"id": "if-else-2",
}, {
"data": {
"answer": "2",
"title": "Answer 2",
"type": "answer",
},
"id": "answer-2",
}, {
"data": {
"answer": "3",
"title": "Answer 3",
"type": "answer",
},
"id": "answer-3",
}]
}
graph = Graph.init(
graph_config=graph_config
)
variable_pool = VariablePool(system_variables={
SystemVariable.QUERY: 'hi',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
}, user_inputs={})
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.CHAT,
workflow_id="333",
user_id="444",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
max_execution_steps=500,
max_execution_time=1200
)
# print("")
app = Flask('test')
items = []
with app.app_context():
generator = graph_engine.run()
for item in generator:
# print(type(item), item)
items.append(item)
assert len(items) == 8
assert items[3].route_node_state.node_id == 'if-else-1'
assert items[4].route_node_state.node_id == 'if-else-1'
assert items[5].route_node_state.node_id == 'answer-1'
assert items[6].route_node_state.node_id == 'answer-1'
# print(graph_engine.graph_runtime_state.model_dump_json(indent=2))