mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 17:05:58 +08:00
add graph engine test
This commit is contained in:
parent
00fb23d0c9
commit
00ec36d47c
@ -3,20 +3,6 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
|
||||||
from core.workflow.nodes.code.code_node import CodeNode
|
|
||||||
from core.workflow.nodes.end.end_node import EndNode
|
|
||||||
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
|
|
||||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
|
||||||
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
|
||||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
|
||||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
|
||||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
|
||||||
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
|
||||||
from core.workflow.nodes.start.start_node import StartNode
|
|
||||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
|
||||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
|
||||||
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
|
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
@ -55,25 +41,6 @@ class NodeType(Enum):
|
|||||||
raise ValueError(f'invalid node type value {value}')
|
raise ValueError(f'invalid node type value {value}')
|
||||||
|
|
||||||
|
|
||||||
node_classes = {
|
|
||||||
NodeType.START: StartNode,
|
|
||||||
NodeType.END: EndNode,
|
|
||||||
NodeType.ANSWER: AnswerNode,
|
|
||||||
NodeType.LLM: LLMNode,
|
|
||||||
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
|
||||||
NodeType.IF_ELSE: IfElseNode,
|
|
||||||
NodeType.CODE: CodeNode,
|
|
||||||
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
|
|
||||||
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
|
|
||||||
NodeType.HTTP_REQUEST: HttpRequestNode,
|
|
||||||
NodeType.TOOL: ToolNode,
|
|
||||||
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
|
|
||||||
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
|
|
||||||
NodeType.ITERATION: IterationNode,
|
|
||||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class SystemVariable(Enum):
|
class SystemVariable(Enum):
|
||||||
"""
|
"""
|
||||||
System Variables.
|
System Variables.
|
||||||
|
@ -1,23 +1,31 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||||
|
|
||||||
|
|
||||||
class RunConditionHandler(ABC):
|
class RunConditionHandler(ABC):
|
||||||
def __init__(self, condition: RunCondition):
|
def __init__(self,
|
||||||
|
init_params: GraphInitParams,
|
||||||
|
graph: Graph,
|
||||||
|
condition: RunCondition):
|
||||||
|
self.init_params = init_params
|
||||||
|
self.graph = graph
|
||||||
self.condition = condition
|
self.condition = condition
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def check(self,
|
def check(self,
|
||||||
|
graph_runtime_state: GraphRuntimeState,
|
||||||
source_node_id: str,
|
source_node_id: str,
|
||||||
target_node_id: str,
|
target_node_id: str) -> bool:
|
||||||
graph: "Graph") -> bool:
|
|
||||||
"""
|
"""
|
||||||
Check if the condition can be executed
|
Check if the condition can be executed
|
||||||
|
|
||||||
|
:param graph_runtime_state: graph runtime state
|
||||||
:param source_node_id: source node id
|
:param source_node_id: source node id
|
||||||
:param target_node_id: target node id
|
:param target_node_id: target node id
|
||||||
:param graph: graph
|
|
||||||
:return: bool
|
:return: bool
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
@ -1,29 +1,33 @@
|
|||||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||||
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
|
|
||||||
|
|
||||||
class BranchIdentifyRunConditionHandler(RunConditionHandler):
|
class BranchIdentifyRunConditionHandler(RunConditionHandler):
|
||||||
|
|
||||||
def check(self,
|
def check(self,
|
||||||
|
graph_runtime_state: GraphRuntimeState,
|
||||||
source_node_id: str,
|
source_node_id: str,
|
||||||
target_node_id: str,
|
target_node_id: str) -> bool:
|
||||||
graph: "Graph") -> bool:
|
|
||||||
"""
|
"""
|
||||||
Check if the condition can be executed
|
Check if the condition can be executed
|
||||||
|
|
||||||
|
:param graph_runtime_state: graph runtime state
|
||||||
:param source_node_id: source node id
|
:param source_node_id: source node id
|
||||||
:param target_node_id: target node id
|
:param target_node_id: target node id
|
||||||
:param graph: graph
|
|
||||||
:return: bool
|
:return: bool
|
||||||
"""
|
"""
|
||||||
if not self.condition.branch_identify:
|
if not self.condition.branch_identify:
|
||||||
raise Exception("Branch identify is required")
|
raise Exception("Branch identify is required")
|
||||||
|
|
||||||
run_state = graph.run_state
|
node_route_state = graph_runtime_state.node_run_state.node_state_mapping.get(source_node_id)
|
||||||
node_route_result = run_state.node_route_results.get(source_node_id)
|
if not node_route_state:
|
||||||
if not node_route_result:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not node_route_result.edge_source_handle:
|
run_result = node_route_state.node_run_result
|
||||||
|
if not run_result:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return self.condition.branch_identify == node_route_result.edge_source_handle
|
if not run_result.edge_source_handle:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.condition.branch_identify == run_result.edge_source_handle
|
||||||
|
@ -1,18 +1,19 @@
|
|||||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||||
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||||
|
|
||||||
|
|
||||||
class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
||||||
def check(self,
|
def check(self,
|
||||||
|
graph_runtime_state: GraphRuntimeState,
|
||||||
source_node_id: str,
|
source_node_id: str,
|
||||||
target_node_id: str,
|
target_node_id: str) -> bool:
|
||||||
graph: "Graph") -> bool:
|
|
||||||
"""
|
"""
|
||||||
Check if the condition can be executed
|
Check if the condition can be executed
|
||||||
|
|
||||||
|
:param graph_runtime_state: graph runtime state
|
||||||
:param source_node_id: source node id
|
:param source_node_id: source node id
|
||||||
:param target_node_id: target node id
|
:param target_node_id: target node id
|
||||||
:param graph: graph
|
|
||||||
:return: bool
|
:return: bool
|
||||||
"""
|
"""
|
||||||
if not self.condition.conditions:
|
if not self.condition.conditions:
|
||||||
@ -21,10 +22,9 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
|||||||
# process condition
|
# process condition
|
||||||
condition_processor = ConditionProcessor()
|
condition_processor = ConditionProcessor()
|
||||||
compare_result, _ = condition_processor.process(
|
compare_result, _ = condition_processor.process(
|
||||||
variable_pool=graph.run_state.variable_pool,
|
variable_pool=graph_runtime_state.variable_pool,
|
||||||
logical_operator="and",
|
logical_operator="and",
|
||||||
conditions=self.condition.conditions
|
conditions=self.condition.conditions
|
||||||
)
|
)
|
||||||
|
|
||||||
return compare_result
|
return compare_result
|
||||||
|
|
||||||
|
@ -1,19 +1,35 @@
|
|||||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||||
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
|
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
|
||||||
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
|
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
|
||||||
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||||
|
|
||||||
|
|
||||||
class ConditionManager:
|
class ConditionManager:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_condition_handler(run_condition: RunCondition) -> RunConditionHandler:
|
def get_condition_handler(
|
||||||
|
init_params: GraphInitParams,
|
||||||
|
graph: Graph,
|
||||||
|
run_condition: RunCondition
|
||||||
|
) -> RunConditionHandler:
|
||||||
"""
|
"""
|
||||||
Get condition handler
|
Get condition handler
|
||||||
|
|
||||||
|
:param init_params: init params
|
||||||
|
:param graph: graph
|
||||||
:param run_condition: run condition
|
:param run_condition: run condition
|
||||||
:return: condition handler
|
:return: condition handler
|
||||||
"""
|
"""
|
||||||
if run_condition.type == "branch_identify":
|
if run_condition.type == "branch_identify":
|
||||||
return BranchIdentifyRunConditionHandler(run_condition)
|
return BranchIdentifyRunConditionHandler(
|
||||||
|
init_params=init_params,
|
||||||
|
graph=graph,
|
||||||
|
condition=run_condition
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return ConditionRunConditionHandlerHandler(run_condition)
|
return ConditionRunConditionHandlerHandler(
|
||||||
|
init_params=init_params,
|
||||||
|
graph=graph,
|
||||||
|
condition=run_condition
|
||||||
|
)
|
||||||
|
@ -3,6 +3,7 @@ from typing import Optional
|
|||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
from core.workflow.entities.node_entities import NodeRunResult
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
class GraphEngineEvent(BaseModel):
|
class GraphEngineEvent(BaseModel):
|
||||||
@ -50,10 +51,12 @@ class NodeRunStartedEvent(BaseNodeEvent):
|
|||||||
|
|
||||||
class NodeRunStreamChunkEvent(BaseNodeEvent):
|
class NodeRunStreamChunkEvent(BaseNodeEvent):
|
||||||
chunk_content: str = Field(..., description="chunk content")
|
chunk_content: str = Field(..., description="chunk content")
|
||||||
|
from_variable_selector: list[str] = Field(..., description="from variable selector")
|
||||||
|
|
||||||
|
|
||||||
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
|
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
|
||||||
retriever_resources: list[dict] = Field(..., description="retriever resources")
|
retriever_resources: list[dict] = Field(..., description="retriever resources")
|
||||||
|
context: str = Field(..., description="context")
|
||||||
|
|
||||||
|
|
||||||
class NodeRunSucceededEvent(BaseNodeEvent):
|
class NodeRunSucceededEvent(BaseNodeEvent):
|
||||||
@ -61,7 +64,10 @@ class NodeRunSucceededEvent(BaseNodeEvent):
|
|||||||
|
|
||||||
|
|
||||||
class NodeRunFailedEvent(BaseNodeEvent):
|
class NodeRunFailedEvent(BaseNodeEvent):
|
||||||
run_result: NodeRunResult = Field(..., description="run result")
|
run_result: NodeRunResult = Field(
|
||||||
|
default=NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED),
|
||||||
|
description="run result"
|
||||||
|
)
|
||||||
reason: str = Field("", description="failed reason")
|
reason: str = Field("", description="failed reason")
|
||||||
|
|
||||||
@model_validator(mode='before')
|
@model_validator(mode='before')
|
||||||
|
@ -3,13 +3,13 @@ import queue
|
|||||||
import time
|
import time
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import Optional, cast
|
from datetime import datetime, timezone
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
|
|
||||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
|
||||||
from core.workflow.entities.node_entities import NodeType
|
from core.workflow.entities.node_entities import NodeType
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||||
@ -19,14 +19,21 @@ from core.workflow.graph_engine.entities.event import (
|
|||||||
GraphRunStartedEvent,
|
GraphRunStartedEvent,
|
||||||
GraphRunSucceededEvent,
|
GraphRunSucceededEvent,
|
||||||
NodeRunFailedEvent,
|
NodeRunFailedEvent,
|
||||||
|
NodeRunRetrieverResourceEvent,
|
||||||
NodeRunStartedEvent,
|
NodeRunStartedEvent,
|
||||||
|
NodeRunStreamChunkEvent,
|
||||||
|
NodeRunSucceededEvent,
|
||||||
)
|
)
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.nodes.base_node import UserFrom, node_classes
|
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||||
|
from core.workflow.nodes import node_classes
|
||||||
|
from core.workflow.nodes.base_node import UserFrom
|
||||||
|
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||||
|
from core.workflow.nodes.test.test_node import TestNode
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||||
|
|
||||||
thread_pool = ThreadPoolExecutor(max_workers=500, thread_name_prefix="ThreadGraphParallelRun")
|
thread_pool = ThreadPoolExecutor(max_workers=500, thread_name_prefix="ThreadGraphParallelRun")
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -43,7 +50,8 @@ class GraphEngine:
|
|||||||
call_depth: int,
|
call_depth: int,
|
||||||
graph: Graph,
|
graph: Graph,
|
||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
callbacks: list[BaseWorkflowCallback]) -> None:
|
max_execution_steps: int,
|
||||||
|
max_execution_time: int) -> None:
|
||||||
self.graph = graph
|
self.graph = graph
|
||||||
self.init_params = GraphInitParams(
|
self.init_params = GraphInitParams(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
@ -61,12 +69,8 @@ class GraphEngine:
|
|||||||
start_at=time.perf_counter()
|
start_at=time.perf_counter()
|
||||||
)
|
)
|
||||||
|
|
||||||
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
|
self.max_execution_steps = max_execution_steps
|
||||||
self.max_execution_steps = cast(int, max_execution_steps)
|
self.max_execution_time = max_execution_time
|
||||||
max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
|
|
||||||
self.max_execution_time = cast(int, max_execution_time)
|
|
||||||
|
|
||||||
self.callbacks = callbacks
|
|
||||||
|
|
||||||
def run_in_block_mode(self):
|
def run_in_block_mode(self):
|
||||||
# TODO convert generator to result
|
# TODO convert generator to result
|
||||||
@ -92,7 +96,7 @@ class GraphEngine:
|
|||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield GraphRunFailedEvent(reason=str(e))
|
yield GraphRunFailedEvent(reason=str(e))
|
||||||
return
|
raise e
|
||||||
|
|
||||||
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
|
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
|
||||||
next_node_id = start_node_id
|
next_node_id = start_node_id
|
||||||
@ -118,7 +122,7 @@ class GraphEngine:
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
yield NodeRunFailedEvent(node_id=next_node_id, reason=str(e))
|
yield NodeRunFailedEvent(node_id=next_node_id, reason=str(e))
|
||||||
return
|
raise e
|
||||||
|
|
||||||
previous_node_id = next_node_id
|
previous_node_id = next_node_id
|
||||||
|
|
||||||
@ -141,11 +145,13 @@ class GraphEngine:
|
|||||||
for edge in edge_mappings:
|
for edge in edge_mappings:
|
||||||
if edge.run_condition:
|
if edge.run_condition:
|
||||||
result = ConditionManager.get_condition_handler(
|
result = ConditionManager.get_condition_handler(
|
||||||
run_condition=edge.run_condition
|
init_params=self.init_params,
|
||||||
|
graph=self.graph,
|
||||||
|
run_condition=edge.run_condition,
|
||||||
).check(
|
).check(
|
||||||
|
graph_runtime_state=self.graph_runtime_state,
|
||||||
source_node_id=edge.source_node_id,
|
source_node_id=edge.source_node_id,
|
||||||
target_node_id=edge.target_node_id,
|
target_node_id=edge.target_node_id,
|
||||||
graph=self.graph
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
@ -250,7 +256,16 @@ class GraphEngine:
|
|||||||
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
|
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
|
||||||
|
|
||||||
# init workflow run state
|
# init workflow run state
|
||||||
node_instance = node_cls( # type: ignore
|
# node_instance = node_cls( # type: ignore
|
||||||
|
# config=node_config,
|
||||||
|
# graph_init_params=self.init_params,
|
||||||
|
# graph=self.graph,
|
||||||
|
# graph_runtime_state=self.graph_runtime_state,
|
||||||
|
# previous_node_id=previous_node_id
|
||||||
|
# )
|
||||||
|
|
||||||
|
# init workflow run state
|
||||||
|
node_instance = TestNode(
|
||||||
config=node_config,
|
config=node_config,
|
||||||
graph_init_params=self.init_params,
|
graph_init_params=self.init_params,
|
||||||
graph=self.graph,
|
graph=self.graph,
|
||||||
@ -268,24 +283,64 @@ class GraphEngine:
|
|||||||
self.graph_runtime_state.node_run_steps += 1
|
self.graph_runtime_state.node_run_steps += 1
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
start_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||||
|
|
||||||
# run node
|
# run node
|
||||||
generator = node_instance.run()
|
generator = node_instance.run()
|
||||||
|
run_result = None
|
||||||
|
for item in generator:
|
||||||
|
if isinstance(item, RunCompletedEvent):
|
||||||
|
run_result = item.run_result
|
||||||
|
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||||
|
yield NodeRunFailedEvent(
|
||||||
|
node_id=node_id,
|
||||||
|
parallel_id=parallel_id,
|
||||||
|
run_result=run_result,
|
||||||
|
reason=run_result.error
|
||||||
|
)
|
||||||
|
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||||
|
yield NodeRunSucceededEvent(
|
||||||
|
node_id=node_id,
|
||||||
|
parallel_id=parallel_id,
|
||||||
|
run_result=run_result
|
||||||
|
)
|
||||||
|
|
||||||
yield from generator
|
self.graph_runtime_state.node_run_state.node_state_mapping[node_id] = RouteNodeState(
|
||||||
|
node_id=node_id,
|
||||||
|
start_at=start_at,
|
||||||
|
status=RouteNodeState.Status.SUCCESS if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||||
|
else RouteNodeState.Status.FAILED,
|
||||||
|
finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||||
|
node_run_result=run_result,
|
||||||
|
failed_reason=run_result.error
|
||||||
|
if run_result.status == WorkflowNodeExecutionStatus.FAILED else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# todo append self.graph_runtime_state.node_run_state.routes
|
||||||
|
break
|
||||||
|
elif isinstance(item, RunStreamChunkEvent):
|
||||||
|
yield NodeRunStreamChunkEvent(
|
||||||
|
node_id=node_id,
|
||||||
|
parallel_id=parallel_id,
|
||||||
|
chunk_content=item.chunk_content,
|
||||||
|
from_variable_selector=item.from_variable_selector,
|
||||||
|
)
|
||||||
|
elif isinstance(item, RunRetrieverResourceEvent):
|
||||||
|
yield NodeRunRetrieverResourceEvent(
|
||||||
|
node_id=node_id,
|
||||||
|
parallel_id=parallel_id,
|
||||||
|
retriever_resources=item.retriever_resources,
|
||||||
|
context=item.context
|
||||||
|
)
|
||||||
|
|
||||||
# todo record state
|
# todo record state
|
||||||
|
|
||||||
# trigger node run success event
|
|
||||||
yield NodeRunSucceededEvent(node_id=node_id, parallel_id=parallel_id)
|
|
||||||
except GenerateTaskStoppedException as e:
|
except GenerateTaskStoppedException as e:
|
||||||
# trigger node run failed event
|
# trigger node run failed event
|
||||||
yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
|
yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# todo logger.exception(f"Node {node.node_data.title} run failed: {str(e)}")
|
logger.exception(f"Node {node_instance.node_data.title} run failed: {str(e)}")
|
||||||
# trigger node run failed event
|
raise e
|
||||||
yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
|
|
||||||
return
|
|
||||||
finally:
|
finally:
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
|
@ -0,0 +1,33 @@
|
|||||||
|
from core.workflow.entities.node_entities import NodeType
|
||||||
|
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||||
|
from core.workflow.nodes.code.code_node import CodeNode
|
||||||
|
from core.workflow.nodes.end.end_node import EndNode
|
||||||
|
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
|
||||||
|
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||||
|
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
||||||
|
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||||
|
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||||
|
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||||
|
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||||
|
from core.workflow.nodes.start.start_node import StartNode
|
||||||
|
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||||
|
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||||
|
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
|
||||||
|
|
||||||
|
node_classes = {
|
||||||
|
NodeType.START: StartNode,
|
||||||
|
NodeType.END: EndNode,
|
||||||
|
NodeType.ANSWER: AnswerNode,
|
||||||
|
NodeType.LLM: LLMNode,
|
||||||
|
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
||||||
|
NodeType.IF_ELSE: IfElseNode,
|
||||||
|
NodeType.CODE: CodeNode,
|
||||||
|
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
|
||||||
|
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
|
||||||
|
NodeType.HTTP_REQUEST: HttpRequestNode,
|
||||||
|
NodeType.TOOL: ToolNode,
|
||||||
|
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
|
||||||
|
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
|
||||||
|
NodeType.ITERATION: IterationNode,
|
||||||
|
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
|
||||||
|
}
|
@ -2,37 +2,20 @@ from abc import ABC, abstractmethod
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
|
||||||
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
|
||||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType, UserFrom
|
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
|
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
|
||||||
from core.workflow.nodes.iterable_node import IterableNodeMixin
|
from core.workflow.nodes.iterable_node import IterableNodeMixin
|
||||||
from models.workflow import WorkflowType
|
|
||||||
|
|
||||||
|
|
||||||
class BaseNode(ABC):
|
class BaseNode(ABC):
|
||||||
_node_data_cls: type[BaseNodeData]
|
_node_data_cls: type[BaseNodeData]
|
||||||
_node_type: NodeType
|
_node_type: NodeType
|
||||||
|
|
||||||
tenant_id: str
|
|
||||||
app_id: str
|
|
||||||
workflow_type: WorkflowType
|
|
||||||
workflow_id: str
|
|
||||||
user_id: str
|
|
||||||
user_from: UserFrom
|
|
||||||
invoke_from: InvokeFrom
|
|
||||||
workflow_call_depth: int
|
|
||||||
graph: Graph
|
|
||||||
graph_runtime_state: GraphRuntimeState
|
|
||||||
previous_node_id: Optional[str] = None
|
|
||||||
|
|
||||||
node_id: str
|
|
||||||
node_data: BaseNodeData
|
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
config: dict,
|
config: dict,
|
||||||
graph_init_params: GraphInitParams,
|
graph_init_params: GraphInitParams,
|
||||||
@ -81,24 +64,6 @@ class BaseNode(ABC):
|
|||||||
else:
|
else:
|
||||||
yield from result
|
yield from result
|
||||||
|
|
||||||
def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
|
|
||||||
"""
|
|
||||||
Publish text chunk
|
|
||||||
:param text: chunk text
|
|
||||||
:param value_selector: value selector
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
# TODO remove callbacks
|
|
||||||
if self.callbacks:
|
|
||||||
for callback in self.callbacks:
|
|
||||||
callback.on_node_text_chunk(
|
|
||||||
node_id=self.node_id,
|
|
||||||
text=text,
|
|
||||||
metadata={
|
|
||||||
"value_selector": value_selector
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict[str, list[str]]:
|
def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict[str, list[str]]:
|
||||||
"""
|
"""
|
||||||
|
@ -170,7 +170,7 @@ class LLMNode(BaseNode):
|
|||||||
model_instance: ModelInstance,
|
model_instance: ModelInstance,
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: list[PromptMessage],
|
||||||
stop: Optional[list[str]] = None) \
|
stop: Optional[list[str]] = None) \
|
||||||
-> Generator[RunStreamChunkEvent | "ModelInvokeCompleted", None, None]:
|
-> Generator["RunStreamChunkEvent | ModelInvokeCompleted", None, None]:
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
:param node_data_model: node data model
|
:param node_data_model: node data model
|
||||||
@ -204,7 +204,7 @@ class LLMNode(BaseNode):
|
|||||||
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||||
|
|
||||||
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
|
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
|
||||||
-> Generator[RunStreamChunkEvent | "ModelInvokeCompleted", None, None]:
|
-> Generator["RunStreamChunkEvent | ModelInvokeCompleted", None, None]:
|
||||||
"""
|
"""
|
||||||
Handle invoke result
|
Handle invoke result
|
||||||
:param invoke_result: invoke result
|
:param invoke_result: invoke result
|
||||||
|
0
api/core/workflow/nodes/test/__init__.py
Normal file
0
api/core/workflow/nodes/test/__init__.py
Normal file
8
api/core/workflow/nodes/test/entities.py
Normal file
8
api/core/workflow/nodes/test/entities.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
|
|
||||||
|
|
||||||
|
class TestNodeData(BaseNodeData):
|
||||||
|
"""
|
||||||
|
Test Node Data.
|
||||||
|
"""
|
||||||
|
pass
|
33
api/core/workflow/nodes/test/test_node.py
Normal file
33
api/core/workflow/nodes/test/test_node.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
|
||||||
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
|
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||||
|
from core.workflow.nodes.base_node import BaseNode
|
||||||
|
from core.workflow.nodes.test.entities import TestNodeData
|
||||||
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
|
|
||||||
|
class TestNode(BaseNode):
|
||||||
|
_node_data_cls = TestNodeData
|
||||||
|
node_type = NodeType.ANSWER
|
||||||
|
|
||||||
|
def _run(self) -> NodeRunResult:
|
||||||
|
"""
|
||||||
|
Run node
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
|
outputs={
|
||||||
|
"content": "abc"
|
||||||
|
},
|
||||||
|
edge_source_handle="1"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||||
|
"""
|
||||||
|
Extract variable selector to variable mapping
|
||||||
|
:param node_data: node data
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return {}
|
@ -17,7 +17,8 @@ from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState
|
|||||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||||
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom, node_classes
|
from core.workflow.nodes import node_classes
|
||||||
|
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
|
||||||
from core.workflow.nodes.iteration.entities import IterationState
|
from core.workflow.nodes.iteration.entities import IterationState
|
||||||
from core.workflow.nodes.llm.entities import LLMNodeData
|
from core.workflow.nodes.llm.entities import LLMNodeData
|
||||||
from core.workflow.nodes.start.start_node import StartNode
|
from core.workflow.nodes.start.start_node import StartNode
|
||||||
@ -93,7 +94,8 @@ class WorkflowEntry:
|
|||||||
call_depth=call_depth,
|
call_depth=call_depth,
|
||||||
graph=graph,
|
graph=graph,
|
||||||
variable_pool=variable_pool,
|
variable_pool=variable_pool,
|
||||||
callbacks=callbacks
|
max_execution_steps=current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS"),
|
||||||
|
max_execution_time=current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
|
||||||
)
|
)
|
||||||
|
|
||||||
# init workflow run
|
# init workflow run
|
||||||
|
@ -10,7 +10,8 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
|||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.workflow.entities.node_entities import NodeType
|
from core.workflow.entities.node_entities import NodeType
|
||||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager, node_classes
|
from core.workflow.nodes import node_classes
|
||||||
|
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||||
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
|
864
api/tests/unit_tests/core/workflow/graph_engine/test_graph.py
Normal file
864
api/tests/unit_tests/core/workflow/graph_engine/test_graph.py
Normal file
@ -0,0 +1,864 @@
|
|||||||
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
|
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||||
|
from core.workflow.utils.condition.entities import Condition
|
||||||
|
|
||||||
|
|
||||||
|
def test_init():
|
||||||
|
graph_config = {
|
||||||
|
"edges": [
|
||||||
|
{
|
||||||
|
"id": "llm-source-answer-target",
|
||||||
|
"source": "llm",
|
||||||
|
"target": "answer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "start-source-qc-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "qc",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "qc-1-llm-target",
|
||||||
|
"source": "qc",
|
||||||
|
"sourceHandle": "1",
|
||||||
|
"target": "llm",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "qc-2-http-target",
|
||||||
|
"source": "qc",
|
||||||
|
"sourceHandle": "2",
|
||||||
|
"target": "http",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "http-source-answer2-target",
|
||||||
|
"source": "http",
|
||||||
|
"target": "answer2",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "start"
|
||||||
|
},
|
||||||
|
"id": "start"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "answer",
|
||||||
|
},
|
||||||
|
"id": "answer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "question-classifier"
|
||||||
|
},
|
||||||
|
"id": "qc",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "http-request",
|
||||||
|
},
|
||||||
|
"id": "http",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "answer",
|
||||||
|
},
|
||||||
|
"id": "answer2",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = Graph.init(
|
||||||
|
graph_config=graph_config
|
||||||
|
)
|
||||||
|
|
||||||
|
start_node_id = "start"
|
||||||
|
|
||||||
|
assert graph.root_node_id == start_node_id
|
||||||
|
assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc"
|
||||||
|
assert {"llm", "http"} == {node.target_node_id for node in graph.edge_mapping.get("qc")}
|
||||||
|
|
||||||
|
|
||||||
|
def test__init_iteration_graph():
|
||||||
|
graph_config = {
|
||||||
|
"edges": [
|
||||||
|
{
|
||||||
|
"id": "llm-answer",
|
||||||
|
"source": "llm",
|
||||||
|
"sourceHandle": "source",
|
||||||
|
"target": "answer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "iteration-source-llm-target",
|
||||||
|
"source": "iteration",
|
||||||
|
"sourceHandle": "source",
|
||||||
|
"target": "llm",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "template-transform-in-iteration-source-llm-in-iteration-target",
|
||||||
|
"source": "template-transform-in-iteration",
|
||||||
|
"sourceHandle": "source",
|
||||||
|
"target": "llm-in-iteration",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "llm-in-iteration-source-answer-in-iteration-target",
|
||||||
|
"source": "llm-in-iteration",
|
||||||
|
"sourceHandle": "source",
|
||||||
|
"target": "answer-in-iteration",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "start-source-code-target",
|
||||||
|
"source": "start",
|
||||||
|
"sourceHandle": "source",
|
||||||
|
"target": "code",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "code-source-iteration-target",
|
||||||
|
"source": "code",
|
||||||
|
"sourceHandle": "source",
|
||||||
|
"target": "iteration",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "start",
|
||||||
|
},
|
||||||
|
"id": "start",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "answer",
|
||||||
|
},
|
||||||
|
"id": "answer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "iteration"
|
||||||
|
},
|
||||||
|
"id": "iteration",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "template-transform",
|
||||||
|
},
|
||||||
|
"id": "template-transform-in-iteration",
|
||||||
|
"parentId": "iteration",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm-in-iteration",
|
||||||
|
"parentId": "iteration",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "answer",
|
||||||
|
},
|
||||||
|
"id": "answer-in-iteration",
|
||||||
|
"parentId": "iteration",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "code",
|
||||||
|
},
|
||||||
|
"id": "code",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = Graph.init(
|
||||||
|
graph_config=graph_config,
|
||||||
|
root_node_id="template-transform-in-iteration"
|
||||||
|
)
|
||||||
|
graph.add_extra_edge(
|
||||||
|
source_node_id="answer-in-iteration",
|
||||||
|
target_node_id="template-transform-in-iteration",
|
||||||
|
run_condition=RunCondition(
|
||||||
|
type="condition",
|
||||||
|
conditions=[
|
||||||
|
Condition(
|
||||||
|
variable_selector=["iteration", "index"],
|
||||||
|
comparison_operator="≤",
|
||||||
|
value="5"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# iteration:
|
||||||
|
# [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration]
|
||||||
|
|
||||||
|
assert graph.root_node_id == "template-transform-in-iteration"
|
||||||
|
assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration"
|
||||||
|
assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration"
|
||||||
|
assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parallels_graph():
|
||||||
|
graph_config = {
|
||||||
|
"edges": [
|
||||||
|
{
|
||||||
|
"id": "start-source-llm1-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "start-source-llm2-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "start-source-llm3-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "llm1-source-answer-target",
|
||||||
|
"source": "llm1",
|
||||||
|
"target": "answer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "llm2-source-answer-target",
|
||||||
|
"source": "llm2",
|
||||||
|
"target": "answer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "llm3-source-answer-target",
|
||||||
|
"source": "llm3",
|
||||||
|
"target": "answer",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "start"
|
||||||
|
},
|
||||||
|
"id": "start"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "answer",
|
||||||
|
},
|
||||||
|
"id": "answer",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = Graph.init(
|
||||||
|
graph_config=graph_config
|
||||||
|
)
|
||||||
|
|
||||||
|
assert graph.root_node_id == "start"
|
||||||
|
for i in range(3):
|
||||||
|
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i+1}"
|
||||||
|
assert graph.edge_mapping.get(f"llm{i+1}") is not None
|
||||||
|
assert graph.edge_mapping.get(f"llm{i+1}")[0].target_node_id == "answer"
|
||||||
|
|
||||||
|
assert len(graph.parallel_mapping) == 1
|
||||||
|
assert len(graph.node_parallel_mapping) == 3
|
||||||
|
|
||||||
|
for node_id in ["llm1", "llm2", "llm3"]:
|
||||||
|
assert node_id in graph.node_parallel_mapping
|
||||||
|
|
||||||
|
|
||||||
|
def test_parallels_graph2():
|
||||||
|
graph_config = {
|
||||||
|
"edges": [
|
||||||
|
{
|
||||||
|
"id": "start-source-llm1-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "start-source-llm2-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "start-source-llm3-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "llm1-source-answer-target",
|
||||||
|
"source": "llm1",
|
||||||
|
"target": "answer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "llm2-source-answer-target",
|
||||||
|
"source": "llm2",
|
||||||
|
"target": "answer",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "start"
|
||||||
|
},
|
||||||
|
"id": "start"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "answer",
|
||||||
|
},
|
||||||
|
"id": "answer",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = Graph.init(
|
||||||
|
graph_config=graph_config
|
||||||
|
)
|
||||||
|
|
||||||
|
assert graph.root_node_id == "start"
|
||||||
|
for i in range(3):
|
||||||
|
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
|
||||||
|
|
||||||
|
if i < 2:
|
||||||
|
assert graph.edge_mapping.get(f"llm{i + 1}") is not None
|
||||||
|
assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == "answer"
|
||||||
|
|
||||||
|
assert len(graph.parallel_mapping) == 1
|
||||||
|
assert len(graph.node_parallel_mapping) == 3
|
||||||
|
|
||||||
|
for node_id in ["llm1", "llm2", "llm3"]:
|
||||||
|
assert node_id in graph.node_parallel_mapping
|
||||||
|
|
||||||
|
|
||||||
|
def test_parallels_graph3():
|
||||||
|
graph_config = {
|
||||||
|
"edges": [
|
||||||
|
{
|
||||||
|
"id": "start-source-llm1-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "start-source-llm2-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "start-source-llm3-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm3",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "start"
|
||||||
|
},
|
||||||
|
"id": "start"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "answer",
|
||||||
|
},
|
||||||
|
"id": "answer",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = Graph.init(
|
||||||
|
graph_config=graph_config
|
||||||
|
)
|
||||||
|
|
||||||
|
assert graph.root_node_id == "start"
|
||||||
|
for i in range(3):
|
||||||
|
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
|
||||||
|
|
||||||
|
assert len(graph.parallel_mapping) == 1
|
||||||
|
assert len(graph.node_parallel_mapping) == 3
|
||||||
|
|
||||||
|
for node_id in ["llm1", "llm2", "llm3"]:
|
||||||
|
assert node_id in graph.node_parallel_mapping
|
||||||
|
|
||||||
|
|
||||||
|
def test_parallels_graph4():
|
||||||
|
graph_config = {
|
||||||
|
"edges": [
|
||||||
|
{
|
||||||
|
"id": "start-source-llm1-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "start-source-llm2-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "start-source-llm3-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "llm1-source-answer-target",
|
||||||
|
"source": "llm1",
|
||||||
|
"target": "code1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "llm2-source-answer-target",
|
||||||
|
"source": "llm2",
|
||||||
|
"target": "code2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "llm3-source-code3-target",
|
||||||
|
"source": "llm3",
|
||||||
|
"target": "code3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "code1-source-answer-target",
|
||||||
|
"source": "code1",
|
||||||
|
"target": "answer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "code2-source-answer-target",
|
||||||
|
"source": "code2",
|
||||||
|
"target": "answer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "code3-source-answer-target",
|
||||||
|
"source": "code3",
|
||||||
|
"target": "answer",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "start"
|
||||||
|
},
|
||||||
|
"id": "start"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "code",
|
||||||
|
},
|
||||||
|
"id": "code1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "code",
|
||||||
|
},
|
||||||
|
"id": "code2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "code",
|
||||||
|
},
|
||||||
|
"id": "code3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "answer",
|
||||||
|
},
|
||||||
|
"id": "answer",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = Graph.init(
|
||||||
|
graph_config=graph_config
|
||||||
|
)
|
||||||
|
|
||||||
|
assert graph.root_node_id == "start"
|
||||||
|
for i in range(3):
|
||||||
|
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
|
||||||
|
assert graph.edge_mapping.get(f"llm{i + 1}") is not None
|
||||||
|
assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == f"code{i + 1}"
|
||||||
|
assert graph.edge_mapping.get(f"code{i + 1}") is not None
|
||||||
|
assert graph.edge_mapping.get(f"code{i + 1}")[0].target_node_id == "answer"
|
||||||
|
|
||||||
|
assert len(graph.parallel_mapping) == 1
|
||||||
|
assert len(graph.node_parallel_mapping) == 6
|
||||||
|
|
||||||
|
for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
|
||||||
|
assert node_id in graph.node_parallel_mapping
|
||||||
|
|
||||||
|
|
||||||
|
def test_parallels_graph5():
|
||||||
|
graph_config = {
|
||||||
|
"edges": [
|
||||||
|
{
|
||||||
|
"id": "start-source-llm1-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "start-source-llm2-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "start-source-llm3-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "start-source-llm3-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "start-source-llm3-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm5",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "llm1-source-code1-target",
|
||||||
|
"source": "llm1",
|
||||||
|
"target": "code1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "llm2-source-code1-target",
|
||||||
|
"source": "llm2",
|
||||||
|
"target": "code1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "llm3-source-code2-target",
|
||||||
|
"source": "llm3",
|
||||||
|
"target": "code2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "llm4-source-code2-target",
|
||||||
|
"source": "llm4",
|
||||||
|
"target": "code2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "llm5-source-code3-target",
|
||||||
|
"source": "llm5",
|
||||||
|
"target": "code3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "code1-source-answer-target",
|
||||||
|
"source": "code1",
|
||||||
|
"target": "answer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "code2-source-answer-target",
|
||||||
|
"source": "code2",
|
||||||
|
"target": "answer",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "start"
|
||||||
|
},
|
||||||
|
"id": "start"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "code",
|
||||||
|
},
|
||||||
|
"id": "code1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "code",
|
||||||
|
},
|
||||||
|
"id": "code2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "code",
|
||||||
|
},
|
||||||
|
"id": "code3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "answer",
|
||||||
|
},
|
||||||
|
"id": "answer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm4"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm5"
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = Graph.init(
|
||||||
|
graph_config=graph_config
|
||||||
|
)
|
||||||
|
|
||||||
|
assert graph.root_node_id == "start"
|
||||||
|
for i in range(5):
|
||||||
|
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
|
||||||
|
|
||||||
|
assert graph.edge_mapping.get("llm1") is not None
|
||||||
|
assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
|
||||||
|
assert graph.edge_mapping.get("llm2") is not None
|
||||||
|
assert graph.edge_mapping.get("llm2")[0].target_node_id == "code1"
|
||||||
|
assert graph.edge_mapping.get("llm3") is not None
|
||||||
|
assert graph.edge_mapping.get("llm3")[0].target_node_id == "code2"
|
||||||
|
assert graph.edge_mapping.get("llm4") is not None
|
||||||
|
assert graph.edge_mapping.get("llm4")[0].target_node_id == "code2"
|
||||||
|
assert graph.edge_mapping.get("llm5") is not None
|
||||||
|
assert graph.edge_mapping.get("llm5")[0].target_node_id == "code3"
|
||||||
|
assert graph.edge_mapping.get("code1") is not None
|
||||||
|
assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
|
||||||
|
assert graph.edge_mapping.get("code2") is not None
|
||||||
|
assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
|
||||||
|
|
||||||
|
assert len(graph.parallel_mapping) == 1
|
||||||
|
assert len(graph.node_parallel_mapping) == 8
|
||||||
|
|
||||||
|
for node_id in ["llm1", "llm2", "llm3", "llm4", "llm5", "code1", "code2", "code3"]:
|
||||||
|
assert node_id in graph.node_parallel_mapping
|
||||||
|
|
||||||
|
|
||||||
|
def test_parallels_graph6():
|
||||||
|
graph_config = {
|
||||||
|
"edges": [
|
||||||
|
{
|
||||||
|
"id": "start-source-llm1-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "start-source-llm2-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "start-source-llm3-target",
|
||||||
|
"source": "start",
|
||||||
|
"target": "llm3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "llm1-source-code1-target",
|
||||||
|
"source": "llm1",
|
||||||
|
"target": "code1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "llm1-source-code2-target",
|
||||||
|
"source": "llm1",
|
||||||
|
"target": "code2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "llm2-source-code3-target",
|
||||||
|
"source": "llm2",
|
||||||
|
"target": "code3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "code1-source-answer-target",
|
||||||
|
"source": "code1",
|
||||||
|
"target": "answer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "code2-source-answer-target",
|
||||||
|
"source": "code2",
|
||||||
|
"target": "answer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "code3-source-answer-target",
|
||||||
|
"source": "code3",
|
||||||
|
"target": "answer",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"nodes": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "start"
|
||||||
|
},
|
||||||
|
"id": "start"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "code",
|
||||||
|
},
|
||||||
|
"id": "code1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "code",
|
||||||
|
},
|
||||||
|
"id": "code2"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "llm",
|
||||||
|
},
|
||||||
|
"id": "llm3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "code",
|
||||||
|
},
|
||||||
|
"id": "code3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"type": "answer",
|
||||||
|
},
|
||||||
|
"id": "answer",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
graph = Graph.init(
|
||||||
|
graph_config=graph_config
|
||||||
|
)
|
||||||
|
|
||||||
|
assert graph.root_node_id == "start"
|
||||||
|
for i in range(3):
|
||||||
|
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
|
||||||
|
|
||||||
|
assert graph.edge_mapping.get("llm1") is not None
|
||||||
|
assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
|
||||||
|
assert graph.edge_mapping.get("llm1") is not None
|
||||||
|
assert graph.edge_mapping.get("llm1")[1].target_node_id == "code2"
|
||||||
|
assert graph.edge_mapping.get("llm2") is not None
|
||||||
|
assert graph.edge_mapping.get("llm2")[0].target_node_id == "code3"
|
||||||
|
assert graph.edge_mapping.get("code1") is not None
|
||||||
|
assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
|
||||||
|
assert graph.edge_mapping.get("code2") is not None
|
||||||
|
assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
|
||||||
|
assert graph.edge_mapping.get("code3") is not None
|
||||||
|
assert graph.edge_mapping.get("code3")[0].target_node_id == "answer"
|
||||||
|
|
||||||
|
assert len(graph.parallel_mapping) == 2
|
||||||
|
assert len(graph.node_parallel_mapping) == 6
|
||||||
|
|
||||||
|
for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
|
||||||
|
assert node_id in graph.node_parallel_mapping
|
||||||
|
|
||||||
|
parent_parallel = None
|
||||||
|
child_parallel = None
|
||||||
|
for p_id, parallel in graph.parallel_mapping.items():
|
||||||
|
if parallel.parent_parallel_id is None:
|
||||||
|
parent_parallel = parallel
|
||||||
|
else:
|
||||||
|
child_parallel = parallel
|
||||||
|
|
||||||
|
for node_id in ["llm1", "llm2", "llm3", "code3"]:
|
||||||
|
assert graph.node_parallel_mapping[node_id] == parent_parallel.id
|
||||||
|
|
||||||
|
for node_id in ["code1", "code2"]:
|
||||||
|
assert graph.node_parallel_mapping[node_id] == child_parallel.id
|
@ -1,9 +1,16 @@
|
|||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.workflow.entities.node_entities import SystemVariable, UserFrom
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.graph_engine.entities.graph import Graph
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||||
from core.workflow.utils.condition.entities import Condition
|
from models.workflow import WorkflowType
|
||||||
|
|
||||||
|
|
||||||
def test_init():
|
@patch('extensions.ext_database.db.session.remove')
|
||||||
|
@patch('extensions.ext_database.db.session.close')
|
||||||
|
def test_run(mock_close, mock_remove):
|
||||||
graph_config = {
|
graph_config = {
|
||||||
"edges": [
|
"edges": [
|
||||||
{
|
{
|
||||||
@ -37,37 +44,43 @@ def test_init():
|
|||||||
"nodes": [
|
"nodes": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"type": "start"
|
"type": "start",
|
||||||
|
"title": "start"
|
||||||
},
|
},
|
||||||
"id": "start"
|
"id": "start"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"type": "llm",
|
"type": "llm",
|
||||||
|
"title": "llm"
|
||||||
},
|
},
|
||||||
"id": "llm"
|
"id": "llm"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"type": "answer",
|
"type": "answer",
|
||||||
|
"title": "answer"
|
||||||
},
|
},
|
||||||
"id": "answer",
|
"id": "answer",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"type": "question-classifier"
|
"type": "question-classifier",
|
||||||
|
"title": "qc"
|
||||||
},
|
},
|
||||||
"id": "qc",
|
"id": "qc",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"type": "http-request",
|
"type": "http-request",
|
||||||
|
"title": "http"
|
||||||
},
|
},
|
||||||
"id": "http",
|
"id": "http",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"type": "answer",
|
"type": "answer",
|
||||||
|
"title": "answer2"
|
||||||
},
|
},
|
||||||
"id": "answer2",
|
"id": "answer2",
|
||||||
}
|
}
|
||||||
@ -78,787 +91,30 @@ def test_init():
|
|||||||
graph_config=graph_config
|
graph_config=graph_config
|
||||||
)
|
)
|
||||||
|
|
||||||
start_node_id = "start"
|
variable_pool = VariablePool(system_variables={
|
||||||
|
SystemVariable.QUERY: 'what\'s the weather in SF',
|
||||||
|
SystemVariable.FILES: [],
|
||||||
|
SystemVariable.CONVERSATION_ID: 'abababa',
|
||||||
|
SystemVariable.USER_ID: 'aaa'
|
||||||
|
}, user_inputs={})
|
||||||
|
|
||||||
assert graph.root_node_id == start_node_id
|
graph_engine = GraphEngine(
|
||||||
assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc"
|
tenant_id="111",
|
||||||
assert {"llm", "http"} == {node.target_node_id for node in graph.edge_mapping.get("qc")}
|
app_id="222",
|
||||||
|
workflow_type=WorkflowType.CHAT,
|
||||||
|
workflow_id="333",
|
||||||
def test__init_iteration_graph():
|
user_id="444",
|
||||||
graph_config = {
|
user_from=UserFrom.ACCOUNT,
|
||||||
"edges": [
|
invoke_from=InvokeFrom.WEB_APP,
|
||||||
{
|
call_depth=0,
|
||||||
"id": "llm-answer",
|
graph=graph,
|
||||||
"source": "llm",
|
variable_pool=variable_pool,
|
||||||
"sourceHandle": "source",
|
max_execution_steps=500,
|
||||||
"target": "answer",
|
max_execution_time=1200
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "iteration-source-llm-target",
|
|
||||||
"source": "iteration",
|
|
||||||
"sourceHandle": "source",
|
|
||||||
"target": "llm",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "template-transform-in-iteration-source-llm-in-iteration-target",
|
|
||||||
"source": "template-transform-in-iteration",
|
|
||||||
"sourceHandle": "source",
|
|
||||||
"target": "llm-in-iteration",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "llm-in-iteration-source-answer-in-iteration-target",
|
|
||||||
"source": "llm-in-iteration",
|
|
||||||
"sourceHandle": "source",
|
|
||||||
"target": "answer-in-iteration",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "start-source-code-target",
|
|
||||||
"source": "start",
|
|
||||||
"sourceHandle": "source",
|
|
||||||
"target": "code",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "code-source-iteration-target",
|
|
||||||
"source": "code",
|
|
||||||
"sourceHandle": "source",
|
|
||||||
"target": "iteration",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"nodes": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "start",
|
|
||||||
},
|
|
||||||
"id": "start",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "answer",
|
|
||||||
},
|
|
||||||
"id": "answer",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "iteration"
|
|
||||||
},
|
|
||||||
"id": "iteration",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "template-transform",
|
|
||||||
},
|
|
||||||
"id": "template-transform-in-iteration",
|
|
||||||
"parentId": "iteration",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm-in-iteration",
|
|
||||||
"parentId": "iteration",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "answer",
|
|
||||||
},
|
|
||||||
"id": "answer-in-iteration",
|
|
||||||
"parentId": "iteration",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "code",
|
|
||||||
},
|
|
||||||
"id": "code",
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
graph = Graph.init(
|
|
||||||
graph_config=graph_config,
|
|
||||||
root_node_id="template-transform-in-iteration"
|
|
||||||
)
|
|
||||||
graph.add_extra_edge(
|
|
||||||
source_node_id="answer-in-iteration",
|
|
||||||
target_node_id="template-transform-in-iteration",
|
|
||||||
run_condition=RunCondition(
|
|
||||||
type="condition",
|
|
||||||
conditions=[
|
|
||||||
Condition(
|
|
||||||
variable_selector=["iteration", "index"],
|
|
||||||
comparison_operator="≤",
|
|
||||||
value="5"
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# iteration:
|
print("")
|
||||||
# [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration]
|
|
||||||
|
|
||||||
assert graph.root_node_id == "template-transform-in-iteration"
|
generator = graph_engine.run()
|
||||||
assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration"
|
for item in generator:
|
||||||
assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration"
|
print(type(item), item)
|
||||||
assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration"
|
|
||||||
|
|
||||||
|
|
||||||
def test_parallels_graph():
|
|
||||||
graph_config = {
|
|
||||||
"edges": [
|
|
||||||
{
|
|
||||||
"id": "start-source-llm1-target",
|
|
||||||
"source": "start",
|
|
||||||
"target": "llm1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "start-source-llm2-target",
|
|
||||||
"source": "start",
|
|
||||||
"target": "llm2",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "start-source-llm3-target",
|
|
||||||
"source": "start",
|
|
||||||
"target": "llm3",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "llm1-source-answer-target",
|
|
||||||
"source": "llm1",
|
|
||||||
"target": "answer",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "llm2-source-answer-target",
|
|
||||||
"source": "llm2",
|
|
||||||
"target": "answer",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "llm3-source-answer-target",
|
|
||||||
"source": "llm3",
|
|
||||||
"target": "answer",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"nodes": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "start"
|
|
||||||
},
|
|
||||||
"id": "start"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm2"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm3"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "answer",
|
|
||||||
},
|
|
||||||
"id": "answer",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
graph = Graph.init(
|
|
||||||
graph_config=graph_config
|
|
||||||
)
|
|
||||||
|
|
||||||
assert graph.root_node_id == "start"
|
|
||||||
for i in range(3):
|
|
||||||
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i+1}"
|
|
||||||
assert graph.edge_mapping.get(f"llm{i+1}") is not None
|
|
||||||
assert graph.edge_mapping.get(f"llm{i+1}")[0].target_node_id == "answer"
|
|
||||||
|
|
||||||
assert len(graph.parallel_mapping) == 1
|
|
||||||
assert len(graph.node_parallel_mapping) == 3
|
|
||||||
|
|
||||||
for node_id in ["llm1", "llm2", "llm3"]:
|
|
||||||
assert node_id in graph.node_parallel_mapping
|
|
||||||
|
|
||||||
|
|
||||||
def test_parallels_graph2():
|
|
||||||
graph_config = {
|
|
||||||
"edges": [
|
|
||||||
{
|
|
||||||
"id": "start-source-llm1-target",
|
|
||||||
"source": "start",
|
|
||||||
"target": "llm1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "start-source-llm2-target",
|
|
||||||
"source": "start",
|
|
||||||
"target": "llm2",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "start-source-llm3-target",
|
|
||||||
"source": "start",
|
|
||||||
"target": "llm3",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "llm1-source-answer-target",
|
|
||||||
"source": "llm1",
|
|
||||||
"target": "answer",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "llm2-source-answer-target",
|
|
||||||
"source": "llm2",
|
|
||||||
"target": "answer",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"nodes": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "start"
|
|
||||||
},
|
|
||||||
"id": "start"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm2"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm3"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "answer",
|
|
||||||
},
|
|
||||||
"id": "answer",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
graph = Graph.init(
|
|
||||||
graph_config=graph_config
|
|
||||||
)
|
|
||||||
|
|
||||||
assert graph.root_node_id == "start"
|
|
||||||
for i in range(3):
|
|
||||||
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
|
|
||||||
|
|
||||||
if i < 2:
|
|
||||||
assert graph.edge_mapping.get(f"llm{i + 1}") is not None
|
|
||||||
assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == "answer"
|
|
||||||
|
|
||||||
assert len(graph.parallel_mapping) == 1
|
|
||||||
assert len(graph.node_parallel_mapping) == 3
|
|
||||||
|
|
||||||
for node_id in ["llm1", "llm2", "llm3"]:
|
|
||||||
assert node_id in graph.node_parallel_mapping
|
|
||||||
|
|
||||||
|
|
||||||
def test_parallels_graph3():
|
|
||||||
graph_config = {
|
|
||||||
"edges": [
|
|
||||||
{
|
|
||||||
"id": "start-source-llm1-target",
|
|
||||||
"source": "start",
|
|
||||||
"target": "llm1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "start-source-llm2-target",
|
|
||||||
"source": "start",
|
|
||||||
"target": "llm2",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "start-source-llm3-target",
|
|
||||||
"source": "start",
|
|
||||||
"target": "llm3",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"nodes": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "start"
|
|
||||||
},
|
|
||||||
"id": "start"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm2"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm3"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "answer",
|
|
||||||
},
|
|
||||||
"id": "answer",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
graph = Graph.init(
|
|
||||||
graph_config=graph_config
|
|
||||||
)
|
|
||||||
|
|
||||||
assert graph.root_node_id == "start"
|
|
||||||
for i in range(3):
|
|
||||||
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
|
|
||||||
|
|
||||||
assert len(graph.parallel_mapping) == 1
|
|
||||||
assert len(graph.node_parallel_mapping) == 3
|
|
||||||
|
|
||||||
for node_id in ["llm1", "llm2", "llm3"]:
|
|
||||||
assert node_id in graph.node_parallel_mapping
|
|
||||||
|
|
||||||
|
|
||||||
def test_parallels_graph4():
|
|
||||||
graph_config = {
|
|
||||||
"edges": [
|
|
||||||
{
|
|
||||||
"id": "start-source-llm1-target",
|
|
||||||
"source": "start",
|
|
||||||
"target": "llm1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "start-source-llm2-target",
|
|
||||||
"source": "start",
|
|
||||||
"target": "llm2",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "start-source-llm3-target",
|
|
||||||
"source": "start",
|
|
||||||
"target": "llm3",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "llm1-source-answer-target",
|
|
||||||
"source": "llm1",
|
|
||||||
"target": "code1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "llm2-source-answer-target",
|
|
||||||
"source": "llm2",
|
|
||||||
"target": "code2",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "llm3-source-code3-target",
|
|
||||||
"source": "llm3",
|
|
||||||
"target": "code3",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "code1-source-answer-target",
|
|
||||||
"source": "code1",
|
|
||||||
"target": "answer",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "code2-source-answer-target",
|
|
||||||
"source": "code2",
|
|
||||||
"target": "answer",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "code3-source-answer-target",
|
|
||||||
"source": "code3",
|
|
||||||
"target": "answer",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"nodes": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "start"
|
|
||||||
},
|
|
||||||
"id": "start"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "code",
|
|
||||||
},
|
|
||||||
"id": "code1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm2"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "code",
|
|
||||||
},
|
|
||||||
"id": "code2"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm3"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "code",
|
|
||||||
},
|
|
||||||
"id": "code3"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "answer",
|
|
||||||
},
|
|
||||||
"id": "answer",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
graph = Graph.init(
|
|
||||||
graph_config=graph_config
|
|
||||||
)
|
|
||||||
|
|
||||||
assert graph.root_node_id == "start"
|
|
||||||
for i in range(3):
|
|
||||||
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
|
|
||||||
assert graph.edge_mapping.get(f"llm{i + 1}") is not None
|
|
||||||
assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == f"code{i + 1}"
|
|
||||||
assert graph.edge_mapping.get(f"code{i + 1}") is not None
|
|
||||||
assert graph.edge_mapping.get(f"code{i + 1}")[0].target_node_id == "answer"
|
|
||||||
|
|
||||||
assert len(graph.parallel_mapping) == 1
|
|
||||||
assert len(graph.node_parallel_mapping) == 6
|
|
||||||
|
|
||||||
for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
|
|
||||||
assert node_id in graph.node_parallel_mapping
|
|
||||||
|
|
||||||
|
|
||||||
def test_parallels_graph5():
|
|
||||||
graph_config = {
|
|
||||||
"edges": [
|
|
||||||
{
|
|
||||||
"id": "start-source-llm1-target",
|
|
||||||
"source": "start",
|
|
||||||
"target": "llm1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "start-source-llm2-target",
|
|
||||||
"source": "start",
|
|
||||||
"target": "llm2",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "start-source-llm3-target",
|
|
||||||
"source": "start",
|
|
||||||
"target": "llm3",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "start-source-llm3-target",
|
|
||||||
"source": "start",
|
|
||||||
"target": "llm4",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "start-source-llm3-target",
|
|
||||||
"source": "start",
|
|
||||||
"target": "llm5",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "llm1-source-code1-target",
|
|
||||||
"source": "llm1",
|
|
||||||
"target": "code1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "llm2-source-code1-target",
|
|
||||||
"source": "llm2",
|
|
||||||
"target": "code1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "llm3-source-code2-target",
|
|
||||||
"source": "llm3",
|
|
||||||
"target": "code2",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "llm4-source-code2-target",
|
|
||||||
"source": "llm4",
|
|
||||||
"target": "code2",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "llm5-source-code3-target",
|
|
||||||
"source": "llm5",
|
|
||||||
"target": "code3",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "code1-source-answer-target",
|
|
||||||
"source": "code1",
|
|
||||||
"target": "answer",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "code2-source-answer-target",
|
|
||||||
"source": "code2",
|
|
||||||
"target": "answer",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"nodes": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "start"
|
|
||||||
},
|
|
||||||
"id": "start"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "code",
|
|
||||||
},
|
|
||||||
"id": "code1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm2"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "code",
|
|
||||||
},
|
|
||||||
"id": "code2"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm3"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "code",
|
|
||||||
},
|
|
||||||
"id": "code3"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "answer",
|
|
||||||
},
|
|
||||||
"id": "answer",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm4"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm5"
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
graph = Graph.init(
|
|
||||||
graph_config=graph_config
|
|
||||||
)
|
|
||||||
|
|
||||||
assert graph.root_node_id == "start"
|
|
||||||
for i in range(5):
|
|
||||||
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
|
|
||||||
|
|
||||||
assert graph.edge_mapping.get("llm1") is not None
|
|
||||||
assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
|
|
||||||
assert graph.edge_mapping.get("llm2") is not None
|
|
||||||
assert graph.edge_mapping.get("llm2")[0].target_node_id == "code1"
|
|
||||||
assert graph.edge_mapping.get("llm3") is not None
|
|
||||||
assert graph.edge_mapping.get("llm3")[0].target_node_id == "code2"
|
|
||||||
assert graph.edge_mapping.get("llm4") is not None
|
|
||||||
assert graph.edge_mapping.get("llm4")[0].target_node_id == "code2"
|
|
||||||
assert graph.edge_mapping.get("llm5") is not None
|
|
||||||
assert graph.edge_mapping.get("llm5")[0].target_node_id == "code3"
|
|
||||||
assert graph.edge_mapping.get("code1") is not None
|
|
||||||
assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
|
|
||||||
assert graph.edge_mapping.get("code2") is not None
|
|
||||||
assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
|
|
||||||
|
|
||||||
assert len(graph.parallel_mapping) == 1
|
|
||||||
assert len(graph.node_parallel_mapping) == 8
|
|
||||||
|
|
||||||
for node_id in ["llm1", "llm2", "llm3", "llm4", "llm5", "code1", "code2", "code3"]:
|
|
||||||
assert node_id in graph.node_parallel_mapping
|
|
||||||
|
|
||||||
|
|
||||||
def test_parallels_graph6():
|
|
||||||
graph_config = {
|
|
||||||
"edges": [
|
|
||||||
{
|
|
||||||
"id": "start-source-llm1-target",
|
|
||||||
"source": "start",
|
|
||||||
"target": "llm1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "start-source-llm2-target",
|
|
||||||
"source": "start",
|
|
||||||
"target": "llm2",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "start-source-llm3-target",
|
|
||||||
"source": "start",
|
|
||||||
"target": "llm3",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "llm1-source-code1-target",
|
|
||||||
"source": "llm1",
|
|
||||||
"target": "code1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "llm1-source-code2-target",
|
|
||||||
"source": "llm1",
|
|
||||||
"target": "code2",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "llm2-source-code3-target",
|
|
||||||
"source": "llm2",
|
|
||||||
"target": "code3",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "code1-source-answer-target",
|
|
||||||
"source": "code1",
|
|
||||||
"target": "answer",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "code2-source-answer-target",
|
|
||||||
"source": "code2",
|
|
||||||
"target": "answer",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "code3-source-answer-target",
|
|
||||||
"source": "code3",
|
|
||||||
"target": "answer",
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"nodes": [
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "start"
|
|
||||||
},
|
|
||||||
"id": "start"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "code",
|
|
||||||
},
|
|
||||||
"id": "code1"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm2"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "code",
|
|
||||||
},
|
|
||||||
"id": "code2"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "llm",
|
|
||||||
},
|
|
||||||
"id": "llm3"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "code",
|
|
||||||
},
|
|
||||||
"id": "code3"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"data": {
|
|
||||||
"type": "answer",
|
|
||||||
},
|
|
||||||
"id": "answer",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
graph = Graph.init(
|
|
||||||
graph_config=graph_config
|
|
||||||
)
|
|
||||||
|
|
||||||
assert graph.root_node_id == "start"
|
|
||||||
for i in range(3):
|
|
||||||
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
|
|
||||||
|
|
||||||
assert graph.edge_mapping.get("llm1") is not None
|
|
||||||
assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
|
|
||||||
assert graph.edge_mapping.get("llm1") is not None
|
|
||||||
assert graph.edge_mapping.get("llm1")[1].target_node_id == "code2"
|
|
||||||
assert graph.edge_mapping.get("llm2") is not None
|
|
||||||
assert graph.edge_mapping.get("llm2")[0].target_node_id == "code3"
|
|
||||||
assert graph.edge_mapping.get("code1") is not None
|
|
||||||
assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
|
|
||||||
assert graph.edge_mapping.get("code2") is not None
|
|
||||||
assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
|
|
||||||
assert graph.edge_mapping.get("code3") is not None
|
|
||||||
assert graph.edge_mapping.get("code3")[0].target_node_id == "answer"
|
|
||||||
|
|
||||||
assert len(graph.parallel_mapping) == 2
|
|
||||||
assert len(graph.node_parallel_mapping) == 6
|
|
||||||
|
|
||||||
for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
|
|
||||||
assert node_id in graph.node_parallel_mapping
|
|
||||||
|
|
||||||
parent_parallel = None
|
|
||||||
child_parallel = None
|
|
||||||
for p_id, parallel in graph.parallel_mapping.items():
|
|
||||||
if parallel.parent_parallel_id is None:
|
|
||||||
parent_parallel = parallel
|
|
||||||
else:
|
|
||||||
child_parallel = parallel
|
|
||||||
|
|
||||||
for node_id in ["llm1", "llm2", "llm3", "code3"]:
|
|
||||||
assert graph.node_parallel_mapping[node_id] == parent_parallel.id
|
|
||||||
|
|
||||||
for node_id in ["code1", "code2"]:
|
|
||||||
assert graph.node_parallel_mapping[node_id] == child_parallel.id
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user