feat: add plugin_model_providers context

This commit is contained in:
takatost 2024-12-19 00:50:46 +08:00
parent 342d4060ff
commit d5c708c62b
2 changed files with 23 additions and 5 deletions

View File

@ -4,11 +4,16 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.provider import PluginToolProviderController
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
tenant_id: ContextVar[str] = ContextVar("tenant_id") tenant_id: ContextVar[str] = ContextVar("tenant_id")
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers") plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers")
plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock") plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock")
plugin_model_providers: ContextVar[list["PluginModelProviderEntity"]] = ContextVar("plugin_model_providers")
plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock")

View File

@ -1,10 +1,12 @@
import logging import logging
import os import os
from collections.abc import Sequence from collections.abc import Sequence
from threading import Lock
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
import contexts
from core.entities import DEFAULT_PLUGIN_ID from core.entities import DEFAULT_PLUGIN_ID
from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
@ -71,13 +73,24 @@ class ModelProviderFactory:
Get all plugin model providers Get all plugin model providers
:return: list of plugin model providers :return: list of plugin model providers
""" """
# Fetch plugin model providers # check if context is set
plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id) try:
contexts.plugin_model_providers.get()
except LookupError:
contexts.plugin_model_providers.set([])
contexts.plugin_model_providers_lock.set(Lock())
for provider in plugin_providers: with contexts.plugin_model_providers_lock.get():
provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider plugin_model_providers = contexts.plugin_model_providers.get()
return plugin_providers # Fetch plugin model providers
plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id)
for provider in plugin_providers:
provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider
plugin_model_providers.append(provider)
return plugin_model_providers
def get_provider_schema(self, provider: str) -> ProviderEntity: def get_provider_schema(self, provider: str) -> ProviderEntity:
""" """