From cc14877ce72400bb677eeda3e81834e80effaafb Mon Sep 17 00:00:00 2001 From: luckylhb90 Date: Wed, 14 May 2025 11:10:05 +0300 Subject: [PATCH] Chore/reduce the invocation of the plugin interface (#19673) Co-authored-by: hobo.l Co-authored-by: crazywoola <427733928@qq.com> --- api/core/entities/provider_configuration.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 86887c9b4a..e1aec6eb7b 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -754,7 +754,7 @@ class ProviderConfiguration(BaseModel): :param only_active: return active model only :return: """ - provider_models = self.get_provider_models(model_type, only_active) + provider_models = self.get_provider_models(model_type, only_active, model) for provider_model in provider_models: if provider_model.model == model: @@ -763,12 +763,13 @@ class ProviderConfiguration(BaseModel): return None def get_provider_models( - self, model_type: Optional[ModelType] = None, only_active: bool = False + self, model_type: Optional[ModelType] = None, only_active: bool = False, model: Optional[str] = None ) -> list[ModelWithProviderEntity]: """ Get provider models. :param model_type: model type :param only_active: only active models + :param model: model name :return: """ model_provider_factory = ModelProviderFactory(self.tenant_id) @@ -791,7 +792,10 @@ class ProviderConfiguration(BaseModel): ) else: provider_models = self._get_custom_provider_models( - model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map + model_types=model_types, + provider_schema=provider_schema, + model_setting_map=model_setting_map, + model=model, ) if only_active: @@ -944,6 +948,7 @@ class ProviderConfiguration(BaseModel): model_types: Sequence[ModelType], provider_schema: ProviderEntity, model_setting_map: dict[ModelType, dict[str, ModelSettings]], + model: Optional[str] = None, ) -> list[ModelWithProviderEntity]: """ Get custom provider models. @@ -996,7 +1001,8 @@ class ProviderConfiguration(BaseModel): for model_configuration in self.custom_configuration.models: if model_configuration.model_type not in model_types: continue - + if model and model != model_configuration.model: + continue try: custom_model_schema = self.get_model_schema( model_type=model_configuration.model_type,