feat: add plugin_model_providers context

This commit is contained in:
takatost 2024-12-19 00:50:46 +08:00
parent 342d4060ff
commit d5c708c62b
2 changed files with 23 additions and 5 deletions

View File

@ -4,11 +4,16 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.workflow.entities.variable_pool import VariablePool
tenant_id: ContextVar[str] = ContextVar("tenant_id")
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers")
plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock")
plugin_model_providers: ContextVar[list["PluginModelProviderEntity"]] = ContextVar("plugin_model_providers")
plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock")

View File

@ -1,10 +1,12 @@
import logging
import os
from collections.abc import Sequence
from threading import Lock
from typing import Optional
from pydantic import BaseModel
import contexts
from core.entities import DEFAULT_PLUGIN_ID
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
@ -71,13 +73,24 @@ class ModelProviderFactory:
Get all plugin model providers
:return: list of plugin model providers
"""
# Fetch plugin model providers
plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id)
# check if context is set
try:
contexts.plugin_model_providers.get()
except LookupError:
contexts.plugin_model_providers.set([])
contexts.plugin_model_providers_lock.set(Lock())
for provider in plugin_providers:
provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider
with contexts.plugin_model_providers_lock.get():
plugin_model_providers = contexts.plugin_model_providers.get()
return plugin_providers
# Fetch plugin model providers
plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id)
for provider in plugin_providers:
provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider
plugin_model_providers.append(provider)
return plugin_model_providers
def get_provider_schema(self, provider: str) -> ProviderEntity:
"""