dify/api/core/tools/utils/configuration.py
2024-03-08 15:22:55 +08:00

164 lines
6.0 KiB
Python

import os
from typing import Any, Union
from pydantic import BaseModel
from yaml import FullLoader, load
from core.helper import encrypter
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.entities.tool_entities import (
ModelToolConfiguration,
ModelToolProviderConfiguration,
ToolProviderCredentials,
)
from core.tools.provider.tool_provider import ToolProviderController
class ToolConfiguration(BaseModel):
    """
    Encrypt, decrypt and mask a tool provider's credentials for one tenant.

    Only credential fields whose schema type is
    ``ToolProviderCredentials.CredentialsType.SECRET_INPUT`` are transformed;
    all other fields are copied through unchanged.
    """

    tenant_id: str
    provider_controller: ToolProviderController

    def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]:
        """
        copy credentials so the caller's dict is never mutated

        NOTE: this is a one-level copy; nested values are shared with the
        input. That is sufficient while values are immutable (typed as str).
        """
        return dict(credentials)

    def _secret_field_names(self, credentials: dict[str, Any]) -> list[str]:
        """
        return the names of SECRET_INPUT fields in the provider's credentials
        schema that are present in ``credentials``
        """
        fields = self.provider_controller.get_credentials_schema()
        return [
            field_name
            for field_name, field in fields.items()
            if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT
            and field_name in credentials
        ]

    @staticmethod
    def _mask_secret(value: Any) -> Any:
        """
        mask a secret value

        Values longer than 6 characters keep their first and last 2 characters
        with everything in between replaced by '*'; shorter values are fully
        replaced by '*' of the same length.
        """
        if len(value) > 6:
            return value[:2] + '*' * (len(value) - 4) + value[-2:]
        return '*' * len(value)

    def _get_credentials_cache(self) -> ToolProviderCredentialsCache:
        """
        build the credentials cache handle for this tenant + provider

        Shared by decrypt and delete so both always address the same key.
        """
        return ToolProviderCredentialsCache(
            tenant_id=self.tenant_id,
            identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
            cache_type=ToolProviderCredentialsCacheType.PROVIDER
        )

    def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
        """
        encrypt tool credentials with tenant id

        :param credentials: plaintext credentials keyed by field name
        :return: a copy of credentials with SECRET_INPUT values encrypted
        """
        credentials = self._deep_copy(credentials)

        # only secret fields are encrypted; everything else passes through
        for field_name in self._secret_field_names(credentials):
            credentials[field_name] = encrypter.encrypt_token(self.tenant_id, credentials[field_name])

        return credentials

    def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
        """
        mask tool credentials for display

        :param credentials: credentials keyed by field name
        :return: a copy of credentials with SECRET_INPUT values masked
        """
        credentials = self._deep_copy(credentials)

        # only secret fields are masked; everything else passes through
        for field_name in self._secret_field_names(credentials):
            credentials[field_name] = self._mask_secret(credentials[field_name])

        return credentials

    def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
        """
        decrypt tool credentials with tenant id

        On a cache hit the cached credentials are returned and ``credentials``
        is ignored entirely. The cache key depends only on tenant and provider
        identity, not on the credential contents, so callers that change
        stored credentials must call ``delete_tool_credentials_cache``.

        :param credentials: stored credentials keyed by field name
        :return: credentials with SECRET_INPUT values decrypted
        """
        cache = self._get_credentials_cache()
        cached_credentials = cache.get()
        if cached_credentials:
            return cached_credentials

        credentials = self._deep_copy(credentials)

        for field_name in self._secret_field_names(credentials):
            try:
                credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
            except Exception:
                # best-effort: a value that fails to decrypt is kept as-is.
                # NOTE(review): presumably covers values stored unencrypted —
                # confirm. Narrowed from a bare `except:` so interpreter-level
                # exceptions (KeyboardInterrupt, SystemExit) still propagate.
                pass

        cache.set(credentials)
        return credentials

    def delete_tool_credentials_cache(self) -> None:
        """
        drop the cached decrypted credentials for this tenant + provider
        """
        self._get_credentials_cache().delete()
class ModelToolConfigurationManager:
    """
    Model as tool configuration

    On first access, parses every ``*.yaml`` file in the ``model_tools``
    directory next to this package. The parsed configurations are cached on
    the class for the lifetime of the process.
    """

    # provider name (yaml file name up to the first '.') -> provider configuration
    _configurations: dict[str, ModelToolProviderConfiguration] = {}
    # '<provider>.<model>' -> model configuration
    _model_configurations: dict[str, ModelToolConfiguration] = {}
    _inited = False

    @classmethod
    def _init_configuration(cls):
        """
        init configuration by parsing every yaml file under ``../model_tools``
        """
        absolute_path = os.path.abspath(os.path.dirname(__file__))
        model_tools_path = os.path.join(absolute_path, '..', 'model_tools')

        # os.listdir order is arbitrary; sort so load order (and which file
        # wins when two files map to the same provider key) is deterministic
        files = sorted(f for f in os.listdir(model_tools_path) if f.endswith('.yaml'))

        # build into locals first so a parse failure part-way through does not
        # leave the class-level caches half populated
        provider_configurations: dict[str, ModelToolProviderConfiguration] = {}
        model_configurations: dict[str, ModelToolConfiguration] = {}
        for file in files:
            provider = file.split('.')[0]
            with open(os.path.join(model_tools_path, file), encoding='utf-8') as f:
                configurations = ModelToolProviderConfiguration(**load(f, Loader=FullLoader))

            for model in configurations.models or []:
                model_configurations[f'{provider}.{model.model}'] = model

            provider_configurations[provider] = configurations

        # update in place so the class-level dict objects keep their identity
        cls._model_configurations.update(model_configurations)
        cls._configurations.update(provider_configurations)
        cls._inited = True

    @classmethod
    def _ensure_inited(cls) -> None:
        """
        load configurations on first use
        """
        if not cls._inited:
            cls._init_configuration()

    @classmethod
    def get_configuration(cls, provider: str) -> Union[ModelToolProviderConfiguration, None]:
        """
        get configuration by provider, or None if the provider is unknown
        """
        cls._ensure_inited()
        return cls._configurations.get(provider, None)

    @classmethod
    def get_all_configuration(cls) -> dict[str, ModelToolProviderConfiguration]:
        """
        get all configurations keyed by provider

        NOTE: returns the shared class-level dict itself, not a copy.
        """
        cls._ensure_inited()
        return cls._configurations

    @classmethod
    def get_model_configuration(cls, provider: str, model: str) -> Union[ModelToolConfiguration, None]:
        """
        get model configuration for ``provider``/``model``, or None if unknown
        """
        cls._ensure_inited()
        return cls._model_configurations.get(f'{provider}.{model}', None)