mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 12:15:52 +08:00
save
This commit is contained in:
parent
1d8ecac093
commit
8375517ccd
@ -1,6 +1,8 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.file.file_obj import FileVar
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
|
||||
@ -21,20 +23,23 @@ class ValueType(Enum):
|
||||
FILE = "file"
|
||||
|
||||
|
||||
class VariablePool:
|
||||
class VariablePool(BaseModel):
|
||||
|
||||
def __init__(self, system_variables: dict[SystemVariable, Any],
|
||||
user_inputs: dict) -> None:
|
||||
# system variables
|
||||
# for example:
|
||||
# {
|
||||
# 'query': 'abc',
|
||||
# 'files': []
|
||||
# }
|
||||
self.variables_mapping = {}
|
||||
self.user_inputs = user_inputs
|
||||
self.system_variables = system_variables
|
||||
for system_variable, value in system_variables.items():
|
||||
variables_mapping: dict[str, dict[int, VariableValue]] = Field(
|
||||
description='Variables mapping',
|
||||
default={},
|
||||
)
|
||||
|
||||
user_inputs: dict = Field(
|
||||
description='User inputs',
|
||||
)
|
||||
|
||||
system_variables: dict[SystemVariable, Any] = Field(
|
||||
description='System variables',
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
for system_variable, value in self.system_variables.items():
|
||||
self.append_variable('sys', [system_variable.value], value)
|
||||
|
||||
def append_variable(self, node_id: str, variable_key_list: list[str], value: VariableValue) -> None:
|
||||
|
@ -1,21 +1,40 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
|
||||
class RunCondition(BaseModel):
|
||||
type: Literal["branch_identify", "condition"]
|
||||
"""condition type"""
|
||||
|
||||
branch_identify: Optional[str] = None
|
||||
"""branch identify, required when type is branch_identify"""
|
||||
|
||||
conditions: Optional[list[Condition]] = None
|
||||
"""conditions to run the node, required when type is condition"""
|
||||
|
||||
|
||||
class GraphNode(BaseModel):
|
||||
id: str
|
||||
"""node id"""
|
||||
|
||||
parent_id: Optional[str] = None
|
||||
"""parent node id, e.g. iteration/loop"""
|
||||
|
||||
predecessor_node_id: Optional[str] = None
|
||||
"""predecessor node id"""
|
||||
|
||||
children_node_ids: list[str] = []
|
||||
"""children node ids"""
|
||||
descendant_node_ids: list[str] = []
|
||||
"""descendant node ids"""
|
||||
|
||||
run_condition_callback: Optional[Callable] = None
|
||||
"""condition function check if the node can be executed"""
|
||||
run_condition: Optional[RunCondition] = None
|
||||
"""condition to run the node"""
|
||||
|
||||
run_condition_callback: Optional[Callable] = Field(None, exclude=True)
|
||||
"""condition function check if the node can be executed, translated from run_conditions, not serialized"""
|
||||
|
||||
node_config: dict
|
||||
"""original node config"""
|
||||
@ -23,72 +42,96 @@ class GraphNode(BaseModel):
|
||||
source_edge_config: Optional[dict] = None
|
||||
"""original source edge config"""
|
||||
|
||||
target_edge_config: Optional[dict] = None
|
||||
"""original target edge config"""
|
||||
sub_graph: Optional["Graph"] = None
|
||||
"""sub graph of the node, e.g. iteration/loop sub graph"""
|
||||
|
||||
def add_child(self, node_id: str) -> None:
|
||||
self.children_node_ids.append(node_id)
|
||||
self.descendant_node_ids.append(node_id)
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
graph_config: dict
|
||||
"""graph config from workflow"""
|
||||
|
||||
graph_nodes: dict[str, GraphNode] = {}
|
||||
"""graph nodes"""
|
||||
|
||||
root_node: Optional[GraphNode] = None
|
||||
root_node: GraphNode
|
||||
"""root node of the graph"""
|
||||
|
||||
@classmethod
|
||||
def init(cls, root_node_config: dict, run_condition: Optional[RunCondition] = None) -> "Graph":
|
||||
"""
|
||||
Init graph
|
||||
|
||||
:param root_node_config: root node config
|
||||
:param run_condition: run condition when root node parent is iteration/loop
|
||||
:return: graph
|
||||
"""
|
||||
root_node = GraphNode(
|
||||
id=root_node_config.get('id'),
|
||||
parent_id=root_node_config.get('parentId'),
|
||||
node_config=root_node_config,
|
||||
run_condition=run_condition
|
||||
)
|
||||
|
||||
graph = cls(root_node=root_node)
|
||||
|
||||
# TODO parse run_condition to run_condition_callback
|
||||
|
||||
graph.add_graph_node(graph.root_node)
|
||||
return graph
|
||||
|
||||
def add_edge(self, edge_config: dict,
|
||||
source_node_config: dict,
|
||||
target_node_config: dict,
|
||||
run_condition_callback: Optional[Callable] = None) -> None:
|
||||
target_node_sub_graph: Optional["Graph"] = None) -> None:
|
||||
"""
|
||||
Add edge to the graph
|
||||
|
||||
:param edge_config: edge config
|
||||
:param source_node_config: source node config
|
||||
:param target_node_config: target node config
|
||||
:param run_condition_callback: condition callback
|
||||
:param target_node_sub_graph: sub graph
|
||||
"""
|
||||
source_node_id = source_node_config.get('id')
|
||||
if not source_node_id:
|
||||
return
|
||||
|
||||
if source_node_id not in self.graph_nodes:
|
||||
return
|
||||
|
||||
target_node_id = target_node_config.get('id')
|
||||
if not target_node_id:
|
||||
return
|
||||
|
||||
if source_node_id not in self.graph_nodes:
|
||||
source_graph_node = GraphNode(
|
||||
id=source_node_id,
|
||||
node_config=source_node_config,
|
||||
children_node_ids=[target_node_id],
|
||||
target_edge_config=edge_config,
|
||||
)
|
||||
source_node = self.graph_nodes[source_node_id]
|
||||
source_node.add_child(target_node_id)
|
||||
|
||||
# if run_conditions:
|
||||
# run_condition_callback = lambda: all()
|
||||
|
||||
self.add_graph_node(source_graph_node)
|
||||
else:
|
||||
source_node = self.graph_nodes[source_node_id]
|
||||
source_node.add_child(target_node_id)
|
||||
source_node.target_edge_config = edge_config
|
||||
|
||||
if target_node_id not in self.graph_nodes:
|
||||
run_condition = None # todo
|
||||
run_condition_callback = None # todo
|
||||
|
||||
target_graph_node = GraphNode(
|
||||
id=target_node_id,
|
||||
parent_id=source_node_config.get('parentId'),
|
||||
predecessor_node_id=source_node_id,
|
||||
node_config=target_node_config,
|
||||
run_condition=run_condition,
|
||||
run_condition_callback=run_condition_callback,
|
||||
source_edge_config=edge_config,
|
||||
sub_graph=target_node_sub_graph
|
||||
)
|
||||
|
||||
self.add_graph_node(target_graph_node)
|
||||
else:
|
||||
target_node = self.graph_nodes[target_node_id]
|
||||
target_node.predecessor_node_id = source_node_id
|
||||
target_node.run_conditions = run_conditions
|
||||
target_node.run_condition_callback = run_condition_callback
|
||||
target_node.source_edge_config = edge_config
|
||||
target_node.sub_graph = target_node_sub_graph
|
||||
|
||||
def add_graph_node(self, graph_node: GraphNode) -> None:
|
||||
"""
|
||||
@ -123,13 +166,13 @@ class Graph(BaseModel):
|
||||
return None
|
||||
|
||||
graph_node = self.graph_nodes[node_id]
|
||||
if not graph_node.children_node_ids:
|
||||
if not graph_node.descendant_node_ids:
|
||||
return None
|
||||
|
||||
descendants_graph = Graph(graph_config=self.graph_config)
|
||||
descendants_graph = Graph()
|
||||
descendants_graph.add_graph_node(graph_node)
|
||||
|
||||
for child_node_id in graph_node.children_node_ids:
|
||||
for child_node_id in graph_node.descendant_node_ids:
|
||||
self._add_descendants_graph_nodes(descendants_graph, child_node_id)
|
||||
|
||||
return descendants_graph
|
||||
@ -147,5 +190,5 @@ class Graph(BaseModel):
|
||||
graph_node = self.graph_nodes[node_id]
|
||||
descendants_graph.add_graph_node(graph_node)
|
||||
|
||||
for child_node_id in graph_node.children_node_ids:
|
||||
for child_node_id in graph_node.descendant_node_ids:
|
||||
self._add_descendants_graph_nodes(descendants_graph, child_node_id)
|
||||
|
@ -1,26 +1,12 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Literal
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
|
||||
class IfElseNodeData(BaseNodeData):
|
||||
"""
|
||||
Answer Node Data.
|
||||
"""
|
||||
class Condition(BaseModel):
|
||||
"""
|
||||
Condition entity
|
||||
"""
|
||||
variable_selector: list[str]
|
||||
comparison_operator: Literal[
|
||||
# for string or array
|
||||
"contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty",
|
||||
# for number
|
||||
"=", "≠", ">", "<", "≥", "≤", "null", "not null"
|
||||
]
|
||||
value: Optional[str] = None
|
||||
|
||||
logical_operator: Literal["and", "or"] = "and"
|
||||
conditions: list[Condition]
|
||||
|
@ -1,10 +1,11 @@
|
||||
from typing import Optional, cast
|
||||
from typing import cast
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.if_else.entities import IfElseNodeData
|
||||
from core.workflow.utils.condition.processor import ConditionAssertionError, ConditionProcessor
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
@ -19,90 +20,42 @@ class IfElseNode(BaseNode):
|
||||
:return:
|
||||
"""
|
||||
node_data = self.node_data
|
||||
node_data = cast(self._node_data_cls, node_data)
|
||||
node_data = cast(IfElseNodeData, node_data)
|
||||
|
||||
node_inputs = {
|
||||
node_inputs: dict[str, list] = {
|
||||
"conditions": []
|
||||
}
|
||||
|
||||
process_datas = {
|
||||
process_datas: dict[str, list] = {
|
||||
"condition_results": []
|
||||
}
|
||||
|
||||
try:
|
||||
logical_operator = node_data.logical_operator
|
||||
input_conditions = []
|
||||
for condition in node_data.conditions:
|
||||
actual_value = variable_pool.get_variable_value(
|
||||
variable_selector=condition.variable_selector
|
||||
)
|
||||
processor = ConditionProcessor()
|
||||
compare_result, sub_condition_compare_results = processor.process(
|
||||
variable_pool=variable_pool,
|
||||
logical_operator=node_data.logical_operator,
|
||||
conditions=node_data.conditions,
|
||||
)
|
||||
|
||||
expected_value = condition.value
|
||||
node_inputs["conditions"] = [{
|
||||
"actual_value": result['actual_value'],
|
||||
"expected_value": result['expected_value'],
|
||||
"comparison_operator": result['comparison_operator'],
|
||||
} for result in sub_condition_compare_results]
|
||||
|
||||
input_conditions.append({
|
||||
"actual_value": actual_value,
|
||||
"expected_value": expected_value,
|
||||
"comparison_operator": condition.comparison_operator
|
||||
})
|
||||
|
||||
node_inputs["conditions"] = input_conditions
|
||||
|
||||
for input_condition in input_conditions:
|
||||
actual_value = input_condition["actual_value"]
|
||||
expected_value = input_condition["expected_value"]
|
||||
comparison_operator = input_condition["comparison_operator"]
|
||||
|
||||
if comparison_operator == "contains":
|
||||
compare_result = self._assert_contains(actual_value, expected_value)
|
||||
elif comparison_operator == "not contains":
|
||||
compare_result = self._assert_not_contains(actual_value, expected_value)
|
||||
elif comparison_operator == "start with":
|
||||
compare_result = self._assert_start_with(actual_value, expected_value)
|
||||
elif comparison_operator == "end with":
|
||||
compare_result = self._assert_end_with(actual_value, expected_value)
|
||||
elif comparison_operator == "is":
|
||||
compare_result = self._assert_is(actual_value, expected_value)
|
||||
elif comparison_operator == "is not":
|
||||
compare_result = self._assert_is_not(actual_value, expected_value)
|
||||
elif comparison_operator == "empty":
|
||||
compare_result = self._assert_empty(actual_value)
|
||||
elif comparison_operator == "not empty":
|
||||
compare_result = self._assert_not_empty(actual_value)
|
||||
elif comparison_operator == "=":
|
||||
compare_result = self._assert_equal(actual_value, expected_value)
|
||||
elif comparison_operator == "≠":
|
||||
compare_result = self._assert_not_equal(actual_value, expected_value)
|
||||
elif comparison_operator == ">":
|
||||
compare_result = self._assert_greater_than(actual_value, expected_value)
|
||||
elif comparison_operator == "<":
|
||||
compare_result = self._assert_less_than(actual_value, expected_value)
|
||||
elif comparison_operator == "≥":
|
||||
compare_result = self._assert_greater_than_or_equal(actual_value, expected_value)
|
||||
elif comparison_operator == "≤":
|
||||
compare_result = self._assert_less_than_or_equal(actual_value, expected_value)
|
||||
elif comparison_operator == "null":
|
||||
compare_result = self._assert_null(actual_value)
|
||||
elif comparison_operator == "not null":
|
||||
compare_result = self._assert_not_null(actual_value)
|
||||
else:
|
||||
continue
|
||||
|
||||
process_datas["condition_results"].append({
|
||||
**input_condition,
|
||||
"result": compare_result
|
||||
})
|
||||
except Exception as e:
|
||||
process_datas["condition_results"] = sub_condition_compare_results
|
||||
except ConditionAssertionError as e:
|
||||
node_inputs["conditions"] = e.conditions
|
||||
process_datas["condition_results"] = e.sub_condition_compare_results
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=node_inputs,
|
||||
process_data=process_datas,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
if logical_operator == "and":
|
||||
compare_result = False not in [condition["result"] for condition in process_datas["condition_results"]]
|
||||
else:
|
||||
compare_result = True in [condition["result"] for condition in process_datas["condition_results"]]
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
@ -114,280 +67,6 @@ class IfElseNode(BaseNode):
|
||||
}
|
||||
)
|
||||
|
||||
def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert contains
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if not actual_value:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, str | list):
|
||||
raise ValueError('Invalid actual value type: string or array')
|
||||
|
||||
if expected_value not in actual_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert not contains
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if not actual_value:
|
||||
return True
|
||||
|
||||
if not isinstance(actual_value, str | list):
|
||||
raise ValueError('Invalid actual value type: string or array')
|
||||
|
||||
if expected_value in actual_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert start with
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if not actual_value:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, str):
|
||||
raise ValueError('Invalid actual value type: string')
|
||||
|
||||
if not actual_value.startswith(expected_value):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert end with
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if not actual_value:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, str):
|
||||
raise ValueError('Invalid actual value type: string')
|
||||
|
||||
if not actual_value.endswith(expected_value):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert is
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, str):
|
||||
raise ValueError('Invalid actual value type: string')
|
||||
|
||||
if actual_value != expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert is not
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, str):
|
||||
raise ValueError('Invalid actual value type: string')
|
||||
|
||||
if actual_value == expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_empty(self, actual_value: Optional[str]) -> bool:
|
||||
"""
|
||||
Assert empty
|
||||
:param actual_value: actual value
|
||||
:return:
|
||||
"""
|
||||
if not actual_value:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _assert_not_empty(self, actual_value: Optional[str]) -> bool:
|
||||
"""
|
||||
Assert not empty
|
||||
:param actual_value: actual value
|
||||
:return:
|
||||
"""
|
||||
if actual_value:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert equal
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value != expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert not equal
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value == expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert greater than
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value <= expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert less than
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value >= expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert greater than or equal
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value < expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert less than or equal
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value > expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_null(self, actual_value: Optional[int | float]) -> bool:
|
||||
"""
|
||||
Assert null
|
||||
:param actual_value: actual value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _assert_not_null(self, actual_value: Optional[int | float]) -> bool:
|
||||
"""
|
||||
Assert not null
|
||||
:param actual_value: actual value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
"""
|
||||
|
0
api/core/workflow/utils/condition/__init__.py
Normal file
0
api/core/workflow/utils/condition/__init__.py
Normal file
17
api/core/workflow/utils/condition/entities.py
Normal file
17
api/core/workflow/utils/condition/entities.py
Normal file
@ -0,0 +1,17 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Condition(BaseModel):
|
||||
"""
|
||||
Condition entity
|
||||
"""
|
||||
variable_selector: list[str]
|
||||
comparison_operator: Literal[
|
||||
# for string or array
|
||||
"contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty",
|
||||
# for number
|
||||
"=", "≠", ">", "<", "≥", "≤", "null", "not null"
|
||||
]
|
||||
value: Optional[str] = None
|
22
api/core/workflow/utils/condition/funcs.py
Normal file
22
api/core/workflow/utils/condition/funcs.py
Normal file
@ -0,0 +1,22 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_runtime_state_entities import WorkflowRuntimeState
|
||||
from core.workflow.graph import GraphNode
|
||||
|
||||
|
||||
def source_handle_condition_func(workflow_runtime_state: WorkflowRuntimeState,
|
||||
graph_node: GraphNode,
|
||||
# TODO cycle_state optional
|
||||
predecessor_node_run_result: Optional[NodeRunResult] = None) -> bool:
|
||||
if not graph_node.source_edge_config:
|
||||
return False
|
||||
|
||||
if not graph_node.source_edge_config.get('sourceHandle'):
|
||||
return True
|
||||
|
||||
source_handle = predecessor_node_run_result.edge_source_handle \
|
||||
if predecessor_node_run_result else None
|
||||
|
||||
return (source_handle is not None
|
||||
and graph_node.source_edge_config.get('sourceHandle') == source_handle)
|
369
api/core/workflow/utils/condition/processor.py
Normal file
369
api/core/workflow/utils/condition/processor.py
Normal file
@ -0,0 +1,369 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
|
||||
class ConditionProcessor:
|
||||
def process(self, variable_pool: VariablePool,
|
||||
logical_operator: Literal["and", "or"],
|
||||
conditions: list[Condition]) -> tuple[bool, list[dict]]:
|
||||
"""
|
||||
Process conditions
|
||||
|
||||
:param variable_pool: variable pool
|
||||
:param logical_operator: logical operator
|
||||
:param conditions: conditions
|
||||
"""
|
||||
input_conditions = []
|
||||
sub_condition_compare_results = []
|
||||
|
||||
try:
|
||||
for condition in conditions:
|
||||
actual_value = variable_pool.get_variable_value(
|
||||
variable_selector=condition.variable_selector
|
||||
)
|
||||
|
||||
expected_value = condition.value
|
||||
|
||||
input_conditions.append({
|
||||
"actual_value": actual_value,
|
||||
"expected_value": expected_value,
|
||||
"comparison_operator": condition.comparison_operator
|
||||
})
|
||||
|
||||
for input_condition in input_conditions:
|
||||
actual_value = input_condition["actual_value"]
|
||||
expected_value = input_condition["expected_value"]
|
||||
comparison_operator = input_condition["comparison_operator"]
|
||||
|
||||
if comparison_operator == "contains":
|
||||
compare_result = self._assert_contains(actual_value, expected_value)
|
||||
elif comparison_operator == "not contains":
|
||||
compare_result = self._assert_not_contains(actual_value, expected_value)
|
||||
elif comparison_operator == "start with":
|
||||
compare_result = self._assert_start_with(actual_value, expected_value)
|
||||
elif comparison_operator == "end with":
|
||||
compare_result = self._assert_end_with(actual_value, expected_value)
|
||||
elif comparison_operator == "is":
|
||||
compare_result = self._assert_is(actual_value, expected_value)
|
||||
elif comparison_operator == "is not":
|
||||
compare_result = self._assert_is_not(actual_value, expected_value)
|
||||
elif comparison_operator == "empty":
|
||||
compare_result = self._assert_empty(actual_value)
|
||||
elif comparison_operator == "not empty":
|
||||
compare_result = self._assert_not_empty(actual_value)
|
||||
elif comparison_operator == "=":
|
||||
compare_result = self._assert_equal(actual_value, expected_value)
|
||||
elif comparison_operator == "≠":
|
||||
compare_result = self._assert_not_equal(actual_value, expected_value)
|
||||
elif comparison_operator == ">":
|
||||
compare_result = self._assert_greater_than(actual_value, expected_value)
|
||||
elif comparison_operator == "<":
|
||||
compare_result = self._assert_less_than(actual_value, expected_value)
|
||||
elif comparison_operator == "≥":
|
||||
compare_result = self._assert_greater_than_or_equal(actual_value, expected_value)
|
||||
elif comparison_operator == "≤":
|
||||
compare_result = self._assert_less_than_or_equal(actual_value, expected_value)
|
||||
elif comparison_operator == "null":
|
||||
compare_result = self._assert_null(actual_value)
|
||||
elif comparison_operator == "not null":
|
||||
compare_result = self._assert_not_null(actual_value)
|
||||
else:
|
||||
continue
|
||||
|
||||
sub_condition_compare_results.append({
|
||||
**input_condition,
|
||||
"result": compare_result
|
||||
})
|
||||
except Exception as e:
|
||||
raise ConditionAssertionError(str(e), input_conditions, sub_condition_compare_results)
|
||||
|
||||
if logical_operator == "and":
|
||||
compare_result = False not in [condition["result"] for condition in sub_condition_compare_results]
|
||||
else:
|
||||
compare_result = True in [condition["result"] for condition in sub_condition_compare_results]
|
||||
|
||||
return compare_result, sub_condition_compare_results
|
||||
|
||||
def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert contains
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if not actual_value:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, str | list):
|
||||
raise ValueError('Invalid actual value type: string or array')
|
||||
|
||||
if expected_value not in actual_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert not contains
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if not actual_value:
|
||||
return True
|
||||
|
||||
if not isinstance(actual_value, str | list):
|
||||
raise ValueError('Invalid actual value type: string or array')
|
||||
|
||||
if expected_value in actual_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert start with
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if not actual_value:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, str):
|
||||
raise ValueError('Invalid actual value type: string')
|
||||
|
||||
if not actual_value.startswith(expected_value):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert end with
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if not actual_value:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, str):
|
||||
raise ValueError('Invalid actual value type: string')
|
||||
|
||||
if not actual_value.endswith(expected_value):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert is
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, str):
|
||||
raise ValueError('Invalid actual value type: string')
|
||||
|
||||
if actual_value != expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert is not
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, str):
|
||||
raise ValueError('Invalid actual value type: string')
|
||||
|
||||
if actual_value == expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_empty(self, actual_value: Optional[str]) -> bool:
|
||||
"""
|
||||
Assert empty
|
||||
:param actual_value: actual value
|
||||
:return:
|
||||
"""
|
||||
if not actual_value:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _assert_not_empty(self, actual_value: Optional[str]) -> bool:
|
||||
"""
|
||||
Assert not empty
|
||||
:param actual_value: actual value
|
||||
:return:
|
||||
"""
|
||||
if actual_value:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _assert_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert equal
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value != expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert not equal
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value == expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert greater than
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value <= expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert less than
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value >= expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert greater than or equal
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value < expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert less than or equal
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value > expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_null(self, actual_value: Optional[int | float]) -> bool:
|
||||
"""
|
||||
Assert null
|
||||
:param actual_value: actual value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _assert_not_null(self, actual_value: Optional[int | float]) -> bool:
|
||||
"""
|
||||
Assert not null
|
||||
:param actual_value: actual value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class ConditionAssertionError(Exception):
|
||||
def __init__(self, message: str, conditions: list[dict], sub_condition_compare_results: list[dict]) -> None:
|
||||
self.message = message
|
||||
self.conditions = conditions
|
||||
self.sub_condition_compare_results = sub_condition_compare_results
|
||||
super().__init__(self.message)
|
@ -1,5 +1,4 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
@ -129,18 +128,172 @@ class WorkflowEngineManager:
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
# run workflow
|
||||
self._run_workflow(
|
||||
graph_config=graph_config,
|
||||
workflow_runtime_state=workflow_runtime_state,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
try:
|
||||
# run workflow
|
||||
self._run_workflow(
|
||||
graph_config=graph_config,
|
||||
workflow_runtime_state=workflow_runtime_state,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
except WorkflowRunFailedError as e:
|
||||
self._workflow_run_failed(
|
||||
error=e.error,
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
self._workflow_run_failed(
|
||||
error=str(e),
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
# workflow run success
|
||||
self._workflow_run_success(
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
def _init_graph(self, graph_config: dict, root_node_id: Optional[str] = None) -> Optional[Graph]:
|
||||
"""
|
||||
Initialize graph
|
||||
|
||||
:param graph_config: graph config
|
||||
:param root_node_id: root node id if needed
|
||||
:return: graph
|
||||
"""
|
||||
# edge configs
|
||||
edge_configs = graph_config.get('edges')
|
||||
if not edge_configs:
|
||||
return None
|
||||
|
||||
edge_configs = cast(list, edge_configs)
|
||||
|
||||
# reorganize edges mapping
|
||||
source_edges_mapping: dict[str, list[dict]] = {}
|
||||
target_edge_ids = set()
|
||||
for edge_config in edge_configs:
|
||||
source_node_id = edge_config.get('source')
|
||||
if not source_node_id:
|
||||
continue
|
||||
|
||||
if source_node_id not in source_edges_mapping:
|
||||
source_edges_mapping[source_node_id] = []
|
||||
|
||||
source_edges_mapping[source_node_id].append(edge_config)
|
||||
|
||||
target_node_id = edge_config.get('target')
|
||||
if target_node_id:
|
||||
target_edge_ids.add(target_node_id)
|
||||
|
||||
# node configs
|
||||
node_configs = graph_config.get('nodes')
|
||||
if not node_configs:
|
||||
return None
|
||||
|
||||
node_configs = cast(list, node_configs)
|
||||
|
||||
# fetch nodes that have no predecessor node
|
||||
root_node_configs = []
|
||||
nodes_mapping: dict[str, dict] = {}
|
||||
for node_config in node_configs:
|
||||
node_id = node_config.get('id')
|
||||
if not node_id:
|
||||
continue
|
||||
|
||||
if node_id not in target_edge_ids:
|
||||
root_node_configs.append(node_config)
|
||||
|
||||
nodes_mapping[node_id] = node_config
|
||||
|
||||
# fetch root node
|
||||
if root_node_id:
|
||||
root_node_config = next((node_config for node_config in root_node_configs
|
||||
if node_config.get('id') == root_node_id), None)
|
||||
else:
|
||||
# if no root node id, use the START type node as root node
|
||||
root_node_config = next((node_config for node_config in root_node_configs
|
||||
if node_config.get('data', {}).get('type', '') == NodeType.START.value), None)
|
||||
|
||||
if not root_node_config:
|
||||
return None
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(
|
||||
root_node_config=root_node_config
|
||||
)
|
||||
|
||||
# add edge from root node
|
||||
self._recursively_add_edges(
|
||||
graph=graph,
|
||||
source_node_config=root_node_config,
|
||||
edges_mapping=source_edges_mapping,
|
||||
nodes_mapping=nodes_mapping,
|
||||
root_node_configs=root_node_configs
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
def _recursively_add_edges(self, graph: Graph,
|
||||
source_node_config: dict,
|
||||
edges_mapping: dict,
|
||||
nodes_mapping: dict,
|
||||
root_node_configs: list[dict]) -> None:
|
||||
"""
|
||||
Add edges
|
||||
|
||||
:param source_node_config: source node config
|
||||
:param edges_mapping: edges mapping
|
||||
:param nodes_mapping: nodes mapping
|
||||
:param root_node_configs: root node configs
|
||||
"""
|
||||
source_node_id = source_node_config.get('id')
|
||||
if not source_node_id:
|
||||
return
|
||||
|
||||
for edge_config in edges_mapping.get(source_node_id, []):
|
||||
target_node_id = edge_config.get('target')
|
||||
if not target_node_id:
|
||||
continue
|
||||
|
||||
target_node_config = nodes_mapping.get(target_node_id)
|
||||
if not target_node_config:
|
||||
continue
|
||||
|
||||
sub_graph: Optional[Graph] = None
|
||||
target_node_type: NodeType = NodeType.value_of(target_node_config.get('data', {}).get('type'))
|
||||
if target_node_type and target_node_type in [IterationNode.node_type, NodeType.LOOP]:
|
||||
# find iteration/loop sub nodes that have no predecessor node
|
||||
for root_node_config in root_node_configs:
|
||||
if root_node_config.get('parentId') == target_node_id:
|
||||
# create sub graph
|
||||
sub_graph = Graph.init(
|
||||
root_node_config=root_node_config
|
||||
)
|
||||
|
||||
self._recursively_add_edges(
|
||||
graph=sub_graph,
|
||||
source_node_config=root_node_config,
|
||||
edges_mapping=edges_mapping,
|
||||
nodes_mapping=nodes_mapping,
|
||||
root_node_configs=root_node_configs
|
||||
)
|
||||
break
|
||||
|
||||
# add edge
|
||||
graph.add_edge(
|
||||
edge_config=edge_config,
|
||||
source_node_config=source_node_config,
|
||||
target_node_config=target_node_config,
|
||||
target_node_sub_graph=sub_graph,
|
||||
)
|
||||
|
||||
# recursively add edges
|
||||
self._recursively_add_edges(
|
||||
graph=graph,
|
||||
source_node_config=target_node_config,
|
||||
edges_mapping=edges_mapping,
|
||||
nodes_mapping=nodes_mapping,
|
||||
root_node_configs=root_node_configs
|
||||
)
|
||||
|
||||
def _run_workflow(self, graph_config: dict,
|
||||
workflow_runtime_state: WorkflowRuntimeState,
|
||||
callbacks: list[BaseWorkflowCallback],
|
||||
@ -157,10 +310,15 @@ class WorkflowEngineManager:
|
||||
"""
|
||||
try:
|
||||
# init graph
|
||||
graph = Graph(
|
||||
graph = self._init_graph(
|
||||
graph_config=graph_config
|
||||
)
|
||||
|
||||
if not graph:
|
||||
raise WorkflowRunFailedError(
|
||||
error='Start node not found in workflow graph.'
|
||||
)
|
||||
|
||||
predecessor_node: Optional[BaseNode] = None
|
||||
current_iteration_node: Optional[BaseIterationNode] = None
|
||||
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
|
||||
@ -231,11 +389,11 @@ class WorkflowEngineManager:
|
||||
|
||||
# max steps reached
|
||||
if workflow_run_state.workflow_node_steps > max_execution_steps:
|
||||
raise ValueError('Max steps {} reached.'.format(max_execution_steps))
|
||||
raise WorkflowRunFailedError('Max steps {} reached.'.format(max_execution_steps))
|
||||
|
||||
# or max execution time reached
|
||||
if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=max_execution_time):
|
||||
raise ValueError('Max execution time {}s reached.'.format(max_execution_time))
|
||||
raise WorkflowRunFailedError('Max execution time {}s reached.'.format(max_execution_time))
|
||||
|
||||
if len(next_nodes) == 1:
|
||||
next_node = next_nodes[0]
|
||||
@ -256,63 +414,59 @@ class WorkflowEngineManager:
|
||||
else:
|
||||
result_dict = {}
|
||||
|
||||
# new thread
|
||||
worker_thread = threading.Thread(target=self._async_run_nodes, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'graph': graph,
|
||||
'workflow_run_state': workflow_run_state,
|
||||
'predecessor_node': predecessor_node,
|
||||
'next_nodes': next_nodes,
|
||||
'callbacks': callbacks,
|
||||
'result': result_dict
|
||||
})
|
||||
|
||||
worker_thread.start()
|
||||
worker_thread.join()
|
||||
# # new thread
|
||||
# worker_thread = threading.Thread(target=self._async_run_nodes, kwargs={
|
||||
# 'flask_app': current_app._get_current_object(),
|
||||
# 'graph': graph,
|
||||
# 'workflow_run_state': workflow_run_state,
|
||||
# 'predecessor_node': predecessor_node,
|
||||
# 'next_nodes': next_nodes,
|
||||
# 'callbacks': callbacks,
|
||||
# 'result': result_dict
|
||||
# })
|
||||
#
|
||||
# worker_thread.start()
|
||||
# worker_thread.join()
|
||||
|
||||
if not workflow_run_state.workflow_node_runs:
|
||||
self._workflow_run_failed(
|
||||
error='Start node not found in workflow graph.',
|
||||
callbacks=callbacks
|
||||
raise WorkflowRunFailedError(
|
||||
error='Start node not found in workflow graph.'
|
||||
)
|
||||
return
|
||||
except GenerateTaskStoppedException as e:
|
||||
return
|
||||
except Exception as e:
|
||||
self._workflow_run_failed(
|
||||
error=str(e),
|
||||
callbacks=callbacks
|
||||
raise WorkflowRunFailedError(
|
||||
error=str(e)
|
||||
)
|
||||
return
|
||||
|
||||
def _async_run_nodes(self, flask_app: Flask,
|
||||
graph: dict,
|
||||
workflow_run_state: WorkflowRunState,
|
||||
predecessor_node: Optional[BaseNode],
|
||||
next_nodes: list[BaseNode],
|
||||
callbacks: list[BaseWorkflowCallback],
|
||||
result: dict):
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
for next_node in next_nodes:
|
||||
# TODO run sub workflows
|
||||
# run node
|
||||
is_continue = self._run_node(
|
||||
graph=graph,
|
||||
workflow_run_state=workflow_run_state,
|
||||
predecessor_node=predecessor_node,
|
||||
current_node=next_node,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
if not is_continue:
|
||||
break
|
||||
|
||||
predecessor_node = next_node
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when generating")
|
||||
finally:
|
||||
db.session.remove()
|
||||
# def _async_run_nodes(self, flask_app: Flask,
|
||||
# graph: dict,
|
||||
# workflow_run_state: WorkflowRunState,
|
||||
# predecessor_node: Optional[BaseNode],
|
||||
# next_nodes: list[BaseNode],
|
||||
# callbacks: list[BaseWorkflowCallback],
|
||||
# result: dict):
|
||||
# with flask_app.app_context():
|
||||
# try:
|
||||
# for next_node in next_nodes:
|
||||
# # TODO run sub workflows
|
||||
# # run node
|
||||
# is_continue = self._run_node(
|
||||
# graph=graph,
|
||||
# workflow_run_state=workflow_run_state,
|
||||
# predecessor_node=predecessor_node,
|
||||
# current_node=next_node,
|
||||
# callbacks=callbacks
|
||||
# )
|
||||
#
|
||||
# if not is_continue:
|
||||
# break
|
||||
#
|
||||
# predecessor_node = next_node
|
||||
# except Exception as e:
|
||||
# logger.exception("Unknown Error when generating")
|
||||
# finally:
|
||||
# db.session.remove()
|
||||
|
||||
def _run_node(self, graph: dict,
|
||||
workflow_run_state: WorkflowRunState,
|
||||
@ -584,14 +738,25 @@ class WorkflowEngineManager:
|
||||
workflow_call_depth=0
|
||||
)
|
||||
|
||||
# run workflow
|
||||
self._run_workflow(
|
||||
graph=workflow.graph,
|
||||
workflow_run_state=workflow_run_state,
|
||||
callbacks=callbacks,
|
||||
start_node=node_id,
|
||||
end_node=end_node_id
|
||||
)
|
||||
try:
|
||||
# run workflow
|
||||
self._run_workflow(
|
||||
graph_config=workflow.graph,
|
||||
workflow_runtime_state=workflow_runtime_state,
|
||||
callbacks=callbacks,
|
||||
start_node=node_id,
|
||||
end_node=end_node_id
|
||||
)
|
||||
except WorkflowRunFailedError as e:
|
||||
self._workflow_run_failed(
|
||||
error=e.error,
|
||||
callbacks=callbacks
|
||||
)
|
||||
except Exception as e:
|
||||
self._workflow_run_failed(
|
||||
error=str(e),
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
# workflow run success
|
||||
self._workflow_run_success(
|
||||
@ -1072,3 +1237,8 @@ class WorkflowEngineManager:
|
||||
variable_key_list=variable_key_list,
|
||||
value=value
|
||||
)
|
||||
|
||||
|
||||
class WorkflowRunFailedError(Exception):
|
||||
def __init__(self, error: str):
|
||||
self.error = error
|
||||
|
@ -0,0 +1,234 @@
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
|
||||
|
||||
def test__init_graph():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "llm-source-answer-target",
|
||||
"source": "llm",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "1717222650545-source-1719481290322-target",
|
||||
"source": "1717222650545",
|
||||
"target": "1719481290322",
|
||||
},
|
||||
{
|
||||
"id": "1719481290322-1-llm-target",
|
||||
"source": "1719481290322",
|
||||
"sourceHandle": "1",
|
||||
"target": "llm",
|
||||
},
|
||||
{
|
||||
"id": "1719481290322-2-1719481315734-target",
|
||||
"source": "1719481290322",
|
||||
"sourceHandle": "2",
|
||||
"target": "1719481315734",
|
||||
},
|
||||
{
|
||||
"id": "1719481315734-source-1719481326339-target",
|
||||
"source": "1719481315734",
|
||||
"target": "1719481326339",
|
||||
}
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"data": {
|
||||
"desc": "",
|
||||
"title": "Start",
|
||||
"type": "start",
|
||||
"variables": [
|
||||
{
|
||||
"label": "name",
|
||||
"max_length": 48,
|
||||
"options": [],
|
||||
"required": False,
|
||||
"type": "text-input",
|
||||
"variable": "name"
|
||||
}
|
||||
]
|
||||
},
|
||||
"id": "1717222650545",
|
||||
"position": {
|
||||
"x": -147.65487258270954,
|
||||
"y": 263.5326708413438
|
||||
},
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"context": {
|
||||
"enabled": False,
|
||||
"variable_selector": []
|
||||
},
|
||||
"desc": "",
|
||||
"memory": {
|
||||
"query_prompt_template": "{{#sys.query#}}",
|
||||
"role_prefix": {
|
||||
"assistant": "",
|
||||
"user": ""
|
||||
},
|
||||
"window": {
|
||||
"enabled": False,
|
||||
"size": 10
|
||||
}
|
||||
},
|
||||
"model": {
|
||||
"completion_params": {
|
||||
"temperature": 0
|
||||
},
|
||||
"mode": "chat",
|
||||
"name": "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"provider": "bedrock"
|
||||
},
|
||||
"prompt_config": {
|
||||
"jinja2_variables": [
|
||||
{
|
||||
"value_selector": [
|
||||
"sys",
|
||||
"query"
|
||||
],
|
||||
"variable": "query"
|
||||
}
|
||||
]
|
||||
},
|
||||
"prompt_template": [
|
||||
{
|
||||
"edition_type": "basic",
|
||||
"id": "8b02d178-3aa0-4dbd-82bf-8b6a40658300",
|
||||
"jinja2_text": "",
|
||||
"role": "system",
|
||||
"text": "yep"
|
||||
}
|
||||
],
|
||||
"title": "LLM",
|
||||
"type": "llm",
|
||||
"variables": [],
|
||||
"vision": {
|
||||
"configs": {
|
||||
"detail": "low"
|
||||
},
|
||||
"enabled": True
|
||||
}
|
||||
},
|
||||
"id": "llm",
|
||||
"position": {
|
||||
"x": 654.0331237272932,
|
||||
"y": 263.5326708413438
|
||||
},
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"answer": "123{{#llm.text#}}",
|
||||
"desc": "",
|
||||
"title": "Answer",
|
||||
"type": "answer",
|
||||
"variables": []
|
||||
},
|
||||
"id": "answer",
|
||||
"position": {
|
||||
"x": 958.1129142362784,
|
||||
"y": 263.5326708413438
|
||||
},
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"classes": [
|
||||
{
|
||||
"id": "1",
|
||||
"name": "happy"
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"name": "sad"
|
||||
}
|
||||
],
|
||||
"desc": "",
|
||||
"instructions": "",
|
||||
"model": {
|
||||
"completion_params": {
|
||||
"temperature": 0.7
|
||||
},
|
||||
"mode": "chat",
|
||||
"name": "gpt-4o",
|
||||
"provider": "openai"
|
||||
},
|
||||
"query_variable_selector": [
|
||||
"1717222650545",
|
||||
"sys.query"
|
||||
],
|
||||
"title": "Question Classifier",
|
||||
"topics": [],
|
||||
"type": "question-classifier"
|
||||
},
|
||||
"id": "1719481290322",
|
||||
"position": {
|
||||
"x": 165.25154615277052,
|
||||
"y": 263.5326708413438
|
||||
}
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"authorization": {
|
||||
"config": None,
|
||||
"type": "no-auth"
|
||||
},
|
||||
"body": {
|
||||
"data": "",
|
||||
"type": "none"
|
||||
},
|
||||
"desc": "",
|
||||
"headers": "",
|
||||
"method": "get",
|
||||
"params": "",
|
||||
"timeout": {
|
||||
"max_connect_timeout": 0,
|
||||
"max_read_timeout": 0,
|
||||
"max_write_timeout": 0
|
||||
},
|
||||
"title": "HTTP Request",
|
||||
"type": "http-request",
|
||||
"url": "https://baidu.com",
|
||||
"variables": []
|
||||
},
|
||||
"height": 88,
|
||||
"id": "1719481315734",
|
||||
"position": {
|
||||
"x": 654.0331237272932,
|
||||
"y": 474.1180064703089
|
||||
}
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"answer": "{{#1719481315734.status_code#}}",
|
||||
"desc": "",
|
||||
"title": "Answer 2",
|
||||
"type": "answer",
|
||||
"variables": []
|
||||
},
|
||||
"height": 105,
|
||||
"id": "1719481326339",
|
||||
"position": {
|
||||
"x": 958.1129142362784,
|
||||
"y": 474.1180064703089
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
workflow_engine_manager = WorkflowEngineManager()
|
||||
graph = workflow_engine_manager._init_graph(
|
||||
graph_config=graph_config
|
||||
)
|
||||
|
||||
assert graph.root_node.id == "1717222650545"
|
||||
assert graph.root_node.source_edge_config is None
|
||||
assert graph.root_node.target_edge_config is not None
|
||||
assert graph.root_node.descendant_node_ids == ["1719481290322"]
|
||||
|
||||
assert graph.graph_nodes.get("1719481290322") is not None
|
||||
assert len(graph.graph_nodes.get("1719481290322").descendant_node_ids) == 2
|
||||
|
||||
assert graph.graph_nodes.get("llm").run_condition_callback is not None
|
||||
assert graph.graph_nodes.get("1719481315734").run_condition_callback is not None
|
Loading…
x
Reference in New Issue
Block a user