mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-12 17:39:06 +08:00
feat: optimize performance (#1928)
This commit is contained in:
parent
5a756ca981
commit
3fa5204b0c
@ -10,6 +10,7 @@ from pydantic import BaseModel
|
|||||||
from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, SimpleModelProviderEntity
|
from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, SimpleModelProviderEntity
|
||||||
from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
|
from core.entities.provider_entities import SystemConfiguration, CustomConfiguration, SystemConfigurationStatus
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
|
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
|
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
|
||||||
from core.model_runtime.model_providers import model_provider_factory
|
from core.model_runtime.model_providers import model_provider_factory
|
||||||
@ -171,6 +172,14 @@ class ProviderConfiguration(BaseModel):
|
|||||||
db.session.add(provider_record)
|
db.session.add(provider_record)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
|
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
identity_id=provider_record.id,
|
||||||
|
cache_type=ProviderCredentialsCacheType.PROVIDER
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_model_credentials_cache.delete()
|
||||||
|
|
||||||
self.switch_preferred_provider_type(ProviderType.CUSTOM)
|
self.switch_preferred_provider_type(ProviderType.CUSTOM)
|
||||||
|
|
||||||
def delete_custom_credentials(self) -> None:
|
def delete_custom_credentials(self) -> None:
|
||||||
@ -193,6 +202,14 @@ class ProviderConfiguration(BaseModel):
|
|||||||
db.session.delete(provider_record)
|
db.session.delete(provider_record)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
|
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
identity_id=provider_record.id,
|
||||||
|
cache_type=ProviderCredentialsCacheType.PROVIDER
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_model_credentials_cache.delete()
|
||||||
|
|
||||||
def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
|
def get_custom_model_credentials(self, model_type: ModelType, model: str, obfuscated: bool = False) \
|
||||||
-> Optional[dict]:
|
-> Optional[dict]:
|
||||||
"""
|
"""
|
||||||
@ -314,6 +331,14 @@ class ProviderConfiguration(BaseModel):
|
|||||||
db.session.add(provider_model_record)
|
db.session.add(provider_model_record)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
|
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
identity_id=provider_model_record.id,
|
||||||
|
cache_type=ProviderCredentialsCacheType.MODEL
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_model_credentials_cache.delete()
|
||||||
|
|
||||||
def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
|
def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None:
|
||||||
"""
|
"""
|
||||||
Delete custom model credentials.
|
Delete custom model credentials.
|
||||||
@ -335,6 +360,14 @@ class ProviderConfiguration(BaseModel):
|
|||||||
db.session.delete(provider_model_record)
|
db.session.delete(provider_model_record)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
|
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
identity_id=provider_model_record.id,
|
||||||
|
cache_type=ProviderCredentialsCacheType.MODEL
|
||||||
|
)
|
||||||
|
|
||||||
|
provider_model_credentials_cache.delete()
|
||||||
|
|
||||||
def get_provider_instance(self) -> ModelProvider:
|
def get_provider_instance(self) -> ModelProvider:
|
||||||
"""
|
"""
|
||||||
Get provider instance.
|
Get provider instance.
|
||||||
|
51
api/core/helper/model_provider_cache.py
Normal file
51
api/core/helper/model_provider_cache.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
import json
|
||||||
|
from enum import Enum
|
||||||
|
from json import JSONDecodeError
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderCredentialsCacheType(Enum):
|
||||||
|
PROVIDER = "provider"
|
||||||
|
MODEL = "provider_model"
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderCredentialsCache:
|
||||||
|
def __init__(self, tenant_id: str, identity_id: str, cache_type: ProviderCredentialsCacheType):
|
||||||
|
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
|
||||||
|
|
||||||
|
def get(self) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Get cached model provider credentials.
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
cached_provider_credentials = redis_client.get(self.cache_key)
|
||||||
|
if cached_provider_credentials:
|
||||||
|
try:
|
||||||
|
cached_provider_credentials = cached_provider_credentials.decode('utf-8')
|
||||||
|
cached_provider_credentials = json.loads(cached_provider_credentials)
|
||||||
|
except JSONDecodeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return cached_provider_credentials
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set(self, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Cache model provider credentials.
|
||||||
|
|
||||||
|
:param credentials: provider credentials
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
redis_client.setex(self.cache_key, 3600, json.dumps(credentials))
|
||||||
|
|
||||||
|
def delete(self) -> None:
|
||||||
|
"""
|
||||||
|
Delete cached model provider credentials.
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
redis_client.delete(self.cache_key)
|
@ -30,6 +30,10 @@ class ModelProviderExtension(BaseModel):
|
|||||||
class ModelProviderFactory:
|
class ModelProviderFactory:
|
||||||
model_provider_extensions: dict[str, ModelProviderExtension] = None
|
model_provider_extensions: dict[str, ModelProviderExtension] = None
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
# for cache in memory
|
||||||
|
self.get_providers()
|
||||||
|
|
||||||
def get_providers(self) -> list[ProviderEntity]:
|
def get_providers(self) -> list[ProviderEntity]:
|
||||||
"""
|
"""
|
||||||
Get all providers
|
Get all providers
|
||||||
|
@ -10,6 +10,7 @@ from core.entities.provider_configuration import ProviderConfigurations, Provide
|
|||||||
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, CustomModelConfiguration, \
|
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, CustomModelConfiguration, \
|
||||||
SystemConfiguration, QuotaConfiguration
|
SystemConfiguration, QuotaConfiguration
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
|
from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType
|
||||||
from core.model_runtime.entities.model_entities import ModelType
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
|
from core.model_runtime.entities.provider_entities import ProviderEntity, CredentialFormSchema, FormType
|
||||||
from core.model_runtime.model_providers import model_provider_factory
|
from core.model_runtime.model_providers import model_provider_factory
|
||||||
@ -79,9 +80,6 @@ class ProviderManager:
|
|||||||
# Get All preferred provider types of the workspace
|
# Get All preferred provider types of the workspace
|
||||||
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
|
provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id)
|
||||||
|
|
||||||
# Get decoding rsa key and cipher for decrypting credentials
|
|
||||||
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
|
||||||
|
|
||||||
provider_configurations = ProviderConfigurations(
|
provider_configurations = ProviderConfigurations(
|
||||||
tenant_id=tenant_id
|
tenant_id=tenant_id
|
||||||
)
|
)
|
||||||
@ -100,19 +98,17 @@ class ProviderManager:
|
|||||||
|
|
||||||
# Convert to custom configuration
|
# Convert to custom configuration
|
||||||
custom_configuration = self._to_custom_configuration(
|
custom_configuration = self._to_custom_configuration(
|
||||||
|
tenant_id,
|
||||||
provider_entity,
|
provider_entity,
|
||||||
provider_records,
|
provider_records,
|
||||||
provider_model_records,
|
provider_model_records
|
||||||
decoding_rsa_key,
|
|
||||||
decoding_cipher_rsa
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert to system configuration
|
# Convert to system configuration
|
||||||
system_configuration = self._to_system_configuration(
|
system_configuration = self._to_system_configuration(
|
||||||
|
tenant_id,
|
||||||
provider_entity,
|
provider_entity,
|
||||||
provider_records,
|
provider_records
|
||||||
decoding_rsa_key,
|
|
||||||
decoding_cipher_rsa
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get preferred provider type
|
# Get preferred provider type
|
||||||
@ -413,19 +409,17 @@ class ProviderManager:
|
|||||||
return provider_name_to_provider_records_dict
|
return provider_name_to_provider_records_dict
|
||||||
|
|
||||||
def _to_custom_configuration(self,
|
def _to_custom_configuration(self,
|
||||||
|
tenant_id: str,
|
||||||
provider_entity: ProviderEntity,
|
provider_entity: ProviderEntity,
|
||||||
provider_records: list[Provider],
|
provider_records: list[Provider],
|
||||||
provider_model_records: list[ProviderModel],
|
provider_model_records: list[ProviderModel]) -> CustomConfiguration:
|
||||||
decoding_rsa_key,
|
|
||||||
decoding_cipher_rsa) -> CustomConfiguration:
|
|
||||||
"""
|
"""
|
||||||
Convert to custom configuration.
|
Convert to custom configuration.
|
||||||
|
|
||||||
|
:param tenant_id: workspace id
|
||||||
:param provider_entity: provider entity
|
:param provider_entity: provider entity
|
||||||
:param provider_records: provider records
|
:param provider_records: provider records
|
||||||
:param provider_model_records: provider model records
|
:param provider_model_records: provider model records
|
||||||
:param decoding_rsa_key: decoding rsa key
|
|
||||||
:param decoding_cipher_rsa: decoding cipher rsa
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# Get provider credential secret variables
|
# Get provider credential secret variables
|
||||||
@ -448,28 +442,48 @@ class ProviderManager:
|
|||||||
# Get custom provider credentials
|
# Get custom provider credentials
|
||||||
custom_provider_configuration = None
|
custom_provider_configuration = None
|
||||||
if custom_provider_record:
|
if custom_provider_record:
|
||||||
try:
|
provider_credentials_cache = ProviderCredentialsCache(
|
||||||
# fix origin data
|
tenant_id=tenant_id,
|
||||||
if (custom_provider_record.encrypted_config
|
identity_id=custom_provider_record.id,
|
||||||
and not custom_provider_record.encrypted_config.startswith("{")):
|
cache_type=ProviderCredentialsCacheType.PROVIDER
|
||||||
provider_credentials = {
|
)
|
||||||
"openai_api_key": custom_provider_record.encrypted_config
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
provider_credentials = json.loads(custom_provider_record.encrypted_config)
|
|
||||||
except JSONDecodeError:
|
|
||||||
provider_credentials = {}
|
|
||||||
|
|
||||||
for variable in provider_credential_secret_variables:
|
# Get cached provider credentials
|
||||||
if variable in provider_credentials:
|
cached_provider_credentials = provider_credentials_cache.get()
|
||||||
try:
|
|
||||||
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
if not cached_provider_credentials:
|
||||||
provider_credentials.get(variable),
|
try:
|
||||||
decoding_rsa_key,
|
# fix origin data
|
||||||
decoding_cipher_rsa
|
if (custom_provider_record.encrypted_config
|
||||||
)
|
and not custom_provider_record.encrypted_config.startswith("{")):
|
||||||
except ValueError:
|
provider_credentials = {
|
||||||
pass
|
"openai_api_key": custom_provider_record.encrypted_config
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
provider_credentials = json.loads(custom_provider_record.encrypted_config)
|
||||||
|
except JSONDecodeError:
|
||||||
|
provider_credentials = {}
|
||||||
|
|
||||||
|
# Get decoding rsa key and cipher for decrypting credentials
|
||||||
|
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
||||||
|
|
||||||
|
for variable in provider_credential_secret_variables:
|
||||||
|
if variable in provider_credentials:
|
||||||
|
try:
|
||||||
|
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||||
|
provider_credentials.get(variable),
|
||||||
|
decoding_rsa_key,
|
||||||
|
decoding_cipher_rsa
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# cache provider credentials
|
||||||
|
provider_credentials_cache.set(
|
||||||
|
credentials=provider_credentials
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
provider_credentials = cached_provider_credentials
|
||||||
|
|
||||||
custom_provider_configuration = CustomProviderConfiguration(
|
custom_provider_configuration = CustomProviderConfiguration(
|
||||||
credentials=provider_credentials
|
credentials=provider_credentials
|
||||||
@ -487,21 +501,41 @@ class ProviderManager:
|
|||||||
if not provider_model_record.encrypted_config:
|
if not provider_model_record.encrypted_config:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||||
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
|
tenant_id=tenant_id,
|
||||||
except JSONDecodeError:
|
identity_id=provider_model_record.id,
|
||||||
continue
|
cache_type=ProviderCredentialsCacheType.MODEL
|
||||||
|
)
|
||||||
|
|
||||||
for variable in model_credential_secret_variables:
|
# Get cached provider model credentials
|
||||||
if variable in provider_model_credentials:
|
cached_provider_model_credentials = provider_model_credentials_cache.get()
|
||||||
try:
|
|
||||||
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
if not cached_provider_model_credentials:
|
||||||
provider_model_credentials.get(variable),
|
try:
|
||||||
decoding_rsa_key,
|
provider_model_credentials = json.loads(provider_model_record.encrypted_config)
|
||||||
decoding_cipher_rsa
|
except JSONDecodeError:
|
||||||
)
|
continue
|
||||||
except ValueError:
|
|
||||||
pass
|
# Get decoding rsa key and cipher for decrypting credentials
|
||||||
|
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
||||||
|
|
||||||
|
for variable in model_credential_secret_variables:
|
||||||
|
if variable in provider_model_credentials:
|
||||||
|
try:
|
||||||
|
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||||
|
provider_model_credentials.get(variable),
|
||||||
|
decoding_rsa_key,
|
||||||
|
decoding_cipher_rsa
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# cache provider model credentials
|
||||||
|
provider_model_credentials_cache.set(
|
||||||
|
credentials=provider_model_credentials
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
provider_model_credentials = cached_provider_model_credentials
|
||||||
|
|
||||||
custom_model_configurations.append(
|
custom_model_configurations.append(
|
||||||
CustomModelConfiguration(
|
CustomModelConfiguration(
|
||||||
@ -517,17 +551,15 @@ class ProviderManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _to_system_configuration(self,
|
def _to_system_configuration(self,
|
||||||
|
tenant_id: str,
|
||||||
provider_entity: ProviderEntity,
|
provider_entity: ProviderEntity,
|
||||||
provider_records: list[Provider],
|
provider_records: list[Provider]) -> SystemConfiguration:
|
||||||
decoding_rsa_key,
|
|
||||||
decoding_cipher_rsa) -> SystemConfiguration:
|
|
||||||
"""
|
"""
|
||||||
Convert to system configuration.
|
Convert to system configuration.
|
||||||
|
|
||||||
|
:param tenant_id: workspace id
|
||||||
:param provider_entity: provider entity
|
:param provider_entity: provider entity
|
||||||
:param provider_records: provider records
|
:param provider_records: provider records
|
||||||
:param decoding_rsa_key: decoding rsa key
|
|
||||||
:param decoding_cipher_rsa: decoding cipher rsa
|
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# Get hosting configuration
|
# Get hosting configuration
|
||||||
@ -580,29 +612,49 @@ class ProviderManager:
|
|||||||
provider_record = quota_type_to_provider_records_dict.get(current_quota_type)
|
provider_record = quota_type_to_provider_records_dict.get(current_quota_type)
|
||||||
|
|
||||||
if provider_record:
|
if provider_record:
|
||||||
try:
|
provider_credentials_cache = ProviderCredentialsCache(
|
||||||
provider_credentials = json.loads(provider_record.encrypted_config)
|
tenant_id=tenant_id,
|
||||||
except JSONDecodeError:
|
identity_id=provider_record.id,
|
||||||
provider_credentials = {}
|
cache_type=ProviderCredentialsCacheType.PROVIDER
|
||||||
|
|
||||||
# Get provider credential secret variables
|
|
||||||
provider_credential_secret_variables = self._extract_secret_variables(
|
|
||||||
provider_entity.provider_credential_schema.credential_form_schemas
|
|
||||||
if provider_entity.provider_credential_schema else []
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for variable in provider_credential_secret_variables:
|
# Get cached provider credentials
|
||||||
if variable in provider_credentials:
|
cached_provider_credentials = provider_credentials_cache.get()
|
||||||
try:
|
|
||||||
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
|
||||||
provider_credentials.get(variable),
|
|
||||||
decoding_rsa_key,
|
|
||||||
decoding_cipher_rsa
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
current_using_credentials = provider_credentials
|
if not cached_provider_credentials:
|
||||||
|
try:
|
||||||
|
provider_credentials = json.loads(provider_record.encrypted_config)
|
||||||
|
except JSONDecodeError:
|
||||||
|
provider_credentials = {}
|
||||||
|
|
||||||
|
# Get provider credential secret variables
|
||||||
|
provider_credential_secret_variables = self._extract_secret_variables(
|
||||||
|
provider_entity.provider_credential_schema.credential_form_schemas
|
||||||
|
if provider_entity.provider_credential_schema else []
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get decoding rsa key and cipher for decrypting credentials
|
||||||
|
decoding_rsa_key, decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id)
|
||||||
|
|
||||||
|
for variable in provider_credential_secret_variables:
|
||||||
|
if variable in provider_credentials:
|
||||||
|
try:
|
||||||
|
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
|
||||||
|
provider_credentials.get(variable),
|
||||||
|
decoding_rsa_key,
|
||||||
|
decoding_cipher_rsa
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
current_using_credentials = provider_credentials
|
||||||
|
|
||||||
|
# cache provider credentials
|
||||||
|
provider_credentials_cache.set(
|
||||||
|
credentials=current_using_credentials
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
current_using_credentials = cached_provider_credentials
|
||||||
else:
|
else:
|
||||||
current_using_credentials = {}
|
current_using_credentials = {}
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ from typing import Optional, cast, Tuple
|
|||||||
import requests
|
import requests
|
||||||
from flask import current_app
|
from flask import current_app
|
||||||
|
|
||||||
from core.entities.model_entities import ModelWithProviderEntity, ModelStatus, DefaultModelEntity
|
from core.entities.model_entities import ModelStatus
|
||||||
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
|
from core.model_runtime.entities.model_entities import ModelType, ParameterRule
|
||||||
from core.model_runtime.model_providers import model_provider_factory
|
from core.model_runtime.model_providers import model_provider_factory
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
|
Loading…
x
Reference in New Issue
Block a user