From 4668c4996a72052de000d6e2ef8cba0bea3129d9 Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Tue, 4 Mar 2025 18:02:06 +0800 Subject: [PATCH] feat: Add caching mechanism for plugin model schemas (#14898) --- api/contexts/__init__.py | 9 ++++ .../model_providers/__base/ai_model.py | 41 +++++++++++++++---- .../model_providers/model_provider_factory.py | 39 +++++++++++++----- 3 files changed, 70 insertions(+), 19 deletions(-) diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index 91438d086a..127b8fe76d 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING from contexts.wrapper import RecyclableContextVar if TYPE_CHECKING: + from core.model_runtime.entities.model_entities import AIModelEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.tools.plugin_tool.provider import PluginToolProviderController from core.workflow.entities.variable_pool import VariablePool @@ -20,11 +21,19 @@ To avoid race-conditions caused by gunicorn thread recycling, using RecyclableCo plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar( ContextVar("plugin_tool_providers") ) + plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock")) plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar( ContextVar("plugin_model_providers") ) + plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( ContextVar("plugin_model_providers_lock") ) + +plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock")) + +plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar( + ContextVar("plugin_model_schemas") +) diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index e79a3c0157..cdd1bba6be 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -1,8 +1,11 @@ import decimal +import hashlib +from threading import Lock from typing import Optional from pydantic import BaseModel, ConfigDict, Field +import contexts from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from core.model_runtime.entities.model_entities import ( @@ -139,15 +142,35 @@ class AIModel(BaseModel): :return: model schema """ plugin_model_manager = PluginModelManager() - return plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model_type=self.model_type.value, - model=model, - credentials=credentials or {}, - ) + cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}" + # sort credentials + sorted_credentials = sorted(credentials.items()) if credentials else [] + cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) + + try: + contexts.plugin_model_schemas.get() + except LookupError: + contexts.plugin_model_schemas.set({}) + contexts.plugin_model_schema_lock.set(Lock()) + + with contexts.plugin_model_schema_lock.get(): + if cache_key in contexts.plugin_model_schemas.get(): + return contexts.plugin_model_schemas.get()[cache_key] + + schema = plugin_model_manager.get_model_schema( + tenant_id=self.tenant_id, + user_id="unknown", + plugin_id=self.plugin_id, + provider=self.provider_name, + model_type=self.model_type.value, + model=model, + credentials=credentials or {}, + ) + + if schema: + contexts.plugin_model_schemas.get()[cache_key] = schema + + return schema def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> Optional[AIModelEntity]: """ diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 4e9b20b033..d2fd4916a4 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -1,3 +1,4 @@ +import hashlib import logging import os from collections.abc import Sequence @@ -206,17 +207,35 @@ class ModelProviderFactory: Get model schema """ plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) - model_schema = self.plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_id, - provider=provider_name, - model_type=model_type.value, - model=model, - credentials=credentials, - ) + cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}" + # sort credentials + sorted_credentials = sorted(credentials.items()) if credentials else [] + cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) - return model_schema + try: + contexts.plugin_model_schemas.get() + except LookupError: + contexts.plugin_model_schemas.set({}) + contexts.plugin_model_schema_lock.set(Lock()) + + with contexts.plugin_model_schema_lock.get(): + if cache_key in contexts.plugin_model_schemas.get(): + return contexts.plugin_model_schemas.get()[cache_key] + + schema = self.plugin_model_manager.get_model_schema( + tenant_id=self.tenant_id, + user_id="unknown", + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials or {}, + ) + + if schema: + contexts.plugin_model_schemas.get()[cache_key] = schema + + return schema def get_models( self,