mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 23:55:52 +08:00
fix bugs
This commit is contained in:
parent
cc96acdae3
commit
90e518b05b
@ -1,7 +1,7 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
from core.file.file_obj import FileVar
|
from core.file.file_obj import FileVar
|
||||||
from core.workflow.entities.node_entities import SystemVariable
|
from core.workflow.entities.node_entities import SystemVariable
|
||||||
@ -38,9 +38,21 @@ class VariablePool(BaseModel):
|
|||||||
description='System variables',
|
description='System variables',
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
@model_validator(mode='before')
|
||||||
for system_variable, value in self.system_variables.items():
|
def append_system_variables(cls, v: dict) -> dict:
|
||||||
self.append_variable('sys', [system_variable.value], value)
|
"""
|
||||||
|
Append system variables
|
||||||
|
:param v: params
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
v['variables_mapping'] = {
|
||||||
|
'sys': {}
|
||||||
|
}
|
||||||
|
system_variables = v['system_variables']
|
||||||
|
for system_variable, value in system_variables.items():
|
||||||
|
variable_key_list_hash = hash((system_variable.value,))
|
||||||
|
v['variables_mapping']['sys'][variable_key_list_hash] = value
|
||||||
|
return v
|
||||||
|
|
||||||
def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None:
|
def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None:
|
||||||
"""
|
"""
|
||||||
|
@ -4,6 +4,7 @@ from core.workflow.graph_engine.entities.graph import Graph
|
|||||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||||
|
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||||
|
|
||||||
|
|
||||||
class RunConditionHandler(ABC):
|
class RunConditionHandler(ABC):
|
||||||
@ -18,13 +19,13 @@ class RunConditionHandler(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def check(self,
|
def check(self,
|
||||||
graph_runtime_state: GraphRuntimeState,
|
graph_runtime_state: GraphRuntimeState,
|
||||||
source_node_id: str,
|
previous_route_node_state: RouteNodeState,
|
||||||
target_node_id: str) -> bool:
|
target_node_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the condition can be executed
|
Check if the condition can be executed
|
||||||
|
|
||||||
:param graph_runtime_state: graph runtime state
|
:param graph_runtime_state: graph runtime state
|
||||||
:param source_node_id: source node id
|
:param previous_route_node_state: previous route node state
|
||||||
:param target_node_id: target node id
|
:param target_node_id: target node id
|
||||||
:return: bool
|
:return: bool
|
||||||
"""
|
"""
|
||||||
|
@ -1,29 +1,26 @@
|
|||||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
|
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||||
|
|
||||||
|
|
||||||
class BranchIdentifyRunConditionHandler(RunConditionHandler):
|
class BranchIdentifyRunConditionHandler(RunConditionHandler):
|
||||||
|
|
||||||
def check(self,
|
def check(self,
|
||||||
graph_runtime_state: GraphRuntimeState,
|
graph_runtime_state: GraphRuntimeState,
|
||||||
source_node_id: str,
|
previous_route_node_state: RouteNodeState,
|
||||||
target_node_id: str) -> bool:
|
target_node_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the condition can be executed
|
Check if the condition can be executed
|
||||||
|
|
||||||
:param graph_runtime_state: graph runtime state
|
:param graph_runtime_state: graph runtime state
|
||||||
:param source_node_id: source node id
|
:param previous_route_node_state: previous route node state
|
||||||
:param target_node_id: target node id
|
:param target_node_id: target node id
|
||||||
:return: bool
|
:return: bool
|
||||||
"""
|
"""
|
||||||
if not self.condition.branch_identify:
|
if not self.condition.branch_identify:
|
||||||
raise Exception("Branch identify is required")
|
raise Exception("Branch identify is required")
|
||||||
|
|
||||||
node_route_state = graph_runtime_state.node_run_state.node_state_mapping.get(source_node_id)
|
run_result = previous_route_node_state.node_run_result
|
||||||
if not node_route_state:
|
|
||||||
return False
|
|
||||||
|
|
||||||
run_result = node_route_state.node_run_result
|
|
||||||
if not run_result:
|
if not run_result:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -1,18 +1,19 @@
|
|||||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
|
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||||
|
|
||||||
|
|
||||||
class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
||||||
def check(self,
|
def check(self,
|
||||||
graph_runtime_state: GraphRuntimeState,
|
graph_runtime_state: GraphRuntimeState,
|
||||||
source_node_id: str,
|
previous_route_node_state: RouteNodeState,
|
||||||
target_node_id: str) -> bool:
|
target_node_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the condition can be executed
|
Check if the condition can be executed
|
||||||
|
|
||||||
:param graph_runtime_state: graph runtime state
|
:param graph_runtime_state: graph runtime state
|
||||||
:param source_node_id: source node id
|
:param previous_route_node_state: previous route node state
|
||||||
:param target_node_id: target node id
|
:param target_node_id: target node id
|
||||||
:return: bool
|
:return: bool
|
||||||
"""
|
"""
|
||||||
|
@ -167,7 +167,7 @@ class GraphEngine:
|
|||||||
run_condition=edge.run_condition,
|
run_condition=edge.run_condition,
|
||||||
).check(
|
).check(
|
||||||
graph_runtime_state=self.graph_runtime_state,
|
graph_runtime_state=self.graph_runtime_state,
|
||||||
source_node_id=edge.source_node_id,
|
previous_route_node_state=previous_route_node_state,
|
||||||
target_node_id=edge.target_node_id,
|
target_node_id=edge.target_node_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ from models.workflow import WorkflowType
|
|||||||
|
|
||||||
@patch('extensions.ext_database.db.session.remove')
|
@patch('extensions.ext_database.db.session.remove')
|
||||||
@patch('extensions.ext_database.db.session.close')
|
@patch('extensions.ext_database.db.session.close')
|
||||||
def test_run(mock_close, mock_remove):
|
def test_run_parallel(mock_close, mock_remove):
|
||||||
graph_config = {
|
graph_config = {
|
||||||
"edges": [
|
"edges": [
|
||||||
{
|
{
|
||||||
@ -127,7 +127,7 @@ def test_run(mock_close, mock_remove):
|
|||||||
max_execution_time=1200
|
max_execution_time=1200
|
||||||
)
|
)
|
||||||
|
|
||||||
print("")
|
# print("")
|
||||||
|
|
||||||
app = Flask('test')
|
app = Flask('test')
|
||||||
|
|
||||||
@ -135,7 +135,7 @@ def test_run(mock_close, mock_remove):
|
|||||||
with app.app_context():
|
with app.app_context():
|
||||||
generator = graph_engine.run()
|
generator = graph_engine.run()
|
||||||
for item in generator:
|
for item in generator:
|
||||||
print(type(item), item)
|
# print(type(item), item)
|
||||||
items.append(item)
|
items.append(item)
|
||||||
if isinstance(item, NodeRunSucceededEvent):
|
if isinstance(item, NodeRunSucceededEvent):
|
||||||
assert item.route_node_state.status == RouteNodeState.Status.SUCCESS
|
assert item.route_node_state.status == RouteNodeState.Status.SUCCESS
|
||||||
@ -153,4 +153,146 @@ def test_run(mock_close, mock_remove):
|
|||||||
assert isinstance(items[2], NodeRunSucceededEvent)
|
assert isinstance(items[2], NodeRunSucceededEvent)
|
||||||
assert items[2].route_node_state.node_id == 'start'
|
assert items[2].route_node_state.node_id == 'start'
|
||||||
|
|
||||||
print(graph_engine.graph_runtime_state)
|
|
||||||
|
@patch('extensions.ext_database.db.session.remove')
|
||||||
|
@patch('extensions.ext_database.db.session.close')
|
||||||
|
def test_run_branch(mock_close, mock_remove):
|
||||||
|
graph_config = {
|
||||||
|
"edges": [{
|
||||||
|
"id": "1",
|
||||||
|
"source": "start",
|
||||||
|
"target": "if-else-1",
|
||||||
|
}, {
|
||||||
|
"id": "2",
|
||||||
|
"source": "if-else-1",
|
||||||
|
"sourceHandle": "true",
|
||||||
|
"target": "answer-1",
|
||||||
|
}, {
|
||||||
|
"id": "3",
|
||||||
|
"source": "if-else-1",
|
||||||
|
"sourceHandle": "false",
|
||||||
|
"target": "if-else-2",
|
||||||
|
}, {
|
||||||
|
"id": "4",
|
||||||
|
"source": "if-else-2",
|
||||||
|
"sourceHandle": "true",
|
||||||
|
"target": "answer-2",
|
||||||
|
}, {
|
||||||
|
"id": "5",
|
||||||
|
"source": "if-else-2",
|
||||||
|
"sourceHandle": "false",
|
||||||
|
"target": "answer-3",
|
||||||
|
}],
|
||||||
|
"nodes": [{
|
||||||
|
"data": {
|
||||||
|
"title": "Start",
|
||||||
|
"type": "start",
|
||||||
|
"variables": []
|
||||||
|
},
|
||||||
|
"id": "start"
|
||||||
|
}, {
|
||||||
|
"data": {
|
||||||
|
"answer": "1",
|
||||||
|
"title": "Answer",
|
||||||
|
"type": "answer",
|
||||||
|
"variables": []
|
||||||
|
},
|
||||||
|
"id": "answer-1",
|
||||||
|
}, {
|
||||||
|
"data": {
|
||||||
|
"cases": [{
|
||||||
|
"case_id": "true",
|
||||||
|
"conditions": [{
|
||||||
|
"comparison_operator": "contains",
|
||||||
|
"id": "b0f02473-08b6-4a81-af91-15345dcb2ec8",
|
||||||
|
"value": "hi",
|
||||||
|
"varType": "string",
|
||||||
|
"variable_selector": ["sys", "query"]
|
||||||
|
}],
|
||||||
|
"id": "true",
|
||||||
|
"logical_operator": "and"
|
||||||
|
}],
|
||||||
|
"desc": "",
|
||||||
|
"title": "IF/ELSE",
|
||||||
|
"type": "if-else"
|
||||||
|
},
|
||||||
|
"id": "if-else-1",
|
||||||
|
}, {
|
||||||
|
"data": {
|
||||||
|
"cases": [{
|
||||||
|
"case_id": "true",
|
||||||
|
"conditions": [{
|
||||||
|
"comparison_operator": "contains",
|
||||||
|
"id": "ae895199-5608-433b-b5f0-0997ae1431e4",
|
||||||
|
"value": "takatost",
|
||||||
|
"varType": "string",
|
||||||
|
"variable_selector": ["sys", "query"]
|
||||||
|
}],
|
||||||
|
"id": "true",
|
||||||
|
"logical_operator": "and"
|
||||||
|
}],
|
||||||
|
"title": "IF/ELSE 2",
|
||||||
|
"type": "if-else"
|
||||||
|
},
|
||||||
|
"id": "if-else-2",
|
||||||
|
}, {
|
||||||
|
"data": {
|
||||||
|
"answer": "2",
|
||||||
|
"title": "Answer 2",
|
||||||
|
"type": "answer",
|
||||||
|
},
|
||||||
|
"id": "answer-2",
|
||||||
|
}, {
|
||||||
|
"data": {
|
||||||
|
"answer": "3",
|
||||||
|
"title": "Answer 3",
|
||||||
|
"type": "answer",
|
||||||
|
},
|
||||||
|
"id": "answer-3",
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = Graph.init(
|
||||||
|
graph_config=graph_config
|
||||||
|
)
|
||||||
|
|
||||||
|
variable_pool = VariablePool(system_variables={
|
||||||
|
SystemVariable.QUERY: 'hi',
|
||||||
|
SystemVariable.FILES: [],
|
||||||
|
SystemVariable.CONVERSATION_ID: 'abababa',
|
||||||
|
SystemVariable.USER_ID: 'aaa'
|
||||||
|
}, user_inputs={})
|
||||||
|
|
||||||
|
graph_engine = GraphEngine(
|
||||||
|
tenant_id="111",
|
||||||
|
app_id="222",
|
||||||
|
workflow_type=WorkflowType.CHAT,
|
||||||
|
workflow_id="333",
|
||||||
|
user_id="444",
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
invoke_from=InvokeFrom.WEB_APP,
|
||||||
|
call_depth=0,
|
||||||
|
graph=graph,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
max_execution_steps=500,
|
||||||
|
max_execution_time=1200
|
||||||
|
)
|
||||||
|
|
||||||
|
# print("")
|
||||||
|
|
||||||
|
app = Flask('test')
|
||||||
|
|
||||||
|
items = []
|
||||||
|
with app.app_context():
|
||||||
|
generator = graph_engine.run()
|
||||||
|
for item in generator:
|
||||||
|
# print(type(item), item)
|
||||||
|
items.append(item)
|
||||||
|
|
||||||
|
assert len(items) == 8
|
||||||
|
assert items[3].route_node_state.node_id == 'if-else-1'
|
||||||
|
assert items[4].route_node_state.node_id == 'if-else-1'
|
||||||
|
assert items[5].route_node_state.node_id == 'answer-1'
|
||||||
|
assert items[6].route_node_state.node_id == 'answer-1'
|
||||||
|
|
||||||
|
# print(graph_engine.graph_runtime_state.model_dump_json(indent=2))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user