From 490b6d092e5efe9795292d6de00d0bc4a84e9568 Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Tue, 25 Feb 2025 12:20:47 +0800 Subject: [PATCH] Fix/plugin race condition (#14253) --- api/app_factory.py | 7 ++ api/contexts/__init__.py | 19 ++++-- api/contexts/wrapper.py | 65 +++++++++++++++++++ api/core/agent/entities.py | 4 +- .../easy_ui_based_app/model_config/manager.py | 6 +- api/core/app/app_config/entities.py | 8 +-- api/core/app/entities/app_invoke_entities.py | 8 +-- api/core/entities/provider_configuration.py | 9 ++- api/core/hosting_configuration.py | 6 +- .../model_providers/model_provider_factory.py | 15 ++--- api/services/dataset_service.py | 10 +-- 11 files changed, 116 insertions(+), 41 deletions(-) create mode 100644 api/contexts/wrapper.py diff --git a/api/app_factory.py b/api/app_factory.py index c0714116a3..52ae05583a 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -2,6 +2,7 @@ import logging import time from configs import dify_config +from contexts.wrapper import RecyclableContextVar from dify_app import DifyApp @@ -16,6 +17,12 @@ def create_flask_app_with_configs() -> DifyApp: dify_app = DifyApp(__name__) dify_app.config.from_mapping(dify_config.model_dump()) + # add before request hook + @dify_app.before_request + def before_request(): + # add a unique identifier to each request + RecyclableContextVar.increment_thread_recycles() + return dify_app diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index 64dae3a2d2..91438d086a 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -2,6 +2,8 @@ from contextvars import ContextVar from threading import Lock from typing import TYPE_CHECKING +from contexts.wrapper import RecyclableContextVar + if TYPE_CHECKING: from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.tools.plugin_tool.provider import PluginToolProviderController @@ -12,8 +14,17 @@ tenant_id: ContextVar[str] = 
ContextVar("tenant_id") workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") -plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers") -plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock") +""" +To avoid race conditions caused by gunicorn thread recycling, the variables below use RecyclableContextVar instead of a plain ContextVar. +""" +plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar( + ContextVar("plugin_tool_providers") +) +plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock")) -plugin_model_providers: ContextVar[list["PluginModelProviderEntity"] | None] = ContextVar("plugin_model_providers") -plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock") +plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar( + ContextVar("plugin_model_providers") +) +plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( + ContextVar("plugin_model_providers_lock") +) diff --git a/api/contexts/wrapper.py b/api/contexts/wrapper.py new file mode 100644 index 0000000000..8cd53487ef --- /dev/null +++ b/api/contexts/wrapper.py @@ -0,0 +1,65 @@ +from contextvars import ContextVar +from typing import Generic, TypeVar + +T = TypeVar("T") + + +class HiddenValue: + pass + + +_default = HiddenValue() + + +class RecyclableContextVar(Generic[T]): + """ + RecyclableContextVar is a wrapper around ContextVar + It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now + + NOTE: you need to call `increment_thread_recycles` before requests + """ + + _thread_recycles: ContextVar[int] = ContextVar("thread_recycles") + + @classmethod + def increment_thread_recycles(cls): + try: + recycles = cls._thread_recycles.get() + 
cls._thread_recycles.set(recycles + 1) + except LookupError: + cls._thread_recycles.set(0) + + def __init__(self, context_var: ContextVar[T]): + self._context_var = context_var + self._updates = ContextVar[int](context_var.name + "_updates", default=0) + + def get(self, default: T | HiddenValue = _default) -> T: + thread_recycles = self._thread_recycles.get(0) + self_updates = self._updates.get() + if thread_recycles > self_updates: + self._updates.set(thread_recycles) + + # check if thread is recycled and should be updated + if thread_recycles < self_updates: + return self._context_var.get() + else: + # thread_recycles >= self_updates, means current context is invalid + if isinstance(default, HiddenValue) or default is _default: + raise LookupError + else: + return default + + def set(self, value: T): + # if `set` was never called before, self._updates can be less than cls._thread_recycles, + # so bring it up to date manually + thread_recycles = self._thread_recycles.get(0) + self_updates = self._updates.get() + if thread_recycles > self_updates: + self._updates.set(thread_recycles) + + if self._updates.get() == self._thread_recycles.get(0): + # move one step ahead of the recycle counter so `get` treats the stored value as current + self._updates.set(self._updates.get() + 1) + + # set the context + self._context_var.set(value) diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 5c7cee3a78..e68b4f2356 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -1,7 +1,7 @@ from enum import StrEnum from typing import Any, Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, Field from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType @@ -14,7 +14,7 @@ class AgentToolEntity(BaseModel): provider_type: ToolProviderType provider_id: str tool_name: str - tool_parameters: dict[str, Any] = {} + tool_parameters: dict[str, Any] = Field(default_factory=dict) plugin_unique_identifier: str | None = None diff --git 
a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index b19865ff4c..54bca10fc3 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -2,9 +2,9 @@ from collections.abc import Mapping from typing import Any from core.app.app_config.entities import ModelConfigEntity -from core.entities import DEFAULT_PLUGIN_ID from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.entities.plugin import ModelProviderID from core.provider_manager import ProviderManager @@ -61,9 +61,7 @@ class ModelConfigManager: raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") if "/" not in config["model"]["provider"]: - config["model"]["provider"] = ( - f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}" - ) + config["model"]["provider"] = str(ModelProviderID(config["model"]["provider"])) if config["model"]["provider"] not in model_provider_names: raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 15bd353484..16b69a4468 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -17,8 +17,8 @@ class ModelConfigEntity(BaseModel): provider: str model: str mode: Optional[str] = None - parameters: dict[str, Any] = {} - stop: list[str] = [] + parameters: dict[str, Any] = Field(default_factory=dict) + stop: list[str] = Field(default_factory=list) class AdvancedChatMessageEntity(BaseModel): @@ -132,7 +132,7 @@ class ExternalDataVariableEntity(BaseModel): variable: str type: str - config: dict[str, Any] = {} + config: dict[str, Any] = 
Field(default_factory=dict) class DatasetRetrieveConfigEntity(BaseModel): @@ -188,7 +188,7 @@ class SensitiveWordAvoidanceEntity(BaseModel): """ type: str - config: dict[str, Any] = {} + config: dict[str, Any] = Field(default_factory=dict) class TextToSpeechEntity(BaseModel): diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index df5da61927..57beeaacc0 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -63,9 +63,9 @@ class ModelConfigWithCredentialsEntity(BaseModel): model_schema: AIModelEntity mode: str provider_model_bundle: ProviderModelBundle - credentials: dict[str, Any] = {} - parameters: dict[str, Any] = {} - stop: list[str] = [] + credentials: dict[str, Any] = Field(default_factory=dict) + parameters: dict[str, Any] = Field(default_factory=dict) + stop: list[str] = Field(default_factory=list) # pydantic configs model_config = ConfigDict(protected_namespaces=()) @@ -94,7 +94,7 @@ class AppGenerateEntity(BaseModel): call_depth: int = 0 # extra parameters, like: auto_generate_conversation_name - extras: dict[str, Any] = {} + extras: dict[str, Any] = Field(default_factory=dict) # tracing instance trace_manager: Optional[TraceQueueManager] = None diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index e032b0fa4a..b1a155fea8 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -6,11 +6,10 @@ from collections.abc import Iterator, Sequence from json import JSONDecodeError from typing import Optional -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from sqlalchemy import or_ from constants import HIDDEN_VALUE -from core.entities import DEFAULT_PLUGIN_ID from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity from core.entities.provider_entities 
import ( CustomConfiguration, @@ -1004,7 +1003,7 @@ class ProviderConfigurations(BaseModel): """ tenant_id: str - configurations: dict[str, ProviderConfiguration] = {} + configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict) def __init__(self, tenant_id: str): super().__init__(tenant_id=tenant_id) @@ -1060,7 +1059,7 @@ class ProviderConfigurations(BaseModel): def __getitem__(self, key): if "/" not in key: - key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}" + key = str(ModelProviderID(key)) return self.configurations[key] @@ -1075,7 +1074,7 @@ class ProviderConfigurations(BaseModel): def get(self, key, default=None) -> ProviderConfiguration | None: if "/" not in key: - key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}" + key = str(ModelProviderID(key)) return self.configurations.get(key, default) # type: ignore diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 67fd355ee9..20d98562de 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -41,9 +41,13 @@ class HostedModerationConfig(BaseModel): class HostingConfiguration: - provider_map: dict[str, HostingProvider] = {} + provider_map: dict[str, HostingProvider] moderation_config: Optional[HostedModerationConfig] = None + def __init__(self) -> None: + self.provider_map = {} + self.moderation_config = None + def init_app(self, app: Flask) -> None: if dify_config.EDITION != "CLOUD": return diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index b311f069a8..4e9b20b033 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -7,7 +7,6 @@ from typing import Optional from pydantic import BaseModel import contexts -from core.entities import DEFAULT_PLUGIN_ID from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map from 
core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity @@ -34,9 +33,11 @@ class ModelProviderExtension(BaseModel): class ModelProviderFactory: - provider_position_map: dict[str, int] = {} + provider_position_map: dict[str, int] def __init__(self, tenant_id: str) -> None: + self.provider_position_map = {} + self.tenant_id = tenant_id self.plugin_model_manager = PluginModelManager() @@ -360,11 +361,5 @@ class ModelProviderFactory: :param provider: provider name :return: plugin id and provider name """ - plugin_id = DEFAULT_PLUGIN_ID - provider_name = provider - if "/" in provider: - # get the plugin_id before provider - plugin_id = "/".join(provider.split("/")[:-1]) - provider_name = provider.split("/")[-1] - - return str(plugin_id), provider_name + provider_id = ModelProviderID(provider) + return provider_id.plugin_id, provider_id.provider_name diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 79961533b0..df38dc3c16 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -13,10 +13,10 @@ from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from configs import dify_config -from core.entities import DEFAULT_PLUGIN_ID from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType +from core.plugin.entities.plugin import ModelProviderID from core.rag.index_processor.constant.index_type import IndexType from core.rag.retrieval.retrieval_methods import RetrievalMethod from events.dataset_event import dataset_was_deleted @@ -328,14 +328,10 @@ class DatasetService: else: # add default plugin id to both setting sets, to make sure the plugin model provider is consistent plugin_model_provider = dataset.embedding_model_provider - if "/" not in 
plugin_model_provider: - plugin_model_provider = f"{DEFAULT_PLUGIN_ID}/{plugin_model_provider}/{plugin_model_provider}" + plugin_model_provider = str(ModelProviderID(plugin_model_provider)) new_plugin_model_provider = data["embedding_model_provider"] - if "/" not in new_plugin_model_provider: - new_plugin_model_provider = ( - f"{DEFAULT_PLUGIN_ID}/{new_plugin_model_provider}/{new_plugin_model_provider}" - ) + new_plugin_model_provider = str(ModelProviderID(new_plugin_model_provider)) if ( new_plugin_model_provider != plugin_model_provider