mirror of https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git (synced 2025-08-18 08:36:01 +08:00)
refactor: Replaces direct DB session usage with context managers (#20569)
Signed-off-by: -LAN- <laipz8200@outlook.com>
commit 72fdafc180
parent db83bfc53a
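The hunks below replace the long-lived scoped db.session with short-lived SQLAlchemy Session objects opened as context managers, and drop the manual finally: db.session.close() / db.session.remove() cleanup that the scoped session required. As a rough orientation only, here is a minimal, self-contained sketch of that pattern; the Provider-like model, the in-memory SQLite engine, and the column set are illustrative stand-ins, not Dify's actual definitions.

# Self-contained sketch of the refactor pattern: a read and an update issued
# through short-lived Session context managers instead of a long-lived
# scoped db.session. The model and engine below are illustrative only.
from datetime import UTC, datetime

from sqlalchemy import DateTime, Integer, String, create_engine, select, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Provider(Base):  # illustrative stand-in, not Dify's Provider model
    __tablename__ = "providers"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String(36))
    quota_used: Mapped[int] = mapped_column(Integer, default=0)
    quota_limit: Mapped[int] = mapped_column(Integer, default=100)
    last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)


engine = create_engine("sqlite+pysqlite:///:memory:")
Base.metadata.create_all(engine)

# Read: expire_on_commit=False keeps the loaded row usable after the block exits.
with Session(engine, expire_on_commit=False) as session:
    provider = session.scalar(select(Provider).where(Provider.tenant_id == "t1"))

# Write: build an UPDATE statement, execute it, and commit; the context manager
# releases the session and its connection when the block exits.
with Session(engine) as session:
    stmt = (
        update(Provider)
        .where(Provider.tenant_id == "t1", Provider.quota_limit > Provider.quota_used)
        .values(
            quota_used=Provider.quota_used + 1,
            last_used=datetime.now(tz=UTC).replace(tzinfo=None),
        )
    )
    session.execute(stmt)
    session.commit()

The context manager closes the session (and returns its connection to the pool) when the block exits, even on error, which is what the removed cleanup calls previously did by hand.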
@@ -53,7 +53,6 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
 from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
 from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
 from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
-from extensions.ext_database import db
 from models.enums import UserFrom
 from models.workflow import WorkflowType
 
@@ -607,8 +606,6 @@ class GraphEngine:
                     error=str(e),
                 )
             )
-        finally:
-            db.session.remove()
 
     def _run_node(
         self,
@@ -646,7 +643,6 @@ class GraphEngine:
             agent_strategy=agent_strategy,
         )
 
-        db.session.close()
         max_retries = node_instance.node_data.retry_config.max_retries
         retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
         retries = 0
@@ -863,8 +859,6 @@ class GraphEngine:
         except Exception as e:
             logger.exception(f"Node {node_instance.node_data.title} run failed")
             raise e
-        finally:
-            db.session.close()
 
     def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
         """
@@ -2,6 +2,9 @@ import json
 from collections.abc import Generator, Mapping, Sequence
 from typing import Any, Optional, cast
 
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
 from core.agent.entities import AgentToolEntity
 from core.agent.plugin_entities import AgentStrategyParameter
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -320,12 +323,9 @@ class AgentNode(ToolNode):
             return None
         conversation_id = conversation_id_variable.value
 
-        # get conversation
-        conversation = (
-            db.session.query(Conversation)
-            .filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
-            .first()
-        )
+        with Session(db.engine, expire_on_commit=False) as session:
+            stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
+            conversation = session.scalar(stmt)
 
         if not conversation:
             return None
@@ -8,6 +8,7 @@ from typing import Any, Optional, cast
 
 from sqlalchemy import Float, and_, func, or_, text
 from sqlalchemy import cast as sqlalchemy_cast
+from sqlalchemy.orm import Session
 
 from core.app.app_config.entities import DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@@ -95,14 +96,15 @@ class KnowledgeRetrievalNode(LLMNode):
             redis_client.zremrangebyscore(key, 0, current_time - 60000)
             request_count = redis_client.zcard(key)
             if request_count > knowledge_rate_limit.limit:
-                # add ratelimit record
-                rate_limit_log = RateLimitLog(
-                    tenant_id=self.tenant_id,
-                    subscription_plan=knowledge_rate_limit.subscription_plan,
-                    operation="knowledge",
-                )
-                db.session.add(rate_limit_log)
-                db.session.commit()
+                with Session(db.engine) as session:
+                    # add ratelimit record
+                    rate_limit_log = RateLimitLog(
+                        tenant_id=self.tenant_id,
+                        subscription_plan=knowledge_rate_limit.subscription_plan,
+                        operation="knowledge",
+                    )
+                    session.add(rate_limit_log)
+                    session.commit()
                 return NodeRunResult(
                     status=WorkflowNodeExecutionStatus.FAILED,
                     inputs=variables,
@@ -7,6 +7,8 @@ from datetime import UTC, datetime
 from typing import TYPE_CHECKING, Any, Optional, cast
 
 import json_repair
+from sqlalchemy import select, update
+from sqlalchemy.orm import Session
 
 from configs import dify_config
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
@@ -303,8 +305,6 @@ class LLMNode(BaseNode[LLMNodeData]):
         prompt_messages: Sequence[PromptMessage],
         stop: Optional[Sequence[str]] = None,
     ) -> Generator[NodeEvent, None, None]:
-        db.session.close()
-
         invoke_result = model_instance.invoke_llm(
             prompt_messages=list(prompt_messages),
             model_parameters=node_data_model.completion_params,
@@ -603,13 +603,9 @@ class LLMNode(BaseNode[LLMNodeData]):
             return None
         conversation_id = conversation_id_variable.value
 
-        # get conversation
-        conversation = (
-            db.session.query(Conversation)
-            .filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
-            .first()
-        )
-
+        with Session(db.engine, expire_on_commit=False) as session:
+            stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
+            conversation = session.scalar(stmt)
 
         if not conversation:
             return None
 
@@ -847,20 +843,24 @@ class LLMNode(BaseNode[LLMNodeData]):
             used_quota = 1
 
         if used_quota is not None and system_configuration.current_quota_type is not None:
-            db.session.query(Provider).filter(
-                Provider.tenant_id == tenant_id,
-                # TODO: Use provider name with prefix after the data migration.
-                Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
-                Provider.provider_type == ProviderType.SYSTEM.value,
-                Provider.quota_type == system_configuration.current_quota_type.value,
-                Provider.quota_limit > Provider.quota_used,
-            ).update(
-                {
-                    "quota_used": Provider.quota_used + used_quota,
-                    "last_used": datetime.now(tz=UTC).replace(tzinfo=None),
-                }
-            )
-            db.session.commit()
+            with Session(db.engine) as session:
+                stmt = (
+                    update(Provider)
+                    .where(
+                        Provider.tenant_id == tenant_id,
+                        # TODO: Use provider name with prefix after the data migration.
+                        Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
+                        Provider.provider_type == ProviderType.SYSTEM.value,
+                        Provider.quota_type == system_configuration.current_quota_type.value,
+                        Provider.quota_limit > Provider.quota_used,
+                    )
+                    .values(
+                        quota_used=Provider.quota_used + used_quota,
+                        last_used=datetime.now(tz=UTC).replace(tzinfo=None),
+                    )
+                )
+                session.execute(stmt)
+                session.commit()
 
     @classmethod
     def _extract_variable_selector_to_variable_mapping(
@@ -31,7 +31,6 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
 from core.workflow.nodes.enums import NodeType
 from core.workflow.nodes.llm import LLMNode, ModelConfig
 from core.workflow.utils import variable_template_parser
-from extensions.ext_database import db
 
 from .entities import ParameterExtractorNodeData
 from .exc import (
@@ -259,8 +258,6 @@ class ParameterExtractorNode(LLMNode):
         tools: list[PromptMessageTool],
         stop: list[str],
     ) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
-        db.session.close()
-
         invoke_result = model_instance.invoke_llm(
             prompt_messages=prompt_messages,
             model_parameters=node_data_model.completion_params,