mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-18 05:35:57 +08:00
feat: add plugin_model_providers context
This commit is contained in:
parent
342d4060ff
commit
d5c708c62b
@ -4,11 +4,16 @@ from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
|
||||
tenant_id: ContextVar[str] = ContextVar("tenant_id")
|
||||
|
||||
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
|
||||
|
||||
plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers")
|
||||
plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock")
|
||||
|
||||
plugin_model_providers: ContextVar[list["PluginModelProviderEntity"]] = ContextVar("plugin_model_providers")
|
||||
plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock")
|
||||
|
@ -1,10 +1,12 @@
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from threading import Lock
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import contexts
|
||||
from core.entities import DEFAULT_PLUGIN_ID
|
||||
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
@ -71,13 +73,24 @@ class ModelProviderFactory:
|
||||
Get all plugin model providers
|
||||
:return: list of plugin model providers
|
||||
"""
|
||||
# Fetch plugin model providers
|
||||
plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id)
|
||||
# check if context is set
|
||||
try:
|
||||
contexts.plugin_model_providers.get()
|
||||
except LookupError:
|
||||
contexts.plugin_model_providers.set([])
|
||||
contexts.plugin_model_providers_lock.set(Lock())
|
||||
|
||||
for provider in plugin_providers:
|
||||
provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider
|
||||
with contexts.plugin_model_providers_lock.get():
|
||||
plugin_model_providers = contexts.plugin_model_providers.get()
|
||||
|
||||
return plugin_providers
|
||||
# Fetch plugin model providers
|
||||
plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id)
|
||||
|
||||
for provider in plugin_providers:
|
||||
provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider
|
||||
plugin_model_providers.append(provider)
|
||||
|
||||
return plugin_model_providers
|
||||
|
||||
def get_provider_schema(self, provider: str) -> ProviderEntity:
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user