diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 7b479ef536..5a664beca8 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -15,7 +15,7 @@ from core.app.entities.app_invoke_entities import ( from core.app.entities.queue_entities import QueueAnnotationReplyEvent, QueueStopEvent, QueueTextChunkEvent from core.moderation.base import ModerationException from core.workflow.entities.node_entities import SystemVariable -from core.workflow.nodes.base_node import UserFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.model import App, Conversation, EndUser, Message diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 36e2deb42d..27a874ce34 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -11,7 +11,7 @@ from core.app.entities.app_invoke_entities import ( WorkflowAppGenerateEntity, ) from core.workflow.entities.node_entities import SystemVariable -from core.workflow.nodes.base_node import UserFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.workflow_engine_manager import WorkflowEngineManager from extensions.ext_database import db from models.model import App, EndUser diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index e30a905cbc..b9fcecc05e 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -32,7 +32,6 @@ from core.tools.utils.configuration import ( ToolParameterConfigurationManager, ) from core.tools.utils.tool_parameter_converter import ToolParameterConverter -from core.workflow.nodes.tool.entities import ToolEntity from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -255,7 +254,7 @@ class ToolManager: return tool_entity @classmethod - def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: + def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: "ToolEntity", invoke_from: InvokeFrom = InvokeFrom.DEBUGGER) -> Tool: """ get the workflow tool runtime """ diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index ef9e5b67ae..d5cf10da0c 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -7,6 +7,7 @@ from core.tools.tool_file_manager import ToolFileManager logger = logging.getLogger(__name__) + class ToolFileMessageTransformer: @classmethod def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage], diff --git a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py index 3c1e8634c7..15946bdd84 100644 --- a/api/core/workflow/graph_engine/condition_handlers/condition_handler.py +++ b/api/core/workflow/graph_engine/condition_handlers/condition_handler.py @@ -21,10 +21,12 @@ class ConditionRunConditionHandlerHandler(RunConditionHandler): # process condition condition_processor = ConditionProcessor() - compare_result, _ = condition_processor.process( + input_conditions, group_result = condition_processor.process_conditions( variable_pool=graph_runtime_state.variable_pool, - logical_operator="and", conditions=self.condition.conditions ) + # Apply the logical operator for the current case + compare_result = all(group_result) + return compare_result diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 9fff28a82f..46c7b6fba9 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -3,6 +3,7 @@ from typing import Optional from pydantic import BaseModel, Field, model_validator from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from models.workflow import WorkflowNodeExecutionStatus @@ -40,7 +41,7 @@ class GraphRunFailedEvent(BaseGraphEvent): class BaseNodeEvent(GraphEngineEvent): - node_id: str = Field(..., description="node id") + route_node_state: RouteNodeState = Field(..., description="route node state") parallel_id: Optional[str] = Field(None, description="parallel id if node is in parallel") # iteration_id: Optional[str] = Field(None, description="iteration id if node is in iteration") @@ -60,21 +61,11 @@ class NodeRunRetrieverResourceEvent(BaseNodeEvent): class NodeRunSucceededEvent(BaseNodeEvent): - run_result: NodeRunResult = Field(..., description="run result") + pass class NodeRunFailedEvent(BaseNodeEvent): - run_result: NodeRunResult = Field( - default=NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED), - description="run result" - ) - reason: str = Field("", description="failed reason") - - @model_validator(mode='before') - def init_reason(cls, values: dict) -> dict: - if not values.get("reason"): - values["reason"] = values.get("run_result").error or "Unknown error" - return values + pass ########################################### diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index e6851ac223..a300d8b2f4 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -16,7 +16,7 @@ class RouteNodeState(BaseModel): FAILED = "failed" PAUSED = "paused" - state_id: str = Field(default_factory=lambda: str(uuid.uuid4())) + id: str = Field(default_factory=lambda: str(uuid.uuid4())) """node state id""" node_id: str @@ -45,11 +45,15 @@ class RouteNodeState(BaseModel): class RuntimeRouteState(BaseModel): - routes: dict[str, list[str]] = Field(default_factory=dict) - """graph state routes (source_node_state_id: target_node_state_id)""" + routes: dict[str, list[str]] = Field( + default_factory=dict, + description="graph state routes (source_node_state_id: target_node_state_id)" + ) - node_state_mapping: dict[str, RouteNodeState] = Field(default_factory=dict) - """node state mapping (route_node_state_id: route_node_state)""" + node_state_mapping: dict[str, RouteNodeState] = Field( + default_factory=dict, + description="node state mapping (route_node_state_id: route_node_state)" + ) def create_node_state(self, node_id: str) -> RouteNodeState: """ @@ -58,7 +62,7 @@ class RuntimeRouteState(BaseModel): :param node_id: node id """ state = RouteNodeState(node_id=node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None)) - self.node_state_mapping[state.state_id] = state + self.node_state_mapping[state.id] = state return state def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None: diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 2d05dabfd7..211d0df5fb 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -10,7 +10,7 @@ from flask import Flask, current_app from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import NodeType +from core.workflow.entities.node_entities import NodeType, UserFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager from core.workflow.graph_engine.entities.event import ( @@ -29,9 +29,7 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.nodes import node_classes -from core.workflow.nodes.base_node import UserFrom from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent -from core.workflow.nodes.test.test_node import TestNode from extensions.ext_database import db from models.workflow import WorkflowNodeExecutionStatus, WorkflowType @@ -86,7 +84,7 @@ class GraphEngine: for item in generator: yield item if isinstance(item, NodeRunFailedEvent): - yield GraphRunFailedEvent(reason=item.reason) + yield GraphRunFailedEvent(reason=item.route_node_state.failed_reason or 'Unknown error.') return # trigger graph run success event @@ -100,7 +98,7 @@ class GraphEngine: def _run(self, start_node_id: str, in_parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]: next_node_id = start_node_id - previous_node_id = None + previous_route_node_state: Optional[RouteNodeState] = None while True: # max steps reached if self.graph_runtime_state.node_run_steps > self.max_execution_steps: @@ -108,23 +106,42 @@ class GraphEngine: # or max execution time reached if self._is_timed_out( - start_at=self.graph_runtime_state.start_at, - max_execution_time=self.max_execution_time + start_at=self.graph_runtime_state.start_at, + max_execution_time=self.max_execution_time ): raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time)) + # init route node state + route_node_state = self.graph_runtime_state.create_node_state( + node_id=next_node_id + ) + try: # run node yield from self._run_node( - node_id=next_node_id, - previous_node_id=previous_node_id, + route_node_state=route_node_state, + previous_node_id=previous_route_node_state.node_id if previous_route_node_state else None, parallel_id=in_parallel_id ) + + self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state + + # append route + if previous_route_node_state: + if previous_route_node_state.id not in self.graph_runtime_state.node_run_state.routes: + self.graph_runtime_state.node_run_state.routes[previous_route_node_state.id] = [] + + self.graph_runtime_state.node_run_state.routes[previous_route_node_state.id].append( + route_node_state.id + ) except Exception as e: - yield NodeRunFailedEvent(node_id=next_node_id, reason=str(e)) + yield NodeRunFailedEvent( + route_node_state=route_node_state, + parallel_id=in_parallel_id + ) raise e - previous_node_id = next_node_id + previous_route_node_state = route_node_state # get next node ids edge_mappings = self.graph.edge_mapping.get(next_node_id) @@ -227,24 +244,32 @@ class GraphEngine: in_parallel_id=parallel_id ) - if generator: - for item in generator: - q.put(item) - except Exception: + for item in generator: + q.put(item) + if isinstance(item, NodeRunFailedEvent): + q.put(GraphRunFailedEvent(reason=item.route_node_state.failed_reason or 'Unknown error.')) + return + + # trigger graph run success event + q.put(GraphRunSucceededEvent()) + except (GraphRunFailedError, NodeRunFailedError) as e: + q.put(GraphRunFailedEvent(reason=e.error)) + except Exception as e: logger.exception("Unknown Error when generating in parallel") + q.put(GraphRunFailedEvent(reason=str(e))) finally: q.put(None) db.session.remove() def _run_node(self, - node_id: str, + route_node_state: RouteNodeState, previous_node_id: Optional[str] = None, - parallel_id: Optional[str] = None - ) -> Generator[GraphEngineEvent, None, None]: + parallel_id: Optional[str] = None) -> Generator[GraphEngineEvent, None, None]: """ Run node """ # get node config + node_id = route_node_state.node_id node_config = self.graph.node_id_config_mapping.get(node_id) if not node_config: raise GraphRunFailedError(f'Node {node_id} config not found.') @@ -256,16 +281,7 @@ class GraphEngine: raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.') # init workflow run state - # node_instance = node_cls( # type: ignore - # config=node_config, - # graph_init_params=self.init_params, - # graph=self.graph, - # graph_runtime_state=self.graph_runtime_state, - # previous_node_id=previous_node_id - # ) - - # init workflow run state - node_instance = TestNode( + node_instance = node_cls( # type: ignore config=node_config, graph_init_params=self.init_params, graph=self.graph, @@ -274,7 +290,10 @@ class GraphEngine: ) # trigger node run start event - yield NodeRunStartedEvent(node_id=node_id, parallel_id=parallel_id) + yield NodeRunStartedEvent( + route_node_state=route_node_state, + parallel_id=parallel_id + ) db.session.close() @@ -283,60 +302,50 @@ class GraphEngine: self.graph_runtime_state.node_run_steps += 1 try: - start_at = datetime.now(timezone.utc).replace(tzinfo=None) - # run node generator = node_instance.run() - run_result = None for item in generator: if isinstance(item, RunCompletedEvent): run_result = item.run_result + route_node_state.status = RouteNodeState.Status.SUCCESS \ + if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED \ + else RouteNodeState.Status.FAILED + route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + route_node_state.node_run_result = run_result + route_node_state.failed_reason = run_result.error \ + if run_result.status == WorkflowNodeExecutionStatus.FAILED else None + if run_result.status == WorkflowNodeExecutionStatus.FAILED: yield NodeRunFailedEvent( - node_id=node_id, parallel_id=parallel_id, - run_result=run_result, - reason=run_result.error + route_node_state=route_node_state ) elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: yield NodeRunSucceededEvent( - node_id=node_id, parallel_id=parallel_id, - run_result=run_result + route_node_state=route_node_state ) - - self.graph_runtime_state.node_run_state.node_state_mapping[node_id] = RouteNodeState( - node_id=node_id, - start_at=start_at, - status=RouteNodeState.Status.SUCCESS if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - else RouteNodeState.Status.FAILED, - finished_at=datetime.now(timezone.utc).replace(tzinfo=None), - node_run_result=run_result, - failed_reason=run_result.error - if run_result.status == WorkflowNodeExecutionStatus.FAILED else None - ) - - # todo append self.graph_runtime_state.node_run_state.routes break elif isinstance(item, RunStreamChunkEvent): yield NodeRunStreamChunkEvent( - node_id=node_id, + route_node_state=route_node_state, parallel_id=parallel_id, chunk_content=item.chunk_content, from_variable_selector=item.from_variable_selector, ) elif isinstance(item, RunRetrieverResourceEvent): yield NodeRunRetrieverResourceEvent( - node_id=node_id, + route_node_state=route_node_state, parallel_id=parallel_id, retriever_resources=item.retriever_resources, context=item.context ) - - # todo record state except GenerateTaskStoppedException as e: # trigger node run failed event - yield NodeRunFailedEvent(node_id=node_id, parallel_id=parallel_id, reason=str(e)) + yield NodeRunFailedEvent( + route_node_state=route_node_state, + parallel_id=parallel_id + ) return except Exception as e: logger.exception(f"Node {node_instance.node_data.title} run failed: {str(e)}") diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index e8f1678ecb..31956f6757 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -21,10 +21,9 @@ class AnswerNode(BaseNode): _node_data_cls = AnswerNodeData node_type = NodeType.ANSWER - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node - :param variable_pool: variable pool :return: """ node_data = self.node_data @@ -38,7 +37,7 @@ class AnswerNode(BaseNode): if part.type == "var": part = cast(VarGenerateRouteChunk, part) value_selector = part.value_selector - value = variable_pool.get_variable_value( + value = self.graph_runtime_state.variable_pool.get_variable_value( variable_selector=value_selector ) diff --git a/api/core/workflow/nodes/base_node.py b/api/core/workflow/nodes/base_node.py index 7215cda4b9..4a02c96c10 100644 --- a/api/core/workflow/nodes/base_node.py +++ b/api/core/workflow/nodes/base_node.py @@ -9,7 +9,7 @@ from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.event import RunCompletedEvent, RunEvent -from core.workflow.nodes.iterable_node import IterableNodeMixin +from core.workflow.nodes.iterable_node_mixin import IterableNodeMixin class BaseNode(ABC): @@ -104,21 +104,19 @@ class BaseNode(ABC): class BaseIterationNode(BaseNode, IterableNodeMixin): @abstractmethod - def _run(self, variable_pool: VariablePool) -> BaseIterationState: + def _run(self) -> BaseIterationState: """ Run node - :param variable_pool: variable pool :return: """ raise NotImplementedError - def run(self, variable_pool: VariablePool) -> BaseIterationState: + def run(self) -> BaseIterationState: """ Run node entry - :param variable_pool: variable pool :return: """ - return self._run(variable_pool=variable_pool) + return self._run(variable_pool=self.graph_runtime_state.variable_pool) def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str: """ diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index e15c1c6f87..dc2478f3a9 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -42,14 +42,13 @@ class CodeNode(BaseNode): return code_provider.get_default_config() - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run code - :param variable_pool: variable pool :return: """ node_data = self.node_data - node_data: CodeNodeData = cast(self._node_data_cls, node_data) + node_data = cast(CodeNodeData, node_data) # Get code language code_language = node_data.code_language @@ -59,7 +58,7 @@ class CodeNode(BaseNode): variables = {} for variable_selector in node_data.variables: variable = variable_selector.variable - value = variable_pool.get_variable_value( + value = self.graph_runtime_state.variable_pool.get_variable_value( variable_selector=variable_selector.value_selector ) diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 08d55d5576..936597c481 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -2,7 +2,6 @@ from typing import cast from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.end.entities import EndNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -12,19 +11,18 @@ class EndNode(BaseNode): _node_data_cls = EndNodeData node_type = NodeType.END - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node - :param variable_pool: variable pool :return: """ node_data = self.node_data - node_data = cast(self._node_data_cls, node_data) + node_data = cast(EndNodeData, node_data) output_variables = node_data.outputs outputs = {} for variable_selector in output_variables: - value = variable_pool.get_variable_value( + value = self.graph_runtime_state.variable_pool.get_variable_value( variable_selector=variable_selector.value_selector ) diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index 24acf984f2..5e855ddc7e 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -49,14 +49,16 @@ class HttpRequestNode(BaseNode): }, } - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: node_data: HttpRequestNodeData = cast(HttpRequestNodeData, self.node_data) # init http executor http_executor = None try: http_executor = HttpExecutor( - node_data=node_data, timeout=self._get_request_timeout(node_data), variable_pool=variable_pool + node_data=node_data, + timeout=self._get_request_timeout(node_data), + variable_pool=self.graph_runtime_state.variable_pool ) # invoke http executor diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index c29588f3b4..da58cd0c1b 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -4,6 +4,7 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.if_else.entities import IfElseNodeData +from core.workflow.utils.condition.processor import ConditionProcessor from models.workflow import WorkflowNodeExecutionStatus @@ -30,11 +31,16 @@ class IfElseNode(BaseNode): input_conditions = [] final_result = False selected_case_id = None + condition_processor = ConditionProcessor() try: # Check if the new cases structure is used if node_data.cases: for case in node_data.cases: - input_conditions, group_result = self.process_conditions(self.graph_runtime_state.variable_pool, case.conditions) + input_conditions, group_result = condition_processor.process_conditions( + variable_pool=self.graph_runtime_state.variable_pool, + conditions=case.conditions + ) + # Apply the logical operator for the current case final_result = all(group_result) if case.logical_operator == "and" else any(group_result) @@ -53,7 +59,10 @@ class IfElseNode(BaseNode): else: # Fallback to old structure if cases are not defined - input_conditions, group_result = self.process_conditions(variable_pool, node_data.conditions) + input_conditions, group_result = condition_processor.process_conditions( + variable_pool=self.graph_runtime_state.variable_pool, + conditions=node_data.conditions + ) final_result = all(group_result) if node_data.logical_operator == "and" else any(group_result) diff --git a/api/core/workflow/nodes/iterable_node.py b/api/core/workflow/nodes/iterable_node_mixin.py similarity index 100% rename from api/core/workflow/nodes/iterable_node.py rename to api/core/workflow/nodes/iterable_node_mixin.py diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index fb12e07f85..1ce723133b 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -17,7 +17,7 @@ class IterationNode(BaseIterationNode): _node_data_cls = IterationNodeData _node_type = NodeType.ITERATION - def _run(self, variable_pool: VariablePool) -> BaseIterationState: + def _run(self) -> BaseIterationState: """ Run the node. """ @@ -32,7 +32,7 @@ class IterationNode(BaseIterationNode): iterator_length=len(iterator) if iterator is not None else 0 )) - self._set_current_iteration_variable(variable_pool, state) + self._set_current_iteration_variable(self.graph_runtime_state.variable_pool, state) return state def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str: diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 9e29bd9ea1..fdba392fe1 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -14,7 +14,6 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrival_methods import RetrievalMethod from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from extensions.ext_database import db @@ -37,11 +36,11 @@ class KnowledgeRetrievalNode(BaseNode): _node_data_cls = KnowledgeRetrievalNodeData node_type = NodeType.KNOWLEDGE_RETRIEVAL - def _run(self, variable_pool: VariablePool) -> NodeRunResult: - node_data: KnowledgeRetrievalNodeData = cast(self._node_data_cls, self.node_data) + def _run(self) -> NodeRunResult: + node_data = cast(KnowledgeRetrievalNodeData, self.node_data) # extract variables - query = variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector) + query = self.graph_runtime_state.variable_pool.get_variable_value(variable_selector=node_data.query_variable_selector) variables = { 'query': query } diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 89a602248b..42bc62a8ee 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -27,7 +27,6 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable from core.workflow.entities.variable_pool import VariablePool -from core.workflow.graph_engine.entities.event import NodeRunRetrieverResourceEvent from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.event import RunCompletedEvent, RunEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.llm.entities import ( @@ -85,9 +84,7 @@ class LLMNode(BaseNode): for event in generator: if isinstance(event, RunRetrieverResourceEvent): context = event.context - yield NodeRunRetrieverResourceEvent( - retriever_resources=event.retriever_resources - ) + yield event if context: node_inputs['#context#'] = context # type: ignore @@ -170,7 +167,7 @@ class LLMNode(BaseNode): model_instance: ModelInstance, prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None) \ - -> Generator["RunStreamChunkEvent | ModelInvokeCompleted", None, None]: + -> Generator[RunEvent, None, None]: """ Invoke large language model :param node_data_model: node data model @@ -204,7 +201,7 @@ class LLMNode(BaseNode): self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) def _handle_invoke_result(self, invoke_result: LLMResult | Generator) \ - -> Generator["RunStreamChunkEvent | ModelInvokeCompleted", None, None]: + -> Generator[RunEvent, None, None]: """ Handle invoke result :param invoke_result: invoke result diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 4b1421cae9..0c20312d84 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -14,8 +14,8 @@ class LoopNode(BaseIterationNode): _node_data_cls = LoopNodeData _node_type = NodeType.LOOP - def _run(self, variable_pool: VariablePool) -> LoopState: - return super()._run(variable_pool) + def _run(self) -> LoopState: + return super()._run() def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str: """ diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index d219156026..eebc36cc55 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -66,12 +66,12 @@ class ParameterExtractorNode(LLMNode): } } - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run the node. """ node_data = cast(ParameterExtractorNodeData, self.node_data) - query = variable_pool.get_variable_value(node_data.query) + query = self.graph_runtime_state.variable_pool.get_variable_value(node_data.query) if not query: raise ValueError("Input variable content not found or is empty") @@ -91,17 +91,20 @@ class ParameterExtractorNode(LLMNode): raise ValueError("Model schema not found") # fetch memory - memory = self._fetch_memory(node_data.memory, variable_pool, model_instance) + memory = self._fetch_memory(node_data.memory, self.graph_runtime_state.variable_pool, model_instance) if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \ and node_data.reasoning_mode == 'function_call': # use function call prompt_messages, prompt_message_tools = self._generate_function_call_prompt( - node_data, query, variable_pool, model_config, memory + node_data, query, self.graph_runtime_state.variable_pool, model_config, memory ) else: # use prompt engineering - prompt_messages = self._generate_prompt_engineering_prompt(node_data, query, variable_pool, model_config, + prompt_messages = self._generate_prompt_engineering_prompt(node_data, + query, + self.graph_runtime_state.variable_pool, + model_config, memory) prompt_message_tools = [] diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index fd51a6c476..a534b5b97f 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,7 +1,6 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.start.entities import StartNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -11,17 +10,16 @@ class StartNode(BaseNode): _node_data_cls = StartNodeData node_type = NodeType.START - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node - :param variable_pool: variable pool :return: """ # Get cleaned inputs - cleaned_inputs = variable_pool.user_inputs + cleaned_inputs = self.graph_runtime_state.variable_pool.user_inputs - for var in variable_pool.system_variables: - cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var] + for var in self.graph_runtime_state.variable_pool.system_variables: + cleaned_inputs['sys.' + var.value] = self.graph_runtime_state.variable_pool.system_variables[var] return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 2c4a2257f5..923c2ae1ae 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -3,13 +3,13 @@ from typing import Optional, cast from core.helper.code_executor.code_executor import CodeExecutionException, CodeExecutor, CodeLanguage from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData from models.workflow import WorkflowNodeExecutionStatus MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get('TEMPLATE_TRANSFORM_MAX_LENGTH', '80000')) + class TemplateTransformNode(BaseNode): _node_data_cls = TemplateTransformNodeData _node_type = NodeType.TEMPLATE_TRANSFORM @@ -34,7 +34,7 @@ class TemplateTransformNode(BaseNode): } } - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run node """ @@ -45,7 +45,7 @@ class TemplateTransformNode(BaseNode): variables = {} for variable_selector in node_data.variables: variable = variable_selector.variable - value = variable_pool.get_variable_value( + value = self.graph_runtime_state.variable_pool.get_variable_value( variable_selector=variable_selector.value_selector ) @@ -63,7 +63,7 @@ class TemplateTransformNode(BaseNode): status=WorkflowNodeExecutionStatus.FAILED, error=str(e) ) - + if len(result['result']) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH: return NodeRunResult( inputs=variables, @@ -78,9 +78,10 @@ class TemplateTransformNode(BaseNode): 'output': result['result'] } ) - + @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[str, list[str]]: + def _extract_variable_selector_to_variable_mapping(cls, node_data: TemplateTransformNodeData) -> dict[ + str, list[str]]: """ Extract variable selector to variable mapping :param node_data: node data @@ -88,4 +89,4 @@ class TemplateTransformNode(BaseNode): """ return { variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables - } \ No newline at end of file + } diff --git a/api/core/workflow/nodes/test/__init__.py b/api/core/workflow/nodes/test/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/core/workflow/nodes/test/entities.py b/api/core/workflow/nodes/test/entities.py deleted file mode 100644 index 8d9610737c..0000000000 --- a/api/core/workflow/nodes/test/entities.py +++ /dev/null @@ -1,8 +0,0 @@ -from core.workflow.entities.base_node_data_entities import BaseNodeData - - -class TestNodeData(BaseNodeData): - """ - Test Node Data. - """ - pass diff --git a/api/core/workflow/nodes/test/test_node.py b/api/core/workflow/nodes/test/test_node.py deleted file mode 100644 index 91e6fd5bc0..0000000000 --- a/api/core/workflow/nodes/test/test_node.py +++ /dev/null @@ -1,33 +0,0 @@ - -from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.nodes.base_node import BaseNode -from core.workflow.nodes.test.entities import TestNodeData -from models.workflow import WorkflowNodeExecutionStatus - - -class TestNode(BaseNode): - _node_data_cls = TestNodeData - node_type = NodeType.ANSWER - - def _run(self) -> NodeRunResult: - """ - Run node - :return: - """ - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - "content": "abc" - }, - edge_source_handle="1" - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]: - """ - Extract variable selector to variable mapping - :param node_data: node data - :return: - """ - return {} diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index cddea03bf8..449e838617 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -23,7 +23,7 @@ class ToolNode(BaseNode): _node_data_cls = ToolNodeData _node_type = NodeType.TOOL - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: """ Run the tool node """ @@ -52,7 +52,7 @@ class ToolNode(BaseNode): ) # get parameters - parameters = self._generate_parameters(variable_pool, node_data, tool_runtime) + parameters = self._generate_parameters(self.graph_runtime_state.variable_pool, node_data, tool_runtime) try: messages = ToolEngine.workflow_invoke( @@ -136,7 +136,8 @@ class ToolNode(BaseNode): return files - def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]: + def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) \ + -> tuple[str, list[FileVar], list[dict]]: """ Convert ToolInvokeMessages into tuple[plain_text, files] """ diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 63ce790625..0576e07824 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -2,7 +2,6 @@ from typing import cast from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType -from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData from models.workflow import WorkflowNodeExecutionStatus @@ -12,7 +11,7 @@ class VariableAggregatorNode(BaseNode): _node_data_cls = VariableAssignerNodeData _node_type = NodeType.VARIABLE_AGGREGATOR - def _run(self, variable_pool: VariablePool) -> NodeRunResult: + def _run(self) -> NodeRunResult: node_data = cast(VariableAssignerNodeData, self.node_data) # Get variables outputs = {} @@ -20,7 +19,7 @@ class VariableAggregatorNode(BaseNode): if not node_data.advanced_settings or not node_data.advanced_settings.group_enabled: for variable in node_data.variables: - value = variable_pool.get_variable_value(variable) + value = self.graph_runtime_state.variable_pool.get_variable_value(variable) if value is not None: outputs = { @@ -34,7 +33,7 @@ class VariableAggregatorNode(BaseNode): else: for group in node_data.advanced_settings.groups: for variable in group.variables: - value = variable_pool.get_variable_value(variable) + value = self.graph_runtime_state.variable_pool.get_variable_value(variable) if value is not None: outputs[group.group_name] = { diff --git a/api/core/workflow/utils/condition/entities.py b/api/core/workflow/utils/condition/entities.py index 524cce1a43..e195730a31 100644 --- a/api/core/workflow/utils/condition/entities.py +++ b/api/core/workflow/utils/condition/entities.py @@ -14,6 +14,4 @@ class Condition(BaseModel): # for number "=", "≠", ">", "<", "≥", "≤", "null", "not null" ] - value_type: Literal["string", "value_selector"] = "string" value: Optional[str] = None - value_selector: Optional[list[str]] = None diff --git a/api/core/workflow/utils/condition/processor.py b/api/core/workflow/utils/condition/processor.py index cebd234570..d617d3abdc 100644 --- a/api/core/workflow/utils/condition/processor.py +++ b/api/core/workflow/utils/condition/processor.py @@ -1,95 +1,101 @@ -from typing import Literal, Optional +from typing import Any, Optional +from core.file.file_obj import FileVar from core.workflow.entities.variable_pool import VariablePool from core.workflow.utils.condition.entities import Condition +from core.workflow.utils.variable_template_parser import VariableTemplateParser class ConditionProcessor: - def process(self, variable_pool: VariablePool, - logical_operator: Literal["and", "or"], - conditions: list[Condition]) -> tuple[bool, list[dict]]: - """ - Process conditions - - :param variable_pool: variable pool - :param logical_operator: logical operator - :param conditions: conditions - """ + def process_conditions(self, variable_pool: VariablePool, conditions: list[Condition]): input_conditions = [] - sub_condition_compare_results = [] + group_result = [] - try: - for condition in conditions: - actual_value = variable_pool.get_variable_value( - variable_selector=condition.variable_selector - ) + index = 0 + for condition in conditions: + index += 1 + actual_value = variable_pool.get_variable_value( + variable_selector=condition.variable_selector + ) - if condition.value_type == "value_selector": - expected_value = variable_pool.get_variable_value( - variable_selector=condition.value_selector - ) + expected_value = None + if condition.value is not None: + variable_template_parser = VariableTemplateParser(template=condition.value) + variable_selectors = variable_template_parser.extract_variable_selectors() + if variable_selectors: + for variable_selector in variable_selectors: + value = variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector + ) + expected_value = variable_template_parser.format({variable_selector.variable: value}) + + if expected_value is None: + expected_value = condition.value else: expected_value = condition.value - input_conditions.append({ + comparison_operator = condition.comparison_operator + input_conditions.append( + { "actual_value": actual_value, "expected_value": expected_value, - "comparison_operator": condition.comparison_operator - }) + "comparison_operator": comparison_operator + } + ) - for input_condition in input_conditions: - actual_value = input_condition["actual_value"] - expected_value = input_condition["expected_value"] - comparison_operator = input_condition["comparison_operator"] + result = self.evaluate_condition(actual_value, comparison_operator, expected_value) + group_result.append(result) - if comparison_operator == "contains": - compare_result = self._assert_contains(actual_value, expected_value) - elif comparison_operator == "not contains": - compare_result = self._assert_not_contains(actual_value, expected_value) - elif comparison_operator == "start with": - compare_result = self._assert_start_with(actual_value, expected_value) - elif comparison_operator == "end with": - compare_result = self._assert_end_with(actual_value, expected_value) - elif comparison_operator == "is": - compare_result = self._assert_is(actual_value, expected_value) - elif comparison_operator == "is not": - compare_result = self._assert_is_not(actual_value, expected_value) - elif comparison_operator == "empty": - compare_result = self._assert_empty(actual_value) - elif comparison_operator == "not empty": - compare_result = self._assert_not_empty(actual_value) - elif comparison_operator == "=": - compare_result = self._assert_equal(actual_value, expected_value) - elif comparison_operator == "≠": - compare_result = self._assert_not_equal(actual_value, expected_value) - elif comparison_operator == ">": - compare_result = self._assert_greater_than(actual_value, expected_value) - elif comparison_operator == "<": - compare_result = self._assert_less_than(actual_value, expected_value) - elif comparison_operator == "≥": - compare_result = self._assert_greater_than_or_equal(actual_value, expected_value) - elif comparison_operator == "≤": - compare_result = self._assert_less_than_or_equal(actual_value, expected_value) - elif comparison_operator == "null": - compare_result = self._assert_null(actual_value) - elif comparison_operator == "not null": - compare_result = self._assert_not_null(actual_value) - else: - continue + return input_conditions, group_result - sub_condition_compare_results.append({ - **input_condition, - "result": compare_result - }) - except Exception as e: - raise ConditionAssertionError(str(e), input_conditions, sub_condition_compare_results) + def evaluate_condition( + self, + actual_value: Optional[str | int | float | dict[Any, Any] | list[Any] | FileVar | None], + comparison_operator: str, + expected_value: Optional[str] = None + ) -> bool: + """ + Evaluate condition + :param actual_value: actual value + :param expected_value: expected value + :param comparison_operator: comparison operator - if logical_operator == "and": - compare_result = False not in [condition["result"] for condition in sub_condition_compare_results] + :return: bool + """ + if comparison_operator == "contains": + return self._assert_contains(actual_value, expected_value) # type: ignore + elif comparison_operator == "not contains": + return self._assert_not_contains(actual_value, expected_value) # type: ignore + elif comparison_operator == "start with": + return self._assert_start_with(actual_value, expected_value) # type: ignore + elif comparison_operator == "end with": + return self._assert_end_with(actual_value, expected_value) # type: ignore + elif comparison_operator == "is": + return self._assert_is(actual_value, expected_value) # type: ignore + elif comparison_operator == "is not": + return self._assert_is_not(actual_value, expected_value) # type: ignore + elif comparison_operator == "empty": + return self._assert_empty(actual_value) # type: ignore + elif comparison_operator == "not empty": + return self._assert_not_empty(actual_value) # type: ignore + elif comparison_operator == "=": + return self._assert_equal(actual_value, expected_value) # type: ignore + elif comparison_operator == "≠": + return self._assert_not_equal(actual_value, expected_value) # type: ignore + elif comparison_operator == ">": + return self._assert_greater_than(actual_value, expected_value) # type: ignore + elif comparison_operator == "<": + return self._assert_less_than(actual_value, expected_value) # type: ignore + elif comparison_operator == "≥": + return self._assert_greater_than_or_equal(actual_value, expected_value) # type: ignore + elif comparison_operator == "≤": + return self._assert_less_than_or_equal(actual_value, expected_value) # type: ignore + elif comparison_operator == "null": + return self._assert_null(actual_value) # type: ignore + elif comparison_operator == "not null": + return self._assert_not_null(actual_value) # type: ignore else: - compare_result = True in [condition["result"] for condition in sub_condition_compare_results] - - return compare_result, sub_condition_compare_results + raise ValueError(f"Invalid comparison operator: {comparison_operator}") def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool: """ @@ -301,7 +307,8 @@ class ConditionProcessor: return False return True - def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: + def _assert_greater_than_or_equal(self, actual_value: Optional[int | float], + expected_value: str | int | float) -> bool: """ Assert greater than or equal :param actual_value: actual value @@ -323,7 +330,8 @@ class ConditionProcessor: return False return True - def _assert_less_than_or_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool: + def _assert_less_than_or_equal(self, actual_value: Optional[int | float], + expected_value: str | int | float) -> bool: """ Assert less than or equal :param actual_value: actual value diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 15cf5367d3..e6baa36b16 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -4,7 +4,7 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import UserFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.nodes.code.code_node import CodeNode from models.workflow import WorkflowNodeExecutionStatus from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index eaed24e56c..b729d247e6 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -4,7 +4,7 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import UserFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.nodes.http_request.http_request_node import HttpRequestNode from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index d7a6c1224f..043fd05a0b 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -12,7 +12,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers import ModelProviderFactory from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import UserFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.nodes.llm.llm_node import LLMNode from extensions.ext_database import db from models.provider import ProviderType diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index e5fd2bc1fd..daa6767232 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -14,7 +14,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import UserFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from extensions.ext_database import db from models.provider import ProviderType diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 02999bf0a2..72856b2d06 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -2,7 +2,7 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import UserFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from models.workflow import WorkflowNodeExecutionStatus from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index fffd074457..be8694968a 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -1,6 +1,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import UserFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.nodes.tool.tool_node import ToolNode from models.workflow import WorkflowNodeExecutionStatus diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index 102711b4b6..d1da4ffe8f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -4,7 +4,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.base_node import UserFrom +from core.workflow.entities.node_entities import UserFrom from extensions.ext_database import db from models.workflow import WorkflowNodeExecutionStatus diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index 6860b2fd97..6b08dc7f33 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import SystemVariable from core.workflow.entities.variable_pool import VariablePool -from core.workflow.nodes.base_node import UserFrom +from core.workflow.entities.node_entities import UserFrom from core.workflow.nodes.if_else.if_else_node import IfElseNode from extensions.ext_database import db from models.workflow import WorkflowNodeExecutionStatus