mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-06-04 11:14:10 +08:00
Fix/plugin race condition (#14253)
This commit is contained in:
parent
42b13bd312
commit
490b6d092e
@ -2,6 +2,7 @@ import logging
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from contexts.wrapper import RecyclableContextVar
|
||||||
from dify_app import DifyApp
|
from dify_app import DifyApp
|
||||||
|
|
||||||
|
|
||||||
@ -16,6 +17,12 @@ def create_flask_app_with_configs() -> DifyApp:
|
|||||||
dify_app = DifyApp(__name__)
|
dify_app = DifyApp(__name__)
|
||||||
dify_app.config.from_mapping(dify_config.model_dump())
|
dify_app.config.from_mapping(dify_config.model_dump())
|
||||||
|
|
||||||
|
# add before request hook
|
||||||
|
@dify_app.before_request
|
||||||
|
def before_request():
|
||||||
|
# add an unique identifier to each request
|
||||||
|
RecyclableContextVar.increment_thread_recycles()
|
||||||
|
|
||||||
return dify_app
|
return dify_app
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,6 +2,8 @@ from contextvars import ContextVar
|
|||||||
from threading import Lock
|
from threading import Lock
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from contexts.wrapper import RecyclableContextVar
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||||
@ -12,8 +14,17 @@ tenant_id: ContextVar[str] = ContextVar("tenant_id")
|
|||||||
|
|
||||||
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
|
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
|
||||||
|
|
||||||
plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers")
|
"""
|
||||||
plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock")
|
To avoid race-conditions caused by gunicorn thread recycling, using RecyclableContextVar to replace with
|
||||||
|
"""
|
||||||
|
plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar(
|
||||||
|
ContextVar("plugin_tool_providers")
|
||||||
|
)
|
||||||
|
plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))
|
||||||
|
|
||||||
plugin_model_providers: ContextVar[list["PluginModelProviderEntity"] | None] = ContextVar("plugin_model_providers")
|
plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
|
||||||
plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock")
|
ContextVar("plugin_model_providers")
|
||||||
|
)
|
||||||
|
plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||||
|
ContextVar("plugin_model_providers_lock")
|
||||||
|
)
|
||||||
|
65
api/contexts/wrapper.py
Normal file
65
api/contexts/wrapper.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class HiddenValue:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
_default = HiddenValue()
|
||||||
|
|
||||||
|
|
||||||
|
class RecyclableContextVar(Generic[T]):
|
||||||
|
"""
|
||||||
|
RecyclableContextVar is a wrapper around ContextVar
|
||||||
|
It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now
|
||||||
|
|
||||||
|
NOTE: you need to call `increment_thread_recycles` before requests
|
||||||
|
"""
|
||||||
|
|
||||||
|
_thread_recycles: ContextVar[int] = ContextVar("thread_recycles")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def increment_thread_recycles(cls):
|
||||||
|
try:
|
||||||
|
recycles = cls._thread_recycles.get()
|
||||||
|
cls._thread_recycles.set(recycles + 1)
|
||||||
|
except LookupError:
|
||||||
|
cls._thread_recycles.set(0)
|
||||||
|
|
||||||
|
def __init__(self, context_var: ContextVar[T]):
|
||||||
|
self._context_var = context_var
|
||||||
|
self._updates = ContextVar[int](context_var.name + "_updates", default=0)
|
||||||
|
|
||||||
|
def get(self, default: T | HiddenValue = _default) -> T:
|
||||||
|
thread_recycles = self._thread_recycles.get(0)
|
||||||
|
self_updates = self._updates.get()
|
||||||
|
if thread_recycles > self_updates:
|
||||||
|
self._updates.set(thread_recycles)
|
||||||
|
|
||||||
|
# check if thread is recycled and should be updated
|
||||||
|
if thread_recycles < self_updates:
|
||||||
|
return self._context_var.get()
|
||||||
|
else:
|
||||||
|
# thread_recycles >= self_updates, means current context is invalid
|
||||||
|
if isinstance(default, HiddenValue) or default is _default:
|
||||||
|
raise LookupError
|
||||||
|
else:
|
||||||
|
return default
|
||||||
|
|
||||||
|
def set(self, value: T):
|
||||||
|
# it leads to a situation that self.updates is less than cls.thread_recycles if `set` was never called before
|
||||||
|
# increase it manually
|
||||||
|
thread_recycles = self._thread_recycles.get(0)
|
||||||
|
self_updates = self._updates.get()
|
||||||
|
if thread_recycles > self_updates:
|
||||||
|
self._updates.set(thread_recycles)
|
||||||
|
|
||||||
|
if self._updates.get() == self._thread_recycles.get(0):
|
||||||
|
# after increment,
|
||||||
|
self._updates.set(self._updates.get() + 1)
|
||||||
|
|
||||||
|
# set the context
|
||||||
|
self._context_var.set(value)
|
@ -1,7 +1,7 @@
|
|||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||||
|
|
||||||
@ -14,7 +14,7 @@ class AgentToolEntity(BaseModel):
|
|||||||
provider_type: ToolProviderType
|
provider_type: ToolProviderType
|
||||||
provider_id: str
|
provider_id: str
|
||||||
tool_name: str
|
tool_name: str
|
||||||
tool_parameters: dict[str, Any] = {}
|
tool_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||||
plugin_unique_identifier: str | None = None
|
plugin_unique_identifier: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,9 +2,9 @@ from collections.abc import Mapping
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from core.app.app_config.entities import ModelConfigEntity
|
from core.app.app_config.entities import ModelConfigEntity
|
||||||
from core.entities import DEFAULT_PLUGIN_ID
|
|
||||||
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
|
||||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||||
|
from core.plugin.entities.plugin import ModelProviderID
|
||||||
from core.provider_manager import ProviderManager
|
from core.provider_manager import ProviderManager
|
||||||
|
|
||||||
|
|
||||||
@ -61,9 +61,7 @@ class ModelConfigManager:
|
|||||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||||
|
|
||||||
if "/" not in config["model"]["provider"]:
|
if "/" not in config["model"]["provider"]:
|
||||||
config["model"]["provider"] = (
|
config["model"]["provider"] = str(ModelProviderID(config["model"]["provider"]))
|
||||||
f"{DEFAULT_PLUGIN_ID}/{config['model']['provider']}/{config['model']['provider']}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if config["model"]["provider"] not in model_provider_names:
|
if config["model"]["provider"] not in model_provider_names:
|
||||||
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}")
|
||||||
|
@ -17,8 +17,8 @@ class ModelConfigEntity(BaseModel):
|
|||||||
provider: str
|
provider: str
|
||||||
model: str
|
model: str
|
||||||
mode: Optional[str] = None
|
mode: Optional[str] = None
|
||||||
parameters: dict[str, Any] = {}
|
parameters: dict[str, Any] = Field(default_factory=dict)
|
||||||
stop: list[str] = []
|
stop: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class AdvancedChatMessageEntity(BaseModel):
|
class AdvancedChatMessageEntity(BaseModel):
|
||||||
@ -132,7 +132,7 @@ class ExternalDataVariableEntity(BaseModel):
|
|||||||
|
|
||||||
variable: str
|
variable: str
|
||||||
type: str
|
type: str
|
||||||
config: dict[str, Any] = {}
|
config: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class DatasetRetrieveConfigEntity(BaseModel):
|
class DatasetRetrieveConfigEntity(BaseModel):
|
||||||
@ -188,7 +188,7 @@ class SensitiveWordAvoidanceEntity(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
type: str
|
type: str
|
||||||
config: dict[str, Any] = {}
|
config: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class TextToSpeechEntity(BaseModel):
|
class TextToSpeechEntity(BaseModel):
|
||||||
|
@ -63,9 +63,9 @@ class ModelConfigWithCredentialsEntity(BaseModel):
|
|||||||
model_schema: AIModelEntity
|
model_schema: AIModelEntity
|
||||||
mode: str
|
mode: str
|
||||||
provider_model_bundle: ProviderModelBundle
|
provider_model_bundle: ProviderModelBundle
|
||||||
credentials: dict[str, Any] = {}
|
credentials: dict[str, Any] = Field(default_factory=dict)
|
||||||
parameters: dict[str, Any] = {}
|
parameters: dict[str, Any] = Field(default_factory=dict)
|
||||||
stop: list[str] = []
|
stop: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
# pydantic configs
|
# pydantic configs
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
@ -94,7 +94,7 @@ class AppGenerateEntity(BaseModel):
|
|||||||
call_depth: int = 0
|
call_depth: int = 0
|
||||||
|
|
||||||
# extra parameters, like: auto_generate_conversation_name
|
# extra parameters, like: auto_generate_conversation_name
|
||||||
extras: dict[str, Any] = {}
|
extras: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
# tracing instance
|
# tracing instance
|
||||||
trace_manager: Optional[TraceQueueManager] = None
|
trace_manager: Optional[TraceQueueManager] = None
|
||||||
|
@ -6,11 +6,10 @@ from collections.abc import Iterator, Sequence
|
|||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from sqlalchemy import or_
|
from sqlalchemy import or_
|
||||||
|
|
||||||
from constants import HIDDEN_VALUE
|
from constants import HIDDEN_VALUE
|
||||||
from core.entities import DEFAULT_PLUGIN_ID
|
|
||||||
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity
|
||||||
from core.entities.provider_entities import (
|
from core.entities.provider_entities import (
|
||||||
CustomConfiguration,
|
CustomConfiguration,
|
||||||
@ -1004,7 +1003,7 @@ class ProviderConfigurations(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
configurations: dict[str, ProviderConfiguration] = {}
|
configurations: dict[str, ProviderConfiguration] = Field(default_factory=dict)
|
||||||
|
|
||||||
def __init__(self, tenant_id: str):
|
def __init__(self, tenant_id: str):
|
||||||
super().__init__(tenant_id=tenant_id)
|
super().__init__(tenant_id=tenant_id)
|
||||||
@ -1060,7 +1059,7 @@ class ProviderConfigurations(BaseModel):
|
|||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
if "/" not in key:
|
if "/" not in key:
|
||||||
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
|
key = str(ModelProviderID(key))
|
||||||
|
|
||||||
return self.configurations[key]
|
return self.configurations[key]
|
||||||
|
|
||||||
@ -1075,7 +1074,7 @@ class ProviderConfigurations(BaseModel):
|
|||||||
|
|
||||||
def get(self, key, default=None) -> ProviderConfiguration | None:
|
def get(self, key, default=None) -> ProviderConfiguration | None:
|
||||||
if "/" not in key:
|
if "/" not in key:
|
||||||
key = f"{DEFAULT_PLUGIN_ID}/{key}/{key}"
|
key = str(ModelProviderID(key))
|
||||||
|
|
||||||
return self.configurations.get(key, default) # type: ignore
|
return self.configurations.get(key, default) # type: ignore
|
||||||
|
|
||||||
|
@ -41,9 +41,13 @@ class HostedModerationConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class HostingConfiguration:
|
class HostingConfiguration:
|
||||||
provider_map: dict[str, HostingProvider] = {}
|
provider_map: dict[str, HostingProvider]
|
||||||
moderation_config: Optional[HostedModerationConfig] = None
|
moderation_config: Optional[HostedModerationConfig] = None
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.provider_map = {}
|
||||||
|
self.moderation_config = None
|
||||||
|
|
||||||
def init_app(self, app: Flask) -> None:
|
def init_app(self, app: Flask) -> None:
|
||||||
if dify_config.EDITION != "CLOUD":
|
if dify_config.EDITION != "CLOUD":
|
||||||
return
|
return
|
||||||
|
@ -7,7 +7,6 @@ from typing import Optional
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
import contexts
|
import contexts
|
||||||
from core.entities import DEFAULT_PLUGIN_ID
|
|
||||||
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
|
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||||
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
|
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
|
||||||
@ -34,9 +33,11 @@ class ModelProviderExtension(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ModelProviderFactory:
|
class ModelProviderFactory:
|
||||||
provider_position_map: dict[str, int] = {}
|
provider_position_map: dict[str, int]
|
||||||
|
|
||||||
def __init__(self, tenant_id: str) -> None:
|
def __init__(self, tenant_id: str) -> None:
|
||||||
|
self.provider_position_map = {}
|
||||||
|
|
||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
self.plugin_model_manager = PluginModelManager()
|
self.plugin_model_manager = PluginModelManager()
|
||||||
|
|
||||||
@ -360,11 +361,5 @@ class ModelProviderFactory:
|
|||||||
:param provider: provider name
|
:param provider: provider name
|
||||||
:return: plugin id and provider name
|
:return: plugin id and provider name
|
||||||
"""
|
"""
|
||||||
plugin_id = DEFAULT_PLUGIN_ID
|
provider_id = ModelProviderID(provider)
|
||||||
provider_name = provider
|
return provider_id.plugin_id, provider_id.provider_name
|
||||||
if "/" in provider:
|
|
||||||
# get the plugin_id before provider
|
|
||||||
plugin_id = "/".join(provider.split("/")[:-1])
|
|
||||||
provider_name = provider.split("/")[-1]
|
|
||||||
|
|
||||||
return str(plugin_id), provider_name
|
|
||||||
|
@ -13,10 +13,10 @@ from sqlalchemy.orm import Session
|
|||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.entities import DEFAULT_PLUGIN_ID
|
|
||||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||||
from core.model_manager import ModelManager
|
from core.model_manager import ModelManager
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
|
from core.plugin.entities.plugin import ModelProviderID
|
||||||
from core.rag.index_processor.constant.index_type import IndexType
|
from core.rag.index_processor.constant.index_type import IndexType
|
||||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
from events.dataset_event import dataset_was_deleted
|
from events.dataset_event import dataset_was_deleted
|
||||||
@ -328,14 +328,10 @@ class DatasetService:
|
|||||||
else:
|
else:
|
||||||
# add default plugin id to both setting sets, to make sure the plugin model provider is consistent
|
# add default plugin id to both setting sets, to make sure the plugin model provider is consistent
|
||||||
plugin_model_provider = dataset.embedding_model_provider
|
plugin_model_provider = dataset.embedding_model_provider
|
||||||
if "/" not in plugin_model_provider:
|
plugin_model_provider = str(ModelProviderID(plugin_model_provider))
|
||||||
plugin_model_provider = f"{DEFAULT_PLUGIN_ID}/{plugin_model_provider}/{plugin_model_provider}"
|
|
||||||
|
|
||||||
new_plugin_model_provider = data["embedding_model_provider"]
|
new_plugin_model_provider = data["embedding_model_provider"]
|
||||||
if "/" not in new_plugin_model_provider:
|
new_plugin_model_provider = str(ModelProviderID(new_plugin_model_provider))
|
||||||
new_plugin_model_provider = (
|
|
||||||
f"{DEFAULT_PLUGIN_ID}/{new_plugin_model_provider}/{new_plugin_model_provider}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
new_plugin_model_provider != plugin_model_provider
|
new_plugin_model_provider != plugin_model_provider
|
||||||
|
Loading…
x
Reference in New Issue
Block a user