refactor: Replaces direct DB session usage with context managers (#20569)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN- 2025-06-03 16:16:06 +08:00 committed by GitHub
parent db83bfc53a
commit 72fdafc180
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 43 additions and 50 deletions

View File

@ -53,7 +53,6 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from extensions.ext_database import db
from models.enums import UserFrom
from models.workflow import WorkflowType
@ -607,8 +606,6 @@ class GraphEngine:
error=str(e),
)
)
finally:
db.session.remove()
def _run_node(
self,
@ -646,7 +643,6 @@ class GraphEngine:
agent_strategy=agent_strategy,
)
db.session.close()
max_retries = node_instance.node_data.retry_config.max_retries
retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
retries = 0
@ -863,8 +859,6 @@ class GraphEngine:
except Exception as e:
logger.exception(f"Node {node_instance.node_data.title} run failed")
raise e
finally:
db.session.close()
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
"""

View File

@ -2,6 +2,9 @@ import json
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory
@ -320,15 +323,12 @@ class AgentNode(ToolNode):
return None
conversation_id = conversation_id_variable.value
# get conversation
conversation = (
db.session.query(Conversation)
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
.first()
)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
if not conversation:
return None
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

View File

@ -8,6 +8,7 @@ from typing import Any, Optional, cast
from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import Session
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@ -95,14 +96,15 @@ class KnowledgeRetrievalNode(LLMNode):
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=self.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
db.session.add(rate_limit_log)
db.session.commit()
with Session(db.engine) as session:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=self.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
session.add(rate_limit_log)
session.commit()
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,

View File

@ -7,6 +7,8 @@ from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, Optional, cast
import json_repair
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@ -303,8 +305,6 @@ class LLMNode(BaseNode[LLMNodeData]):
prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None,
) -> Generator[NodeEvent, None, None]:
db.session.close()
invoke_result = model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters=node_data_model.completion_params,
@ -603,15 +603,11 @@ class LLMNode(BaseNode[LLMNodeData]):
return None
conversation_id = conversation_id_variable.value
# get conversation
conversation = (
db.session.query(Conversation)
.filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
.first()
)
if not conversation:
return None
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
@ -847,20 +843,24 @@ class LLMNode(BaseNode[LLMNodeData]):
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
db.session.query(Provider).filter(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
).update(
{
"quota_used": Provider.quota_used + used_quota,
"last_used": datetime.now(tz=UTC).replace(tzinfo=None),
}
)
db.session.commit()
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=datetime.now(tz=UTC).replace(tzinfo=None),
)
)
session.execute(stmt)
session.commit()
@classmethod
def _extract_variable_selector_to_variable_mapping(

View File

@ -31,7 +31,6 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.llm import LLMNode, ModelConfig
from core.workflow.utils import variable_template_parser
from extensions.ext_database import db
from .entities import ParameterExtractorNodeData
from .exc import (
@ -259,8 +258,6 @@ class ParameterExtractorNode(LLMNode):
tools: list[PromptMessageTool],
stop: list[str],
) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
db.session.close()
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=node_data_model.completion_params,