diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index da7d81cabb..a1c05a2af3 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -520,7 +520,13 @@ class ProviderConfiguration(BaseModel): provider_models.extend( [ ModelWithProviderEntity( - **m.dict(), + model=m.model, + label=m.label, + model_type=m.model_type, + features=m.features, + fetch_from=m.fetch_from, + model_properties=m.model_properties, + deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), status=ModelStatus.ACTIVE ) @@ -569,7 +575,13 @@ class ProviderConfiguration(BaseModel): for m in models: provider_models.append( ModelWithProviderEntity( - **m.dict(), + model=m.model, + label=m.label, + model_type=m.model_type, + features=m.features, + fetch_from=m.fetch_from, + model_properties=m.model_properties, + deprecated=m.deprecated, provider=SimpleModelProviderEntity(self.provider), status=ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE ) @@ -597,7 +609,13 @@ class ProviderConfiguration(BaseModel): provider_models.append( ModelWithProviderEntity( - **custom_model_schema.dict(), + model=custom_model_schema.model, + label=custom_model_schema.label, + model_type=custom_model_schema.model_type, + features=custom_model_schema.features, + fetch_from=custom_model_schema.fetch_from, + model_properties=custom_model_schema.model_properties, + deprecated=custom_model_schema.deprecated, provider=SimpleModelProviderEntity(self.provider), status=ModelStatus.ACTIVE ) diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 8edf1df8d6..d1e93bb839 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -24,6 +24,9 @@ class ProviderManager: """ ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers. """ + def __init__(self) -> None: + self.decoding_rsa_key = None + self.decoding_cipher_rsa = None def get_configurations(self, tenant_id: str) -> ProviderConfigurations: """ @@ -472,15 +475,16 @@ class ProviderManager: provider_credentials = {} # Get decoding rsa key and cipher for decrypting credentials - decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) + if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: + self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) for variable in provider_credential_secret_variables: if variable in provider_credentials: try: provider_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_credentials.get(variable), - decoding_rsa_key, - decoding_cipher_rsa + self.decoding_rsa_key, + self.decoding_cipher_rsa ) except ValueError: pass @@ -524,15 +528,16 @@ class ProviderManager: continue # Get decoding rsa key and cipher for decrypting credentials - decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) + if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: + self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) for variable in model_credential_secret_variables: if variable in provider_model_credentials: try: provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_model_credentials.get(variable), - decoding_rsa_key, - decoding_cipher_rsa + self.decoding_rsa_key, + self.decoding_cipher_rsa ) except ValueError: pass @@ -641,15 +646,16 @@ class ProviderManager: ) # Get decoding rsa key and cipher for decrypting credentials - decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) + if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: + self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) for variable in provider_credential_secret_variables: if variable in provider_credentials: try: provider_credentials[variable] = encrypter.decrypt_token_with_decoding( provider_credentials.get(variable), - decoding_rsa_key, - decoding_cipher_rsa + self.decoding_rsa_key, + self.decoding_cipher_rsa ) except ValueError: pass diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 31fa09ac21..422a59b0cc 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -14,7 +14,7 @@ from core.provider_manager import ProviderManager from models.provider import ProviderType from services.entities.model_provider_entities import ProviderResponse, CustomConfigurationResponse, \ SystemConfigurationResponse, CustomConfigurationStatus, ProviderWithModelsResponse, ModelResponse, \ - DefaultModelResponse, ModelWithProviderEntityResponse + DefaultModelResponse, ModelWithProviderEntityResponse, SimpleProviderEntityResponse logger = logging.getLogger(__name__) @@ -45,7 +45,17 @@ class ModelProviderService: continue provider_response = ProviderResponse( - **provider_configuration.provider.dict(), + provider=provider_configuration.provider.provider, + label=provider_configuration.provider.label, + description=provider_configuration.provider.description, + icon_small=provider_configuration.provider.icon_small, + icon_large=provider_configuration.provider.icon_large, + background=provider_configuration.provider.background, + help=provider_configuration.provider.help, + supported_model_types=provider_configuration.provider.supported_model_types, + configurate_methods=provider_configuration.provider.configurate_methods, + provider_credential_schema=provider_configuration.provider.provider_credential_schema, + model_credential_schema=provider_configuration.provider.model_credential_schema, preferred_provider_type=provider_configuration.preferred_provider_type, custom_configuration=CustomConfigurationResponse( status=CustomConfigurationStatus.ACTIVE @@ -53,7 +63,9 @@ class ModelProviderService: else CustomConfigurationStatus.NO_CONFIGURE ), system_configuration=SystemConfigurationResponse( - **provider_configuration.system_configuration.dict() + enabled=provider_configuration.system_configuration.enabled, + current_quota_type=provider_configuration.system_configuration.current_quota_type, + quota_configurations=provider_configuration.system_configuration.quota_configurations ) ) @@ -369,7 +381,15 @@ class ModelProviderService: ) return DefaultModelResponse( - **result.dict() + model=result.model, + model_type=result.model_type, + provider=SimpleProviderEntityResponse( + provider=result.provider.provider, + label=result.provider.label, + icon_small=result.provider.icon_small, + icon_large=result.provider.icon_large, + supported_model_types=result.provider.supported_model_types + ) ) if result else None def update_default_model_of_model_type(self, tenant_id: str, model_type: str, provider: str, model: str) -> None: