refactor: Replaces direct DB session usage with context managers (#20569)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-06-03 16:16:06 +08:00 committed by GitHub
parent db83bfc53a
commit 72fdafc180
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 43 additions and 50 deletions

View File

@ -53,7 +53,6 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from extensions.ext_database import db
from models.enums import UserFrom from models.enums import UserFrom
from models.workflow import WorkflowType from models.workflow import WorkflowType
@ -607,8 +606,6 @@ class GraphEngine:
error=str(e), error=str(e),
) )
) )
finally:
db.session.remove()
def _run_node( def _run_node(
self, self,
@ -646,7 +643,6 @@ class GraphEngine:
agent_strategy=agent_strategy, agent_strategy=agent_strategy,
) )
db.session.close()
max_retries = node_instance.node_data.retry_config.max_retries max_retries = node_instance.node_data.retry_config.max_retries
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
retries = 0 retries = 0
@ -863,8 +859,6 @@ class GraphEngine:
except Exception as e: except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed") logger.exception(f"Node {node_instance.node_data.title} run failed")
raise e raise e
finally:
db.session.close()
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
""" """

View File

@ -2,6 +2,9 @@ import json
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
@ -320,15 +323,12 @@ class AgentNode(ToolNode):
return None return None
conversation_id = conversation_id_variable.value conversation_id = conversation_id_variable.value
# get conversation with Session(db.engine, expire_on_commit=False) as session:
conversation = ( stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
db.session.query(Conversation) conversation = session.scalar(stmt)
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
.first()
)
if not conversation: if not conversation:
return None return None
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

View File

@ -8,6 +8,7 @@ from typing import Any, Optional, cast
from sqlalchemy import Float, and_, func, or_, text from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import Session
from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@ -95,14 +96,15 @@ class KnowledgeRetrievalNode(LLMNode):
redis_client.zremrangebyscore(key, 0, current_time - 60000) redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key) request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit: if request_count > knowledge_rate_limit.limit:
# add ratelimit record with Session(db.engine) as session:
rate_limit_log = RateLimitLog( # add ratelimit record
tenant_id=self.tenant_id, rate_limit_log = RateLimitLog(
subscription_plan=knowledge_rate_limit.subscription_plan, tenant_id=self.tenant_id,
operation="knowledge", subscription_plan=knowledge_rate_limit.subscription_plan,
) operation="knowledge",
db.session.add(rate_limit_log) )
db.session.commit() session.add(rate_limit_log)
session.commit()
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables, inputs=variables,

View File

@ -7,6 +7,8 @@ from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Optional, cast from typing import TYPE_CHECKING, Any, Optional, cast
import json_repair import json_repair
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from configs import dify_config from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@ -303,8 +305,6 @@ class LLMNode(BaseNode[LLMNodeData]):
prompt_messages: Sequence[PromptMessage], prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None, stop: Optional[Sequence[str]] = None,
) -> Generator[NodeEvent, None, None]: ) -> Generator[NodeEvent, None, None]:
db.session.close()
invoke_result = model_instance.invoke_llm( invoke_result = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), prompt_messages=list(prompt_messages),
model_parameters=node_data_model.completion_params, model_parameters=node_data_model.completion_params,
@ -603,15 +603,11 @@ class LLMNode(BaseNode[LLMNodeData]):
return None return None
conversation_id = conversation_id_variable.value conversation_id = conversation_id_variable.value
# get conversation with Session(db.engine, expire_on_commit=False) as session:
conversation = ( stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
db.session.query(Conversation) conversation = session.scalar(stmt)
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id) if not conversation:
.first() return None
)
if not conversation:
return None
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
@ -847,20 +843,24 @@ class LLMNode(BaseNode[LLMNodeData]):
used_quota = 1 used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None: if used_quota is not None and system_configuration.current_quota_type is not None:
db.session.query(Provider).filter( with Session(db.engine) as session:
Provider.tenant_id == tenant_id, stmt = (
# TODO: Use provider name with prefix after the data migration. update(Provider)
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, .where(
Provider.provider_type == ProviderType.SYSTEM.value, Provider.tenant_id == tenant_id,
Provider.quota_type == system_configuration.current_quota_type.value, # TODO: Use provider name with prefix after the data migration.
Provider.quota_limit > Provider.quota_used, Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
).update( Provider.provider_type == ProviderType.SYSTEM.value,
{ Provider.quota_type == system_configuration.current_quota_type.value,
"quota_used": Provider.quota_used + used_quota, Provider.quota_limit > Provider.quota_used,
"last_used": datetime.now(tz=UTC).replace(tzinfo=None), )
} .values(
) quota_used=Provider.quota_used + used_quota,
db.session.commit() last_used=datetime.now(tz=UTC).replace(tzinfo=None),
)
)
session.execute(stmt)
session.commit()
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping( def _extract_variable_selector_to_variable_mapping(

View File

@ -31,7 +31,6 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.llm import LLMNode, ModelConfig from core.workflow.nodes.llm import LLMNode, ModelConfig
from core.workflow.utils import variable_template_parser from core.workflow.utils import variable_template_parser
from extensions.ext_database import db
from .entities import ParameterExtractorNodeData from .entities import ParameterExtractorNodeData
from .exc import ( from .exc import (
@ -259,8 +258,6 @@ class ParameterExtractorNode(LLMNode):
tools: list[PromptMessageTool], tools: list[PromptMessageTool],
stop: list[str], stop: list[str],
) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: ) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
db.session.close()
invoke_result = model_instance.invoke_llm( invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages, prompt_messages=prompt_messages,
model_parameters=node_data_model.completion_params, model_parameters=node_data_model.completion_params,