diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index a587ec9cf7..da7d81cabb 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -10,6 +10,7 @@ from pydantic import BaseModel from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, SimpleModelProviderEntity from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus from core.helper import encrypter +from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType from core.model_runtime.model_providers import model_provider_factory @@ -171,6 +172,14 @@ class ProviderConfiguration(BaseModel): db.session.add(provider_record) db.session.commit() + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER + ) + + provider_model_credentials_cache.delete() + self.switch_preferred_provider_type(ProviderType.CUSTOM) def delete_custom_credentials(self) -> None: @@ -193,6 +202,14 @@ class ProviderConfiguration(BaseModel): db.session.delete(provider_record) db.session.commit() + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER + ) + + provider_model_credentials_cache.delete() + def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \ -> Optional[dict]: """ @@ -314,6 +331,14 @@ class ProviderConfiguration(BaseModel): db.session.add(provider_model_record) db.session.commit() + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_model_record.id, + cache_type=ProviderCredentialsCacheType.MODEL + ) + + provider_model_credentials_cache.delete() + def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None: """ Delete custom model credentials. @@ -335,6 +360,14 @@ class ProviderConfiguration(BaseModel): db.session.delete(provider_model_record) db.session.commit() + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_model_record.id, + cache_type=ProviderCredentialsCacheType.MODEL + ) + + provider_model_credentials_cache.delete() + def get_provider_instance(self) -> ModelProvider: """ Get provider instance. diff --git a/api/core/helper/model_provider_cache.py b/api/core/helper/model_provider_cache.py new file mode 100644 index 0000000000..64908f9501 --- /dev/null +++ b/api/core/helper/model_provider_cache.py @@ -0,0 +1,51 @@ +import json +from enum import Enum +from json import JSONDecodeError +from typing import Optional + +from extensions.ext_redis import redis_client + + +class ProviderCredentialsCacheType(Enum): + PROVIDER = "provider" + MODEL = "provider_model" + + +class ProviderCredentialsCache: + def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType): + self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}" + + def get(self) -> Optional[dict]: + """ + Get cached model provider credentials. + + :return: + """ + cached_provider_credentials = redis_client.get(self.cache_key) + if cached_provider_credentials: + try: + cached_provider_credentials = cached_provider_credentials.decode('utf-8') + cached_provider_credentials = json.loads(cached_provider_credentials) + except JSONDecodeError: + return None + + return cached_provider_credentials + else: + return None + + def set(self, credentials: dict) -> None: + """ + Cache model provider credentials. + + :param credentials: provider credentials + :return: + """ + redis_client.setex(self.cache_key, 3600, json.dumps(credentials)) + + def delete(self) -> None: + """ + Delete cached model provider credentials. + + :return: + """ + redis_client.delete(self.cache_key) diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index ad6e785435..740425d43f 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -30,6 +30,10 @@ class ModelProviderExtension(BaseModel): class ModelProviderFactory: model_provider_extensions: dict[str, ModelProviderExtension] = None + def __init__(self) -> None: + # for cache in memory + self.get_providers() + def get_providers(self) -> list[ProviderEntity]: """ Get all providers diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 60a388a633..aadc5344e0 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -10,6 +10,7 @@ from core.entities.provider_configuration import ProviderConfigurations, Provide from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, CustomModelConfiguration, \ SystemConfiguration, QuotaConfiguration from core.helper import encrypter +from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType from core.model_runtime.model_providers import model_provider_factory @@ -79,9 +80,6 @@ class ProviderManager: # Get All preferred provider types of the workspace provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id) - # Get decoding rsa key and cipher for decrypting credentials - decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) - provider_configurations = ProviderConfigurations( tenant_id=tenant_id ) @@ -100,19 +98,17 @@ class ProviderManager: # Convert to custom configuration custom_configuration = self._to_custom_configuration( + tenant_id, provider_entity, provider_records, - provider_model_records, - decoding_rsa_key, - decoding_cipher_rsa + provider_model_records ) # Convert to system configuration system_configuration = self._to_system_configuration( + tenant_id, provider_entity, - provider_records, - decoding_rsa_key, - decoding_cipher_rsa + provider_records ) # Get preferred provider type @@ -413,19 +409,17 @@ class ProviderManager: return provider_name_to_provider_records_dict def _to_custom_configuration(self, + tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider], - provider_model_records: list[ProviderModel], - decoding_rsa_key, - decoding_cipher_rsa) -> CustomConfiguration: + provider_model_records: list[ProviderModel]) -> CustomConfiguration: """ Convert to custom configuration. + :param tenant_id: workspace id :param provider_entity: provider entity :param provider_records: provider records :param provider_model_records: provider model records - :param decoding_rsa_key: decoding rsa key - :param decoding_cipher_rsa: decoding cipher rsa :return: """ # Get provider credential secret variables @@ -448,28 +442,48 @@ class ProviderManager: # Get custom provider credentials custom_provider_configuration = None if custom_provider_record: - try: - # fix origin data - if (custom_provider_record.encrypted_config - and not custom_provider_record.encrypted_config.startswith("{")): - provider_credentials = { - "openai_api_key": custom_provider_record.encrypted_config - } - else: - provider_credentials = json.loads(custom_provider_record.encrypted_config) - except JSONDecodeError: - provider_credentials = {} + provider_credentials_cache = ProviderCredentialsCache( + tenant_id=tenant_id, + identity_id=custom_provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER + ) - for variable in provider_credential_secret_variables: - if variable in provider_credentials: - try: - provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_credentials.get(variable), - decoding_rsa_key, - decoding_cipher_rsa - ) - except ValueError: - pass + # Get cached provider credentials + cached_provider_credentials = provider_credentials_cache.get() + + if not cached_provider_credentials: + try: + # fix origin data + if (custom_provider_record.encrypted_config + and not custom_provider_record.encrypted_config.startswith("{")): + provider_credentials = { + "openai_api_key": custom_provider_record.encrypted_config + } + else: + provider_credentials = json.loads(custom_provider_record.encrypted_config) + except JSONDecodeError: + provider_credentials = {} + + # Get decoding rsa key and cipher for decrypting credentials + decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) + + for variable in provider_credential_secret_variables: + if variable in provider_credentials: + try: + provider_credentials[variable] = encrypter.decrypt_token_with_decoding( + provider_credentials.get(variable), + decoding_rsa_key, + decoding_cipher_rsa + ) + except ValueError: + pass + + # cache provider credentials + provider_credentials_cache.set( + credentials=provider_credentials + ) + else: + provider_credentials = cached_provider_credentials custom_provider_configuration = CustomProviderConfiguration( credentials=provider_credentials @@ -487,21 +501,41 @@ class ProviderManager: if not provider_model_record.encrypted_config: continue - try: - provider_model_credentials = json.loads(provider_model_record.encrypted_config) - except JSONDecodeError: - continue + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=tenant_id, + identity_id=provider_model_record.id, + cache_type=ProviderCredentialsCacheType.MODEL + ) - for variable in model_credential_secret_variables: - if variable in provider_model_credentials: - try: - provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_model_credentials.get(variable), - decoding_rsa_key, - decoding_cipher_rsa - ) - except ValueError: - pass + # Get cached provider model credentials + cached_provider_model_credentials = provider_model_credentials_cache.get() + + if not cached_provider_model_credentials: + try: + provider_model_credentials = json.loads(provider_model_record.encrypted_config) + except JSONDecodeError: + continue + + # Get decoding rsa key and cipher for decrypting credentials + decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) + + for variable in model_credential_secret_variables: + if variable in provider_model_credentials: + try: + provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( + provider_model_credentials.get(variable), + decoding_rsa_key, + decoding_cipher_rsa + ) + except ValueError: + pass + + # cache provider model credentials + provider_model_credentials_cache.set( + credentials=provider_model_credentials + ) + else: + provider_model_credentials = cached_provider_model_credentials custom_model_configurations.append( CustomModelConfiguration( @@ -517,17 +551,15 @@ class ProviderManager: ) def _to_system_configuration(self, + tenant_id: str, provider_entity: ProviderEntity, - provider_records: list[Provider], - decoding_rsa_key, - decoding_cipher_rsa) -> SystemConfiguration: + provider_records: list[Provider]) -> SystemConfiguration: """ Convert to system configuration. + :param tenant_id: workspace id :param provider_entity: provider entity :param provider_records: provider records - :param decoding_rsa_key: decoding rsa key - :param decoding_cipher_rsa: decoding cipher rsa :return: """ # Get hosting configuration @@ -580,29 +612,49 @@ class ProviderManager: provider_record = quota_type_to_provider_records_dict.get(current_quota_type) if provider_record: - try: - provider_credentials = json.loads(provider_record.encrypted_config) - except JSONDecodeError: - provider_credentials = {} - - # Get provider credential secret variables - provider_credential_secret_variables = self._extract_secret_variables( - provider_entity.provider_credential_schema.credential_form_schemas - if provider_entity.provider_credential_schema else [] + provider_credentials_cache = ProviderCredentialsCache( + tenant_id=tenant_id, + identity_id=provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER ) - for variable in provider_credential_secret_variables: - if variable in provider_credentials: - try: - provider_credentials[variable] = encrypter.decrypt_token_with_decoding( - provider_credentials.get(variable), - decoding_rsa_key, - decoding_cipher_rsa - ) - except ValueError: - pass + # Get cached provider credentials + cached_provider_credentials = provider_credentials_cache.get() - current_using_credentials = provider_credentials + if not cached_provider_credentials: + try: + provider_credentials = json.loads(provider_record.encrypted_config) + except JSONDecodeError: + provider_credentials = {} + + # Get provider credential secret variables + provider_credential_secret_variables = self._extract_secret_variables( + provider_entity.provider_credential_schema.credential_form_schemas + if provider_entity.provider_credential_schema else [] + ) + + # Get decoding rsa key and cipher for decrypting credentials + decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) + + for variable in provider_credential_secret_variables: + if variable in provider_credentials: + try: + provider_credentials[variable] = encrypter.decrypt_token_with_decoding( + provider_credentials.get(variable), + decoding_rsa_key, + decoding_cipher_rsa + ) + except ValueError: + pass + + current_using_credentials = provider_credentials + + # cache provider credentials + provider_credentials_cache.set( + credentials=current_using_credentials + ) + else: + current_using_credentials = cached_provider_credentials else: current_using_credentials = {} diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index b832504457..31fa09ac21 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -6,7 +6,7 @@ from typing import Optional, cast, Tuple import requests from flask import current_app -from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, DefaultModelEntity +from core.entities.model_entities import ModelStatus from core.model_runtime.entities.model_entities import ModelType, ParameterRule from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel