diff --git a/api/core/workflow/entities/workflow_runtime_state.py b/api/core/workflow/entities/workflow_runtime_state.py deleted file mode 100644 index 1ef9bf1571..0000000000 --- a/api/core/workflow/entities/workflow_runtime_state.py +++ /dev/null @@ -1,28 +0,0 @@ - -from pydantic import BaseModel, Field - -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.graph_engine.entities.graph import Graph -from core.workflow.graph_engine.entities.runtime_graph import RuntimeGraph -from core.workflow.nodes.base_node import UserFrom -from models.workflow import WorkflowType - - -class WorkflowRuntimeState(BaseModel): - tenant_id: str - app_id: str - workflow_id: str - workflow_type: WorkflowType - user_id: str - user_from: UserFrom - variable_pool: VariablePool - invoke_from: InvokeFrom - graph: Graph - call_depth: int - start_at: float - - total_tokens: int = 0 - node_run_steps: int = 0 - - runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph) diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index dbff4fba65..cceea6ee9f 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -3,9 +3,7 @@ from typing import Optional, cast from pydantic import BaseModel, Field -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool -from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager +from core.workflow.entities.node_entities import NodeType from core.workflow.graph_engine.entities.run_condition import RunCondition @@ -28,33 +26,6 @@ class GraphParallel(BaseModel): """parent parallel id if exists""" -class GraphStateRoute(BaseModel): - route_id: str - """route id""" - - node_id: str - """node id""" - - -class GraphState(BaseModel): - routes: dict[str, list[GraphStateRoute]] = Field(default_factory=dict) - """graph state routes (source_node_id: routes)""" - - variable_pool: VariablePool - """variable pool""" - - node_route_results: dict[str, NodeRunResult] = Field(default_factory=dict) - """node results in route (node_id: run_result)""" - - -class NextGraphNode(BaseModel): - node_id: str - """next node id""" - - parallel: Optional[GraphParallel] = None - """parallel""" - - class Graph(BaseModel): root_node_id: str """root node id of the graph""" @@ -71,19 +42,14 @@ class Graph(BaseModel): node_parallel_mapping: dict[str, str] = Field(default_factory=dict) """graph node parallel mapping (node id: parallel id)""" - run_state: GraphState - """graph run state""" - @classmethod def init(cls, graph_config: dict, - variable_pool: VariablePool, root_node_id: Optional[str] = None) -> "Graph": """ Init graph :param graph_config: graph config - :param variable_pool: variable pool :param root_node_id: root node id :return: graph """ @@ -149,7 +115,7 @@ class Graph(BaseModel): # fetch root node if not root_node_id: # if no root node id, use the START type node as root node - root_node_id = next((node_config for node_config in root_node_configs + root_node_id = next((node_config.get("id") for node_config in root_node_configs if node_config.get('data', {}).get('type', '') == NodeType.START.value), None) if not root_node_id or root_node_id not in root_node_ids: @@ -178,80 +144,12 @@ class Graph(BaseModel): root_node_id=root_node_id, node_ids=node_ids, edge_mapping=edge_mapping, - run_state=GraphState( - variable_pool=variable_pool - ), parallel_mapping=parallel_mapping, node_parallel_mapping=node_parallel_mapping ) return graph - @classmethod - def _recursively_add_node_ids(cls, - node_ids: list[str], - edge_mapping: dict[str, list[GraphEdge]], - node_id: str) -> None: - """ - Recursively add node ids - - :param node_ids: node ids - :param edge_mapping: edge mapping - :param node_id: node id - """ - for graph_edge in edge_mapping.get(node_id, []): - if graph_edge.target_node_id in node_ids: - continue - - node_ids.append(graph_edge.target_node_id) - cls._recursively_add_node_ids( - node_ids=node_ids, - edge_mapping=edge_mapping, - node_id=graph_edge.target_node_id - ) - - def next_node_ids(self) -> list[NextGraphNode]: - """ - Get next node ids - """ - # get current node ids in state - if not self.run_state.routes: - return [NextGraphNode(node_id=self.root_node_id)] - - route_final_graph_edges: list[GraphEdge] = [] - for route in self.run_state.routes[self.root_node_id]: - graph_edges = self.edge_mapping.get(route.node_id) - if not graph_edges: - continue - - for edge in graph_edges: - if edge.target_node_id not in self.run_state.routes: - route_final_graph_edges.append(edge) - - next_graph_nodes = [] - for route_final_graph_edge in route_final_graph_edges: - node_id = route_final_graph_edge.target_node_id - # check condition - if route_final_graph_edge.run_condition: - result = ConditionManager.get_condition_handler( - run_condition=route_final_graph_edge.run_condition - ).check( - source_node_id=route_final_graph_edge.source_node_id, - target_node_id=route_final_graph_edge.target_node_id, - graph=self - ) - - if not result: - continue - - parallel = None - if route_final_graph_edge.target_node_id in self.node_parallel_mapping: - parallel = self.parallel_mapping[self.node_parallel_mapping[node_id]] - - next_graph_nodes.append(NextGraphNode(node_id=node_id, parallel=parallel)) - - return next_graph_nodes - def add_extra_edge(self, source_node_id: str, target_node_id: str, run_condition: Optional[RunCondition] = None) -> None: @@ -295,6 +193,29 @@ class Graph(BaseModel): return leaf_node_ids + @classmethod + def _recursively_add_node_ids(cls, + node_ids: list[str], + edge_mapping: dict[str, list[GraphEdge]], + node_id: str) -> None: + """ + Recursively add node ids + + :param node_ids: node ids + :param edge_mapping: edge mapping + :param node_id: node id + """ + for graph_edge in edge_mapping.get(node_id, []): + if graph_edge.target_node_id in node_ids: + continue + + node_ids.append(graph_edge.target_node_id) + cls._recursively_add_node_ids( + node_ids=node_ids, + edge_mapping=edge_mapping, + node_id=graph_edge.target_node_id + ) + @classmethod def _recursively_add_parallels(cls, edge_mapping: dict[str, list[GraphEdge]], diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py index 8eafb8869e..7075ff75ba 100644 --- a/api/core/workflow/graph_engine/entities/graph_runtime_state.py +++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py @@ -4,12 +4,12 @@ from pydantic import BaseModel, Field from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.graph_engine.entities.graph import Graph -from core.workflow.graph_engine.entities.runtime_graph import RuntimeGraph +from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState from core.workflow.nodes.base_node import UserFrom class GraphRuntimeState(BaseModel): + # init params tenant_id: str app_id: str user_id: str @@ -17,10 +17,10 @@ class GraphRuntimeState(BaseModel): invoke_from: InvokeFrom call_depth: int - graph: Graph variable_pool: VariablePool - + start_at: Optional[float] = None total_tokens: int = 0 node_run_steps: int = 0 - runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph) + + node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState) diff --git a/api/core/workflow/graph_engine/entities/next_graph_node.py b/api/core/workflow/graph_engine/entities/next_graph_node.py new file mode 100644 index 0000000000..6aa4341ddf --- /dev/null +++ b/api/core/workflow/graph_engine/entities/next_graph_node.py @@ -0,0 +1,13 @@ +from typing import Optional + +from pydantic import BaseModel + +from core.workflow.graph_engine.entities.graph import GraphParallel + + +class NextGraphNode(BaseModel): + node_id: str + """next node id""" + + parallel: Optional[GraphParallel] = None + """parallel""" diff --git a/api/core/workflow/graph_engine/entities/runtime_graph.py b/api/core/workflow/graph_engine/entities/runtime_graph.py deleted file mode 100644 index 916a5a983f..0000000000 --- a/api/core/workflow/graph_engine/entities/runtime_graph.py +++ /dev/null @@ -1,38 +0,0 @@ -from datetime import datetime, timezone -from typing import Optional - -from pydantic import BaseModel - -from core.workflow.entities.node_entities import NodeRunResult -from core.workflow.graph_engine.entities.runtime_node import RuntimeNode -from models.workflow import WorkflowNodeExecutionStatus - - -class RuntimeGraph(BaseModel): - runtime_nodes: dict[str, RuntimeNode] = {} - """runtime nodes""" - - def add_runtime_node(self, runtime_node: RuntimeNode) -> None: - self.runtime_nodes[runtime_node.id] = runtime_node - - def add_link(self, source_runtime_node_id: str, target_runtime_node_id: str) -> None: - if source_runtime_node_id in self.runtime_nodes and target_runtime_node_id in self.runtime_nodes: - target_runtime_node = self.runtime_nodes[target_runtime_node_id] - target_runtime_node.predecessor_runtime_node_id = source_runtime_node_id - - def runtime_node_finished(self, runtime_node_id: str, node_run_result: NodeRunResult) -> None: - if runtime_node_id in self.runtime_nodes: - runtime_node = self.runtime_nodes[runtime_node_id] - runtime_node.status = RuntimeNode.Status.SUCCESS \ - if node_run_result.status == WorkflowNodeExecutionStatus.RUNNING \ - else RuntimeNode.Status.FAILED - runtime_node.node_run_result = node_run_result - runtime_node.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) - runtime_node.failed_reason = node_run_result.error - - def runtime_node_paused(self, runtime_node_id: str, paused_by: Optional[str] = None) -> None: - if runtime_node_id in self.runtime_nodes: - runtime_node = self.runtime_nodes[runtime_node_id] - runtime_node.status = RuntimeNode.Status.PAUSED - runtime_node.paused_at = datetime.now(timezone.utc).replace(tzinfo=None) - runtime_node.paused_by = paused_by diff --git a/api/core/workflow/graph_engine/entities/runtime_node.py b/api/core/workflow/graph_engine/entities/runtime_node.py deleted file mode 100644 index 2ba894fa13..0000000000 --- a/api/core/workflow/graph_engine/entities/runtime_node.py +++ /dev/null @@ -1,48 +0,0 @@ -import uuid -from datetime import datetime -from enum import Enum -from typing import Optional - -from pydantic import BaseModel, Field - -from core.workflow.entities.node_entities import NodeRunResult -from core.workflow.graph_engine.entities.graph import GraphNode - - -class RuntimeNode(BaseModel): - class Status(Enum): - PENDING = "pending" - RUNNING = "running" - SUCCESS = "success" - FAILED = "failed" - PAUSED = "paused" - - id: str = Field(default_factory=lambda: str(uuid.uuid4())) - """random id for current runtime node""" - - graph_node: GraphNode - """graph node""" - - node_run_result: Optional[NodeRunResult] = None - """node run result""" - - status: Status = Status.PENDING - """node status""" - - start_at: Optional[datetime] = None - """start time""" - - paused_at: Optional[datetime] = None - """paused time""" - - finished_at: Optional[datetime] = None - """finished time""" - - failed_reason: Optional[str] = None - """failed reason""" - - paused_by: Optional[str] = None - """paused by""" - - predecessor_runtime_node_id: Optional[str] = None - """predecessor runtime node id""" diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py new file mode 100644 index 0000000000..e6851ac223 --- /dev/null +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -0,0 +1,111 @@ +import uuid +from datetime import datetime, timezone +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field + +from core.workflow.entities.node_entities import NodeRunResult +from models.workflow import WorkflowNodeExecutionStatus + + +class RouteNodeState(BaseModel): + class Status(Enum): + RUNNING = "running" + SUCCESS = "success" + FAILED = "failed" + PAUSED = "paused" + + state_id: str = Field(default_factory=lambda: str(uuid.uuid4())) + """node state id""" + + node_id: str + """node id""" + + node_run_result: Optional[NodeRunResult] = None + """node run result""" + + status: Status = Status.RUNNING + """node status""" + + start_at: datetime + """start time""" + + paused_at: Optional[datetime] = None + """paused time""" + + finished_at: Optional[datetime] = None + """finished time""" + + failed_reason: Optional[str] = None + """failed reason""" + + paused_by: Optional[str] = None + """paused by""" + + +class RuntimeRouteState(BaseModel): + routes: dict[str, list[str]] = Field(default_factory=dict) + """graph state routes (source_node_state_id: target_node_state_id)""" + + node_state_mapping: dict[str, RouteNodeState] = Field(default_factory=dict) + """node state mapping (route_node_state_id: route_node_state)""" + + def create_node_state(self, node_id: str) -> RouteNodeState: + """ + Create node state + + :param node_id: node id + """ + state = RouteNodeState(node_id=node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None)) + self.node_state_mapping[state.state_id] = state + return state + + def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None: + """ + Add route to the graph state + + :param source_node_state_id: source node state id + :param target_node_state_id: target node state id + """ + if source_node_state_id not in self.routes: + self.routes[source_node_state_id] = [] + + self.routes[source_node_state_id].append(target_node_state_id) + + def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) \ + -> list[RouteNodeState]: + """ + Get routes with node state by source node id + + :param source_node_state_id: source node state id + :return: routes with node state + """ + return [self.node_state_mapping[target_state_id] + for target_state_id in self.routes.get(source_node_state_id, [])] + + def set_node_state_finished(self, node_state_id: str, run_result: NodeRunResult) -> None: + """ + Node finished + + :param node_state_id: route node state id + :param run_result: run result + """ + if node_state_id not in self.node_state_mapping: + raise Exception(f"Route state {node_state_id} not found") + + route = self.node_state_mapping[node_state_id] + + if route.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]: + raise Exception(f"Route state {node_state_id} already finished") + + if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + route.status = RouteNodeState.Status.SUCCESS + elif run_result.status == WorkflowNodeExecutionStatus.FAILED: + route.status = RouteNodeState.Status.FAILED + route.failed_reason = run_result.error + else: + raise Exception(f"Invalid route status {run_result.status}") + + route.node_run_result = run_result + route.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 0851bec603..c0f5acdb9a 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -22,6 +22,7 @@ class GraphEngine: graph: Graph, variable_pool: VariablePool, callbacks: list[BaseWorkflowCallback]) -> None: + self.graph = graph self.graph_runtime_state = GraphRuntimeState( tenant_id=tenant_id, app_id=app_id, @@ -29,7 +30,6 @@ class GraphEngine: user_from=user_from, invoke_from=invoke_from, call_depth=call_depth, - graph=graph, variable_pool=variable_pool ) @@ -43,3 +43,49 @@ class GraphEngine: def run(self) -> Generator: self.graph_runtime_state.start_at = time.perf_counter() pass + + # def next_node_ids(self, node_state_id: str) -> list[NextGraphNode]: + # """ + # Get next node ids + # + # :param node_state_id: source node state id + # """ + # # get current node ids in state + # node_run_state = self.graph_runtime_state.node_run_state + # graph = self.graph + # if not node_run_state.routes: + # return [NextGraphNode(node_id=graph.root_node_id)] + # + # route_final_graph_edges: list[GraphEdge] = [] + # for route in route_state.routes[graph.root_node_id]: + # graph_edges = graph.edge_mapping.get(route.node_id) + # if not graph_edges: + # continue + # + # for edge in graph_edges: + # if edge.target_node_id not in route_state.routes: + # route_final_graph_edges.append(edge) + # + # next_graph_nodes = [] + # for route_final_graph_edge in route_final_graph_edges: + # node_id = route_final_graph_edge.target_node_id + # # check condition + # if route_final_graph_edge.run_condition: + # result = ConditionManager.get_condition_handler( + # run_condition=route_final_graph_edge.run_condition + # ).check( + # source_node_id=route_final_graph_edge.source_node_id, + # target_node_id=route_final_graph_edge.target_node_id, + # graph=self + # ) + # + # if not result: + # continue + # + # parallel = None + # if route_final_graph_edge.target_node_id in graph.node_parallel_mapping: + # parallel = graph.parallel_mapping[graph.node_parallel_mapping[node_id]] + # + # next_graph_nodes.append(NextGraphNode(node_id=node_id, parallel=parallel)) + # + # return next_graph_nodes diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index de59302c18..0412ba572e 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -106,7 +106,7 @@ class WorkflowEntry: ) # init graph - graph = self._init_graph( + graph = Graph.init( graph_config=graph_config ) @@ -152,86 +152,6 @@ class WorkflowEntry: return rst - def _init_graph(self, graph_config: dict, root_node_id: Optional[str] = None) -> Optional[Graph]: - """ - Initialize graph - - :param graph_config: graph config - :param root_node_id: root node id if needed - :return: graph - """ - # edge configs - edge_configs = graph_config.get('edges') - if not edge_configs: - return None - - edge_configs = cast(list, edge_configs) - - # reorganize edges mapping - source_edges_mapping: dict[str, list[dict]] = {} - target_edge_ids = set() - for edge_config in edge_configs: - source_node_id = edge_config.get('source') - if not source_node_id: - continue - - if source_node_id not in source_edges_mapping: - source_edges_mapping[source_node_id] = [] - - source_edges_mapping[source_node_id].append(edge_config) - - target_node_id = edge_config.get('target') - if target_node_id: - target_edge_ids.add(target_node_id) - - # node configs - node_configs = graph_config.get('nodes') - if not node_configs: - return None - - node_configs = cast(list, node_configs) - - # fetch nodes that have no predecessor node - root_node_configs = [] - nodes_mapping: dict[str, dict] = {} - for node_config in node_configs: - node_id = node_config.get('id') - if not node_id: - continue - - if node_id not in target_edge_ids: - root_node_configs.append(node_config) - - nodes_mapping[node_id] = node_config - - # fetch root node - if root_node_id: - root_node_config = next((node_config for node_config in root_node_configs - if node_config.get('id') == root_node_id), None) - else: - # if no root node id, use the START type node as root node - root_node_config = next((node_config for node_config in root_node_configs - if node_config.get('data', {}).get('type', '') == NodeType.START.value), None) - - if not root_node_config: - return None - - # init graph - graph = Graph.init( - root_node_config=root_node_config - ) - - # add edge from root node - self._recursively_add_edges( - graph=graph, - source_node_config=root_node_config, - edges_mapping=source_edges_mapping, - nodes_mapping=nodes_mapping, - root_node_configs=root_node_configs - ) - - return graph - def _recursively_add_edges(self, graph: Graph, source_node_config: dict, edges_mapping: dict, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/__init__.py b/api/tests/unit_tests/core/workflow/graph_engine/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py new file mode 100644 index 0000000000..0b595fc97a --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -0,0 +1,208 @@ +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.utils.condition.entities import Condition + + +def test_init(): + graph_config = { + "edges": [ + { + "id": "llm-source-answer-target", + "source": "llm", + "target": "answer", + }, + { + "id": "start-source-qc-target", + "source": "start", + "target": "qc", + }, + { + "id": "qc-1-llm-target", + "source": "qc", + "sourceHandle": "1", + "target": "llm", + }, + { + "id": "qc-2-http-target", + "source": "qc", + "sourceHandle": "2", + "target": "http", + }, + { + "id": "http-source-answer2-target", + "source": "http", + "target": "answer2", + } + ], + "nodes": [ + { + "data": { + "type": "start" + }, + "id": "start" + }, + { + "data": { + "type": "llm", + }, + "id": "llm" + }, + { + "data": { + "type": "answer", + }, + "id": "answer", + }, + { + "data": { + "type": "question-classifier" + }, + "id": "qc", + }, + { + "data": { + "type": "http-request", + }, + "id": "http", + }, + { + "data": { + "type": "answer", + }, + "id": "answer2", + } + ], + } + + graph = Graph.init( + graph_config=graph_config + ) + + start_node_id = "start" + + assert graph.root_node_id == start_node_id + assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc" + assert {"llm", "http"} == {node.target_node_id for node in graph.edge_mapping.get("qc")} + + +def test__init_graph_with_iteration(): + graph_config = { + "edges": [ + { + "id": "llm-answer", + "source": "llm", + "sourceHandle": "source", + "target": "answer", + }, + { + "id": "iteration-source-llm-target", + "source": "iteration", + "sourceHandle": "source", + "target": "llm", + }, + { + "id": "template-transform-in-iteration-source-llm-in-iteration-target", + "source": "template-transform-in-iteration", + "sourceHandle": "source", + "target": "llm-in-iteration", + }, + { + "id": "llm-in-iteration-source-answer-in-iteration-target", + "source": "llm-in-iteration", + "sourceHandle": "source", + "target": "answer-in-iteration", + }, + { + "id": "start-source-code-target", + "source": "start", + "sourceHandle": "source", + "target": "code", + }, + { + "id": "code-source-iteration-target", + "source": "code", + "sourceHandle": "source", + "target": "iteration", + } + ], + "nodes": [ + { + "data": { + "type": "start", + }, + "id": "start", + }, + { + "data": { + "type": "llm", + }, + "id": "llm", + }, + { + "data": { + "type": "answer", + }, + "id": "answer", + }, + { + "data": { + "type": "iteration" + }, + "id": "iteration", + }, + { + "data": { + "type": "template-transform", + }, + "id": "template-transform-in-iteration", + "parentId": "iteration", + }, + { + "data": { + "type": "llm", + }, + "id": "llm-in-iteration", + "parentId": "iteration", + }, + { + "data": { + "type": "answer", + }, + "id": "answer-in-iteration", + "parentId": "iteration", + }, + { + "data": { + "type": "code", + }, + "id": "code", + } + ] + } + + graph = Graph.init( + graph_config=graph_config, + root_node_id="template-transform-in-iteration" + ) + graph.add_extra_edge( + source_node_id="answer-in-iteration", + target_node_id="template-transform-in-iteration", + run_condition=RunCondition( + type="condition", + conditions=[ + Condition( + variable_selector=["iteration", "index"], + comparison_operator="≤", + value="5" + ) + ] + ) + ) + + # iteration: + # [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration] + + assert graph.root_node_id == "template-transform-in-iteration" + assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration" + assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration" + assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration" diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py deleted file mode 100644 index 47e62aff8c..0000000000 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ /dev/null @@ -1,693 +0,0 @@ -from typing import Optional - -from core.workflow.graph_engine.entities.graph import Graph -from core.workflow.workflow_entry import WorkflowEntry - - -def test__init_graph(): - graph_config = { - "edges": [ - { - "id": "llm-source-answer-target", - "source": "llm", - "target": "answer", - }, - { - "id": "1717222650545-source-1719481290322-target", - "source": "1717222650545", - "target": "1719481290322", - }, - { - "id": "1719481290322-1-llm-target", - "source": "1719481290322", - "sourceHandle": "1", - "target": "llm", - }, - { - "id": "1719481290322-2-1719481315734-target", - "source": "1719481290322", - "sourceHandle": "2", - "target": "1719481315734", - }, - { - "id": "1719481315734-source-1719481326339-target", - "source": "1719481315734", - "target": "1719481326339", - } - ], - "nodes": [ - { - "data": { - "desc": "", - "title": "Start", - "type": "start", - "variables": [ - { - "label": "name", - "max_length": 48, - "options": [], - "required": False, - "type": "text-input", - "variable": "name" - } - ] - }, - "id": "1717222650545", - "position": { - "x": -147.65487258270954, - "y": 263.5326708413438 - }, - }, - { - "data": { - "context": { - "enabled": False, - "variable_selector": [] - }, - "desc": "", - "memory": { - "query_prompt_template": "{{#sys.query#}}", - "role_prefix": { - "assistant": "", - "user": "" - }, - "window": { - "enabled": False, - "size": 10 - } - }, - "model": { - "completion_params": { - "temperature": 0 - }, - "mode": "chat", - "name": "anthropic.claude-3-sonnet-20240229-v1:0", - "provider": "bedrock" - }, - "prompt_config": { - "jinja2_variables": [ - { - "value_selector": [ - "sys", - "query" - ], - "variable": "query" - } - ] - }, - "prompt_template": [ - { - "edition_type": "basic", - "id": "8b02d178-3aa0-4dbd-82bf-8b6a40658300", - "jinja2_text": "", - "role": "system", - "text": "yep" - } - ], - "title": "LLM", - "type": "llm", - "variables": [], - "vision": { - "configs": { - "detail": "low" - }, - "enabled": True - } - }, - "id": "llm", - "position": { - "x": 654.0331237272932, - "y": 263.5326708413438 - }, - }, - { - "data": { - "answer": "123{{#llm.text#}}", - "desc": "", - "title": "Answer", - "type": "answer", - "variables": [] - }, - "id": "answer", - "position": { - "x": 958.1129142362784, - "y": 263.5326708413438 - }, - }, - { - "data": { - "classes": [ - { - "id": "1", - "name": "happy" - }, - { - "id": "2", - "name": "sad" - } - ], - "desc": "", - "instructions": "", - "model": { - "completion_params": { - "temperature": 0.7 - }, - "mode": "chat", - "name": "gpt-4o", - "provider": "openai" - }, - "query_variable_selector": [ - "1717222650545", - "sys.query" - ], - "title": "Question Classifier", - "topics": [], - "type": "question-classifier" - }, - "id": "1719481290322", - "position": { - "x": 165.25154615277052, - "y": 263.5326708413438 - } - }, - { - "data": { - "authorization": { - "config": None, - "type": "no-auth" - }, - "body": { - "data": "", - "type": "none" - }, - "desc": "", - "headers": "", - "method": "get", - "params": "", - "timeout": { - "max_connect_timeout": 0, - "max_read_timeout": 0, - "max_write_timeout": 0 - }, - "title": "HTTP Request", - "type": "http-request", - "url": "https://baidu.com", - "variables": [] - }, - "height": 88, - "id": "1719481315734", - "position": { - "x": 654.0331237272932, - "y": 474.1180064703089 - } - }, - { - "data": { - "answer": "{{#1719481315734.status_code#}}", - "desc": "", - "title": "Answer 2", - "type": "answer", - "variables": [] - }, - "height": 105, - "id": "1719481326339", - "position": { - "x": 958.1129142362784, - "y": 474.1180064703089 - }, - } - ], - } - - workflow_entry = WorkflowEntry() - graph = workflow_entry._init_graph( - graph_config=graph_config - ) - - assert graph.root_node.id == "1717222650545" - assert graph.root_node.source_edge_config is None - assert graph.root_node.descendant_node_ids == ["1719481290322"] - - assert graph.graph_nodes.get("1719481290322") is not None - assert len(graph.graph_nodes.get("1719481290322").descendant_node_ids) == 2 - - assert graph.graph_nodes.get("llm").run_condition is not None - assert graph.graph_nodes.get("1719481315734").run_condition is not None - - -def test__init_graph_with_iteration(): - graph_config = { - "edges": [ - { - "data": { - "sourceType": "llm", - "targetType": "answer" - }, - "id": "llm-answer", - "source": "llm", - "sourceHandle": "source", - "target": "answer", - "targetHandle": "target", - "type": "custom" - }, - { - "data": { - "isInIteration": False, - "sourceType": "iteration", - "targetType": "llm" - }, - "id": "1720001776597-source-llm-target", - "selected": False, - "source": "1720001776597", - "sourceHandle": "source", - "target": "llm", - "targetHandle": "target", - "type": "custom", - "zIndex": 0 - }, - { - "data": { - "isInIteration": True, - "iteration_id": "1720001776597", - "sourceType": "template-transform", - "targetType": "llm" - }, - "id": "1720001783092-source-1720001859851-target", - "source": "1720001783092", - "sourceHandle": "source", - "target": "1720001859851", - "targetHandle": "target", - "type": "custom", - "zIndex": 1002 - }, - { - "data": { - "isInIteration": True, - "iteration_id": "1720001776597", - "sourceType": "llm", - "targetType": "answer" - }, - "id": "1720001859851-source-1720001879621-target", - "source": "1720001859851", - "sourceHandle": "source", - "target": "1720001879621", - "targetHandle": "target", - "type": "custom", - "zIndex": 1002 - }, - { - "data": { - "isInIteration": False, - "sourceType": "start", - "targetType": "code" - }, - "id": "1720001771022-source-1720001956578-target", - "source": "1720001771022", - "sourceHandle": "source", - "target": "1720001956578", - "targetHandle": "target", - "type": "custom", - "zIndex": 0 - }, - { - "data": { - "isInIteration": False, - "sourceType": "code", - "targetType": "iteration" - }, - "id": "1720001956578-source-1720001776597-target", - "source": "1720001956578", - "sourceHandle": "source", - "target": "1720001776597", - "targetHandle": "target", - "type": "custom", - "zIndex": 0 - } - ], - "nodes": [ - { - "data": { - "desc": "", - "selected": False, - "title": "Start", - "type": "start", - "variables": [] - }, - "height": 53, - "id": "1720001771022", - "position": { - "x": 80, - "y": 282 - }, - "positionAbsolute": { - "x": 80, - "y": 282 - }, - "selected": False, - "sourcePosition": "right", - "targetPosition": "left", - "type": "custom", - "width": 244 - }, - { - "data": { - "context": { - "enabled": False, - "variable_selector": [] - }, - "desc": "", - "memory": { - "role_prefix": { - "assistant": "", - "user": "" - }, - "window": { - "enabled": False, - "size": 10 - } - }, - "model": { - "completion_params": { - "temperature": 0.7 - }, - "mode": "chat", - "name": "gpt-3.5-turbo", - "provider": "openai" - }, - "prompt_template": [ - { - "id": "b7d1350e-cf0d-4ff3-8ad0-52b6f1218781", - "role": "system", - "text": "" - } - ], - "selected": False, - "title": "LLM", - "type": "llm", - "variables": [], - "vision": { - "enabled": False - } - }, - "height": 97, - "id": "llm", - "position": { - "x": 1730.595805935594, - "y": 282 - }, - "positionAbsolute": { - "x": 1730.595805935594, - "y": 282 - }, - "selected": True, - "sourcePosition": "right", - "targetPosition": "left", - "type": "custom", - "width": 244 - }, - { - "data": { - "answer": "{{#llm.text#}}", - "desc": "", - "selected": False, - "title": "Answer", - "type": "answer", - "variables": [] - }, - "height": 105, - "id": "answer", - "position": { - "x": 2042.803154918583, - "y": 282 - }, - "positionAbsolute": { - "x": 2042.803154918583, - "y": 282 - }, - "selected": False, - "sourcePosition": "right", - "targetPosition": "left", - "type": "custom", - "width": 244 - }, - { - "data": { - "desc": "", - "height": 202, - "iterator_selector": [ - "1720001956578", - "result" - ], - "output_selector": [ - "1720001859851", - "text" - ], - "output_type": "array[string]", - "selected": False, - "startNodeType": "template-transform", - "start_node_id": "1720001783092", - "title": "Iteration", - "type": "iteration", - "width": 985 - }, - "height": 202, - "id": "1720001776597", - "position": { - "x": 678.6748900850307, - "y": 282 - }, - "positionAbsolute": { - "x": 678.6748900850307, - "y": 282 - }, - "selected": False, - "sourcePosition": "right", - "targetPosition": "left", - "type": "custom", - "width": 985, - "zIndex": 1 - }, - { - "data": { - "desc": "", - "isInIteration": True, - "isIterationStart": True, - "iteration_id": "1720001776597", - "selected": False, - "template": "{{ arg1 }}", - "title": "Template", - "type": "template-transform", - "variables": [ - { - "value_selector": [ - "1720001776597", - "item" - ], - "variable": "arg1" - } - ] - }, - "extent": "parent", - "height": 53, - "id": "1720001783092", - "parentId": "1720001776597", - "position": { - "x": 117, - "y": 85 - }, - "positionAbsolute": { - "x": 795.6748900850307, - "y": 367 - }, - "selected": False, - "sourcePosition": "right", - "targetPosition": "left", - "type": "custom", - "width": 244, - "zIndex": 1001 - }, - { - "data": { - "context": { - "enabled": False, - "variable_selector": [] - }, - "desc": "", - "isInIteration": True, - "iteration_id": "1720001776597", - "model": { - "completion_params": { - "temperature": 0.7 - }, - "mode": "chat", - "name": "gpt-3.5-turbo", - "provider": "openai" - }, - "prompt_template": [ - { - "id": "9575b8f2-33c4-4611-b6d0-17d8d436a250", - "role": "system", - "text": "{{#1720001783092.output#}}" - } - ], - "selected": False, - "title": "LLM 2", - "type": "llm", - "variables": [], - "vision": { - "enabled": False - } - }, - "extent": "parent", - "height": 97, - "id": "1720001859851", - "parentId": "1720001776597", - "position": { - "x": 421, - "y": 85 - }, - "positionAbsolute": { - "x": 1099.6748900850307, - "y": 367 - }, - "selected": False, - "sourcePosition": "right", - "targetPosition": "left", - "type": "custom", - "width": 244, - "zIndex": 1002 - }, - { - "data": { - "answer": "{{#1720001859851.text#}}", - "desc": "", - "isInIteration": True, - "iteration_id": "1720001776597", - "selected": False, - "title": "Answer 2", - "type": "answer", - "variables": [] - }, - "extent": "parent", - "height": 105, - "id": "1720001879621", - "parentId": "1720001776597", - "position": { - "x": 725, - "y": 85 - }, - "positionAbsolute": { - "x": 1403.6748900850307, - "y": 367 - }, - "selected": False, - "sourcePosition": "right", - "targetPosition": "left", - "type": "custom", - "width": 244, - "zIndex": 1002 - }, - { - "data": { - "code": "\ndef main() -> dict:\n return {\n \"result\": [\n \"a\",\n \"b\"\n ]\n }\n", - "code_language": "python3", - "desc": "", - "outputs": { - "result": { - "children": None, - "type": "array[string]" - } - }, - "selected": False, - "title": "Code", - "type": "code", - "variables": [] - }, - "height": 53, - "id": "1720001956578", - "position": { - "x": 380, - "y": 282 - }, - "positionAbsolute": { - "x": 380, - "y": 282 - }, - "sourcePosition": "right", - "targetPosition": "left", - "type": "custom", - "width": 244 - } - ] - } - - workflow_entry = WorkflowEntry() - graph = workflow_entry._init_graph( - graph_config=graph_config - ) - - # start 1720001771022 -> code 1720001956578 -> iteration 1720001776597 -> llm llm -> answer answer - # iteration 1720001776597: - # [template 1720001783092 -> llm 1720001859851 -> answer 1720001879621] - - main_graph_orders = [ - "1720001771022", - "1720001956578", - "1720001776597", - "llm", - "answer" - ] - - iteration_sub_graph_orders = [ - "1720001783092", - "1720001859851", - "1720001879621" - ] - - assert graph.root_node.id == "1720001771022" - - print("") - - current_graph = graph - for i, node_id in enumerate(main_graph_orders): - current_root_node = current_graph.root_node - assert current_root_node is not None - assert current_root_node.id == node_id - - if current_root_node.node_config.get("data", {}).get("type") == "iteration": - assert current_root_node.sub_graph is not None - - sub_graph = current_root_node.sub_graph - assert sub_graph.root_node.id == "1720001783092" - - current_sub_graph = sub_graph - for j, sub_node_id in enumerate(iteration_sub_graph_orders): - sub_descendant_graphs = current_sub_graph.get_descendant_graphs(node_id=current_sub_graph.root_node.id) - print(f"Iteration [{current_sub_graph.root_node.id}] -> {len(sub_descendant_graphs)}" - f" {[sub_descendant_graph.root_node.id for sub_descendant_graph in sub_descendant_graphs]}") - - if j == len(iteration_sub_graph_orders) - 1: - break - - assert len(sub_descendant_graphs) == 1 - - first_sub_descendant_graph = sub_descendant_graphs[0] - assert first_sub_descendant_graph.root_node.id == iteration_sub_graph_orders[j + 1] - assert first_sub_descendant_graph.root_node.predecessor_node_id == sub_node_id - - current_sub_graph = first_sub_descendant_graph - - descendant_graphs = current_graph.get_descendant_graphs(node_id=current_graph.root_node.id) - print(f"[{current_graph.root_node.id}] -> {len(descendant_graphs)}" - f" {[descendant_graph.root_node.id for descendant_graph in descendant_graphs]}") - if i == len(main_graph_orders) - 1: - assert len(descendant_graphs) == 0 - break - - assert len(descendant_graphs) == 1 - - first_descendant_graph = descendant_graphs[0] - assert first_descendant_graph.root_node.id == main_graph_orders[i + 1] - assert first_descendant_graph.root_node.predecessor_node_id == node_id - - current_graph = first_descendant_graph