Fix/plugin race condition (#14253)

This commit is contained in:
Yeuoly 2025-02-25 12:20:47 +08:00 committed by GitHub
parent 42b13bd312
commit 490b6d092e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 116 additions and 41 deletions

View File

@ -2,6 +2,7 @@ import logging
import time import time
from configs import dify_config from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from dify_app import DifyApp from dify_app import DifyApp
@ -16,6 +17,12 @@ def create_flask_app_with_configs() -> DifyApp:
dify_app = DifyApp(__name__) dify_app = DifyApp(__name__)
dify_app.config.from_mapping(dify_config.model_dump()) dify_app.config.from_mapping(dify_config.model_dump())
# add before request hook
@dify_app.before_request
def before_request():
# add a unique identifier to each request
RecyclableContextVar.increment_thread_recycles()
return dify_app return dify_app

View File

@ -2,6 +2,8 @@ from contextvars import ContextVar
from threading import Lock from threading import Lock
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from contexts.wrapper import RecyclableContextVar
if TYPE_CHECKING: if TYPE_CHECKING:
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.provider import PluginToolProviderController
@ -12,8 +14,17 @@ tenant_id: ContextVar[str] = ContextVar("tenant_id")
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers") """
plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock") To avoid race conditions caused by gunicorn thread recycling, RecyclableContextVar is used in place of a plain ContextVar
"""
plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar(
ContextVar("plugin_tool_providers")
)
plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))
plugin_model_providers: ContextVar[list["PluginModelProviderEntity"] | None] = ContextVar("plugin_model_providers") plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock") ContextVar("plugin_model_providers")
)
plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_model_providers_lock")
)

65
api/contexts/wrapper.py Normal file
View File

@ -0,0 +1,65 @@
from contextvars import ContextVar
from typing import Generic, TypeVar
T = TypeVar("T")
class HiddenValue:
pass
_default = HiddenValue()
class RecyclableContextVar(Generic[T]):
"""
RecyclableContextVar is a wrapper around ContextVar
It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now
NOTE: you need to call `increment_thread_recycles` before requests
"""
_thread_recycles: ContextVar[int] = ContextVar("thread_recycles")
@classmethod
def increment_thread_recycles(cls):
try:
recycles = cls._thread_recycles.get()
cls._thread_recycles.set(recycles + 1)
except LookupError:
cls._thread_recycles.set(0)
def __init__(self, context_var: ContextVar[T]):
self._context_var = context_var
self._updates = ContextVar[int](context_var.name + "_updates", default=0)
def get(self, default: T | HiddenValue = _default) -> T:
thread_recycles = self._thread_recycles.get(0)
self_updates = self._updates.get()
if thread_recycles > self_updates:
self._updates.set(thread_recycles)
# check if thread is recycled and should be updated
if thread_recycles < self_updates:
return self._context_var.get()
else:
# thread_recycles >= self_updates, means current context is invalid
if isinstance(default, HiddenValue) or default is _default:
raise LookupError
else:
return default
def set(self, value: T):
# it leads to a situation that self.updates is less than cls.thread_recycles if `set` was never called before
# increase it manually
thread_recycles = self._thread_recycles.get(0)
self_updates = self._updates.get()
if thread_recycles > self_updates:
self._updates.set(thread_recycles)
if self._updates.get() == self._thread_recycles.get(0):
# after increment,
self._updates.set(self._updates.get() + 1)
# set the context
self._context_var.set(value)

View File

@ -1,7 +1,7 @@
from enum import StrEnum from enum import StrEnum
from typing import Any, Optional, Union from typing import Any, Optional, Union
from pydantic import BaseModel from pydantic import BaseModel, Field
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
@ -14,7 +14,7 @@ class AgentToolEntity(BaseModel):
provider_type: ToolProviderType provider_type: ToolProviderType
provider_id: str provider_id: str
tool_name: str tool_name: str
tool_parameters: dict[str, Any] = {} tool_parameters: dict[str, Any] = Field(default_factory=dict)
plugin_unique_identifier: str | None = None plugin_unique_identifier: str | None = None

View File

@ -2,9 +2,9 @@ from collections.abc import Mapping
from typing import Any from typing import Any
from core.app.app_config.entities import ModelConfigEntity from core.app.app_config.entities import ModelConfigEntity
from core.entities import DEFAULT_PLUGIN_ID
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
@ -61,9 +61,7 @@ class ModelConfigManager:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
if "/" not in config["model"]["provider"]: if "/" not in config["model"]["provider"]:
config["model"]["provider"] = ( config["model"]["provider"] = str(ModelProviderID(config["model"]["provider"]))
f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
)
if config["model"]["provider"] not in model_provider_names: if config["model"]["provider"] not in model_provider_names:
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")

View File

@ -17,8 +17,8 @@ class ModelConfigEntity(BaseModel):
provider: str provider: str
model: str model: str
mode: Optional[str] = None mode: Optional[str] = None
parameters: dict[str, Any] = {} parameters: dict[str, Any] = Field(default_factory=dict)
stop: list[str] = [] stop: list[str] = Field(default_factory=list)
class AdvancedChatMessageEntity(BaseModel): class AdvancedChatMessageEntity(BaseModel):
@ -132,7 +132,7 @@ class ExternalDataVariableEntity(BaseModel):
variable: str variable: str
type: str type: str
config: dict[str, Any] = {} config: dict[str, Any] = Field(default_factory=dict)
class DatasetRetrieveConfigEntity(BaseModel): class DatasetRetrieveConfigEntity(BaseModel):
@ -188,7 +188,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
""" """
type: str type: str
config: dict[str, Any] = {} config: dict[str, Any] = Field(default_factory=dict)
class TextToSpeechEntity(BaseModel): class TextToSpeechEntity(BaseModel):

View File

@ -63,9 +63,9 @@ class ModelConfigWithCredentialsEntity(BaseModel):
model_schema: AIModelEntity model_schema: AIModelEntity
mode: str mode: str
provider_model_bundle: ProviderModelBundle provider_model_bundle: ProviderModelBundle
credentials: dict[str, Any] = {} credentials: dict[str, Any] = Field(default_factory=dict)
parameters: dict[str, Any] = {} parameters: dict[str, Any] = Field(default_factory=dict)
stop: list[str] = [] stop: list[str] = Field(default_factory=list)
# pydantic configs # pydantic configs
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
@ -94,7 +94,7 @@ class AppGenerateEntity(BaseModel):
call_depth: int = 0 call_depth: int = 0
# extra parameters, like: auto_generate_conversation_name # extra parameters, like: auto_generate_conversation_name
extras: dict[str, Any] = {} extras: dict[str, Any] = Field(default_factory=dict)
# tracing instance # tracing instance
trace_manager: Optional[TraceQueueManager] = None trace_manager: Optional[TraceQueueManager] = None

View File

@ -6,11 +6,10 @@ from collections.abc import Iterator, Sequence
from json import JSONDecodeError from json import JSONDecodeError
from typing import Optional from typing import Optional
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import or_ from sqlalchemy import or_
from constants import HIDDEN_VALUE from constants import HIDDEN_VALUE
from core.entities import DEFAULT_PLUGIN_ID
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
from core.entities.provider_entities import ( from core.entities.provider_entities import (
CustomConfiguration, CustomConfiguration,
@ -1004,7 +1003,7 @@ class ProviderConfigurations(BaseModel):
""" """
tenant_id: str tenant_id: str
configurations: dict[str, ProviderConfiguration] = {} configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict)
def __init__(self, tenant_id: str): def __init__(self, tenant_id: str):
super().__init__(tenant_id=tenant_id) super().__init__(tenant_id=tenant_id)
@ -1060,7 +1059,7 @@ class ProviderConfigurations(BaseModel):
def __getitem__(self, key): def __getitem__(self, key):
if "/" not in key: if "/" not in key:
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}" key = str(ModelProviderID(key))
return self.configurations[key] return self.configurations[key]
@ -1075,7 +1074,7 @@ class ProviderConfigurations(BaseModel):
def get(self, key, default=None) -> ProviderConfiguration | None: def get(self, key, default=None) -> ProviderConfiguration | None:
if "/" not in key: if "/" not in key:
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}" key = str(ModelProviderID(key))
return self.configurations.get(key, default) # type: ignore return self.configurations.get(key, default) # type: ignore

View File

@ -41,9 +41,13 @@ class HostedModerationConfig(BaseModel):
class HostingConfiguration: class HostingConfiguration:
provider_map: dict[str, HostingProvider] = {} provider_map: dict[str, HostingProvider]
moderation_config: Optional[HostedModerationConfig] = None moderation_config: Optional[HostedModerationConfig] = None
def __init__(self) -> None:
self.provider_map = {}
self.moderation_config = None
def init_app(self, app: Flask) -> None: def init_app(self, app: Flask) -> None:
if dify_config.EDITION != "CLOUD": if dify_config.EDITION != "CLOUD":
return return

View File

@ -7,7 +7,6 @@ from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
import contexts import contexts
from core.entities import DEFAULT_PLUGIN_ID
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
@ -34,9 +33,11 @@ class ModelProviderExtension(BaseModel):
class ModelProviderFactory: class ModelProviderFactory:
provider_position_map: dict[str, int] = {} provider_position_map: dict[str, int]
def __init__(self, tenant_id: str) -> None: def __init__(self, tenant_id: str) -> None:
self.provider_position_map = {}
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.plugin_model_manager = PluginModelManager() self.plugin_model_manager = PluginModelManager()
@ -360,11 +361,5 @@ class ModelProviderFactory:
:param provider: provider name :param provider: provider name
:return: plugin id and provider name :return: plugin id and provider name
""" """
plugin_id = DEFAULT_PLUGIN_ID provider_id = ModelProviderID(provider)
provider_name = provider return provider_id.plugin_id, provider_id.provider_name
if "/" in provider:
# get the plugin_id before provider
plugin_id = "/".join(provider.split("/")[:-1])
provider_name = provider.split("/")[-1]
return str(plugin_id), provider_name

View File

@ -13,10 +13,10 @@ from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from configs import dify_config from configs import dify_config
from core.entities import DEFAULT_PLUGIN_ID
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.constant.index_type import IndexType
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from events.dataset_event import dataset_was_deleted from events.dataset_event import dataset_was_deleted
@ -328,14 +328,10 @@ class DatasetService:
else: else:
# add default plugin id to both setting sets, to make sure the plugin model provider is consistent # add default plugin id to both setting sets, to make sure the plugin model provider is consistent
plugin_model_provider = dataset.embedding_model_provider plugin_model_provider = dataset.embedding_model_provider
if "/" not in plugin_model_provider: plugin_model_provider = str(ModelProviderID(plugin_model_provider))
plugin_model_provider = f"{DEFAULT_PLUGIN_ID}/{plugin_model_provider}/{plugin_model_provider}"
new_plugin_model_provider = data["embedding_model_provider"] new_plugin_model_provider = data["embedding_model_provider"]
if "/" not in new_plugin_model_provider: new_plugin_model_provider = str(ModelProviderID(new_plugin_model_provider))
new_plugin_model_provider = (
f"{DEFAULT_PLUGIN_ID}/{new_plugin_model_provider}/{new_plugin_model_provider}"
)
if ( if (
new_plugin_model_provider != plugin_model_provider new_plugin_model_provider != plugin_model_provider