add graph engine test

This commit is contained in:
takatost 2024-07-16 16:37:37 +08:00
parent 00fb23d0c9
commit 00ec36d47c
17 changed files with 1122 additions and 904 deletions

View File

@ -3,20 +3,6 @@ from typing import Any, Optional
from pydantic import BaseModel from pydantic import BaseModel
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
@ -55,25 +41,6 @@ class NodeType(Enum):
raise ValueError(f'invalid node type value {value}') raise ValueError(f'invalid node type value {value}')
node_classes = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.ANSWER: AnswerNode,
NodeType.LLM: LLMNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.IF_ELSE: IfElseNode,
NodeType.CODE: CodeNode,
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.TOOL: ToolNode,
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: IterationNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
}
class SystemVariable(Enum): class SystemVariable(Enum):
""" """
System Variables. System Variables.

View File

@ -1,23 +1,31 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.run_condition import RunCondition from core.workflow.graph_engine.entities.run_condition import RunCondition
class RunConditionHandler(ABC): class RunConditionHandler(ABC):
def __init__(self, condition: RunCondition): def __init__(self,
init_params: GraphInitParams,
graph: Graph,
condition: RunCondition):
self.init_params = init_params
self.graph = graph
self.condition = condition self.condition = condition
@abstractmethod @abstractmethod
def check(self, def check(self,
graph_runtime_state: GraphRuntimeState,
source_node_id: str, source_node_id: str,
target_node_id: str, target_node_id: str) -> bool:
graph: "Graph") -> bool:
""" """
Check if the condition can be executed Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param source_node_id: source node id :param source_node_id: source node id
:param target_node_id: target node id :param target_node_id: target node id
:param graph: graph
:return: bool :return: bool
""" """
raise NotImplementedError raise NotImplementedError

View File

@ -1,29 +1,33 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
class BranchIdentifyRunConditionHandler(RunConditionHandler): class BranchIdentifyRunConditionHandler(RunConditionHandler):
def check(self, def check(self,
graph_runtime_state: GraphRuntimeState,
source_node_id: str, source_node_id: str,
target_node_id: str, target_node_id: str) -> bool:
graph: "Graph") -> bool:
""" """
Check if the condition can be executed Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param source_node_id: source node id :param source_node_id: source node id
:param target_node_id: target node id :param target_node_id: target node id
:param graph: graph
:return: bool :return: bool
""" """
if not self.condition.branch_identify: if not self.condition.branch_identify:
raise Exception("Branch identify is required") raise Exception("Branch identify is required")
run_state = graph.run_state node_route_state = graph_runtime_state.node_run_state.node_state_mapping.get(source_node_id)
node_route_result = run_state.node_route_results.get(source_node_id) if not node_route_state:
if not node_route_result:
return False return False
if not node_route_result.edge_source_handle: run_result = node_route_state.node_run_result
if not run_result:
return False return False
return self.condition.branch_identify == node_route_result.edge_source_handle if not run_result.edge_source_handle:
return False
return self.condition.branch_identify == run_result.edge_source_handle

View File

@ -1,18 +1,19 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.utils.condition.processor import ConditionProcessor from core.workflow.utils.condition.processor import ConditionProcessor
class ConditionRunConditionHandlerHandler(RunConditionHandler): class ConditionRunConditionHandlerHandler(RunConditionHandler):
def check(self, def check(self,
graph_runtime_state: GraphRuntimeState,
source_node_id: str, source_node_id: str,
target_node_id: str, target_node_id: str) -> bool:
graph: "Graph") -> bool:
""" """
Check if the condition can be executed Check if the condition can be executed
:param graph_runtime_state: graph runtime state
:param source_node_id: source node id :param source_node_id: source node id
:param target_node_id: target node id :param target_node_id: target node id
:param graph: graph
:return: bool :return: bool
""" """
if not self.condition.conditions: if not self.condition.conditions:
@ -21,10 +22,9 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler):
# process condition # process condition
condition_processor = ConditionProcessor() condition_processor = ConditionProcessor()
compare_result, _ = condition_processor.process( compare_result, _ = condition_processor.process(
variable_pool=graph.run_state.variable_pool, variable_pool=graph_runtime_state.variable_pool,
logical_operator="and", logical_operator="and",
conditions=self.condition.conditions conditions=self.condition.conditions
) )
return compare_result return compare_result

View File

