From 3efb5fe7e26834ebf76910b8fa159b98af19ecfa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Wed, 22 May 2024 11:18:03 +0800 Subject: [PATCH] Refactor part of the ProviderManager code to improve readability (#4524) --- api/core/provider_manager.py | 77 +++++++++++------------------------- 1 file changed, 22 insertions(+), 55 deletions(-) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 0db84d3b69..0281ddad0a 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -105,14 +105,8 @@ class ProviderManager: # Construct ProviderConfiguration objects for each provider for provider_entity in provider_entities: provider_name = provider_entity.provider - - provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider) - if not provider_records: - provider_records = [] - - provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider) - if not provider_model_records: - provider_model_records = [] + provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, []) + provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, []) # Convert to custom configuration custom_configuration = self._to_custom_configuration( @@ -134,38 +128,24 @@ class ProviderManager: if preferred_provider_type_record: preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type) + elif custom_configuration.provider or custom_configuration.models: + preferred_provider_type = ProviderType.CUSTOM + elif system_configuration.enabled: + preferred_provider_type = ProviderType.SYSTEM else: - if custom_configuration.provider or custom_configuration.models: - preferred_provider_type = ProviderType.CUSTOM - elif system_configuration.enabled: - preferred_provider_type = ProviderType.SYSTEM - else: - preferred_provider_type = ProviderType.CUSTOM + preferred_provider_type = ProviderType.CUSTOM using_provider_type = preferred_provider_type + has_valid_quota = any(quota_conf.is_valid for quota_conf in system_configuration.quota_configurations) + if preferred_provider_type == ProviderType.SYSTEM: - if not system_configuration.enabled: + if not system_configuration.enabled or not has_valid_quota: using_provider_type = ProviderType.CUSTOM - has_valid_quota = False - for quota_configuration in system_configuration.quota_configurations: - if quota_configuration.is_valid: - has_valid_quota = True - break - - if not has_valid_quota: - using_provider_type = ProviderType.CUSTOM else: if not custom_configuration.provider and not custom_configuration.models: - if system_configuration.enabled: - has_valid_quota = False - for quota_configuration in system_configuration.quota_configurations: - if quota_configuration.is_valid: - has_valid_quota = True - break - - if has_valid_quota: - using_provider_type = ProviderType.SYSTEM + if system_configuration.enabled and has_valid_quota: + using_provider_type = ProviderType.SYSTEM provider_configuration = ProviderConfiguration( tenant_id=tenant_id, @@ -233,30 +213,17 @@ class ProviderManager: ) if available_models: - found = False - for available_model in available_models: - if available_model.model == "gpt-4": - default_model = TenantDefaultModel( - tenant_id=tenant_id, - model_type=model_type.to_origin_model_type(), - provider_name=available_model.provider.provider, - model_name=available_model.model - ) - db.session.add(default_model) - db.session.commit() - found = True - break + available_model = next((model for model in available_models if model.model == "gpt-4"), + available_models[0]) - if not found: - available_model = available_models[0] - default_model = TenantDefaultModel( - tenant_id=tenant_id, - model_type=model_type.to_origin_model_type(), - provider_name=available_model.provider.provider, - model_name=available_model.model - ) - db.session.add(default_model) - db.session.commit() + default_model = TenantDefaultModel( + tenant_id=tenant_id, + model_type=model_type.to_origin_model_type(), + provider_name=available_model.provider.provider, + model_name=available_model.model + ) + db.session.add(default_model) + db.session.commit() if not default_model: return None