add graph engine test

takatost 2024-07-16 16:37:37 +08:00
parent 00fb23d0c9
commit 00ec36d47c
17 changed files with 1122 additions and 904 deletions

View File

@@ -3,20 +3,6 @@ from typing import Any, Optional
from pydantic import BaseModel
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
from models.workflow import WorkflowNodeExecutionStatus
@@ -55,25 +41,6 @@ class NodeType(Enum):
raise ValueError(f'invalid node type value {value}')
node_classes = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.ANSWER: AnswerNode,
NodeType.LLM: LLMNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.IF_ELSE: IfElseNode,
NodeType.CODE: CodeNode,
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.TOOL: ToolNode,
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: IterationNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
}
class SystemVariable(Enum):
"""
System Variables.

View File

@@ -1,23 +1,31 @@
from abc import ABC, abstractmethod
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.run_condition import RunCondition
class RunConditionHandler(ABC):
def __init__(self, condition: RunCondition):
def __init__(self,
init_params: GraphInitParams,
graph: Graph,
condition: RunCondition):
self.init_params = init_params
self.graph = graph
self.condition = condition
@abstractmethod
def check(self,
graph_runtime_state: GraphRuntimeState,
source_node_id: str,
target_node_id: str,
graph: "Graph") -> bool:
target_node_id: str) -> bool:
"""
Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param source_node_id: source node id
:param target_node_id: target node id
:param graph: graph
:return: bool
"""
raise NotImplementedError
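For orientation, a minimal sketch (not part of this commit) of a handler written against the updated interface; the class name AlwaysTrueHandler is hypothetical:

from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState

class AlwaysTrueHandler(RunConditionHandler):
    """Hypothetical handler that lets every edge run; shown only to illustrate the new signature."""

    def check(self,
              graph_runtime_state: GraphRuntimeState,
              source_node_id: str,
              target_node_id: str) -> bool:
        # the graph is now reachable as self.graph instead of being passed into check()
        return True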

View File

@@ -1,29 +1,33 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
class BranchIdentifyRunConditionHandler(RunConditionHandler):
def check(self,
graph_runtime_state: GraphRuntimeState,
source_node_id: str,
target_node_id: str,
graph: "Graph") -> bool:
target_node_id: str) -> bool:
"""
Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param source_node_id: source node id
:param target_node_id: target node id
:param graph: graph
:return: bool
"""
if not self.condition.branch_identify:
raise Exception("Branch identify is required")
run_state = graph.run_state
node_route_result = run_state.node_route_results.get(source_node_id)
if not node_route_result:
node_route_state = graph_runtime_state.node_run_state.node_state_mapping.get(source_node_id)
if not node_route_state:
return False
if not node_route_result.edge_source_handle:
run_result = node_route_state.node_run_result
if not run_result:
return False
return self.condition.branch_identify == node_route_result.edge_source_handle
if not run_result.edge_source_handle:
return False
return self.condition.branch_identify == run_result.edge_source_handle

View File

@@ -1,18 +1,19 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.utils.condition.processor import ConditionProcessor
class ConditionRunConditionHandlerHandler(RunConditionHandler):
def check(self,
graph_runtime_state: GraphRuntimeState,
source_node_id: str,
target_node_id: str,
graph: "Graph") -> bool:
target_node_id: str) -> bool:
"""
Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param source_node_id: source node id
:param target_node_id: target node id
:param graph: graph
:return: bool
"""
if not self.condition.conditions:
@@ -21,10 +22,9 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler):
# process condition
condition_processor = ConditionProcessor()
compare_result, _ = condition_processor.process(
variable_pool=graph.run_state.variable_pool,
variable_pool=graph_runtime_state.variable_pool,
logical_operator="and",
conditions=self.condition.conditions
)
return compare_result

View File

@@ -1,19 +1,35 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.run_condition import RunCondition
class ConditionManager:
@staticmethod
def get_condition_handler(run_condition: RunCondition) -> RunConditionHandler:
def get_condition_handler(
init_params: GraphInitParams,
graph: Graph,
run_condition: RunCondition
) -> RunConditionHandler:
"""
Get condition handler
:param init_params: init params
:param graph: graph
:param run_condition: run condition
:return: condition handler
"""
if run_condition.type == "branch_identify":
return BranchIdentifyRunConditionHandler(run_condition)
return BranchIdentifyRunConditionHandler(
init_params=init_params,
graph=graph,
condition=run_condition
)
else:
return ConditionRunConditionHandlerHandler(run_condition)
return ConditionRunConditionHandlerHandler(
init_params=init_params,
graph=graph,
condition=run_condition
)
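A condensed usage sketch of the widened factory; the surrounding variables mirror the call site further down in graph_engine.py and are illustrative here:

handler = ConditionManager.get_condition_handler(
    init_params=self.init_params,
    graph=self.graph,
    run_condition=edge.run_condition,
)
if handler.check(
    graph_runtime_state=self.graph_runtime_state,
    source_node_id=edge.source_node_id,
    target_node_id=edge.target_node_id,
):
    ...  # follow this edge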

View File

@@ -3,6 +3,7 @@ from typing import Optional
from pydantic import BaseModel, Field, model_validator
from core.workflow.entities.node_entities import NodeRunResult
from models.workflow import WorkflowNodeExecutionStatus
class GraphEngineEvent(BaseModel):
@@ -50,10 +51,12 @@ class NodeRunStartedEvent(BaseNodeEvent):
class NodeRunStreamChunkEvent(BaseNodeEvent):
chunk_content: str = Field(..., description="chunk content")
from_variable_selector: list[str] = Field(..., description="from variable selector")
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
retriever_resources: list[dict] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
class NodeRunSucceededEvent(BaseNodeEvent):
@@ -61,7 +64,10 @@ class NodeRunSucceededEvent(BaseNodeEvent):
class NodeRunFailedEvent(BaseNodeEvent):
run_result: NodeRunResult = Field(..., description="run result")
run_result: NodeRunResult = Field(
default=NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED),
description="run result"
)
reason: str = Field("", description="failed reason")
@model_validator(mode='before')
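With the FAILED default in place, callers that only know the failure reason can omit run_result; a minimal sketch (field values illustrative):

event = NodeRunFailedEvent(node_id="llm", reason="connection timed out")
assert event.run_result.status == WorkflowNodeExecutionStatus.FAILED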

View File

@@ -3,13 +3,13 @@ import queue
import time
from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, cast
from datetime import datetime, timezone
from typing import Optional
from flask import Flask, current_app
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
@@ -19,14 +19,21 @@ from core.workflow.graph_engine.entities.event import (
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.base_node import UserFrom, node_classes
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import node_classes
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.test.test_node import TestNode
from extensions.ext_database import db
from models.workflow import WorkflowType
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
thread_pool = ThreadPoolExecutor(max_workers=500, thread_name_prefix="ThreadGraphParallelRun")
logger = logging.getLogger(__name__)
@@ -43,7 +50,8 @@ class GraphEngine:
call_depth: int,
graph: Graph,
variable_pool: VariablePool,
callbacks: list[BaseWorkflowCallback]) -> None:
max_execution_steps: int,
max_execution_time: int) -> None:
self.graph = graph
self.init_params = GraphInitParams(
tenant_id=tenant_id,
@@ -61,12 +69,8 @@
start_at=time.perf_counter()
)
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
self.max_execution_steps = cast(int, max_execution_steps)
max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
self.max_execution_time = cast(int, max_execution_time)
self.callbacks = callbacks
self.max_execution_steps = max_execution_steps
self.max_execution_time = max_execution_time
def run_in_block_mode(self):
# TODO convert generator to result
@@ -92,7 +96,7 @@
return
except Exception as e:
yield GraphRunFailedEvent(reason=str(e))
return
raise e
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
next_node_id = start_node_id
@@ -118,7 +122,7 @@
)
except Exception as e:
yield NodeRunFailedEvent(node_id=next_node_id, reason=str(e))
return
raise e
previous_node_id = next_node_id
@@ -141,11 +145,13 @@
for edge in edge_mappings:
if edge.run_condition:
result = ConditionManager.get_condition_handler(
run_condition=edge.run_condition
init_params=self.init_params,
graph=self.graph,
run_condition=edge.run_condition,
).check(
graph_runtime_state=self.graph_runtime_state,
source_node_id=edge.source_node_id,
target_node_id=edge.target_node_id,
graph=self.graph
)
if result:
@@ -250,7 +256,16 @@
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
# init workflow run state
node_instance = node_cls( # type: ignore
# node_instance = node_cls( # type: ignore
# config=node_config,
# graph_init_params=self.init_params,
# graph=self.graph,
# graph_runtime_state=self.graph_runtime_state,
# previous_node_id=previous_node_id
# )
# init workflow run state
node_instance = TestNode(
config=node_config,
graph_init_params=self.init_params,
graph=self.graph,
@@ -268,24 +283,64 @@
self.graph_runtime_state.node_run_steps += 1
try:
start_at = datetime.now(timezone.utc).replace(tzinfo=None)
# run node
generator = node_instance.run()
run_result = None
for item in generator:
if isinstance(item, RunCompletedEvent):
run_result = item.run_result
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
yield NodeRunFailedEvent(
node_id=node_id,
parallel_id=parallel_id,
run_result=run_result,
reason=run_result.error
)
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
yield NodeRunSucceededEvent(
node_id=node_id,
parallel_id=parallel_id,
run_result=run_result
)
yield from generator
self.graph_runtime_state.node_run_state.node_state_mapping[node_id] = RouteNodeState(
node_id=node_id,
start_at=start_at,
status=RouteNodeState.Status.SUCCESS if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
else RouteNodeState.Status.FAILED,
finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
node_run_result=run_result,
failed_reason=run_result.error
if run_result.status == WorkflowNodeExecutionStatus.FAILED else None
)
# todo append self.graph_runtime_state.node_run_state.routes
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
node_id=node_id,
parallel_id=parallel_id,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
)
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
node_id=node_id,
parallel_id=parallel_id,
retriever_resources=item.retriever_resources,
context=item.context
)
# todo record state
# trigger node run success event
yield NodeRunSucceededEvent(node_id=node_id, parallel_id=parallel_id)
except GenerateTaskStoppedException as e:
# trigger node run failed event
yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
return
except Exception as e:
# todo logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
# trigger node run failed event
yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
return
logger.exception(f"Node {node_instance.node_data.title} run failed: {str(e)}")
raise e
finally:
db.session.close()
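Taken together with the constructor change above, a condensed sketch of driving the engine with explicit limits; it mirrors the new test at the end of this commit, and the literal values are illustrative:

graph_engine = GraphEngine(
    tenant_id="111",
    app_id="222",
    workflow_type=WorkflowType.CHAT,
    workflow_id="333",
    user_id="444",
    user_from=UserFrom.ACCOUNT,
    invoke_from=InvokeFrom.WEB_APP,
    call_depth=0,
    graph=graph,
    variable_pool=variable_pool,
    max_execution_steps=500,
    max_execution_time=1200
)

for event in graph_engine.run():
    print(type(event), event)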

View File

@@ -0,0 +1,33 @@
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
node_classes = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.ANSWER: AnswerNode,
NodeType.LLM: LLMNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.IF_ELSE: IfElseNode,
NodeType.CODE: CodeNode,
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.TOOL: ToolNode,
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: IterationNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
}
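A brief sketch of how this mapping is consumed; the lookup mirrors the node resolution in graph_engine.py, NodeType.value_of is the enum helper hinted at by the ValueError shown earlier, and node_config stands in for one node dict from the graph config:

from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes import node_classes

node_type = NodeType.value_of(node_config.get("data", {}).get("type"))
node_cls = node_classes.get(node_type)
if not node_cls:
    raise ValueError(f"node type {node_type} not supported")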

View File

@@ -2,37 +2,20 @@ from abc import ABC, abstractmethod
from collections.abc import Generator
from typing import Optional
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType, UserFrom
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from core.workflow.nodes.iterable_node import IterableNodeMixin
from models.workflow import WorkflowType
class BaseNode(ABC):
_node_data_cls: type[BaseNodeData]
_node_type: NodeType
tenant_id: str
app_id: str
workflow_type: WorkflowType
workflow_id: str
user_id: str
user_from: UserFrom
invoke_from: InvokeFrom
workflow_call_depth: int
graph: Graph
graph_runtime_state: GraphRuntimeState
previous_node_id: Optional[str] = None
node_id: str
node_data: BaseNodeData
def __init__(self,
config: dict,
graph_init_params: GraphInitParams,
@@ -81,24 +64,6 @@
else:
yield from result
def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
"""
Publish text chunk
:param text: chunk text
:param value_selector: value selector
:return:
"""
# TODO remove callbacks
if self.callbacks:
for callback in self.callbacks:
callback.on_node_text_chunk(
node_id=self.node_id,
text=text,
metadata={
"value_selector": value_selector
}
)
@classmethod
def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict[str, list[str]]:
"""

View File

@@ -170,7 +170,7 @@ class LLMNode(BaseNode):
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None) \
-> Generator[RunStreamChunkEvent | "ModelInvokeCompleted", None, None]:
-> Generator["RunStreamChunkEvent | ModelInvokeCompleted", None, None]:
"""
Invoke large language model
:param node_data_model: node data model
@@ -204,7 +204,7 @@
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
-> Generator[RunStreamChunkEvent | "ModelInvokeCompleted", None, None]:
-> Generator["RunStreamChunkEvent | ModelInvokeCompleted", None, None]:
"""
Handle invoke result
:param invoke_result: invoke result
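The only change in this file is quoting the entire return annotation so both union members become lazy forward references; a stand-alone sketch of the pattern, where Chunk and Done are stand-ins rather than the LLMNode types:

from collections.abc import Generator

class Chunk: ...

def stream() -> Generator["Chunk | Done", None, None]:
    # "Done" is defined after this function, so the whole union is written as a string
    yield Chunk()
    yield Done()

class Done: ...

for item in stream():
    print(type(item).__name__)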

View File

View File

@@ -0,0 +1,8 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
class TestNodeData(BaseNodeData):
"""
Test Node Data.
"""
pass

View File

@@ -0,0 +1,33 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.test.entities import TestNodeData
from models.workflow import WorkflowNodeExecutionStatus
class TestNode(BaseNode):
_node_data_cls = TestNodeData
_node_type = NodeType.ANSWER
def _run(self) -> NodeRunResult:
"""
Run node
:return:
"""
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"content": "abc"
},
edge_source_handle="1"
)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
"""
Extract variable selector to variable mapping
:param node_data: node data
:return:
"""
return {}

View File

@@ -17,7 +17,8 @@ from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom, node_classes
from core.workflow.nodes import node_classes
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
from core.workflow.nodes.iteration.entities import IterationState
from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.start.start_node import StartNode
@@ -93,7 +94,8 @@ class WorkflowEntry:
call_depth=call_depth,
graph=graph,
variable_pool=variable_pool,
callbacks=callbacks
max_execution_steps=current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS"),
max_execution_time=current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
)
# init workflow run

View File

@@ -10,7 +10,8 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import NodeType
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.workflow_engine_manager import WorkflowEngineManager, node_classes
from core.workflow.nodes import node_classes
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
from models.account import Account

View File

@@ -0,0 +1,864 @@
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.utils.condition.entities import Condition
def test_init():
graph_config = {
"edges": [
{
"id": "llm-source-answer-target",
"source": "llm",
"target": "answer",
},
{
"id": "start-source-qc-target",
"source": "start",
"target": "qc",
},
{
"id": "qc-1-llm-target",
"source": "qc",
"sourceHandle": "1",
"target": "llm",
},
{
"id": "qc-2-http-target",
"source": "qc",
"sourceHandle": "2",
"target": "http",
},
{
"id": "http-source-answer2-target",
"source": "http",
"target": "answer2",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
{
"data": {
"type": "question-classifier"
},
"id": "qc",
},
{
"data": {
"type": "http-request",
},
"id": "http",
},
{
"data": {
"type": "answer",
},
"id": "answer2",
}
],
}
graph = Graph.init(
graph_config=graph_config
)
start_node_id = "start"
assert graph.root_node_id == start_node_id
assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc"
assert {"llm", "http"} == {node.target_node_id for node in graph.edge_mapping.get("qc")}
def test__init_iteration_graph():
graph_config = {
"edges": [
{
"id": "llm-answer",
"source": "llm",
"sourceHandle": "source",
"target": "answer",
},
{
"id": "iteration-source-llm-target",
"source": "iteration",
"sourceHandle": "source",
"target": "llm",
},
{
"id": "template-transform-in-iteration-source-llm-in-iteration-target",
"source": "template-transform-in-iteration",
"sourceHandle": "source",
"target": "llm-in-iteration",
},
{
"id": "llm-in-iteration-source-answer-in-iteration-target",
"source": "llm-in-iteration",
"sourceHandle": "source",
"target": "answer-in-iteration",
},
{
"id": "start-source-code-target",
"source": "start",
"sourceHandle": "source",
"target": "code",
},
{
"id": "code-source-iteration-target",
"source": "code",
"sourceHandle": "source",
"target": "iteration",
}
],
"nodes": [
{
"data": {
"type": "start",
},
"id": "start",
},
{
"data": {
"type": "llm",
},
"id": "llm",
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
{
"data": {
"type": "iteration"
},
"id": "iteration",
},
{
"data": {
"type": "template-transform",
},
"id": "template-transform-in-iteration",
"parentId": "iteration",
},
{
"data": {
"type": "llm",
},
"id": "llm-in-iteration",
"parentId": "iteration",
},
{
"data": {
"type": "answer",
},
"id": "answer-in-iteration",
"parentId": "iteration",
},
{
"data": {
"type": "code",
},
"id": "code",
}
]
}
graph = Graph.init(
graph_config=graph_config,
root_node_id="template-transform-in-iteration"
)
graph.add_extra_edge(
source_node_id="answer-in-iteration",
target_node_id="template-transform-in-iteration",
run_condition=RunCondition(
type="condition",
conditions=[
Condition(
variable_selector=["iteration", "index"],
comparison_operator="",
value="5"
)
]
)
)
# iteration:
# [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration]
assert graph.root_node_id == "template-transform-in-iteration"
assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration"
assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration"
assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration"
def test_parallels_graph():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "answer",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "answer",
},
{
"id": "llm3-source-answer-target",
"source": "llm3",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i+1}"
assert graph.edge_mapping.get(f"llm{i+1}") is not None
assert graph.edge_mapping.get(f"llm{i+1}")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 3
for node_id in ["llm1", "llm2", "llm3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph2():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "answer",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
if i < 2:
assert graph.edge_mapping.get(f"llm{i + 1}") is not None
assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 3
for node_id in ["llm1", "llm2", "llm3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph3():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 3
for node_id in ["llm1", "llm2", "llm3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph4():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "code1",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "code2",
},
{
"id": "llm3-source-code3-target",
"source": "llm3",
"target": "code3",
},
{
"id": "code1-source-answer-target",
"source": "code1",
"target": "answer",
},
{
"id": "code2-source-answer-target",
"source": "code2",
"target": "answer",
},
{
"id": "code3-source-answer-target",
"source": "code3",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "code",
},
"id": "code1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "code",
},
"id": "code2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "code",
},
"id": "code3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert graph.edge_mapping.get(f"llm{i + 1}") is not None
assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == f"code{i + 1}"
assert graph.edge_mapping.get(f"code{i + 1}") is not None
assert graph.edge_mapping.get(f"code{i + 1}")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 6
for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph5():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm4",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm5",
},
{
"id": "llm1-source-code1-target",
"source": "llm1",
"target": "code1",
},
{
"id": "llm2-source-code1-target",
"source": "llm2",
"target": "code1",
},
{
"id": "llm3-source-code2-target",
"source": "llm3",
"target": "code2",
},
{
"id": "llm4-source-code2-target",
"source": "llm4",
"target": "code2",
},
{
"id": "llm5-source-code3-target",
"source": "llm5",
"target": "code3",
},
{
"id": "code1-source-answer-target",
"source": "code1",
"target": "answer",
},
{
"id": "code2-source-answer-target",
"source": "code2",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "code",
},
"id": "code1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "code",
},
"id": "code2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "code",
},
"id": "code3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
{
"data": {
"type": "llm",
},
"id": "llm4"
},
{
"data": {
"type": "llm",
},
"id": "llm5"
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(5):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert graph.edge_mapping.get("llm1") is not None
assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
assert graph.edge_mapping.get("llm2") is not None
assert graph.edge_mapping.get("llm2")[0].target_node_id == "code1"
assert graph.edge_mapping.get("llm3") is not None
assert graph.edge_mapping.get("llm3")[0].target_node_id == "code2"
assert graph.edge_mapping.get("llm4") is not None
assert graph.edge_mapping.get("llm4")[0].target_node_id == "code2"
assert graph.edge_mapping.get("llm5") is not None
assert graph.edge_mapping.get("llm5")[0].target_node_id == "code3"
assert graph.edge_mapping.get("code1") is not None
assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
assert graph.edge_mapping.get("code2") is not None
assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 8
for node_id in ["llm1", "llm2", "llm3", "llm4", "llm5", "code1", "code2", "code3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph6():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-code1-target",
"source": "llm1",
"target": "code1",
},
{
"id": "llm1-source-code2-target",
"source": "llm1",
"target": "code2",
},
{
"id": "llm2-source-code3-target",
"source": "llm2",
"target": "code3",
},
{
"id": "code1-source-answer-target",
"source": "code1",
"target": "answer",
},
{
"id": "code2-source-answer-target",
"source": "code2",
"target": "answer",
},
{
"id": "code3-source-answer-target",
"source": "code3",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "code",
},
"id": "code1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "code",
},
"id": "code2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "code",
},
"id": "code3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert graph.edge_mapping.get("llm1") is not None
assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
assert graph.edge_mapping.get("llm1") is not None
assert graph.edge_mapping.get("llm1")[1].target_node_id == "code2"
assert graph.edge_mapping.get("llm2") is not None
assert graph.edge_mapping.get("llm2")[0].target_node_id == "code3"
assert graph.edge_mapping.get("code1") is not None
assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
assert graph.edge_mapping.get("code2") is not None
assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
assert graph.edge_mapping.get("code3") is not None
assert graph.edge_mapping.get("code3")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 2
assert len(graph.node_parallel_mapping) == 6
for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
assert node_id in graph.node_parallel_mapping
parent_parallel = None
child_parallel = None
for p_id, parallel in graph.parallel_mapping.items():
if parallel.parent_parallel_id is None:
parent_parallel = parallel
else:
child_parallel = parallel
for node_id in ["llm1", "llm2", "llm3", "code3"]:
assert graph.node_parallel_mapping[node_id] == parent_parallel.id
for node_id in ["code1", "code2"]:
assert graph.node_parallel_mapping[node_id] == child_parallel.id

View File

@@ -1,9 +1,16 @@
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import SystemVariable, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.utils.condition.entities import Condition
from core.workflow.graph_engine.graph_engine import GraphEngine
from models.workflow import WorkflowType
def test_init():
@patch('extensions.ext_database.db.session.remove')
@patch('extensions.ext_database.db.session.close')
def test_run(mock_close, mock_remove):
graph_config = {
"edges": [
{
@@ -37,37 +44,43 @@ def test_init():
"nodes": [
{
"data": {
"type": "start"
"type": "start",
"title": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
"title": "llm"
},
"id": "llm"
},
{
"data": {
"type": "answer",
"title": "answer"
},
"id": "answer",
},
{
"data": {
"type": "question-classifier"
"type": "question-classifier",
"title": "qc"
},
"id": "qc",
},
{
"data": {
"type": "http-request",
"title": "http"
},
"id": "http",
},
{
"data": {
"type": "answer",
"title": "answer2"
},
"id": "answer2",
}
@@ -78,787 +91,30 @@ def test_init():
graph_config=graph_config
)
start_node_id = "start"
variable_pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
}, user_inputs={})
assert graph.root_node_id == start_node_id
assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc"
assert {"llm", "http"} == {node.target_node_id for node in graph.edge_mapping.get("qc")}
def test__init_iteration_graph():
graph_config = {
"edges": [
{
"id": "llm-answer",
"source": "llm",
"sourceHandle": "source",
"target": "answer",
},
{
"id": "iteration-source-llm-target",
"source": "iteration",
"sourceHandle": "source",
"target": "llm",
},
{
"id": "template-transform-in-iteration-source-llm-in-iteration-target",
"source": "template-transform-in-iteration",
"sourceHandle": "source",
"target": "llm-in-iteration",
},
{
"id": "llm-in-iteration-source-answer-in-iteration-target",
"source": "llm-in-iteration",
"sourceHandle": "source",
"target": "answer-in-iteration",
},
{
"id": "start-source-code-target",
"source": "start",
"sourceHandle": "source",
"target": "code",
},
{
"id": "code-source-iteration-target",
"source": "code",
"sourceHandle": "source",
"target": "iteration",
}
],
"nodes": [
{
"data": {
"type": "start",
},
"id": "start",
},
{
"data": {
"type": "llm",
},
"id": "llm",
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
{
"data": {
"type": "iteration"
},
"id": "iteration",
},
{
"data": {
"type": "template-transform",
},
"id": "template-transform-in-iteration",
"parentId": "iteration",
},
{
"data": {
"type": "llm",
},
"id": "llm-in-iteration",
"parentId": "iteration",
},
{
"data": {
"type": "answer",
},
"id": "answer-in-iteration",
"parentId": "iteration",
},
{
"data": {
"type": "code",
},
"id": "code",
}
]
}
graph = Graph.init(
graph_config=graph_config,
root_node_id="template-transform-in-iteration"
)
graph.add_extra_edge(
source_node_id="answer-in-iteration",
target_node_id="template-transform-in-iteration",
run_condition=RunCondition(
type="condition",
conditions=[
Condition(
variable_selector=["iteration", "index"],
comparison_operator="",
value="5"
)
]
)
graph_engine = GraphEngine(
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.CHAT,
workflow_id="333",
user_id="444",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.WEB_APP,
call_depth=0,
graph=graph,
variable_pool=variable_pool,
max_execution_steps=500,
max_execution_time=1200
)
# iteration:
# [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration]
print("")
assert graph.root_node_id == "template-transform-in-iteration"
assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration"
assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration"
assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration"
def test_parallels_graph():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "answer",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "answer",
},
{
"id": "llm3-source-answer-target",
"source": "llm3",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i+1}"
assert graph.edge_mapping.get(f"llm{i+1}") is not None
assert graph.edge_mapping.get(f"llm{i+1}")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 3
for node_id in ["llm1", "llm2", "llm3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph2():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "answer",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
if i < 2:
assert graph.edge_mapping.get(f"llm{i + 1}") is not None
assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 3
for node_id in ["llm1", "llm2", "llm3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph3():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 3
for node_id in ["llm1", "llm2", "llm3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph4():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "code1",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "code2",
},
{
"id": "llm3-source-code3-target",
"source": "llm3",
"target": "code3",
},
{
"id": "code1-source-answer-target",
"source": "code1",
"target": "answer",
},
{
"id": "code2-source-answer-target",
"source": "code2",
"target": "answer",
},
{
"id": "code3-source-answer-target",
"source": "code3",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "code",
},
"id": "code1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "code",
},
"id": "code2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "code",
},
"id": "code3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert graph.edge_mapping.get(f"llm{i + 1}") is not None
assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == f"code{i + 1}"
assert graph.edge_mapping.get(f"code{i + 1}") is not None
assert graph.edge_mapping.get(f"code{i + 1}")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 6
for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph5():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm4",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm5",
},
{
"id": "llm1-source-code1-target",
"source": "llm1",
"target": "code1",
},
{
"id": "llm2-source-code1-target",
"source": "llm2",
"target": "code1",
},
{
"id": "llm3-source-code2-target",
"source": "llm3",
"target": "code2",
},
{
"id": "llm4-source-code2-target",
"source": "llm4",
"target": "code2",
},
{
"id": "llm5-source-code3-target",
"source": "llm5",
"target": "code3",
},
{
"id": "code1-source-answer-target",
"source": "code1",
"target": "answer",
},
{
"id": "code2-source-answer-target",
"source": "code2",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "code",
},
"id": "code1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "code",
},
"id": "code2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "code",
},
"id": "code3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
{
"data": {
"type": "llm",
},
"id": "llm4"
},
{
"data": {
"type": "llm",
},
"id": "llm5"
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(5):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert graph.edge_mapping.get("llm1") is not None
assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
assert graph.edge_mapping.get("llm2") is not None
assert graph.edge_mapping.get("llm2")[0].target_node_id == "code1"
assert graph.edge_mapping.get("llm3") is not None
assert graph.edge_mapping.get("llm3")[0].target_node_id == "code2"
assert graph.edge_mapping.get("llm4") is not None
assert graph.edge_mapping.get("llm4")[0].target_node_id == "code2"
assert graph.edge_mapping.get("llm5") is not None
assert graph.edge_mapping.get("llm5")[0].target_node_id == "code3"
assert graph.edge_mapping.get("code1") is not None
assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
assert graph.edge_mapping.get("code2") is not None
assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 8
for node_id in ["llm1", "llm2", "llm3", "llm4", "llm5", "code1", "code2", "code3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph6():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-code1-target",
"source": "llm1",
"target": "code1",
},
{
"id": "llm1-source-code2-target",
"source": "llm1",
"target": "code2",
},
{
"id": "llm2-source-code3-target",
"source": "llm2",
"target": "code3",
},
{
"id": "code1-source-answer-target",
"source": "code1",
"target": "answer",
},
{
"id": "code2-source-answer-target",
"source": "code2",
"target": "answer",
},
{
"id": "code3-source-answer-target",
"source": "code3",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "code",
},
"id": "code1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "code",
},
"id": "code2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "code",
},
"id": "code3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert graph.edge_mapping.get("llm1") is not None
assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
assert graph.edge_mapping.get("llm1") is not None
assert graph.edge_mapping.get("llm1")[1].target_node_id == "code2"
assert graph.edge_mapping.get("llm2") is not None
assert graph.edge_mapping.get("llm2")[0].target_node_id == "code3"
assert graph.edge_mapping.get("code1") is not None
assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
assert graph.edge_mapping.get("code2") is not None
assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
assert graph.edge_mapping.get("code3") is not None
assert graph.edge_mapping.get("code3")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 2
assert len(graph.node_parallel_mapping) == 6
for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
assert node_id in graph.node_parallel_mapping
parent_parallel = None
child_parallel = None
for p_id, parallel in graph.parallel_mapping.items():
if parallel.parent_parallel_id is None:
parent_parallel = parallel
else:
child_parallel = parallel
for node_id in ["llm1", "llm2", "llm3", "code3"]:
assert graph.node_parallel_mapping[node_id] == parent_parallel.id
for node_id in ["code1", "code2"]:
assert graph.node_parallel_mapping[node_id] == child_parallel.id
generator = graph_engine.run()
for item in generator:
print(type(item), item)