mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-04-18 11:49:41 +08:00
Fix/plugin race condition (#14253)
This commit is contained in:
parent
42b13bd312
commit
490b6d092e
@ -2,6 +2,7 @@ import logging
|
||||
import time
|
||||
|
||||
from configs import dify_config
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from dify_app import DifyApp
|
||||
|
||||
|
||||
@ -16,6 +17,12 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
dify_app = DifyApp(__name__)
|
||||
dify_app.config.from_mapping(dify_config.model_dump())
|
||||
|
||||
# add before request hook
|
||||
@dify_app.before_request
|
||||
def before_request():
|
||||
# add an unique identifier to each request
|
||||
RecyclableContextVar.increment_thread_recycles()
|
||||
|
||||
return dify_app
|
||||
|
||||
|
||||
|
@ -2,6 +2,8 @@ from contextvars import ContextVar
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
@ -12,8 +14,17 @@ tenant_id: ContextVar[str] = ContextVar("tenant_id")
|
||||
|
||||
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
|
||||
|
||||
plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers")
|
||||
plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock")
|
||||
"""
|
||||
To avoid race-conditions caused by gunicorn thread recycling, using RecyclableContextVar to replace with
|
||||
"""
|
||||
plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar(
|
||||
ContextVar("plugin_tool_providers")
|
||||
)
|
||||
plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))
|
||||
|
||||
plugin_model_providers: ContextVar[list["PluginModelProviderEntity"] | None] = ContextVar("plugin_model_providers")
|
||||
plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock")
|
||||
plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
|
||||
ContextVar("plugin_model_providers")
|
||||
)
|
||||
plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||
ContextVar("plugin_model_providers_lock")
|
||||
)
|
||||
|
65
api/contexts/wrapper.py
Normal file
65
api/contexts/wrapper.py
Normal file
@ -0,0 +1,65 @@
|
||||
from contextvars import ContextVar
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class HiddenValue:
|
||||
pass
|
||||
|
||||
|
||||
_default = HiddenValue()
|
||||
|
||||
|
||||
class RecyclableContextVar(Generic[T]):
|
||||
"""
|
||||
RecyclableContextVar is a wrapper around ContextVar
|
||||
It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now
|
||||
|
||||
NOTE: you need to call `increment_thread_recycles` before requests
|
||||
"""
|
||||
|
||||
_thread_recycles: ContextVar[int] = ContextVar("thread_recycles")
|
||||
|
||||
@classmethod
|
||||
def increment_thread_recycles(cls):
|
||||
try:
|
||||
recycles = cls._thread_recycles.get()
|
||||
cls._thread_recycles.set(recycles + 1)
|
||||
except LookupError:
|
||||
cls._thread_recycles.set(0)
|
||||
|
||||
def __init__(self, context_var: ContextVar[T]):
|
||||
self._context_var = context_var
|
||||
self._updates = ContextVar[int](context_var.name + "_updates", default=0)
|
||||
|
||||
def get(self, default: T | HiddenValue = _default) -> T:
|
||||
thread_recycles = self._thread_recycles.get(0)
|
||||
self_updates = self._updates.get()
|
||||
if thread_recycles > self_updates:
|
||||
self._updates.set(thread_recycles)
|
||||
|
||||
# check if thread is recycled and should be updated
|
||||
if thread_recycles < self_updates:
|
||||
return self._context_var.get()
|
||||
else:
|
||||
# thread_recycles >= self_updates, means current context is invalid
|
||||
if isinstance(default, HiddenValue) or default is _default:
|
||||
raise LookupError
|
||||
else:
|
||||
return default
|
||||
|
||||
def set(self, value: T):
|
||||
# it leads to a situation that self.updates is less than cls.thread_recycles if `set` was never called before
|
||||
# increase it manually
|
||||
thread_recycles = self._thread_recycles.get(0)
|
||||
self_updates = self._updates.get()
|
||||
if thread_recycles > self_updates:
|
||||
self._updates.set(thread_recycles)
|
||||
|
||||
if self._updates.get() == self._thread_recycles.get(0):
|
||||
# after increment,
|
||||
self._updates.set(self._updates.get() + 1)
|
||||
|
||||
# set the context
|
||||
self._context_var.set(value)
|
@ -1,7 +1,7 @@
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
|
||||
@ -14,7 +14,7 @@ class AgentToolEntity(BaseModel):
|
||||
provider_type: ToolProviderType
|
||||
provider_id: str
|
||||
tool_name: str
|
||||
tool_parameters: dict[str, Any] = {}
|
||||
tool_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
plugin_unique_identifier: str | None = None
|
||||
|
||||
|
||||
|
@ -2,9 +2,9 @@ from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.provider_manager import ProviderManager
|
||||
|
||||
|
||||
@ -61,9 +61,7 @@ class ModelConfigManager:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
||||
if "/" not in config["model"]["provider"]:
|
||||
config["model"]["provider"] = (
|
||||
f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
|
||||
)
|
||||
config["model"]["provider"] = str(ModelProviderID(config["model"]["provider"]))
|
||||
|
||||
if config["model"]["provider"] not in model_provider_names:
|
||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||
|
@ -17,8 +17,8 @@ class ModelConfigEntity(BaseModel):
|
||||
provider: str
|
||||
model: str
|
||||
mode: Optional[str] = None
|
||||
parameters: dict[str, Any] = {}
|
||||
stop: list[str] = []
|
||||
parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
stop: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AdvancedChatMessageEntity(BaseModel):
|
||||
@ -132,7 +132,7 @@ class ExternalDataVariableEntity(BaseModel):
|
||||
|
||||
variable: str
|
||||
type: str
|
||||
config: dict[str, Any] = {}
|
||||
config: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class DatasetRetrieveConfigEntity(BaseModel):
|
||||
@ -188,7 +188,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
|
||||
"""
|
||||
|
||||
type: str
|
||||
config: dict[str, Any] = {}
|
||||
config: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TextToSpeechEntity(BaseModel):
|
||||
|
@ -63,9 +63,9 @@ class ModelConfigWithCredentialsEntity(BaseModel):
|
||||
model_schema: AIModelEntity
|
||||
mode: str
|
||||
provider_model_bundle: ProviderModelBundle
|
||||
credentials: dict[str, Any] = {}
|
||||
parameters: dict[str, Any] = {}
|
||||
stop: list[str] = []
|
||||
credentials: dict[str, Any] = Field(default_factory=dict)
|
||||
parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
stop: list[str] = Field(default_factory=list)
|
||||
|
||||
# pydantic configs
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
@ -94,7 +94,7 @@ class AppGenerateEntity(BaseModel):
|
||||
call_depth: int = 0
|
||||
|
||||
# extra parameters, like: auto_generate_conversation_name
|
||||
extras: dict[str, Any] = {}
|
||||
extras: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# tracing instance
|
||||
trace_manager: Optional[TraceQueueManager] = None
|
||||
|
@ -6,11 +6,10 @@ from collections.abc import Iterator, Sequence
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import or_
|
||||
|
||||
from constants import HIDDEN_VALUE
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
||||
from core.entities.provider_entities import (
|
||||
CustomConfiguration,
|
||||
@ -1004,7 +1003,7 @@ class ProviderConfigurations(BaseModel):
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
configurations: dict[str, ProviderConfiguration] = {}
|
||||
configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict)
|
||||
|
||||
def __init__(self, tenant_id: str):
|
||||
super().__init__(tenant_id=tenant_id)
|
||||
@ -1060,7 +1059,7 @@ class ProviderConfigurations(BaseModel):
|
||||
|
||||
def __getitem__(self, key):
|
||||
if "/" not in key:
|
||||
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
|
||||
key = str(ModelProviderID(key))
|
||||
|
||||
return self.configurations[key]
|
||||
|
||||
@ -1075,7 +1074,7 @@ class ProviderConfigurations(BaseModel):
|
||||
|
||||
def get(self, key, default=None) -> ProviderConfiguration | None:
|
||||
if "/" not in key:
|
||||
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
|
||||
key = str(ModelProviderID(key))
|
||||
|
||||
return self.configurations.get(key, default) # type: ignore
|
||||
|
||||
|
@ -41,9 +41,13 @@ class HostedModerationConfig(BaseModel):
|
||||
|
||||
|
||||
class HostingConfiguration:
|
||||
provider_map: dict[str, HostingProvider] = {}
|
||||
provider_map: dict[str, HostingProvider]
|
||||
moderation_config: Optional[HostedModerationConfig] = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.provider_map = {}
|
||||
self.moderation_config = None
|
||||
|
||||
def init_app(self, app: Flask) -> None:
|
||||
if dify_config.EDITION != "CLOUD":
|
||||
return
|
||||
|
@ -7,7 +7,6 @@ from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
import contexts
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
|
||||
@ -34,9 +33,11 @@ class ModelProviderExtension(BaseModel):
|
||||
|
||||
|
||||
class ModelProviderFactory:
|
||||
provider_position_map: dict[str, int] = {}
|
||||
provider_position_map: dict[str, int]
|
||||
|
||||
def __init__(self, tenant_id: str) -> None:
|
||||
self.provider_position_map = {}
|
||||
|
||||
self.tenant_id = tenant_id
|
||||
self.plugin_model_manager = PluginModelManager()
|
||||
|
||||
@ -360,11 +361,5 @@ class ModelProviderFactory:
|
||||
:param provider: provider name
|
||||
:return: plugin id and provider name
|
||||
"""
|
||||
plugin_id = DEFAULT_PLUGIN_ID
|
||||
provider_name = provider
|
||||
if "/" in provider:
|
||||
# get the plugin_id before provider
|
||||
plugin_id = "/".join(provider.split("/")[:-1])
|
||||
provider_name = provider.split("/")[-1]
|
||||
|
||||
return str(plugin_id), provider_name
|
||||
provider_id = ModelProviderID(provider)
|
||||
return provider_id.plugin_id, provider_id.provider_name
|
||||
|
@ -13,10 +13,10 @@ from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.rag.index_processor.constant.index_type import IndexType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from events.dataset_event import dataset_was_deleted
|
||||
@ -328,14 +328,10 @@ class DatasetService:
|
||||
else:
|
||||
# add default plugin id to both setting sets, to make sure the plugin model provider is consistent
|
||||
plugin_model_provider = dataset.embedding_model_provider
|
||||
if "/" not in plugin_model_provider:
|
||||
plugin_model_provider = f"{DEFAULT_PLUGIN_ID}/{plugin_model_provider}/{plugin_model_provider}"
|
||||
plugin_model_provider = str(ModelProviderID(plugin_model_provider))
|
||||
|
||||
new_plugin_model_provider = data["embedding_model_provider"]
|
||||
if "/" not in new_plugin_model_provider:
|
||||
new_plugin_model_provider = (
|
||||
f"{DEFAULT_PLUGIN_ID}/{new_plugin_model_provider}/{new_plugin_model_provider}"
|
||||
)
|
||||
new_plugin_model_provider = str(ModelProviderID(new_plugin_model_provider))
|
||||
|
||||
if (
|
||||
new_plugin_model_provider != plugin_model_provider
|
||||
|
Loading…
x
Reference in New Issue
Block a user