Refactor part of the ProviderManager code to improve readability (#4524)

This commit is contained in:
非法操作 2024-05-22 11:18:03 +08:00 committed by GitHub
parent ee53f98d8c
commit 3efb5fe7e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -105,14 +105,8 @@ class ProviderManager:
# Construct ProviderConfiguration objects for each provider # Construct ProviderConfiguration objects for each provider
for provider_entity in provider_entities: for provider_entity in provider_entities:
provider_name = provider_entity.provider provider_name = provider_entity.provider
provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, [])
provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider) provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, [])
if not provider_records:
provider_records = []
provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider)
if not provider_model_records:
provider_model_records = []
# Convert to custom configuration # Convert to custom configuration
custom_configuration = self._to_custom_configuration( custom_configuration = self._to_custom_configuration(
@ -134,8 +128,7 @@ class ProviderManager:
if preferred_provider_type_record: if preferred_provider_type_record:
preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type) preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type)
else: elif custom_configuration.provider or custom_configuration.models:
if custom_configuration.provider or custom_configuration.models:
preferred_provider_type = ProviderType.CUSTOM preferred_provider_type = ProviderType.CUSTOM
elif system_configuration.enabled: elif system_configuration.enabled:
preferred_provider_type = ProviderType.SYSTEM preferred_provider_type = ProviderType.SYSTEM
@ -143,28 +136,15 @@ class ProviderManager:
preferred_provider_type = ProviderType.CUSTOM preferred_provider_type = ProviderType.CUSTOM
using_provider_type = preferred_provider_type using_provider_type = preferred_provider_type
has_valid_quota = any(quota_conf.is_valid for quota_conf in system_configuration.quota_configurations)
if preferred_provider_type == ProviderType.SYSTEM: if preferred_provider_type == ProviderType.SYSTEM:
if not system_configuration.enabled: if not system_configuration.enabled or not has_valid_quota:
using_provider_type = ProviderType.CUSTOM using_provider_type = ProviderType.CUSTOM
has_valid_quota = False
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.is_valid:
has_valid_quota = True
break
if not has_valid_quota:
using_provider_type = ProviderType.CUSTOM
else: else:
if not custom_configuration.provider and not custom_configuration.models: if not custom_configuration.provider and not custom_configuration.models:
if system_configuration.enabled: if system_configuration.enabled and has_valid_quota:
has_valid_quota = False
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.is_valid:
has_valid_quota = True
break
if has_valid_quota:
using_provider_type = ProviderType.SYSTEM using_provider_type = ProviderType.SYSTEM
provider_configuration = ProviderConfiguration( provider_configuration = ProviderConfiguration(
@ -233,22 +213,9 @@ class ProviderManager:
) )
if available_models: if available_models:
found = False available_model = next((model for model in available_models if model.model == "gpt-4"),
for available_model in available_models: available_models[0])
if available_model.model == "gpt-4":
default_model = TenantDefaultModel(
tenant_id=tenant_id,
model_type=model_type.to_origin_model_type(),
provider_name=available_model.provider.provider,
model_name=available_model.model
)
db.session.add(default_model)
db.session.commit()
found = True
break
if not found:
available_model = available_models[0]
default_model = TenantDefaultModel( default_model = TenantDefaultModel(
tenant_id=tenant_id, tenant_id=tenant_id,
model_type=model_type.to_origin_model_type(), model_type=model_type.to_origin_model_type(),