From d5c708c62b50c93494efd5eab7a7e9b74bdf75f0 Mon Sep 17 00:00:00 2001 From: takatost Date: Thu, 19 Dec 2024 00:50:46 +0800 Subject: [PATCH] feat: add plugin_model_providers context --- api/contexts/__init__.py | 5 ++++ .../model_providers/model_provider_factory.py | 23 +++++++++++++++---- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index b8354fa012..25419af278 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -4,11 +4,16 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from core.tools.plugin_tool.provider import PluginToolProviderController + from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.workflow.entities.variable_pool import VariablePool + tenant_id: ContextVar[str] = ContextVar("tenant_id") workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") plugin_tool_providers: ContextVar[dict[str, "PluginToolProviderController"]] = ContextVar("plugin_tool_providers") plugin_tool_providers_lock: ContextVar[Lock] = ContextVar("plugin_tool_providers_lock") + +plugin_model_providers: ContextVar[list["PluginModelProviderEntity"]] = ContextVar("plugin_model_providers") +plugin_model_providers_lock: ContextVar[Lock] = ContextVar("plugin_model_providers_lock") diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index c79c3a2b62..8eb9b35d41 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -1,10 +1,12 @@ import logging import os from collections.abc import Sequence +from threading import Lock from typing import Optional from pydantic import BaseModel +import contexts from core.entities import DEFAULT_PLUGIN_ID from core.helper.position_helper import get_provider_position_map, sort_to_dict_by_position_map from core.model_runtime.entities.model_entities import AIModelEntity, ModelType @@ -71,13 +73,24 @@ class ModelProviderFactory: Get all plugin model providers :return: list of plugin model providers """ - # Fetch plugin model providers - plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id) + # check if context is set + try: + contexts.plugin_model_providers.get() + except LookupError: + contexts.plugin_model_providers.set([]) + contexts.plugin_model_providers_lock.set(Lock()) - for provider in plugin_providers: - provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider + with contexts.plugin_model_providers_lock.get(): + plugin_model_providers = contexts.plugin_model_providers.get() - return plugin_providers + # Fetch plugin model providers + plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id) + + for provider in plugin_providers: + provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider + plugin_model_providers.append(provider) + + return plugin_model_providers def get_provider_schema(self, provider: str) -> ProviderEntity: """