@ -1,19 +1,35 @@
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.run_condition import RunCondition from core.workflow.graph_engine.entities.run_condition import RunCondition
class ConditionManager: class ConditionManager:
@staticmethod @staticmethod
def get_condition_handler(run_condition: RunCondition) -> RunConditionHandler: def get_condition_handler(
init_params: GraphInitParams,
graph: Graph,
run_condition: RunCondition
) -> RunConditionHandler:
""" """
Get condition handler Get condition handler
:param init_params: init params
:param graph: graph
:param run_condition: run condition :param run_condition: run condition
:return: condition handler :return: condition handler
""" """
if run_condition.type == "branch_identify": if run_condition.type == "branch_identify":
return BranchIdentifyRunConditionHandler(run_condition) return BranchIdentifyRunConditionHandler(
init_params=init_params,
graph=graph,
condition=run_condition
)
else: else:
return ConditionRunConditionHandlerHandler(run_condition) return ConditionRunConditionHandlerHandler(
init_params=init_params,
graph=graph,
condition=run_condition
)

View File

@ -3,6 +3,7 @@ from typing import Optional
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from models.workflow import WorkflowNodeExecutionStatus
class GraphEngineEvent(BaseModel): class GraphEngineEvent(BaseModel):
@ -50,10 +51,12 @@ class NodeRunStartedEvent(BaseNodeEvent):
class NodeRunStreamChunkEvent(BaseNodeEvent): class NodeRunStreamChunkEvent(BaseNodeEvent):
chunk_content: str = Field(..., description="chunk content") chunk_content: str = Field(..., description="chunk content")
from_variable_selector: list[str] = Field(..., description="from variable selector")
class NodeRunRetrieverResourceEvent(BaseNodeEvent): class NodeRunRetrieverResourceEvent(BaseNodeEvent):
retriever_resources: list[dict] = Field(..., description="retriever resources") retriever_resources: list[dict] = Field(..., description="retriever resources")
context: str = Field(..., description="context")
class NodeRunSucceededEvent(BaseNodeEvent): class NodeRunSucceededEvent(BaseNodeEvent):
@ -61,7 +64,10 @@ class NodeRunSucceededEvent(BaseNodeEvent):
class NodeRunFailedEvent(BaseNodeEvent): class NodeRunFailedEvent(BaseNodeEvent):
run_result: NodeRunResult = Field(..., description="run result") run_result: NodeRunResult = Field(
default=NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED),
description="run result"
)
reason: str = Field("", description="failed reason") reason: str = Field("", description="failed reason")
@model_validator(mode='before') @model_validator(mode='before')

View File

