diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 4e91508786..d11352f066 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -3,20 +3,6 @@ from typing import Any, Optional from pydantic import BaseModel -from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.http_request.http_request_node import HttpRequestNode -from core.workflow.nodes.if_else.if_else_node import IfElseNode -from core.workflow.nodes.iteration.iteration_node import IterationNode -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode -from core.workflow.nodes.llm.llm_node import LLMNode -from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode -from core.workflow.nodes.tool.tool_node import ToolNode -from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode from models.workflow import WorkflowNodeExecutionStatus @@ -55,25 +41,6 @@ class NodeType(Enum): raise ValueError(f'invalid node type value {value}') -node_classes = { - NodeType.START: StartNode, - NodeType.END: EndNode, - NodeType.ANSWER: AnswerNode, - NodeType.LLM: LLMNode, - NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, - NodeType.IF_ELSE: IfElseNode, - NodeType.CODE: CodeNode, - NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode, - NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, - NodeType.HTTP_REQUEST: HttpRequestNode, - NodeType.TOOL: ToolNode, - NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, - NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR - NodeType.ITERATION: IterationNode, - NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode -} - - class SystemVariable(Enum): """ System Variables. diff --git a/api/core/workflow/graph_engine/condition_handlers/base_handler.py b/api/core/workflow/graph_engine/condition_handlers/base_handler.py index 05d8f1765f..bdaa3d1529 100644 --- a/api/core/workflow/graph_engine/condition_handlers/base_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/base_handler.py @@ -1,23 +1,31 @@ from abc import ABC, abstractmethod +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.run_condition import RunCondition class RunConditionHandler(ABC): - def __init__(self, condition: RunCondition): + def __init__(self, + init_params: GraphInitParams, + graph: Graph, + condition: RunCondition): + self.init_params = init_params + self.graph = graph self.condition = condition @abstractmethod def check(self, + graph_runtime_state: GraphRuntimeState, source_node_id: str, - target_node_id: str, - graph: "Graph") -> bool: + target_node_id: str) -> bool: """ Check if the condition can be executed + :param graph_runtime_state: graph runtime state :param source_node_id: source node id :param target_node_id: target node id - :param graph: graph :return: bool """ raise NotImplementedError diff --git a/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py index 000547b5cc..f2a52c6dab 100644 --- a/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/branch_identify_handler.py @@ -1,29 +1,33 @@ from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState class BranchIdentifyRunConditionHandler(RunConditionHandler): def check(self, + graph_runtime_state: GraphRuntimeState, source_node_id: str, - target_node_id: str, - graph: "Graph") -> bool: + target_node_id: str) -> bool: """ Check if the condition can be executed + :param graph_runtime_state: graph runtime state :param source_node_id: source node id :param target_node_id: target node id - :param graph: graph :return: bool """ if not self.condition.branch_identify: raise Exception("Branch identify is required") - run_state = graph.run_state - node_route_result = run_state.node_route_results.get(source_node_id) - if not node_route_result: + node_route_state = graph_runtime_state.node_run_state.node_state_mapping.get(source_node_id) + if not node_route_state: return False - if not node_route_result.edge_source_handle: + run_result = node_route_state.node_run_result + if not run_result: return False - return self.condition.branch_identify == node_route_result.edge_source_handle + if not run_result.edge_source_handle: + return False + + return self.condition.branch_identify == run_result.edge_source_handle diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py index c71438cf89..3c1e8634c7 100644 --- a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py @@ -1,18 +1,19 @@ from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.utils.condition.processor import ConditionProcessor class ConditionRunConditionHandlerHandler(RunConditionHandler): def check(self, + graph_runtime_state: GraphRuntimeState, source_node_id: str, - target_node_id: str, - graph: "Graph") -> bool: + target_node_id: str) -> bool: """ Check if the condition can be executed + :param graph_runtime_state: graph runtime state :param source_node_id: source node id :param target_node_id: target node id - :param graph: graph :return: bool """ if not self.condition.conditions: @@ -21,10 +22,9 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler): # process condition condition_processor = ConditionProcessor() compare_result, _ = condition_processor.process( - variable_pool=graph.run_state.variable_pool, + variable_pool=graph_runtime_state.variable_pool, logical_operator="and", conditions=self.condition.conditions ) return compare_result - diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_manager.py b/api/core/workflow/graph_engine/condition_handlers/condition_manager.py index 5b3c430418..2eb2e58bfc 100644 --- a/api/core/workflow/graph_engine/condition_handlers/condition_manager.py +++ b/api/core/workflow/graph_engine/condition_handlers/condition_manager.py @@ -1,19 +1,35 @@ from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.run_condition import RunCondition class ConditionManager: @staticmethod - def get_condition_handler(run_condition: RunCondition) -> RunConditionHandler: + def get_condition_handler( + init_params: GraphInitParams, + graph: Graph, + run_condition: RunCondition + ) -> RunConditionHandler: """ Get condition handler + :param init_params: init params + :param graph: graph :param run_condition: run condition :return: condition handler """ if run_condition.type == "branch_identify": - return BranchIdentifyRunConditionHandler(run_condition) + return BranchIdentifyRunConditionHandler( + init_params=init_params, + graph=graph, + condition=run_condition + ) else: - return ConditionRunConditionHandlerHandler(run_condition) + return ConditionRunConditionHandlerHandler( + init_params=init_params, + graph=graph, + condition=run_condition + ) diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index a9a984f83d..9fff28a82f 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -3,6 +3,7 @@ from typing import Optional from pydantic import BaseModel, Field, model_validator from core.workflow.entities.node_entities import NodeRunResult +from models.workflow import WorkflowNodeExecutionStatus class GraphEngineEvent(BaseModel): @@ -50,10 +51,12 @@ class NodeRunStartedEvent(BaseNodeEvent): class NodeRunStreamChunkEvent(BaseNodeEvent): chunk_content: str = Field(..., description="chunk content") + from_variable_selector: list[str] = Field(..., description="from variable selector") class NodeRunRetrieverResourceEvent(BaseNodeEvent): retriever_resources: list[dict] = Field(..., description="retriever resources") + context: str = Field(..., description="context") class NodeRunSucceededEvent(BaseNodeEvent): @@ -61,7 +64,10 @@ class NodeRunSucceededEvent(BaseNodeEvent): class NodeRunFailedEvent(BaseNodeEvent): - run_result: NodeRunResult = Field(..., description="run result") + run_result: NodeRunResult = Field( + default=NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED), + description="run result" + ) reason: str = Field("", description="failed reason") @model_validator(mode='before') diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index abda39b6e3..2d05dabfd7 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -3,13 +3,13 @@ import queue import time from collections.abc import Generator from concurrent.futures import ThreadPoolExecutor -from typing import Optional, cast +from datetime import datetime, timezone +from typing import Optional from flask import Flask, current_app from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback from core.workflow.entities.node_entities import NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager @@ -19,14 +19,21 @@ from core.workflow.graph_engine.entities.event import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunFailedEvent, + NodeRunRetrieverResourceEvent, NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, ) from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState -from core.workflow.nodes.base_node import UserFrom, node_classes +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState +from core.workflow.nodes import node_classes +from core.workflow.nodes.base_node import UserFrom +from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent +from core.workflow.nodes.test.test_node import TestNode from extensions.ext_database import db -from models.workflow import WorkflowType +from models.workflow import WorkflowNodeExecutionStatus, WorkflowType thread_pool = ThreadPoolExecutor(max_workers=500, thread_name_prefix="ThreadGraphParallelRun") logger = logging.getLogger(__name__) @@ -43,7 +50,8 @@ class GraphEngine: call_depth: int, graph: Graph, variable_pool: VariablePool, - callbacks: list[BaseWorkflowCallback]) -> None: + max_execution_steps: int, + max_execution_time: int) -> None: self.graph = graph self.init_params = GraphInitParams( tenant_id=tenant_id, @@ -61,12 +69,8 @@ class GraphEngine: start_at=time.perf_counter() ) - max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS") - self.max_execution_steps = cast(int, max_execution_steps) - max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME") - self.max_execution_time = cast(int, max_execution_time) - - self.callbacks = callbacks + self.max_execution_steps = max_execution_steps + self.max_execution_time = max_execution_time def run_in_block_mode(self): # TODO convert generator to result @@ -92,7 +96,7 @@ class GraphEngine: return except Exception as e: yield GraphRunFailedEvent(reason=str(e)) - return + raise e def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]: next_node_id = start_node_id @@ -118,7 +122,7 @@ class GraphEngine: ) except Exception as e: yield NodeRunFailedEvent(node_id=next_node_id, reason=str(e)) - return + raise e previous_node_id = next_node_id @@ -141,11 +145,13 @@ class GraphEngine: for edge in edge_mappings: if edge.run_condition: result = ConditionManager.get_condition_handler( - run_condition=edge.run_condition + init_params=self.init_params, + graph=self.graph, + run_condition=edge.run_condition, ).check( + graph_runtime_state=self.graph_runtime_state, source_node_id=edge.source_node_id, target_node_id=edge.target_node_id, - graph=self.graph ) if result: @@ -250,7 +256,16 @@ class GraphEngine: raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.') # init workflow run state - node_instance = node_cls( # type: ignore + # node_instance = node_cls( # type: ignore + # config=node_config, + # graph_init_params=self.init_params, + # graph=self.graph, + # graph_runtime_state=self.graph_runtime_state, + # previous_node_id=previous_node_id + # ) + + # init workflow run state + node_instance = TestNode( config=node_config, graph_init_params=self.init_params, graph=self.graph, @@ -268,24 +283,64 @@ class GraphEngine: self.graph_runtime_state.node_run_steps += 1 try: + start_at = datetime.now(timezone.utc).replace(tzinfo=None) + # run node generator = node_instance.run() + run_result = None + for item in generator: + if isinstance(item, RunCompletedEvent): + run_result = item.run_result + if run_result.status == WorkflowNodeExecutionStatus.FAILED: + yield NodeRunFailedEvent( + node_id=node_id, + parallel_id=parallel_id, + run_result=run_result, + reason=run_result.error + ) + elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: + yield NodeRunSucceededEvent( + node_id=node_id, + parallel_id=parallel_id, + run_result=run_result + ) - yield from generator + self.graph_runtime_state.node_run_state.node_state_mapping[node_id] = RouteNodeState( + node_id=node_id, + start_at=start_at, + status=RouteNodeState.Status.SUCCESS if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + else RouteNodeState.Status.FAILED, + finished_at=datetime.now(timezone.utc).replace(tzinfo=None), + node_run_result=run_result, + failed_reason=run_result.error + if run_result.status == WorkflowNodeExecutionStatus.FAILED else None + ) + + # todo append self.graph_runtime_state.node_run_state.routes + break + elif isinstance(item, RunStreamChunkEvent): + yield NodeRunStreamChunkEvent( + node_id=node_id, + parallel_id=parallel_id, + chunk_content=item.chunk_content, + from_variable_selector=item.from_variable_selector, + ) + elif isinstance(item, RunRetrieverResourceEvent): + yield NodeRunRetrieverResourceEvent( + node_id=node_id, + parallel_id=parallel_id, + retriever_resources=item.retriever_resources, + context=item.context + ) # todo record state - - # trigger node run success event - yield NodeRunSucceededEvent(node_id=node_id, parallel_id=parallel_id) except GenerateTaskStoppedException as e: # trigger node run failed event yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e)) return except Exception as e: - # todo logger.exception(f"Node {node.node_data.title} run failed: {str(e)}") - # trigger node run failed event - yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e)) - return + logger.exception(f"Node {node_instance.node_data.title} run failed: {str(e)}") + raise e finally: db.session.close() diff --git a/api/core/workflow/nodes/__init__.py b/api/core/workflow/nodes/__init__.py index e69de29bb2..df1eb98989 100644 --- a/api/core/workflow/nodes/__init__.py +++ b/api/core/workflow/nodes/__init__.py @@ -0,0 +1,33 @@ +from core.workflow.entities.node_entities import NodeType +from core.workflow.nodes.answer.answer_node import AnswerNode +from core.workflow.nodes.code.code_node import CodeNode +from core.workflow.nodes.end.end_node import EndNode +from core.workflow.nodes.http_request.http_request_node import HttpRequestNode +from core.workflow.nodes.if_else.if_else_node import IfElseNode +from core.workflow.nodes.iteration.iteration_node import IterationNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from core.workflow.nodes.llm.llm_node import LLMNode +from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode +from core.workflow.nodes.start.start_node import StartNode +from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from core.workflow.nodes.tool.tool_node import ToolNode +from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode + +node_classes = { + NodeType.START: StartNode, + NodeType.END: EndNode, + NodeType.ANSWER: AnswerNode, + NodeType.LLM: LLMNode, + NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, + NodeType.IF_ELSE: IfElseNode, + NodeType.CODE: CodeNode, + NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode, + NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, + NodeType.HTTP_REQUEST: HttpRequestNode, + NodeType.TOOL: ToolNode, + NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, + NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR + NodeType.ITERATION: IterationNode, + NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode +} diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index d4ab6a5f5f..7215cda4b9 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -2,37 +2,20 @@ from abc import ABC, abstractmethod from collections.abc import Generator from typing import Optional -from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType, UserFrom +from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.event import RunCompletedEvent, RunEvent from core.workflow.nodes.iterable_node import IterableNodeMixin -from models.workflow import WorkflowType class BaseNode(ABC): _node_data_cls: type[BaseNodeData] _node_type: NodeType - tenant_id: str - app_id: str - workflow_type: WorkflowType - workflow_id: str - user_id: str - user_from: UserFrom - invoke_from: InvokeFrom - workflow_call_depth: int - graph: Graph - graph_runtime_state: GraphRuntimeState - previous_node_id: Optional[str] = None - - node_id: str - node_data: BaseNodeData - def __init__(self, config: dict, graph_init_params: GraphInitParams, @@ -81,24 +64,6 @@ class BaseNode(ABC): else: yield from result - def publish_text_chunk(self, text: str, value_selector: list[str] = None) -> None: - """ - Publish text chunk - :param text: chunk text - :param value_selector: value selector - :return: - """ - # TODO remove callbacks - if self.callbacks: - for callback in self.callbacks: - callback.on_node_text_chunk( - node_id=self.node_id, - text=text, - metadata={ - "value_selector": value_selector - } - ) - @classmethod def extract_variable_selector_to_variable_mapping(cls, config: dict) -> dict[str, list[str]]: """ diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 9b498062cc..cb0f3dc4d7 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -170,7 +170,7 @@ class LLMNode(BaseNode): model_instance: ModelInstance, prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None) \ - -> Generator[RunStreamChunkEvent | "ModelInvokeCompleted", None, None]: + -> Generator["RunStreamChunkEvent | ModelInvokeCompleted", None, None]: """ Invoke large language model :param node_data_model: node data model @@ -204,7 +204,7 @@ class LLMNode(BaseNode): self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \ - -> Generator[RunStreamChunkEvent | "ModelInvokeCompleted", None, None]: + -> Generator["RunStreamChunkEvent | ModelInvokeCompleted", None, None]: """ Handle invoke result :param invoke_result: invoke result diff --git a/api/core/workflow/nodes/test/__init__.py b/api/core/workflow/nodes/test/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/test/entities.py b/api/core/workflow/nodes/test/entities.py new file mode 100644 index 0000000000..8d9610737c --- /dev/null +++ b/api/core/workflow/nodes/test/entities.py @@ -0,0 +1,8 @@ +from core.workflow.entities.base_node_data_entities import BaseNodeData + + +class TestNodeData(BaseNodeData): + """ + Test Node Data. + """ + pass diff --git a/api/core/workflow/nodes/test/test_node.py b/api/core/workflow/nodes/test/test_node.py new file mode 100644 index 0000000000..91e6fd5bc0 --- /dev/null +++ b/api/core/workflow/nodes/test/test_node.py @@ -0,0 +1,33 @@ + +from core.workflow.entities.base_node_data_entities import BaseNodeData +from core.workflow.entities.node_entities import NodeRunResult, NodeType +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.test.entities import TestNodeData +from models.workflow import WorkflowNodeExecutionStatus + + +class TestNode(BaseNode): + _node_data_cls = TestNodeData + node_type = NodeType.ANSWER + + def _run(self) -> NodeRunResult: + """ + Run node + :return: + """ + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={ + "content": "abc" + }, + edge_source_handle="1" + ) + + @classmethod + def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: + """ + Extract variable selector to variable mapping + :param node_data: node data + :return: + """ + return {} diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 5a54016aab..7ea9a76d5b 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -17,7 +17,8 @@ from core.workflow.entities.workflow_runtime_state import WorkflowRuntimeState from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom, node_classes +from core.workflow.nodes import node_classes +from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom from core.workflow.nodes.iteration.entities import IterationState from core.workflow.nodes.llm.entities import LLMNodeData from core.workflow.nodes.start.start_node import StartNode @@ -93,7 +94,8 @@ class WorkflowEntry: call_depth=call_depth, graph=graph, variable_pool=variable_pool, - callbacks=callbacks + max_execution_steps=current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS"), + max_execution_time=current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME") ) # init workflow run diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 17e92b876e..eb8c54dcbe 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -10,7 +10,8 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.entities.node_entities import NodeType from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.workflow_engine_manager import WorkflowEngineManager, node_classes +from core.workflow.nodes import node_classes +from core.workflow.workflow_engine_manager import WorkflowEngineManager from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db from models.account import Account diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py new file mode 100644 index 0000000000..2987343693 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph.py @@ -0,0 +1,864 @@ +from core.workflow.graph_engine.entities.graph import Graph +from core.workflow.graph_engine.entities.run_condition import RunCondition +from core.workflow.utils.condition.entities import Condition + + +def test_init(): + graph_config = { + "edges": [ + { + "id": "llm-source-answer-target", + "source": "llm", + "target": "answer", + }, + { + "id": "start-source-qc-target", + "source": "start", + "target": "qc", + }, + { + "id": "qc-1-llm-target", + "source": "qc", + "sourceHandle": "1", + "target": "llm", + }, + { + "id": "qc-2-http-target", + "source": "qc", + "sourceHandle": "2", + "target": "http", + }, + { + "id": "http-source-answer2-target", + "source": "http", + "target": "answer2", + } + ], + "nodes": [ + { + "data": { + "type": "start" + }, + "id": "start" + }, + { + "data": { + "type": "llm", + }, + "id": "llm" + }, + { + "data": { + "type": "answer", + }, + "id": "answer", + }, + { + "data": { + "type": "question-classifier" + }, + "id": "qc", + }, + { + "data": { + "type": "http-request", + }, + "id": "http", + }, + { + "data": { + "type": "answer", + }, + "id": "answer2", + } + ], + } + + graph = Graph.init( + graph_config=graph_config + ) + + start_node_id = "start" + + assert graph.root_node_id == start_node_id + assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc" + assert {"llm", "http"} == {node.target_node_id for node in graph.edge_mapping.get("qc")} + + +def test__init_iteration_graph(): + graph_config = { + "edges": [ + { + "id": "llm-answer", + "source": "llm", + "sourceHandle": "source", + "target": "answer", + }, + { + "id": "iteration-source-llm-target", + "source": "iteration", + "sourceHandle": "source", + "target": "llm", + }, + { + "id": "template-transform-in-iteration-source-llm-in-iteration-target", + "source": "template-transform-in-iteration", + "sourceHandle": "source", + "target": "llm-in-iteration", + }, + { + "id": "llm-in-iteration-source-answer-in-iteration-target", + "source": "llm-in-iteration", + "sourceHandle": "source", + "target": "answer-in-iteration", + }, + { + "id": "start-source-code-target", + "source": "start", + "sourceHandle": "source", + "target": "code", + }, + { + "id": "code-source-iteration-target", + "source": "code", + "sourceHandle": "source", + "target": "iteration", + } + ], + "nodes": [ + { + "data": { + "type": "start", + }, + "id": "start", + }, + { + "data": { + "type": "llm", + }, + "id": "llm", + }, + { + "data": { + "type": "answer", + }, + "id": "answer", + }, + { + "data": { + "type": "iteration" + }, + "id": "iteration", + }, + { + "data": { + "type": "template-transform", + }, + "id": "template-transform-in-iteration", + "parentId": "iteration", + }, + { + "data": { + "type": "llm", + }, + "id": "llm-in-iteration", + "parentId": "iteration", + }, + { + "data": { + "type": "answer", + }, + "id": "answer-in-iteration", + "parentId": "iteration", + }, + { + "data": { + "type": "code", + }, + "id": "code", + } + ] + } + + graph = Graph.init( + graph_config=graph_config, + root_node_id="template-transform-in-iteration" + ) + graph.add_extra_edge( + source_node_id="answer-in-iteration", + target_node_id="template-transform-in-iteration", + run_condition=RunCondition( + type="condition", + conditions=[ + Condition( + variable_selector=["iteration", "index"], + comparison_operator="≤", + value="5" + ) + ] + ) + ) + + # iteration: + # [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration] + + assert graph.root_node_id == "template-transform-in-iteration" + assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration" + assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration" + assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration" + + +def test_parallels_graph(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "answer", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "answer", + }, + { + "id": "llm3-source-answer-target", + "source": "llm3", + "target": "answer", + } + ], + "nodes": [ + { + "data": { + "type": "start" + }, + "id": "start" + }, + { + "data": { + "type": "llm", + }, + "id": "llm1" + }, + { + "data": { + "type": "llm", + }, + "id": "llm2" + }, + { + "data": { + "type": "llm", + }, + "id": "llm3" + }, + { + "data": { + "type": "answer", + }, + "id": "answer", + }, + ], + } + + graph = Graph.init( + graph_config=graph_config + ) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i+1}" + assert graph.edge_mapping.get(f"llm{i+1}") is not None + assert graph.edge_mapping.get(f"llm{i+1}")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 3 + + for node_id in ["llm1", "llm2", "llm3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph2(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "answer", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "answer", + } + ], + "nodes": [ + { + "data": { + "type": "start" + }, + "id": "start" + }, + { + "data": { + "type": "llm", + }, + "id": "llm1" + }, + { + "data": { + "type": "llm", + }, + "id": "llm2" + }, + { + "data": { + "type": "llm", + }, + "id": "llm3" + }, + { + "data": { + "type": "answer", + }, + "id": "answer", + }, + ], + } + + graph = Graph.init( + graph_config=graph_config + ) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + + if i < 2: + assert graph.edge_mapping.get(f"llm{i + 1}") is not None + assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 3 + + for node_id in ["llm1", "llm2", "llm3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph3(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + ], + "nodes": [ + { + "data": { + "type": "start" + }, + "id": "start" + }, + { + "data": { + "type": "llm", + }, + "id": "llm1" + }, + { + "data": { + "type": "llm", + }, + "id": "llm2" + }, + { + "data": { + "type": "llm", + }, + "id": "llm3" + }, + { + "data": { + "type": "answer", + }, + "id": "answer", + }, + ], + } + + graph = Graph.init( + graph_config=graph_config + ) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 3 + + for node_id in ["llm1", "llm2", "llm3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph4(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm1-source-answer-target", + "source": "llm1", + "target": "code1", + }, + { + "id": "llm2-source-answer-target", + "source": "llm2", + "target": "code2", + }, + { + "id": "llm3-source-code3-target", + "source": "llm3", + "target": "code3", + }, + { + "id": "code1-source-answer-target", + "source": "code1", + "target": "answer", + }, + { + "id": "code2-source-answer-target", + "source": "code2", + "target": "answer", + }, + { + "id": "code3-source-answer-target", + "source": "code3", + "target": "answer", + } + ], + "nodes": [ + { + "data": { + "type": "start" + }, + "id": "start" + }, + { + "data": { + "type": "llm", + }, + "id": "llm1" + }, + { + "data": { + "type": "code", + }, + "id": "code1" + }, + { + "data": { + "type": "llm", + }, + "id": "llm2" + }, + { + "data": { + "type": "code", + }, + "id": "code2" + }, + { + "data": { + "type": "llm", + }, + "id": "llm3" + }, + { + "data": { + "type": "code", + }, + "id": "code3" + }, + { + "data": { + "type": "answer", + }, + "id": "answer", + }, + ], + } + + graph = Graph.init( + graph_config=graph_config + ) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + assert graph.edge_mapping.get(f"llm{i + 1}") is not None + assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == f"code{i + 1}" + assert graph.edge_mapping.get(f"code{i + 1}") is not None + assert graph.edge_mapping.get(f"code{i + 1}")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 6 + + for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph5(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm4", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm5", + }, + { + "id": "llm1-source-code1-target", + "source": "llm1", + "target": "code1", + }, + { + "id": "llm2-source-code1-target", + "source": "llm2", + "target": "code1", + }, + { + "id": "llm3-source-code2-target", + "source": "llm3", + "target": "code2", + }, + { + "id": "llm4-source-code2-target", + "source": "llm4", + "target": "code2", + }, + { + "id": "llm5-source-code3-target", + "source": "llm5", + "target": "code3", + }, + { + "id": "code1-source-answer-target", + "source": "code1", + "target": "answer", + }, + { + "id": "code2-source-answer-target", + "source": "code2", + "target": "answer", + } + ], + "nodes": [ + { + "data": { + "type": "start" + }, + "id": "start" + }, + { + "data": { + "type": "llm", + }, + "id": "llm1" + }, + { + "data": { + "type": "code", + }, + "id": "code1" + }, + { + "data": { + "type": "llm", + }, + "id": "llm2" + }, + { + "data": { + "type": "code", + }, + "id": "code2" + }, + { + "data": { + "type": "llm", + }, + "id": "llm3" + }, + { + "data": { + "type": "code", + }, + "id": "code3" + }, + { + "data": { + "type": "answer", + }, + "id": "answer", + }, + { + "data": { + "type": "llm", + }, + "id": "llm4" + }, + { + "data": { + "type": "llm", + }, + "id": "llm5" + }, + ], + } + + graph = Graph.init( + graph_config=graph_config + ) + + assert graph.root_node_id == "start" + for i in range(5): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + + assert graph.edge_mapping.get("llm1") is not None + assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1" + assert graph.edge_mapping.get("llm2") is not None + assert graph.edge_mapping.get("llm2")[0].target_node_id == "code1" + assert graph.edge_mapping.get("llm3") is not None + assert graph.edge_mapping.get("llm3")[0].target_node_id == "code2" + assert graph.edge_mapping.get("llm4") is not None + assert graph.edge_mapping.get("llm4")[0].target_node_id == "code2" + assert graph.edge_mapping.get("llm5") is not None + assert graph.edge_mapping.get("llm5")[0].target_node_id == "code3" + assert graph.edge_mapping.get("code1") is not None + assert graph.edge_mapping.get("code1")[0].target_node_id == "answer" + assert graph.edge_mapping.get("code2") is not None + assert graph.edge_mapping.get("code2")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 1 + assert len(graph.node_parallel_mapping) == 8 + + for node_id in ["llm1", "llm2", "llm3", "llm4", "llm5", "code1", "code2", "code3"]: + assert node_id in graph.node_parallel_mapping + + +def test_parallels_graph6(): + graph_config = { + "edges": [ + { + "id": "start-source-llm1-target", + "source": "start", + "target": "llm1", + }, + { + "id": "start-source-llm2-target", + "source": "start", + "target": "llm2", + }, + { + "id": "start-source-llm3-target", + "source": "start", + "target": "llm3", + }, + { + "id": "llm1-source-code1-target", + "source": "llm1", + "target": "code1", + }, + { + "id": "llm1-source-code2-target", + "source": "llm1", + "target": "code2", + }, + { + "id": "llm2-source-code3-target", + "source": "llm2", + "target": "code3", + }, + { + "id": "code1-source-answer-target", + "source": "code1", + "target": "answer", + }, + { + "id": "code2-source-answer-target", + "source": "code2", + "target": "answer", + }, + { + "id": "code3-source-answer-target", + "source": "code3", + "target": "answer", + } + ], + "nodes": [ + { + "data": { + "type": "start" + }, + "id": "start" + }, + { + "data": { + "type": "llm", + }, + "id": "llm1" + }, + { + "data": { + "type": "code", + }, + "id": "code1" + }, + { + "data": { + "type": "llm", + }, + "id": "llm2" + }, + { + "data": { + "type": "code", + }, + "id": "code2" + }, + { + "data": { + "type": "llm", + }, + "id": "llm3" + }, + { + "data": { + "type": "code", + }, + "id": "code3" + }, + { + "data": { + "type": "answer", + }, + "id": "answer", + }, + ], + } + + graph = Graph.init( + graph_config=graph_config + ) + + assert graph.root_node_id == "start" + for i in range(3): + assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" + + assert graph.edge_mapping.get("llm1") is not None + assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1" + assert graph.edge_mapping.get("llm1") is not None + assert graph.edge_mapping.get("llm1")[1].target_node_id == "code2" + assert graph.edge_mapping.get("llm2") is not None + assert graph.edge_mapping.get("llm2")[0].target_node_id == "code3" + assert graph.edge_mapping.get("code1") is not None + assert graph.edge_mapping.get("code1")[0].target_node_id == "answer" + assert graph.edge_mapping.get("code2") is not None + assert graph.edge_mapping.get("code2")[0].target_node_id == "answer" + assert graph.edge_mapping.get("code3") is not None + assert graph.edge_mapping.get("code3")[0].target_node_id == "answer" + + assert len(graph.parallel_mapping) == 2 + assert len(graph.node_parallel_mapping) == 6 + + for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]: + assert node_id in graph.node_parallel_mapping + + parent_parallel = None + child_parallel = None + for p_id, parallel in graph.parallel_mapping.items(): + if parallel.parent_parallel_id is None: + parent_parallel = parallel + else: + child_parallel = parallel + + for node_id in ["llm1", "llm2", "llm3", "code3"]: + assert graph.node_parallel_mapping[node_id] == parent_parallel.id + + for node_id in ["code1", "code2"]: + assert graph.node_parallel_mapping[node_id] == child_parallel.id diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 2987343693..5032bac330 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -1,9 +1,16 @@ +from unittest.mock import patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.entities.node_entities import SystemVariable, UserFrom +from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.graph import Graph -from core.workflow.graph_engine.entities.run_condition import RunCondition -from core.workflow.utils.condition.entities import Condition +from core.workflow.graph_engine.graph_engine import GraphEngine +from models.workflow import WorkflowType -def test_init(): +@patch('extensions.ext_database.db.session.remove') +@patch('extensions.ext_database.db.session.close') +def test_run(mock_close, mock_remove): graph_config = { "edges": [ { @@ -37,37 +44,43 @@ def test_init(): "nodes": [ { "data": { - "type": "start" + "type": "start", + "title": "start" }, "id": "start" }, { "data": { "type": "llm", + "title": "llm" }, "id": "llm" }, { "data": { "type": "answer", + "title": "answer" }, "id": "answer", }, { "data": { - "type": "question-classifier" + "type": "question-classifier", + "title": "qc" }, "id": "qc", }, { "data": { "type": "http-request", + "title": "http" }, "id": "http", }, { "data": { "type": "answer", + "title": "answer2" }, "id": "answer2", } @@ -78,787 +91,30 @@ def test_init(): graph_config=graph_config ) - start_node_id = "start" + variable_pool = VariablePool(system_variables={ + SystemVariable.QUERY: 'what\'s the weather in SF', + SystemVariable.FILES: [], + SystemVariable.CONVERSATION_ID: 'abababa', + SystemVariable.USER_ID: 'aaa' + }, user_inputs={}) - assert graph.root_node_id == start_node_id - assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc" - assert {"llm", "http"} == {node.target_node_id for node in graph.edge_mapping.get("qc")} - - -def test__init_iteration_graph(): - graph_config = { - "edges": [ - { - "id": "llm-answer", - "source": "llm", - "sourceHandle": "source", - "target": "answer", - }, - { - "id": "iteration-source-llm-target", - "source": "iteration", - "sourceHandle": "source", - "target": "llm", - }, - { - "id": "template-transform-in-iteration-source-llm-in-iteration-target", - "source": "template-transform-in-iteration", - "sourceHandle": "source", - "target": "llm-in-iteration", - }, - { - "id": "llm-in-iteration-source-answer-in-iteration-target", - "source": "llm-in-iteration", - "sourceHandle": "source", - "target": "answer-in-iteration", - }, - { - "id": "start-source-code-target", - "source": "start", - "sourceHandle": "source", - "target": "code", - }, - { - "id": "code-source-iteration-target", - "source": "code", - "sourceHandle": "source", - "target": "iteration", - } - ], - "nodes": [ - { - "data": { - "type": "start", - }, - "id": "start", - }, - { - "data": { - "type": "llm", - }, - "id": "llm", - }, - { - "data": { - "type": "answer", - }, - "id": "answer", - }, - { - "data": { - "type": "iteration" - }, - "id": "iteration", - }, - { - "data": { - "type": "template-transform", - }, - "id": "template-transform-in-iteration", - "parentId": "iteration", - }, - { - "data": { - "type": "llm", - }, - "id": "llm-in-iteration", - "parentId": "iteration", - }, - { - "data": { - "type": "answer", - }, - "id": "answer-in-iteration", - "parentId": "iteration", - }, - { - "data": { - "type": "code", - }, - "id": "code", - } - ] - } - - graph = Graph.init( - graph_config=graph_config, - root_node_id="template-transform-in-iteration" - ) - graph.add_extra_edge( - source_node_id="answer-in-iteration", - target_node_id="template-transform-in-iteration", - run_condition=RunCondition( - type="condition", - conditions=[ - Condition( - variable_selector=["iteration", "index"], - comparison_operator="≤", - value="5" - ) - ] - ) + graph_engine = GraphEngine( + tenant_id="111", + app_id="222", + workflow_type=WorkflowType.CHAT, + workflow_id="333", + user_id="444", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.WEB_APP, + call_depth=0, + graph=graph, + variable_pool=variable_pool, + max_execution_steps=500, + max_execution_time=1200 ) - # iteration: - # [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration] + print("") - assert graph.root_node_id == "template-transform-in-iteration" - assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration" - assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration" - assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration" - - -def test_parallels_graph(): - graph_config = { - "edges": [ - { - "id": "start-source-llm1-target", - "source": "start", - "target": "llm1", - }, - { - "id": "start-source-llm2-target", - "source": "start", - "target": "llm2", - }, - { - "id": "start-source-llm3-target", - "source": "start", - "target": "llm3", - }, - { - "id": "llm1-source-answer-target", - "source": "llm1", - "target": "answer", - }, - { - "id": "llm2-source-answer-target", - "source": "llm2", - "target": "answer", - }, - { - "id": "llm3-source-answer-target", - "source": "llm3", - "target": "answer", - } - ], - "nodes": [ - { - "data": { - "type": "start" - }, - "id": "start" - }, - { - "data": { - "type": "llm", - }, - "id": "llm1" - }, - { - "data": { - "type": "llm", - }, - "id": "llm2" - }, - { - "data": { - "type": "llm", - }, - "id": "llm3" - }, - { - "data": { - "type": "answer", - }, - "id": "answer", - }, - ], - } - - graph = Graph.init( - graph_config=graph_config - ) - - assert graph.root_node_id == "start" - for i in range(3): - assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i+1}" - assert graph.edge_mapping.get(f"llm{i+1}") is not None - assert graph.edge_mapping.get(f"llm{i+1}")[0].target_node_id == "answer" - - assert len(graph.parallel_mapping) == 1 - assert len(graph.node_parallel_mapping) == 3 - - for node_id in ["llm1", "llm2", "llm3"]: - assert node_id in graph.node_parallel_mapping - - -def test_parallels_graph2(): - graph_config = { - "edges": [ - { - "id": "start-source-llm1-target", - "source": "start", - "target": "llm1", - }, - { - "id": "start-source-llm2-target", - "source": "start", - "target": "llm2", - }, - { - "id": "start-source-llm3-target", - "source": "start", - "target": "llm3", - }, - { - "id": "llm1-source-answer-target", - "source": "llm1", - "target": "answer", - }, - { - "id": "llm2-source-answer-target", - "source": "llm2", - "target": "answer", - } - ], - "nodes": [ - { - "data": { - "type": "start" - }, - "id": "start" - }, - { - "data": { - "type": "llm", - }, - "id": "llm1" - }, - { - "data": { - "type": "llm", - }, - "id": "llm2" - }, - { - "data": { - "type": "llm", - }, - "id": "llm3" - }, - { - "data": { - "type": "answer", - }, - "id": "answer", - }, - ], - } - - graph = Graph.init( - graph_config=graph_config - ) - - assert graph.root_node_id == "start" - for i in range(3): - assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" - - if i < 2: - assert graph.edge_mapping.get(f"llm{i + 1}") is not None - assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == "answer" - - assert len(graph.parallel_mapping) == 1 - assert len(graph.node_parallel_mapping) == 3 - - for node_id in ["llm1", "llm2", "llm3"]: - assert node_id in graph.node_parallel_mapping - - -def test_parallels_graph3(): - graph_config = { - "edges": [ - { - "id": "start-source-llm1-target", - "source": "start", - "target": "llm1", - }, - { - "id": "start-source-llm2-target", - "source": "start", - "target": "llm2", - }, - { - "id": "start-source-llm3-target", - "source": "start", - "target": "llm3", - }, - ], - "nodes": [ - { - "data": { - "type": "start" - }, - "id": "start" - }, - { - "data": { - "type": "llm", - }, - "id": "llm1" - }, - { - "data": { - "type": "llm", - }, - "id": "llm2" - }, - { - "data": { - "type": "llm", - }, - "id": "llm3" - }, - { - "data": { - "type": "answer", - }, - "id": "answer", - }, - ], - } - - graph = Graph.init( - graph_config=graph_config - ) - - assert graph.root_node_id == "start" - for i in range(3): - assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" - - assert len(graph.parallel_mapping) == 1 - assert len(graph.node_parallel_mapping) == 3 - - for node_id in ["llm1", "llm2", "llm3"]: - assert node_id in graph.node_parallel_mapping - - -def test_parallels_graph4(): - graph_config = { - "edges": [ - { - "id": "start-source-llm1-target", - "source": "start", - "target": "llm1", - }, - { - "id": "start-source-llm2-target", - "source": "start", - "target": "llm2", - }, - { - "id": "start-source-llm3-target", - "source": "start", - "target": "llm3", - }, - { - "id": "llm1-source-answer-target", - "source": "llm1", - "target": "code1", - }, - { - "id": "llm2-source-answer-target", - "source": "llm2", - "target": "code2", - }, - { - "id": "llm3-source-code3-target", - "source": "llm3", - "target": "code3", - }, - { - "id": "code1-source-answer-target", - "source": "code1", - "target": "answer", - }, - { - "id": "code2-source-answer-target", - "source": "code2", - "target": "answer", - }, - { - "id": "code3-source-answer-target", - "source": "code3", - "target": "answer", - } - ], - "nodes": [ - { - "data": { - "type": "start" - }, - "id": "start" - }, - { - "data": { - "type": "llm", - }, - "id": "llm1" - }, - { - "data": { - "type": "code", - }, - "id": "code1" - }, - { - "data": { - "type": "llm", - }, - "id": "llm2" - }, - { - "data": { - "type": "code", - }, - "id": "code2" - }, - { - "data": { - "type": "llm", - }, - "id": "llm3" - }, - { - "data": { - "type": "code", - }, - "id": "code3" - }, - { - "data": { - "type": "answer", - }, - "id": "answer", - }, - ], - } - - graph = Graph.init( - graph_config=graph_config - ) - - assert graph.root_node_id == "start" - for i in range(3): - assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" - assert graph.edge_mapping.get(f"llm{i + 1}") is not None - assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == f"code{i + 1}" - assert graph.edge_mapping.get(f"code{i + 1}") is not None - assert graph.edge_mapping.get(f"code{i + 1}")[0].target_node_id == "answer" - - assert len(graph.parallel_mapping) == 1 - assert len(graph.node_parallel_mapping) == 6 - - for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]: - assert node_id in graph.node_parallel_mapping - - -def test_parallels_graph5(): - graph_config = { - "edges": [ - { - "id": "start-source-llm1-target", - "source": "start", - "target": "llm1", - }, - { - "id": "start-source-llm2-target", - "source": "start", - "target": "llm2", - }, - { - "id": "start-source-llm3-target", - "source": "start", - "target": "llm3", - }, - { - "id": "start-source-llm3-target", - "source": "start", - "target": "llm4", - }, - { - "id": "start-source-llm3-target", - "source": "start", - "target": "llm5", - }, - { - "id": "llm1-source-code1-target", - "source": "llm1", - "target": "code1", - }, - { - "id": "llm2-source-code1-target", - "source": "llm2", - "target": "code1", - }, - { - "id": "llm3-source-code2-target", - "source": "llm3", - "target": "code2", - }, - { - "id": "llm4-source-code2-target", - "source": "llm4", - "target": "code2", - }, - { - "id": "llm5-source-code3-target", - "source": "llm5", - "target": "code3", - }, - { - "id": "code1-source-answer-target", - "source": "code1", - "target": "answer", - }, - { - "id": "code2-source-answer-target", - "source": "code2", - "target": "answer", - } - ], - "nodes": [ - { - "data": { - "type": "start" - }, - "id": "start" - }, - { - "data": { - "type": "llm", - }, - "id": "llm1" - }, - { - "data": { - "type": "code", - }, - "id": "code1" - }, - { - "data": { - "type": "llm", - }, - "id": "llm2" - }, - { - "data": { - "type": "code", - }, - "id": "code2" - }, - { - "data": { - "type": "llm", - }, - "id": "llm3" - }, - { - "data": { - "type": "code", - }, - "id": "code3" - }, - { - "data": { - "type": "answer", - }, - "id": "answer", - }, - { - "data": { - "type": "llm", - }, - "id": "llm4" - }, - { - "data": { - "type": "llm", - }, - "id": "llm5" - }, - ], - } - - graph = Graph.init( - graph_config=graph_config - ) - - assert graph.root_node_id == "start" - for i in range(5): - assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" - - assert graph.edge_mapping.get("llm1") is not None - assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1" - assert graph.edge_mapping.get("llm2") is not None - assert graph.edge_mapping.get("llm2")[0].target_node_id == "code1" - assert graph.edge_mapping.get("llm3") is not None - assert graph.edge_mapping.get("llm3")[0].target_node_id == "code2" - assert graph.edge_mapping.get("llm4") is not None - assert graph.edge_mapping.get("llm4")[0].target_node_id == "code2" - assert graph.edge_mapping.get("llm5") is not None - assert graph.edge_mapping.get("llm5")[0].target_node_id == "code3" - assert graph.edge_mapping.get("code1") is not None - assert graph.edge_mapping.get("code1")[0].target_node_id == "answer" - assert graph.edge_mapping.get("code2") is not None - assert graph.edge_mapping.get("code2")[0].target_node_id == "answer" - - assert len(graph.parallel_mapping) == 1 - assert len(graph.node_parallel_mapping) == 8 - - for node_id in ["llm1", "llm2", "llm3", "llm4", "llm5", "code1", "code2", "code3"]: - assert node_id in graph.node_parallel_mapping - - -def test_parallels_graph6(): - graph_config = { - "edges": [ - { - "id": "start-source-llm1-target", - "source": "start", - "target": "llm1", - }, - { - "id": "start-source-llm2-target", - "source": "start", - "target": "llm2", - }, - { - "id": "start-source-llm3-target", - "source": "start", - "target": "llm3", - }, - { - "id": "llm1-source-code1-target", - "source": "llm1", - "target": "code1", - }, - { - "id": "llm1-source-code2-target", - "source": "llm1", - "target": "code2", - }, - { - "id": "llm2-source-code3-target", - "source": "llm2", - "target": "code3", - }, - { - "id": "code1-source-answer-target", - "source": "code1", - "target": "answer", - }, - { - "id": "code2-source-answer-target", - "source": "code2", - "target": "answer", - }, - { - "id": "code3-source-answer-target", - "source": "code3", - "target": "answer", - } - ], - "nodes": [ - { - "data": { - "type": "start" - }, - "id": "start" - }, - { - "data": { - "type": "llm", - }, - "id": "llm1" - }, - { - "data": { - "type": "code", - }, - "id": "code1" - }, - { - "data": { - "type": "llm", - }, - "id": "llm2" - }, - { - "data": { - "type": "code", - }, - "id": "code2" - }, - { - "data": { - "type": "llm", - }, - "id": "llm3" - }, - { - "data": { - "type": "code", - }, - "id": "code3" - }, - { - "data": { - "type": "answer", - }, - "id": "answer", - }, - ], - } - - graph = Graph.init( - graph_config=graph_config - ) - - assert graph.root_node_id == "start" - for i in range(3): - assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}" - - assert graph.edge_mapping.get("llm1") is not None - assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1" - assert graph.edge_mapping.get("llm1") is not None - assert graph.edge_mapping.get("llm1")[1].target_node_id == "code2" - assert graph.edge_mapping.get("llm2") is not None - assert graph.edge_mapping.get("llm2")[0].target_node_id == "code3" - assert graph.edge_mapping.get("code1") is not None - assert graph.edge_mapping.get("code1")[0].target_node_id == "answer" - assert graph.edge_mapping.get("code2") is not None - assert graph.edge_mapping.get("code2")[0].target_node_id == "answer" - assert graph.edge_mapping.get("code3") is not None - assert graph.edge_mapping.get("code3")[0].target_node_id == "answer" - - assert len(graph.parallel_mapping) == 2 - assert len(graph.node_parallel_mapping) == 6 - - for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]: - assert node_id in graph.node_parallel_mapping - - parent_parallel = None - child_parallel = None - for p_id, parallel in graph.parallel_mapping.items(): - if parallel.parent_parallel_id is None: - parent_parallel = parallel - else: - child_parallel = parallel - - for node_id in ["llm1", "llm2", "llm3", "code3"]: - assert graph.node_parallel_mapping[node_id] == parent_parallel.id - - for node_id in ["code1", "code2"]: - assert graph.node_parallel_mapping[node_id] == child_parallel.id + generator = graph_engine.run() + for item in generator: + print(type(item), item)