From 72fdafc180adafb81165fe6409772de27283c3c2 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 3 Jun 2025 16:16:06 +0800 Subject: [PATCH] refactor: Replaces direct DB session usage with context managers (#20569) Signed-off-by: -LAN- --- .../workflow/graph_engine/graph_engine.py | 6 --- api/core/workflow/nodes/agent/agent_node.py | 16 +++--- .../knowledge_retrieval_node.py | 18 ++++--- api/core/workflow/nodes/llm/node.py | 50 +++++++++---------- .../parameter_extractor_node.py | 3 -- 5 files changed, 43 insertions(+), 50 deletions(-) diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 3eb99fde81..363b2ee920 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -53,7 +53,6 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING -from extensions.ext_database import db from models.enums import UserFrom from models.workflow import WorkflowType @@ -607,8 +606,6 @@ class GraphEngine: error=str(e), ) ) - finally: - db.session.remove() def _run_node( self, @@ -646,7 +643,6 @@ class GraphEngine: agent_strategy=agent_strategy, ) - db.session.close() max_retries = node_instance.node_data.retry_config.max_retries retry_interval = node_instance.node_data.retry_config.retry_interval_seconds retries = 0 @@ -863,8 +859,6 @@ class GraphEngine: except Exception as e: logger.exception(f"Node {node_instance.node_data.title} run failed") raise e - finally: - db.session.close() def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): """ diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 30b17cbd84..faa8f90bea 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -2,6 +2,9 @@ import json from collections.abc import Generator, Mapping, Sequence from typing import Any, Optional, cast +from sqlalchemy import select +from sqlalchemy.orm import Session + from core.agent.entities import AgentToolEntity from core.agent.plugin_entities import AgentStrategyParameter from core.memory.token_buffer_memory import TokenBufferMemory @@ -320,15 +323,12 @@ class AgentNode(ToolNode): return None conversation_id = conversation_id_variable.value - # get conversation - conversation = ( - db.session.query(Conversation) - .filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id) - .first() - ) + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id) + conversation = session.scalar(stmt) - if not conversation: - return None + if not conversation: + return None memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 2ddb4f8a0b..53124f962a 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -8,6 +8,7 @@ from typing import Any, Optional, cast from sqlalchemy import Float, and_, func, or_, text from sqlalchemy import cast as sqlalchemy_cast +from sqlalchemy.orm import Session from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity @@ -95,14 +96,15 @@ class KnowledgeRetrievalNode(LLMNode): redis_client.zremrangebyscore(key, 0, current_time - 60000) request_count = redis_client.zcard(key) if request_count > knowledge_rate_limit.limit: - # add ratelimit record - rate_limit_log = RateLimitLog( - tenant_id=self.tenant_id, - subscription_plan=knowledge_rate_limit.subscription_plan, - operation="knowledge", - ) - db.session.add(rate_limit_log) - db.session.commit() + with Session(db.engine) as session: + # add ratelimit record + rate_limit_log = RateLimitLog( + tenant_id=self.tenant_id, + subscription_plan=knowledge_rate_limit.subscription_plan, + operation="knowledge", + ) + session.add(rate_limit_log) + session.commit() return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 0fd7c31ffb..df8f614db3 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -7,6 +7,8 @@ from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, Optional, cast import json_repair +from sqlalchemy import select, update +from sqlalchemy.orm import Session from configs import dify_config from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity @@ -303,8 +305,6 @@ class LLMNode(BaseNode[LLMNodeData]): prompt_messages: Sequence[PromptMessage], stop: Optional[Sequence[str]] = None, ) -> Generator[NodeEvent, None, None]: - db.session.close() - invoke_result = model_instance.invoke_llm( prompt_messages=list(prompt_messages), model_parameters=node_data_model.completion_params, @@ -603,15 +603,11 @@ class LLMNode(BaseNode[LLMNodeData]): return None conversation_id = conversation_id_variable.value - # get conversation - conversation = ( - db.session.query(Conversation) - .filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id) - .first() - ) - - if not conversation: - return None + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id) + conversation = session.scalar(stmt) + if not conversation: + return None memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) @@ -847,20 +843,24 @@ class LLMNode(BaseNode[LLMNodeData]): used_quota = 1 if used_quota is not None and system_configuration.current_quota_type is not None: - db.session.query(Provider).filter( - Provider.tenant_id == tenant_id, - # TODO: Use provider name with prefix after the data migration. - Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == system_configuration.current_quota_type.value, - Provider.quota_limit > Provider.quota_used, - ).update( - { - "quota_used": Provider.quota_used + used_quota, - "last_used": datetime.now(tz=UTC).replace(tzinfo=None), - } - ) - db.session.commit() + with Session(db.engine) as session: + stmt = ( + update(Provider) + .where( + Provider.tenant_id == tenant_id, + # TODO: Use provider name with prefix after the data migration. + Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == system_configuration.current_quota_type.value, + Provider.quota_limit > Provider.quota_used, + ) + .values( + quota_used=Provider.quota_used + used_quota, + last_used=datetime.now(tz=UTC).replace(tzinfo=None), + ) + ) + session.execute(stmt) + session.commit() @classmethod def _extract_variable_selector_to_variable_mapping( diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index ea4070e224..4f31258b15 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -31,7 +31,6 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution from core.workflow.nodes.enums import NodeType from core.workflow.nodes.llm import LLMNode, ModelConfig from core.workflow.utils import variable_template_parser -from extensions.ext_database import db from .entities import ParameterExtractorNodeData from .exc import ( @@ -259,8 +258,6 @@ class ParameterExtractorNode(LLMNode): tools: list[PromptMessageTool], stop: list[str], ) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: - db.session.close() - invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=node_data_model.completion_params,