@ -3,13 +3,13 @@ import queue
import time import time
from collections.abc import Generator from collections.abc import Generator
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Optional, cast from datetime import datetime, timezone
from typing import Optional
from flask import Flask, current_app from flask import Flask, current_app
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeType from core.workflow.entities.node_entities import NodeType
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
@ -19,14 +19,21 @@ from core.workflow.graph_engine.entities.event import (
GraphRunStartedEvent, GraphRunStartedEvent,
GraphRunSucceededEvent, GraphRunSucceededEvent,
NodeRunFailedEvent, NodeRunFailedEvent,
NodeRunRetrieverResourceEvent,
NodeRunStartedEvent, NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.base_node import UserFrom, node_classes from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
from core.workflow.nodes import node_classes
from core.workflow.nodes.base_node import UserFrom
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.test.test_node import TestNode
from extensions.ext_database import db from extensions.ext_database import db
from models.workflow import WorkflowType from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
thread_pool = ThreadPoolExecutor(max_workers=500, thread_name_prefix="ThreadGraphParallelRun") thread_pool = ThreadPoolExecutor(max_workers=500, thread_name_prefix="ThreadGraphParallelRun")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -43,7 +50,8 @@ class GraphEngine:
call_depth: int, call_depth: int,
graph: Graph, graph: Graph,
variable_pool: VariablePool, variable_pool: VariablePool,
callbacks: list[BaseWorkflowCallback]) -> None: max_execution_steps: int,
max_execution_time: int) -> None:
self.graph = graph self.graph = graph
self.init_params = GraphInitParams( self.init_params = GraphInitParams(
tenant_id=tenant_id, tenant_id=tenant_id,
@ -61,12 +69,8 @@ class GraphEngine:
start_at=time.perf_counter() start_at=time.perf_counter()
) )
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS") self.max_execution_steps = max_execution_steps
self.max_execution_steps = cast(int, max_execution_steps) self.max_execution_time = max_execution_time
max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
self.max_execution_time = cast(int, max_execution_time)
self.callbacks = callbacks
def run_in_block_mode(self): def run_in_block_mode(self):
# TODO convert generator to result # TODO convert generator to result
@ -92,7 +96,7 @@ class GraphEngine:
return return
except Exception as e: except Exception as e:
yield GraphRunFailedEvent(reason=str(e)) yield GraphRunFailedEvent(reason=str(e))
return raise e
def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]: def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]:
next_node_id = start_node_id next_node_id = start_node_id
@ -118,7 +122,7 @@ class GraphEngine:
) )
except Exception as e: except Exception as e:
yield NodeRunFailedEvent(node_id=next_node_id, reason=str(e)) yield NodeRunFailedEvent(node_id=next_node_id, reason=str(e))
return raise e
previous_node_id = next_node_id previous_node_id = next_node_id
@ -141,11 +145,13 @@ class GraphEngine:
for edge in edge_mappings: for edge in edge_mappings:
if edge.run_condition: if edge.run_condition:
result = ConditionManager.get_condition_handler( result = ConditionManager.get_condition_handler(
run_condition=edge.run_condition init_params=self.init_params,
graph=self.graph,
run_condition=edge.run_condition,
).check( ).check(
graph_runtime_state=self.graph_runtime_state,
source_node_id=edge.source_node_id, source_node_id=edge.source_node_id,
target_node_id=edge.target_node_id, target_node_id=edge.target_node_id,
graph=self.graph
) )
if result: if result:
@ -250,7 +256,16 @@ class GraphEngine:
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.') raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
# init workflow run state # init workflow run state
node_instance = node_cls( # type: ignore # node_instance = node_cls( # type: ignore
# config=node_config,
# graph_init_params=self.init_params,
# graph=self.graph,
# graph_runtime_state=self.graph_runtime_state,
# previous_node_id=previous_node_id
# )
# init workflow run state
node_instance = TestNode(
config=node_config, config=node_config,
graph_init_params=self.init_params, graph_init_params=self.init_params,
graph=self.graph, graph=self.graph,
@ -268,24 +283,64 @@ class GraphEngine:
self.graph_runtime_state.node_run_steps += 1 self.graph_runtime_state.node_run_steps += 1
try: try:
start_at = datetime.now(timezone.utc).replace(tzinfo=None)
# run node # run node
generator = node_instance.run() generator = node_instance.run()
run_result = None
for item in generator:
if isinstance(item, RunCompletedEvent):
run_result = item.run_result
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
yield NodeRunFailedEvent(
node_id=node_id,
parallel_id=parallel_id,
run_result=run_result,
reason=run_result.error
)
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
yield NodeRunSucceededEvent(
node_id=node_id,
parallel_id=parallel_id,
run_result=run_result
)
yield from generator self.graph_runtime_state.node_run_state.node_state_mapping[node_id] = RouteNodeState(
node_id=node_id,
start_at=start_at,
status=RouteNodeState.Status.SUCCESS if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
else RouteNodeState.Status.FAILED,
finished_at=datetime.now(timezone.utc).replace(tzinfo=None),
node_run_result=run_result,
failed_reason=run_result.error
if run_result.status == WorkflowNodeExecutionStatus.FAILED else None
)
# todo append self.graph_runtime_state.node_run_state.routes
break
elif isinstance(item, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
node_id=node_id,
parallel_id=parallel_id,
chunk_content=item.chunk_content,
from_variable_selector=item.from_variable_selector,
)
elif isinstance(item, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
node_id=node_id,
parallel_id=parallel_id,
retriever_resources=item.retriever_resources,
context=item.context
)
# todo record state # todo record state
# trigger node run success event
yield NodeRunSucceededEvent(node_id=node_id, parallel_id=parallel_id)
except GenerateTaskStoppedException as e: except GenerateTaskStoppedException as e:
# trigger node run failed event # trigger node run failed event
yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e)) yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
return return
except Exception as e: except Exception as e:
# todo logger.exception(f"Node {node.node_data.title} run failed: {str(e)}") logger.exception(f"Node {node_instance.node_data.title} run failed: {str(e)}")
# trigger node run failed event raise e
yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e))
return
finally: finally:
db.session.close() db.session.close()

View File

@ -0,0 +1,33 @@
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
node_classes = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.ANSWER: AnswerNode,
NodeType.LLM: LLMNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.IF_ELSE: IfElseNode,
NodeType.CODE: CodeNode,
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.TOOL: ToolNode,
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
NodeType.ITERATION: IterationNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
}

View File

