refactor runtime

takatost 2024-07-08 16:29:13 +08:00
parent 1adaf42f9d
commit 0e885a3cae
12 changed files with 410 additions and 998 deletions


@@ -1,28 +0,0 @@
from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_graph import RuntimeGraph
from core.workflow.nodes.base_node import UserFrom
from models.workflow import WorkflowType
class WorkflowRuntimeState(BaseModel):
tenant_id: str
app_id: str
workflow_id: str
workflow_type: WorkflowType
user_id: str
user_from: UserFrom
variable_pool: VariablePool
invoke_from: InvokeFrom
graph: Graph
call_depth: int
start_at: float
total_tokens: int = 0
node_run_steps: int = 0
runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph)


@@ -3,9 +3,7 @@ from typing import Optional, cast
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
from core.workflow.entities.node_entities import NodeType
from core.workflow.graph_engine.entities.run_condition import RunCondition
@@ -28,33 +26,6 @@ class GraphParallel(BaseModel):
"""parent parallel id if exists"""
class GraphStateRoute(BaseModel):
route_id: str
"""route id"""
node_id: str
"""node id"""
class GraphState(BaseModel):
routes: dict[str, list[GraphStateRoute]] = Field(default_factory=dict)
"""graph state routes (source_node_id: routes)"""
variable_pool: VariablePool
"""variable pool"""
node_route_results: dict[str, NodeRunResult] = Field(default_factory=dict)
"""node results in route (node_id: run_result)"""
class NextGraphNode(BaseModel):
node_id: str
"""next node id"""
parallel: Optional[GraphParallel] = None
"""parallel"""
class Graph(BaseModel):
root_node_id: str
"""root node id of the graph"""
@@ -71,19 +42,14 @@ class Graph(BaseModel):
node_parallel_mapping: dict[str, str] = Field(default_factory=dict)
"""graph node parallel mapping (node id: parallel id)"""
run_state: GraphState
"""graph run state"""
@classmethod
def init(cls,
graph_config: dict,
variable_pool: VariablePool,
root_node_id: Optional[str] = None) -> "Graph":
"""
Init graph
:param graph_config: graph config
:param variable_pool: variable pool
:param root_node_id: root node id
:return: graph
"""
@@ -149,7 +115,7 @@ class Graph(BaseModel):
# fetch root node
if not root_node_id:
# if no root node id, use the START type node as root node
root_node_id = next((node_config for node_config in root_node_configs
root_node_id = next((node_config.get("id") for node_config in root_node_configs
if node_config.get('data', {}).get('type', '') == NodeType.START.value), None)
if not root_node_id or root_node_id not in root_node_ids:
@@ -178,80 +144,12 @@ class Graph(BaseModel):
root_node_id=root_node_id,
node_ids=node_ids,
edge_mapping=edge_mapping,
run_state=GraphState(
variable_pool=variable_pool
),
parallel_mapping=parallel_mapping,
node_parallel_mapping=node_parallel_mapping
)
return graph
@classmethod
def _recursively_add_node_ids(cls,
node_ids: list[str],
edge_mapping: dict[str, list[GraphEdge]],
node_id: str) -> None:
"""
Recursively add node ids
:param node_ids: node ids
:param edge_mapping: edge mapping
:param node_id: node id
"""
for graph_edge in edge_mapping.get(node_id, []):
if graph_edge.target_node_id in node_ids:
continue
node_ids.append(graph_edge.target_node_id)
cls._recursively_add_node_ids(
node_ids=node_ids,
edge_mapping=edge_mapping,
node_id=graph_edge.target_node_id
)
def next_node_ids(self) -> list[NextGraphNode]:
"""
Get next node ids
"""
# get current node ids in state
if not self.run_state.routes:
return [NextGraphNode(node_id=self.root_node_id)]
route_final_graph_edges: list[GraphEdge] = []
for route in self.run_state.routes[self.root_node_id]:
graph_edges = self.edge_mapping.get(route.node_id)
if not graph_edges:
continue
for edge in graph_edges:
if edge.target_node_id not in self.run_state.routes:
route_final_graph_edges.append(edge)
next_graph_nodes = []
for route_final_graph_edge in route_final_graph_edges:
node_id = route_final_graph_edge.target_node_id
# check condition
if route_final_graph_edge.run_condition:
result = ConditionManager.get_condition_handler(
run_condition=route_final_graph_edge.run_condition
).check(
source_node_id=route_final_graph_edge.source_node_id,
target_node_id=route_final_graph_edge.target_node_id,
graph=self
)
if not result:
continue
parallel = None
if route_final_graph_edge.target_node_id in self.node_parallel_mapping:
parallel = self.parallel_mapping[self.node_parallel_mapping[node_id]]
next_graph_nodes.append(NextGraphNode(node_id=node_id, parallel=parallel))
return next_graph_nodes
def add_extra_edge(self, source_node_id: str,
target_node_id: str,
run_condition: Optional[RunCondition] = None) -> None:
@@ -295,6 +193,29 @@ class Graph(BaseModel):
return leaf_node_ids
@classmethod
def _recursively_add_node_ids(cls,
node_ids: list[str],
edge_mapping: dict[str, list[GraphEdge]],
node_id: str) -> None:
"""
Recursively add node ids
:param node_ids: node ids
:param edge_mapping: edge mapping
:param node_id: node id
"""
for graph_edge in edge_mapping.get(node_id, []):
if graph_edge.target_node_id in node_ids:
continue
node_ids.append(graph_edge.target_node_id)
cls._recursively_add_node_ids(
node_ids=node_ids,
edge_mapping=edge_mapping,
node_id=graph_edge.target_node_id
)
@classmethod
def _recursively_add_parallels(cls,
edge_mapping: dict[str, list[GraphEdge]],

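For reference, a standalone sketch of the traversal that _recursively_add_node_ids performs, with GraphEdge reduced to a bare target node id (illustrative data only, not the repo's types):

def collect_reachable(node_ids: list[str], edge_mapping: dict[str, list[str]], node_id: str) -> None:
    # depth-first walk that appends every node reachable from node_id exactly once
    for target_node_id in edge_mapping.get(node_id, []):
        if target_node_id in node_ids:
            continue
        node_ids.append(target_node_id)
        collect_reachable(node_ids, edge_mapping, target_node_id)

edges = {"start": ["qc"], "qc": ["llm", "http"], "llm": ["answer"]}
reachable = ["start"]
collect_reachable(reachable, edges, "start")
assert reachable == ["start", "qc", "llm", "answer", "http"]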

@@ -4,12 +4,12 @@ from pydantic import BaseModel, Field
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_graph import RuntimeGraph
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
from core.workflow.nodes.base_node import UserFrom
class GraphRuntimeState(BaseModel):
# init params
tenant_id: str
app_id: str
user_id: str
@@ -17,10 +17,10 @@ class GraphRuntimeState(BaseModel):
invoke_from: InvokeFrom
call_depth: int
graph: Graph
variable_pool: VariablePool
start_at: Optional[float] = None
total_tokens: int = 0
node_run_steps: int = 0
runtime_graph: RuntimeGraph = Field(default_factory=RuntimeGraph)
node_run_state: RuntimeRouteState = Field(default_factory=RuntimeRouteState)


@@ -0,0 +1,13 @@
from typing import Optional
from pydantic import BaseModel
from core.workflow.graph_engine.entities.graph import GraphParallel
class NextGraphNode(BaseModel):
node_id: str
"""next node id"""
parallel: Optional[GraphParallel] = None
"""parallel"""


@@ -1,38 +0,0 @@
from datetime import datetime, timezone
from typing import Optional
from pydantic import BaseModel
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.graph_engine.entities.runtime_node import RuntimeNode
from models.workflow import WorkflowNodeExecutionStatus
class RuntimeGraph(BaseModel):
runtime_nodes: dict[str, RuntimeNode] = {}
"""runtime nodes"""
def add_runtime_node(self, runtime_node: RuntimeNode) -> None:
self.runtime_nodes[runtime_node.id] = runtime_node
def add_link(self, source_runtime_node_id: str, target_runtime_node_id: str) -> None:
if source_runtime_node_id in self.runtime_nodes and target_runtime_node_id in self.runtime_nodes:
target_runtime_node = self.runtime_nodes[target_runtime_node_id]
target_runtime_node.predecessor_runtime_node_id = source_runtime_node_id
def runtime_node_finished(self, runtime_node_id: str, node_run_result: NodeRunResult) -> None:
if runtime_node_id in self.runtime_nodes:
runtime_node = self.runtime_nodes[runtime_node_id]
runtime_node.status = RuntimeNode.Status.SUCCESS \
if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED \
else RuntimeNode.Status.FAILED
runtime_node.node_run_result = node_run_result
runtime_node.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
runtime_node.failed_reason = node_run_result.error
def runtime_node_paused(self, runtime_node_id: str, paused_by: Optional[str] = None) -> None:
if runtime_node_id in self.runtime_nodes:
runtime_node = self.runtime_nodes[runtime_node_id]
runtime_node.status = RuntimeNode.Status.PAUSED
runtime_node.paused_at = datetime.now(timezone.utc).replace(tzinfo=None)
runtime_node.paused_by = paused_by


@@ -1,48 +0,0 @@
import uuid
from datetime import datetime
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.graph_engine.entities.graph import GraphNode
class RuntimeNode(BaseModel):
class Status(Enum):
PENDING = "pending"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
PAUSED = "paused"
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
"""random id for current runtime node"""
graph_node: GraphNode
"""graph node"""
node_run_result: Optional[NodeRunResult] = None
"""node run result"""
status: Status = Status.PENDING
"""node status"""
start_at: Optional[datetime] = None
"""start time"""
paused_at: Optional[datetime] = None
"""paused time"""
finished_at: Optional[datetime] = None
"""finished time"""
failed_reason: Optional[str] = None
"""failed reason"""
paused_by: Optional[str] = None
"""paused by"""
predecessor_runtime_node_id: Optional[str] = None
"""predecessor runtime node id"""


@@ -0,0 +1,111 @@
import uuid
from datetime import datetime, timezone
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
from core.workflow.entities.node_entities import NodeRunResult
from models.workflow import WorkflowNodeExecutionStatus
class RouteNodeState(BaseModel):
class Status(Enum):
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
PAUSED = "paused"
state_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
"""node state id"""
node_id: str
"""node id"""
node_run_result: Optional[NodeRunResult] = None
"""node run result"""
status: Status = Status.RUNNING
"""node status"""
start_at: datetime
"""start time"""
paused_at: Optional[datetime] = None
"""paused time"""
finished_at: Optional[datetime] = None
"""finished time"""
failed_reason: Optional[str] = None
"""failed reason"""
paused_by: Optional[str] = None
"""paused by"""
class RuntimeRouteState(BaseModel):
routes: dict[str, list[str]] = Field(default_factory=dict)
"""graph state routes (source_node_state_id: target_node_state_id)"""
node_state_mapping: dict[str, RouteNodeState] = Field(default_factory=dict)
"""node state mapping (route_node_state_id: route_node_state)"""
def create_node_state(self, node_id: str) -> RouteNodeState:
"""
Create node state
:param node_id: node id
"""
state = RouteNodeState(node_id=node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None))
self.node_state_mapping[state.state_id] = state
return state
def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None:
"""
Add route to the graph state
:param source_node_state_id: source node state id
:param target_node_state_id: target node state id
"""
if source_node_state_id not in self.routes:
self.routes[source_node_state_id] = []
self.routes[source_node_state_id].append(target_node_state_id)
def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) \
-> list[RouteNodeState]:
"""
Get routes with node state by source node state id
:param source_node_state_id: source node state id
:return: routes with node state
"""
return [self.node_state_mapping[target_state_id]
for target_state_id in self.routes.get(source_node_state_id, [])]
def set_node_state_finished(self, node_state_id: str, run_result: NodeRunResult) -> None:
"""
Node finished
:param node_state_id: route node state id
:param run_result: run result
"""
if node_state_id not in self.node_state_mapping:
raise Exception(f"Route state {node_state_id} not found")
route = self.node_state_mapping[node_state_id]
if route.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]:
raise Exception(f"Route state {node_state_id} already finished")
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
route.status = RouteNodeState.Status.SUCCESS
elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
route.status = RouteNodeState.Status.FAILED
route.failed_reason = run_result.error
else:
raise Exception(f"Invalid route status {run_result.status}")
route.node_run_result = run_result
route.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)

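A minimal usage sketch of the RuntimeRouteState API added above. The imports mirror this file's own; that NodeRunResult accepts a bare status is an assumption about the surrounding codebase:

from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState
from models.workflow import WorkflowNodeExecutionStatus

state = RuntimeRouteState()

# each executed node gets its own RouteNodeState, keyed by a random state_id
start_state = state.create_node_state(node_id="start")  # status defaults to RUNNING
llm_state = state.create_node_state(node_id="llm")

# record that execution flowed from the start state to the llm state
state.add_route(
    source_node_state_id=start_state.state_id,
    target_node_state_id=llm_state.state_id,
)

# close out the start node with a successful result (assumed NodeRunResult constructor)
state.set_node_state_finished(
    node_state_id=start_state.state_id,
    run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
)

assert start_state.status == RouteNodeState.Status.SUCCESS
successors = state.get_routes_with_node_state_by_source_node_state_id(start_state.state_id)
assert [s.node_id for s in successors] == ["llm"]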

@@ -22,6 +22,7 @@ class GraphEngine:
graph: Graph,
variable_pool: VariablePool,
callbacks: list[BaseWorkflowCallback]) -> None:
self.graph = graph
self.graph_runtime_state = GraphRuntimeState(
tenant_id=tenant_id,
app_id=app_id,
@@ -29,7 +30,6 @@ class GraphEngine:
user_from=user_from,
invoke_from=invoke_from,
call_depth=call_depth,
graph=graph,
variable_pool=variable_pool
)
@@ -43,3 +43,49 @@ class GraphEngine:
def run(self) -> Generator:
self.graph_runtime_state.start_at = time.perf_counter()
pass
# def next_node_ids(self, node_state_id: str) -> list[NextGraphNode]:
# """
# Get next node ids
#
# :param node_state_id: source node state id
# """
# # get current node ids in state
# node_run_state = self.graph_runtime_state.node_run_state
# graph = self.graph
# if not node_run_state.routes:
# return [NextGraphNode(node_id=graph.root_node_id)]
#
# route_final_graph_edges: list[GraphEdge] = []
# for route in node_run_state.routes[graph.root_node_id]:
# graph_edges = graph.edge_mapping.get(route.node_id)
# if not graph_edges:
# continue
#
# for edge in graph_edges:
# if edge.target_node_id not in node_run_state.routes:
# route_final_graph_edges.append(edge)
#
# next_graph_nodes = []
# for route_final_graph_edge in route_final_graph_edges:
# node_id = route_final_graph_edge.target_node_id
# # check condition
# if route_final_graph_edge.run_condition:
# result = ConditionManager.get_condition_handler(
# run_condition=route_final_graph_edge.run_condition
# ).check(
# source_node_id=route_final_graph_edge.source_node_id,
# target_node_id=route_final_graph_edge.target_node_id,
# graph=self
# )
#
# if not result:
# continue
#
# parallel = None
# if route_final_graph_edge.target_node_id in graph.node_parallel_mapping:
# parallel = graph.parallel_mapping[graph.node_parallel_mapping[node_id]]
#
# next_graph_nodes.append(NextGraphNode(node_id=node_id, parallel=parallel))
#
# return next_graph_nodes

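The commented-out block above sketches the planned frontier computation; below is a standalone restatement with the run-condition check reduced to a plain callable (illustrative names only, not the repo's ConditionManager API):

def next_node_ids(root_node_id: str,
                  edge_mapping: dict[str, list[str]],
                  executed: set[str],
                  condition=None) -> list[str]:
    # executed: node ids that already have a route state
    # condition: optional (source, target) -> bool gate, standing in for RunCondition
    if not executed:
        return [root_node_id]
    frontier = []
    for source_node_id in executed:
        for target_node_id in edge_mapping.get(source_node_id, []):
            if target_node_id in executed:
                continue
            if condition is not None and not condition(source_node_id, target_node_id):
                continue
            frontier.append(target_node_id)
    return frontier

edges = {"start": ["qc"], "qc": ["llm", "http"]}
assert next_node_ids("start", edges, set()) == ["start"]
assert next_node_ids("start", edges, {"start"}) == ["qc"]
assert sorted(next_node_ids("start", edges, {"start", "qc"})) == ["http", "llm"]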

@@ -106,7 +106,7 @@ class WorkflowEntry:
)
# init graph
graph = self._init_graph(
graph = Graph.init(
graph_config=graph_config
)
@@ -152,86 +152,6 @@ class WorkflowEntry:
return rst
def _init_graph(self, graph_config: dict, root_node_id: Optional[str] = None) -> Optional[Graph]:
"""
Initialize graph
:param graph_config: graph config
:param root_node_id: root node id if needed
:return: graph
"""
# edge configs
edge_configs = graph_config.get('edges')
if not edge_configs:
return None
edge_configs = cast(list, edge_configs)
# reorganize edges mapping
source_edges_mapping: dict[str, list[dict]] = {}
target_edge_ids = set()
for edge_config in edge_configs:
source_node_id = edge_config.get('source')
if not source_node_id:
continue
if source_node_id not in source_edges_mapping:
source_edges_mapping[source_node_id] = []
source_edges_mapping[source_node_id].append(edge_config)
target_node_id = edge_config.get('target')
if target_node_id:
target_edge_ids.add(target_node_id)
# node configs
node_configs = graph_config.get('nodes')
if not node_configs:
return None
node_configs = cast(list, node_configs)
# fetch nodes that have no predecessor node
root_node_configs = []
nodes_mapping: dict[str, dict] = {}
for node_config in node_configs:
node_id = node_config.get('id')
if not node_id:
continue
if node_id not in target_edge_ids:
root_node_configs.append(node_config)
nodes_mapping[node_id] = node_config
# fetch root node
if root_node_id:
root_node_config = next((node_config for node_config in root_node_configs
if node_config.get('id') == root_node_id), None)
else:
# if no root node id, use the START type node as root node
root_node_config = next((node_config for node_config in root_node_configs
if node_config.get('data', {}).get('type', '') == NodeType.START.value), None)
if not root_node_config:
return None
# init graph
graph = Graph.init(
root_node_config=root_node_config
)
# add edge from root node
self._recursively_add_edges(
graph=graph,
source_node_config=root_node_config,
edges_mapping=source_edges_mapping,
nodes_mapping=nodes_mapping,
root_node_configs=root_node_configs
)
return graph
def _recursively_add_edges(self, graph: Graph,
source_node_config: dict,
edges_mapping: dict,


@@ -0,0 +1,208 @@
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.utils.condition.entities import Condition
def test_init():
graph_config = {
"edges": [
{
"id": "llm-source-answer-target",
"source": "llm",
"target": "answer",
},
{
"id": "start-source-qc-target",
"source": "start",
"target": "qc",
},
{
"id": "qc-1-llm-target",
"source": "qc",
"sourceHandle": "1",
"target": "llm",
},
{
"id": "qc-2-http-target",
"source": "qc",
"sourceHandle": "2",
"target": "http",
},
{
"id": "http-source-answer2-target",
"source": "http",
"target": "answer2",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
{
"data": {
"type": "question-classifier"
},
"id": "qc",
},
{
"data": {
"type": "http-request",
},
"id": "http",
},
{
"data": {
"type": "answer",
},
"id": "answer2",
}
],
}
graph = Graph.init(
graph_config=graph_config
)
start_node_id = "start"
assert graph.root_node_id == start_node_id
assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc"
assert {"llm", "http"} == {node.target_node_id for node in graph.edge_mapping.get("qc")}
def test__init_graph_with_iteration():
graph_config = {
"edges": [
{
"id": "llm-answer",
"source": "llm",
"sourceHandle": "source",
"target": "answer",
},
{
"id": "iteration-source-llm-target",
"source": "iteration",
"sourceHandle": "source",
"target": "llm",
},
{
"id": "template-transform-in-iteration-source-llm-in-iteration-target",
"source": "template-transform-in-iteration",
"sourceHandle": "source",
"target": "llm-in-iteration",
},
{
"id": "llm-in-iteration-source-answer-in-iteration-target",
"source": "llm-in-iteration",
"sourceHandle": "source",
"target": "answer-in-iteration",
},
{
"id": "start-source-code-target",
"source": "start",
"sourceHandle": "source",
"target": "code",
},
{
"id": "code-source-iteration-target",
"source": "code",
"sourceHandle": "source",
"target": "iteration",
}
],
"nodes": [
{
"data": {
"type": "start",
},
"id": "start",
},
{
"data": {
"type": "llm",
},
"id": "llm",
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
{
"data": {
"type": "iteration"
},
"id": "iteration",
},
{
"data": {
"type": "template-transform",
},
"id": "template-transform-in-iteration",
"parentId": "iteration",
},
{
"data": {
"type": "llm",
},
"id": "llm-in-iteration",
"parentId": "iteration",
},
{
"data": {
"type": "answer",
},
"id": "answer-in-iteration",
"parentId": "iteration",
},
{
"data": {
"type": "code",
},
"id": "code",
}
]
}
graph = Graph.init(
graph_config=graph_config,
root_node_id="template-transform-in-iteration"
)
graph.add_extra_edge(
source_node_id="answer-in-iteration",
target_node_id="template-transform-in-iteration",
run_condition=RunCondition(
type="condition",
conditions=[
Condition(
variable_selector=["iteration", "index"],
comparison_operator="",
value="5"
)
]
)
)
# iteration:
# [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration]
assert graph.root_node_id == "template-transform-in-iteration"
assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration"
assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration"
assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration"


@@ -1,693 +0,0 @@
from typing import Optional
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.workflow_entry import WorkflowEntry
def test__init_graph():
graph_config = {
"edges": [
{
"id": "llm-source-answer-target",
"source": "llm",
"target": "answer",
},
{
"id": "1717222650545-source-1719481290322-target",
"source": "1717222650545",
"target": "1719481290322",
},
{
"id": "1719481290322-1-llm-target",
"source": "1719481290322",
"sourceHandle": "1",
"target": "llm",
},
{
"id": "1719481290322-2-1719481315734-target",
"source": "1719481290322",
"sourceHandle": "2",
"target": "1719481315734",
},
{
"id": "1719481315734-source-1719481326339-target",
"source": "1719481315734",
"target": "1719481326339",
}
],
"nodes": [
{
"data": {
"desc": "",
"title": "Start",
"type": "start",
"variables": [
{
"label": "name",
"max_length": 48,
"options": [],
"required": False,
"type": "text-input",
"variable": "name"
}
]
},
"id": "1717222650545",
"position": {
"x": -147.65487258270954,
"y": 263.5326708413438
},
},
{
"data": {
"context": {
"enabled": False,
"variable_selector": []
},
"desc": "",
"memory": {
"query_prompt_template": "{{#sys.query#}}",
"role_prefix": {
"assistant": "",
"user": ""
},
"window": {
"enabled": False,
"size": 10
}
},
"model": {
"completion_params": {
"temperature": 0
},
"mode": "chat",
"name": "anthropic.claude-3-sonnet-20240229-v1:0",
"provider": "bedrock"
},
"prompt_config": {
"jinja2_variables": [
{
"value_selector": [
"sys",
"query"
],
"variable": "query"
}
]
},
"prompt_template": [
{
"edition_type": "basic",
"id": "8b02d178-3aa0-4dbd-82bf-8b6a40658300",
"jinja2_text": "",
"role": "system",
"text": "yep"
}
],
"title": "LLM",
"type": "llm",
"variables": [],
"vision": {
"configs": {
"detail": "low"
},
"enabled": True
}
},
"id": "llm",
"position": {
"x": 654.0331237272932,
"y": 263.5326708413438
},
},
{
"data": {
"answer": "123{{#llm.text#}}",
"desc": "",
"title": "Answer",
"type": "answer",
"variables": []
},
"id": "answer",
"position": {
"x": 958.1129142362784,
"y": 263.5326708413438
},
},
{
"data": {
"classes": [
{
"id": "1",
"name": "happy"
},
{
"id": "2",
"name": "sad"
}
],
"desc": "",
"instructions": "",
"model": {
"completion_params": {
"temperature": 0.7
},
"mode": "chat",
"name": "gpt-4o",
"provider": "openai"
},
"query_variable_selector": [
"1717222650545",
"sys.query"
],
"title": "Question Classifier",
"topics": [],
"type": "question-classifier"
},
"id": "1719481290322",
"position": {
"x": 165.25154615277052,
"y": 263.5326708413438
}
},
{
"data": {
"authorization": {
"config": None,
"type": "no-auth"
},
"body": {
"data": "",
"type": "none"
},
"desc": "",
"headers": "",
"method": "get",
"params": "",
"timeout": {
"max_connect_timeout": 0,
"max_read_timeout": 0,
"max_write_timeout": 0
},
"title": "HTTP Request",
"type": "http-request",
"url": "https://baidu.com",
"variables": []
},
"height": 88,
"id": "1719481315734",
"position": {
"x": 654.0331237272932,
"y": 474.1180064703089
}
},
{
"data": {
"answer": "{{#1719481315734.status_code#}}",
"desc": "",
"title": "Answer 2",
"type": "answer",
"variables": []
},
"height": 105,
"id": "1719481326339",
"position": {
"x": 958.1129142362784,
"y": 474.1180064703089
},
}
],
}
workflow_entry = WorkflowEntry()
graph = workflow_entry._init_graph(
graph_config=graph_config
)
assert graph.root_node.id == "1717222650545"
assert graph.root_node.source_edge_config is None
assert graph.root_node.descendant_node_ids == ["1719481290322"]
assert graph.graph_nodes.get("1719481290322") is not None
assert len(graph.graph_nodes.get("1719481290322").descendant_node_ids) == 2
assert graph.graph_nodes.get("llm").run_condition is not None
assert graph.graph_nodes.get("1719481315734").run_condition is not None
def test__init_graph_with_iteration():
graph_config = {
"edges": [
{
"data": {
"sourceType": "llm",
"targetType": "answer"
},
"id": "llm-answer",
"source": "llm",
"sourceHandle": "source",
"target": "answer",
"targetHandle": "target",
"type": "custom"
},
{
"data": {
"isInIteration": False,
"sourceType": "iteration",
"targetType": "llm"
},
"id": "1720001776597-source-llm-target",
"selected": False,
"source": "1720001776597",
"sourceHandle": "source",
"target": "llm",
"targetHandle": "target",
"type": "custom",
"zIndex": 0
},
{
"data": {
"isInIteration": True,
"iteration_id": "1720001776597",
"sourceType": "template-transform",
"targetType": "llm"
},
"id": "1720001783092-source-1720001859851-target",
"source": "1720001783092",
"sourceHandle": "source",
"target": "1720001859851",
"targetHandle": "target",
"type": "custom",
"zIndex": 1002
},
{
"data": {
"isInIteration": True,
"iteration_id": "1720001776597",
"sourceType": "llm",
"targetType": "answer"
},
"id": "1720001859851-source-1720001879621-target",
"source": "1720001859851",
"sourceHandle": "source",
"target": "1720001879621",
"targetHandle": "target",
"type": "custom",
"zIndex": 1002
},
{
"data": {
"isInIteration": False,
"sourceType": "start",
"targetType": "code"
},
"id": "1720001771022-source-1720001956578-target",
"source": "1720001771022",
"sourceHandle": "source",
"target": "1720001956578",
"targetHandle": "target",
"type": "custom",
"zIndex": 0
},
{
"data": {
"isInIteration": False,
"sourceType": "code",
"targetType": "iteration"
},
"id": "1720001956578-source-1720001776597-target",
"source": "1720001956578",
"sourceHandle": "source",
"target": "1720001776597",
"targetHandle": "target",
"type": "custom",
"zIndex": 0
}
],
"nodes": [
{
"data": {
"desc": "",
"selected": False,
"title": "Start",
"type": "start",
"variables": []
},
"height": 53,
"id": "1720001771022",
"position": {
"x": 80,
"y": 282
},
"positionAbsolute": {
"x": 80,
"y": 282
},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244
},
{
"data": {
"context": {
"enabled": False,
"variable_selector": []
},
"desc": "",
"memory": {
"role_prefix": {
"assistant": "",
"user": ""
},
"window": {
"enabled": False,
"size": 10
}
},
"model": {
"completion_params": {
"temperature": 0.7
},
"mode": "chat",
"name": "gpt-3.5-turbo",
"provider": "openai"
},
"prompt_template": [
{
"id": "b7d1350e-cf0d-4ff3-8ad0-52b6f1218781",
"role": "system",
"text": ""
}
],
"selected": False,
"title": "LLM",
"type": "llm",
"variables": [],
"vision": {
"enabled": False
}
},
"height": 97,
"id": "llm",
"position": {
"x": 1730.595805935594,
"y": 282
},
"positionAbsolute": {
"x": 1730.595805935594,
"y": 282
},
"selected": True,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244
},
{
"data": {
"answer": "{{#llm.text#}}",
"desc": "",
"selected": False,
"title": "Answer",
"type": "answer",
"variables": []
},
"height": 105,
"id": "answer",
"position": {
"x": 2042.803154918583,
"y": 282
},
"positionAbsolute": {
"x": 2042.803154918583,
"y": 282
},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244
},
{
"data": {
"desc": "",
"height": 202,
"iterator_selector": [
"1720001956578",
"result"
],
"output_selector": [
"1720001859851",
"text"
],
"output_type": "array[string]",
"selected": False,
"startNodeType": "template-transform",
"start_node_id": "1720001783092",
"title": "Iteration",
"type": "iteration",
"width": 985
},
"height": 202,
"id": "1720001776597",
"position": {
"x": 678.6748900850307,
"y": 282
},
"positionAbsolute": {
"x": 678.6748900850307,
"y": 282
},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 985,
"zIndex": 1
},
{
"data": {
"desc": "",
"isInIteration": True,
"isIterationStart": True,
"iteration_id": "1720001776597",
"selected": False,
"template": "{{ arg1 }}",
"title": "Template",
"type": "template-transform",
"variables": [
{
"value_selector": [
"1720001776597",
"item"
],
"variable": "arg1"
}
]
},
"extent": "parent",
"height": 53,
"id": "1720001783092",
"parentId": "1720001776597",
"position": {
"x": 117,
"y": 85
},
"positionAbsolute": {
"x": 795.6748900850307,
"y": 367
},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244,
"zIndex": 1001
},
{
"data": {
"context": {
"enabled": False,
"variable_selector": []
},
"desc": "",
"isInIteration": True,
"iteration_id": "1720001776597",
"model": {
"completion_params": {
"temperature": 0.7
},
"mode": "chat",
"name": "gpt-3.5-turbo",
"provider": "openai"
},
"prompt_template": [
{
"id": "9575b8f2-33c4-4611-b6d0-17d8d436a250",
"role": "system",
"text": "{{#1720001783092.output#}}"
}
],
"selected": False,
"title": "LLM 2",
"type": "llm",
"variables": [],
"vision": {
"enabled": False
}
},
"extent": "parent",
"height": 97,
"id": "1720001859851",
"parentId": "1720001776597",
"position": {
"x": 421,
"y": 85
},
"positionAbsolute": {
"x": 1099.6748900850307,
"y": 367
},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244,
"zIndex": 1002
},
{
"data": {
"answer": "{{#1720001859851.text#}}",
"desc": "",
"isInIteration": True,
"iteration_id": "1720001776597",
"selected": False,
"title": "Answer 2",
"type": "answer",
"variables": []
},
"extent": "parent",
"height": 105,
"id": "1720001879621",
"parentId": "1720001776597",
"position": {
"x": 725,
"y": 85
},
"positionAbsolute": {
"x": 1403.6748900850307,
"y": 367
},
"selected": False,
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244,
"zIndex": 1002
},
{
"data": {
"code": "\ndef main() -> dict:\n return {\n \"result\": [\n \"a\",\n \"b\"\n ]\n }\n",
"code_language": "python3",
"desc": "",
"outputs": {
"result": {
"children": None,
"type": "array[string]"
}
},
"selected": False,
"title": "Code",
"type": "code",
"variables": []
},
"height": 53,
"id": "1720001956578",
"position": {
"x": 380,
"y": 282
},
"positionAbsolute": {
"x": 380,
"y": 282
},
"sourcePosition": "right",
"targetPosition": "left",
"type": "custom",
"width": 244
}
]
}
workflow_entry = WorkflowEntry()
graph = workflow_entry._init_graph(
graph_config=graph_config
)
# start 1720001771022 -> code 1720001956578 -> iteration 1720001776597 -> llm llm -> answer answer
# iteration 1720001776597:
# [template 1720001783092 -> llm 1720001859851 -> answer 1720001879621]
main_graph_orders = [
"1720001771022",
"1720001956578",
"1720001776597",
"llm",
"answer"
]
iteration_sub_graph_orders = [
"1720001783092",
"1720001859851",
"1720001879621"
]
assert graph.root_node.id == "1720001771022"
print("")
current_graph = graph
for i, node_id in enumerate(main_graph_orders):
current_root_node = current_graph.root_node
assert current_root_node is not None
assert current_root_node.id == node_id
if current_root_node.node_config.get("data", {}).get("type") == "iteration":
assert current_root_node.sub_graph is not None
sub_graph = current_root_node.sub_graph
assert sub_graph.root_node.id == "1720001783092"
current_sub_graph = sub_graph
for j, sub_node_id in enumerate(iteration_sub_graph_orders):
sub_descendant_graphs = current_sub_graph.get_descendant_graphs(node_id=current_sub_graph.root_node.id)
print(f"Iteration [{current_sub_graph.root_node.id}] -> {len(sub_descendant_graphs)}"
f" {[sub_descendant_graph.root_node.id for sub_descendant_graph in sub_descendant_graphs]}")
if j == len(iteration_sub_graph_orders) - 1:
break
assert len(sub_descendant_graphs) == 1
first_sub_descendant_graph = sub_descendant_graphs[0]
assert first_sub_descendant_graph.root_node.id == iteration_sub_graph_orders[j + 1]
assert first_sub_descendant_graph.root_node.predecessor_node_id == sub_node_id
current_sub_graph = first_sub_descendant_graph
descendant_graphs = current_graph.get_descendant_graphs(node_id=current_graph.root_node.id)
print(f"[{current_graph.root_node.id}] -> {len(descendant_graphs)}"
f" {[descendant_graph.root_node.id for descendant_graph in descendant_graphs]}")
if i == len(main_graph_orders) - 1:
assert len(descendant_graphs) == 0
break
assert len(descendant_graphs) == 1
first_descendant_graph = descendant_graphs[0]
assert first_descendant_graph.root_node.id == main_graph_orders[i + 1]
assert first_descendant_graph.root_node.predecessor_node_id == node_id
current_graph = first_descendant_graph