@ -2,37 +2,20 @@ from abc import ABC, abstractmethod
from collections.abc import Generator from collections.abc import Generator
from typing import Optional from typing import Optional
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType, UserFrom from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.event import RunCompletedEvent, RunEvent from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from core.workflow.nodes.iterable_node import IterableNodeMixin from core.workflow.nodes.iterable_node import IterableNodeMixin
from models.workflow import WorkflowType
class BaseNode(ABC): class BaseNode(ABC):
_node_data_cls: type[BaseNodeData] _node_data_cls: type[BaseNodeData]
_node_type: NodeType _node_type: NodeType
tenant_id: str
app_id: str
workflow_type: WorkflowType
workflow_id: str
user_id: str
user_from: UserFrom
invoke_from: InvokeFrom
workflow_call_depth: int
graph: Graph
graph_runtime_state: GraphRuntimeState
previous_node_id: Optional[str] = None
node_id: str
node_data: BaseNodeData
def __init__(self, def __init__(self,
config: dict, config: dict,
graph_init_params: GraphInitParams, graph_init_params: GraphInitParams,
@ -81,24 +64,6 @@ class BaseNode(ABC):
else: else:
yield from result yield from result
def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None:
"""
Publish text chunk
:param text: chunk text
:param value_selector: value selector
:return:
"""
# TODO remove callbacks
if self.callbacks:
for callback in self.callbacks:
callback.on_node_text_chunk(
node_id=self.node_id,
text=text,
metadata={
"value_selector": value_selector
}
)
@classmethod @classmethod
def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict[str, list[str]]: def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict[str, list[str]]:
""" """

View File

@ -170,7 +170,7 @@ class LLMNode(BaseNode):
model_instance: ModelInstance, model_instance: ModelInstance,
prompt_messages: list[PromptMessage], prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None) \ stop: Optional[list[str]] = None) \
-> Generator[RunStreamChunkEvent | "ModelInvokeCompleted", None, None]: -> Generator["RunStreamChunkEvent | ModelInvokeCompleted", None, None]:
""" """
Invoke large language model Invoke large language model
:param node_data_model: node data model :param node_data_model: node data model
@ -204,7 +204,7 @@ class LLMNode(BaseNode):
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \ def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \
-> Generator[RunStreamChunkEvent | "ModelInvokeCompleted", None, None]: -> Generator["RunStreamChunkEvent | ModelInvokeCompleted", None, None]:
""" """
Handle invoke result Handle invoke result
:param invoke_result: invoke result :param invoke_result: invoke result

View File

View File

@ -0,0 +1,8 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
class TestNodeData(BaseNodeData):
"""
Test Node Data.
"""
pass

View File

@ -0,0 +1,33 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.test.entities import TestNodeData
from models.workflow import WorkflowNodeExecutionStatus
class TestNode(BaseNode):
_node_data_cls = TestNodeData
node_type = NodeType.ANSWER
def _run(self) -> NodeRunResult:
"""
Run node
:return:
"""
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"content": "abc"
},
edge_source_handle="1"
)
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
"""
Extract variable selector to variable mapping
:param node_data: node data
:return:
"""
return {}

View File

@ -17,7 +17,8 @@ from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState
from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom, node_classes from core.workflow.nodes import node_classes
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
from core.workflow.nodes.iteration.entities import IterationState from core.workflow.nodes.iteration.entities import IterationState
from core.workflow.nodes.llm.entities import LLMNodeData from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.start.start_node import StartNode from core.workflow.nodes.start.start_node import StartNode
@ -93,7 +94,8 @@ class WorkflowEntry:
call_depth=call_depth, call_depth=call_depth,
graph=graph, graph=graph,
variable_pool=variable_pool, variable_pool=variable_pool,
callbacks=callbacks max_execution_steps=current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS"),
max_execution_time=current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
) )
# init workflow run # init workflow run

View File

@ -10,7 +10,8 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import NodeType from core.workflow.entities.node_entities import NodeType
from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.workflow_engine_manager import WorkflowEngineManager, node_classes from core.workflow.nodes import node_classes
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account

View File

@ -0,0 +1,864 @@
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition
from core.workflow.utils.condition.entities import Condition
def test_init():
graph_config = {
"edges": [
{
"id": "llm-source-answer-target",
"source": "llm",
"target": "answer",
},
{
"id": "start-source-qc-target",
"source": "start",
"target": "qc",
},
{
"id": "qc-1-llm-target",
"source": "qc",
"sourceHandle": "1",
"target": "llm",
},
{
"id": "qc-2-http-target",
"source": "qc",
"sourceHandle": "2",
"target": "http",
},
{
"id": "http-source-answer2-target",
"source": "http",
"target": "answer2",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
{
"data": {
"type": "question-classifier"
},
"id": "qc",
},
{
"data": {
"type": "http-request",
},
"id": "http",
},
{
"data": {
"type": "answer",
},
"id": "answer2",
}
],
}
graph = Graph.init(
graph_config=graph_config
)
start_node_id = "start"
assert graph.root_node_id == start_node_id
assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc"
assert {"llm", "http"} == {node.target_node_id for node in graph.edge_mapping.get("qc")}
def test__init_iteration_graph():
graph_config = {
"edges": [
{
"id": "llm-answer",
"source": "llm",
"sourceHandle": "source",
"target": "answer",
},
{
"id": "iteration-source-llm-target",
"source": "iteration",
"sourceHandle": "source",
"target": "llm",
},
{
"id": "template-transform-in-iteration-source-llm-in-iteration-target",
"source": "template-transform-in-iteration",
"sourceHandle": "source",
"target": "llm-in-iteration",
},
{
"id": "llm-in-iteration-source-answer-in-iteration-target",
"source": "llm-in-iteration",
"sourceHandle": "source",
"target": "answer-in-iteration",
},
{
"id": "start-source-code-target",
"source": "start",
"sourceHandle": "source",
"target": "code",
},
{
"id": "code-source-iteration-target",
"source": "code",
"sourceHandle": "source",
"target": "iteration",
}
],
"nodes": [
{
"data": {
"type": "start",
},
"id": "start",
},
{
"data": {
"type": "llm",
},
"id": "llm",
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
{
"data": {
"type": "iteration"
},
"id": "iteration",
},
{
"data": {
"type": "template-transform",
},
"id": "template-transform-in-iteration",
"parentId": "iteration",
},
{
"data": {
"type": "llm",
},
"id": "llm-in-iteration",
"parentId": "iteration",
},
{
"data": {
"type": "answer",
},
"id": "answer-in-iteration",
"parentId": "iteration",
},
{
"data": {
"type": "code",
},
"id": "code",
}
]
}
graph = Graph.init(
graph_config=graph_config,
root_node_id="template-transform-in-iteration"
)
graph.add_extra_edge(
source_node_id="answer-in-iteration",
target_node_id="template-transform-in-iteration",
run_condition=RunCondition(
type="condition",
conditions=[
Condition(
variable_selector=["iteration", "index"],
comparison_operator="",
value="5"
)
]
)
)
# iteration:
# [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration]
assert graph.root_node_id == "template-transform-in-iteration"
assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration"
assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration"
assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration"
def test_parallels_graph():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "answer",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "answer",
},
{
"id": "llm3-source-answer-target",
"source": "llm3",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i+1}"
assert graph.edge_mapping.get(f"llm{i+1}") is not None
assert graph.edge_mapping.get(f"llm{i+1}")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 3
for node_id in ["llm1", "llm2", "llm3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph2():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "answer",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
if i < 2:
assert graph.edge_mapping.get(f"llm{i + 1}") is not None
assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 3
for node_id in ["llm1", "llm2", "llm3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph3():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 3
for node_id in ["llm1", "llm2", "llm3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph4():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "code1",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "code2",
},
{
"id": "llm3-source-code3-target",
"source": "llm3",
"target": "code3",
},
{
"id": "code1-source-answer-target",
"source": "code1",
"target": "answer",
},
{
"id": "code2-source-answer-target",
"source": "code2",
"target": "answer",
},
{
"id": "code3-source-answer-target",
"source": "code3",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "code",
},
"id": "code1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "code",
},
"id": "code2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "code",
},
"id": "code3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert graph.edge_mapping.get(f"llm{i + 1}") is not None
assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == f"code{i + 1}"
assert graph.edge_mapping.get(f"code{i + 1}") is not None
assert graph.edge_mapping.get(f"code{i + 1}")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 6
for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph5():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm4",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm5",
},
{
"id": "llm1-source-code1-target",
"source": "llm1",
"target": "code1",
},
{
"id": "llm2-source-code1-target",
"source": "llm2",
"target": "code1",
},
{
"id": "llm3-source-code2-target",
"source": "llm3",
"target": "code2",
},
{
"id": "llm4-source-code2-target",
"source": "llm4",
"target": "code2",
},
{
"id": "llm5-source-code3-target",
"source": "llm5",
"target": "code3",
},
{
"id": "code1-source-answer-target",
"source": "code1",
"target": "answer",
},
{
"id": "code2-source-answer-target",
"source": "code2",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "code",
},
"id": "code1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "code",
},
"id": "code2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "code",
},
"id": "code3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
{
"data": {
"type": "llm",
},
"id": "llm4"
},
{
"data": {
"type": "llm",
},
"id": "llm5"
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(5):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert graph.edge_mapping.get("llm1") is not None
assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
assert graph.edge_mapping.get("llm2") is not None
assert graph.edge_mapping.get("llm2")[0].target_node_id == "code1"
assert graph.edge_mapping.get("llm3") is not None
assert graph.edge_mapping.get("llm3")[0].target_node_id == "code2"
assert graph.edge_mapping.get("llm4") is not None
assert graph.edge_mapping.get("llm4")[0].target_node_id == "code2"
assert graph.edge_mapping.get("llm5") is not None
assert graph.edge_mapping.get("llm5")[0].target_node_id == "code3"
assert graph.edge_mapping.get("code1") is not None
assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
assert graph.edge_mapping.get("code2") is not None
assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 8
for node_id in ["llm1", "llm2", "llm3", "llm4", "llm5", "code1", "code2", "code3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph6():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-code1-target",
"source": "llm1",
"target": "code1",
},
{
"id": "llm1-source-code2-target",
"source": "llm1",
"target": "code2",
},
{
"id": "llm2-source-code3-target",
"source": "llm2",
"target": "code3",
},
{
"id": "code1-source-answer-target",
"source": "code1",
"target": "answer",
},
{
"id": "code2-source-answer-target",
"source": "code2",
"target": "answer",
},
{
"id": "code3-source-answer-target",
"source": "code3",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "code",
},
"id": "code1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "code",
},
"id": "code2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "code",
},
"id": "code3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert graph.edge_mapping.get("llm1") is not None
assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
assert graph.edge_mapping.get("llm1") is not None
assert graph.edge_mapping.get("llm1")[1].target_node_id == "code2"
assert graph.edge_mapping.get("llm2") is not None
assert graph.edge_mapping.get("llm2")[0].target_node_id == "code3"
assert graph.edge_mapping.get("code1") is not None
assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
assert graph.edge_mapping.get("code2") is not None
assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
assert graph.edge_mapping.get("code3") is not None
assert graph.edge_mapping.get("code3")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 2
assert len(graph.node_parallel_mapping) == 6
for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
assert node_id in graph.node_parallel_mapping
parent_parallel = None
child_parallel = None
for p_id, parallel in graph.parallel_mapping.items():
if parallel.parent_parallel_id is None:
parent_parallel = parallel
else:
child_parallel = parallel
for node_id in ["llm1", "llm2", "llm3", "code3"]:
assert graph.node_parallel_mapping[node_id] == parent_parallel.id
for node_id in ["code1", "code2"]:
assert graph.node_parallel_mapping[node_id] == child_parallel.id

View File

@ -1,9 +1,16 @@
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import SystemVariable, UserFrom
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.run_condition import RunCondition from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.utils.condition.entities import Condition from models.workflow import WorkflowType
def test_init(): @patch('extensions.ext_database.db.session.remove')
@patch('extensions.ext_database.db.session.close')
def test_run(mock_close, mock_remove):
graph_config = { graph_config = {
"edges": [ "edges": [
{ {
@ -37,37 +44,43 @@ def test_init():
"nodes": [ "nodes": [
{ {
"data": { "data": {
"type": "start" "type": "start",
"title": "start"
}, },
"id": "start" "id": "start"
}, },
{ {
"data": { "data": {
"type": "llm", "type": "llm",
"title": "llm"
}, },
"id": "llm" "id": "llm"
}, },
{ {
"data": { "data": {
"type": "answer", "type": "answer",
"title": "answer"
}, },
"id": "answer", "id": "answer",
}, },
{ {
"data": { "data": {
"type": "question-classifier" "type": "question-classifier",
"title": "qc"
}, },
"id": "qc", "id": "qc",
}, },
{ {
"data": { "data": {
"type": "http-request", "type": "http-request",
"title": "http"
}, },
"id": "http", "id": "http",
}, },
{ {
"data": { "data": {
"type": "answer", "type": "answer",
"title": "answer2"
}, },
"id": "answer2", "id": "answer2",
} }
@ -78,787 +91,30 @@ def test_init():
graph_config=graph_config graph_config=graph_config
) )
start_node_id = "start" variable_pool = VariablePool(system_variables={
SystemVariable.QUERY: 'what\'s the weather in SF',
SystemVariable.FILES: [],
SystemVariable.CONVERSATION_ID: 'abababa',
SystemVariable.USER_ID: 'aaa'
}, user_inputs={})
assert graph.root_node_id == start_node_id graph_engine = GraphEngine(
assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc" tenant_id="111",
assert {"llm", "http"} == {node.target_node_id for node in graph.edge_mapping.get("qc")} app_id="222",
workflow_type=WorkflowType.CHAT,
workflow_id="333",
def test__init_iteration_graph(): user_id="444",
graph_config = { user_from=UserFrom.ACCOUNT,
"edges": [ invoke_from=InvokeFrom.WEB_APP,
{ call_depth=0,
"id": "llm-answer", graph=graph,
"source": "llm", variable_pool=variable_pool,
"sourceHandle": "source", max_execution_steps=500,
"target": "answer", max_execution_time=1200
},
{
"id": "iteration-source-llm-target",
"source": "iteration",
"sourceHandle": "source",
"target": "llm",
},
{
"id": "template-transform-in-iteration-source-llm-in-iteration-target",
"source": "template-transform-in-iteration",
"sourceHandle": "source",
"target": "llm-in-iteration",
},
{
"id": "llm-in-iteration-source-answer-in-iteration-target",
"source": "llm-in-iteration",
"sourceHandle": "source",
"target": "answer-in-iteration",
},
{
"id": "start-source-code-target",
"source": "start",
"sourceHandle": "source",
"target": "code",
},
{
"id": "code-source-iteration-target",
"source": "code",
"sourceHandle": "source",
"target": "iteration",
}
],
"nodes": [
{
"data": {
"type": "start",
},
"id": "start",
},
{
"data": {
"type": "llm",
},
"id": "llm",
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
{
"data": {
"type": "iteration"
},
"id": "iteration",
},
{
"data": {
"type": "template-transform",
},
"id": "template-transform-in-iteration",
"parentId": "iteration",
},
{
"data": {
"type": "llm",
},
"id": "llm-in-iteration",
"parentId": "iteration",
},
{
"data": {
"type": "answer",
},
"id": "answer-in-iteration",
"parentId": "iteration",
},
{
"data": {
"type": "code",
},
"id": "code",
}
]
}
graph = Graph.init(
graph_config=graph_config,
root_node_id="template-transform-in-iteration"
)
graph.add_extra_edge(
source_node_id="answer-in-iteration",
target_node_id="template-transform-in-iteration",
run_condition=RunCondition(
type="condition",
conditions=[
Condition(
variable_selector=["iteration", "index"],
comparison_operator="",
value="5"
)
]
)
) )
# iteration: print("")
# [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration]
assert graph.root_node_id == "template-transform-in-iteration" generator = graph_engine.run()
assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration" for item in generator:
assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration" print(type(item), item)
assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration"
def test_parallels_graph():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "answer",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "answer",
},
{
"id": "llm3-source-answer-target",
"source": "llm3",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i+1}"
assert graph.edge_mapping.get(f"llm{i+1}") is not None
assert graph.edge_mapping.get(f"llm{i+1}")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 3
for node_id in ["llm1", "llm2", "llm3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph2():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "answer",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
if i < 2:
assert graph.edge_mapping.get(f"llm{i + 1}") is not None
assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 3
for node_id in ["llm1", "llm2", "llm3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph3():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 3
for node_id in ["llm1", "llm2", "llm3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph4():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-answer-target",
"source": "llm1",
"target": "code1",
},
{
"id": "llm2-source-answer-target",
"source": "llm2",
"target": "code2",
},
{
"id": "llm3-source-code3-target",
"source": "llm3",
"target": "code3",
},
{
"id": "code1-source-answer-target",
"source": "code1",
"target": "answer",
},
{
"id": "code2-source-answer-target",
"source": "code2",
"target": "answer",
},
{
"id": "code3-source-answer-target",
"source": "code3",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "code",
},
"id": "code1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "code",
},
"id": "code2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "code",
},
"id": "code3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert graph.edge_mapping.get(f"llm{i + 1}") is not None
assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == f"code{i + 1}"
assert graph.edge_mapping.get(f"code{i + 1}") is not None
assert graph.edge_mapping.get(f"code{i + 1}")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 6
for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph5():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm4",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm5",
},
{
"id": "llm1-source-code1-target",
"source": "llm1",
"target": "code1",
},
{
"id": "llm2-source-code1-target",
"source": "llm2",
"target": "code1",
},
{
"id": "llm3-source-code2-target",
"source": "llm3",
"target": "code2",
},
{
"id": "llm4-source-code2-target",
"source": "llm4",
"target": "code2",
},
{
"id": "llm5-source-code3-target",
"source": "llm5",
"target": "code3",
},
{
"id": "code1-source-answer-target",
"source": "code1",
"target": "answer",
},
{
"id": "code2-source-answer-target",
"source": "code2",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "code",
},
"id": "code1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "code",
},
"id": "code2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "code",
},
"id": "code3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
{
"data": {
"type": "llm",
},
"id": "llm4"
},
{
"data": {
"type": "llm",
},
"id": "llm5"
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(5):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert graph.edge_mapping.get("llm1") is not None
assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
assert graph.edge_mapping.get("llm2") is not None
assert graph.edge_mapping.get("llm2")[0].target_node_id == "code1"
assert graph.edge_mapping.get("llm3") is not None
assert graph.edge_mapping.get("llm3")[0].target_node_id == "code2"
assert graph.edge_mapping.get("llm4") is not None
assert graph.edge_mapping.get("llm4")[0].target_node_id == "code2"
assert graph.edge_mapping.get("llm5") is not None
assert graph.edge_mapping.get("llm5")[0].target_node_id == "code3"
assert graph.edge_mapping.get("code1") is not None
assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
assert graph.edge_mapping.get("code2") is not None
assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 1
assert len(graph.node_parallel_mapping) == 8
for node_id in ["llm1", "llm2", "llm3", "llm4", "llm5", "code1", "code2", "code3"]:
assert node_id in graph.node_parallel_mapping
def test_parallels_graph6():
graph_config = {
"edges": [
{
"id": "start-source-llm1-target",
"source": "start",
"target": "llm1",
},
{
"id": "start-source-llm2-target",
"source": "start",
"target": "llm2",
},
{
"id": "start-source-llm3-target",
"source": "start",
"target": "llm3",
},
{
"id": "llm1-source-code1-target",
"source": "llm1",
"target": "code1",
},
{
"id": "llm1-source-code2-target",
"source": "llm1",
"target": "code2",
},
{
"id": "llm2-source-code3-target",
"source": "llm2",
"target": "code3",
},
{
"id": "code1-source-answer-target",
"source": "code1",
"target": "answer",
},
{
"id": "code2-source-answer-target",
"source": "code2",
"target": "answer",
},
{
"id": "code3-source-answer-target",
"source": "code3",
"target": "answer",
}
],
"nodes": [
{
"data": {
"type": "start"
},
"id": "start"
},
{
"data": {
"type": "llm",
},
"id": "llm1"
},
{
"data": {
"type": "code",
},
"id": "code1"
},
{
"data": {
"type": "llm",
},
"id": "llm2"
},
{
"data": {
"type": "code",
},
"id": "code2"
},
{
"data": {
"type": "llm",
},
"id": "llm3"
},
{
"data": {
"type": "code",
},
"id": "code3"
},
{
"data": {
"type": "answer",
},
"id": "answer",
},
],
}
graph = Graph.init(
graph_config=graph_config
)
assert graph.root_node_id == "start"
for i in range(3):
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
assert graph.edge_mapping.get("llm1") is not None
assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
assert graph.edge_mapping.get("llm1") is not None
assert graph.edge_mapping.get("llm1")[1].target_node_id == "code2"
assert graph.edge_mapping.get("llm2") is not None
assert graph.edge_mapping.get("llm2")[0].target_node_id == "code3"
assert graph.edge_mapping.get("code1") is not None
assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
assert graph.edge_mapping.get("code2") is not None
assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
assert graph.edge_mapping.get("code3") is not None
assert graph.edge_mapping.get("code3")[0].target_node_id == "answer"
assert len(graph.parallel_mapping) == 2
assert len(graph.node_parallel_mapping) == 6
for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
assert node_id in graph.node_parallel_mapping
parent_parallel = None
child_parallel = None
for p_id, parallel in graph.parallel_mapping.items():
if parallel.parent_parallel_id is None:
parent_parallel = parallel
else:
child_parallel = parallel
for node_id in ["llm1", "llm2", "llm3", "code3"]:
assert graph.node_parallel_mapping[node_id] == parent_parallel.id
for node_id in ["code1", "code2"]:
assert graph.node_parallel_mapping[node_id] == child_parallel.id