diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py new file mode 100644 index 0000000000..6c5d3b8fb6 --- /dev/null +++ b/api/core/helper/tool_provider_cache.py @@ -0,0 +1,49 @@ +import json +from enum import Enum +from json import JSONDecodeError +from typing import Optional + +from extensions.ext_redis import redis_client + + +class ToolProviderCredentialsCacheType(Enum): + PROVIDER = "tool_provider" + +class ToolProviderCredentialsCache: + def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType): + self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}" + + def get(self) -> Optional[dict]: + """ + Get cached model provider credentials. + + :return: + """ + cached_provider_credentials = redis_client.get(self.cache_key) + if cached_provider_credentials: + try: + cached_provider_credentials = cached_provider_credentials.decode('utf-8') + cached_provider_credentials = json.loads(cached_provider_credentials) + except JSONDecodeError: + return None + + return cached_provider_credentials + else: + return None + + def set(self, credentials: dict) -> None: + """ + Cache model provider credentials. + + :param credentials: provider credentials + :return: + """ + redis_client.setex(self.cache_key, 86400, json.dumps(credentials)) + + def delete(self) -> None: + """ + Delete cached model provider credentials. + + :return: + """ + redis_client.delete(self.cache_key) \ No newline at end of file diff --git a/api/core/tools/provider/_position.yaml b/api/core/tools/provider/_position.yaml new file mode 100644 index 0000000000..7fb762cd74 --- /dev/null +++ b/api/core/tools/provider/_position.yaml @@ -0,0 +1,15 @@ +- google +- bing +- wikipedia +- dalle +- azuredalle +- webscraper +- wolframalpha +- github +- chart +- time +- yahoo +- stablediffusion +- vectorizer +- youtube +- gaode diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index d4e47640ab..ccee002185 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -1,31 +1,29 @@ -from typing import List - from core.tools.entities.user_entities import UserToolProvider +from core.tools.entities.tool_entities import ToolProviderType +from typing import List +from yaml import load, FullLoader -position = { - 'google': 1, - 'bing': 2, - 'wikipedia': 2, - 'dalle': 3, - 'webscraper': 4, - 'wolframalpha': 5, - 'chart': 6, - 'time': 7, - 'yahoo': 8, - 'stablediffusion': 9, - 'vectorizer': 10, - 'youtube': 11, - 'github': 12, - 'gaode': 13 -} +import os.path +position = {} class BuiltinToolProviderSort: @staticmethod def sort(providers: List[UserToolProvider]) -> List[UserToolProvider]: + global position + if not position: + tmp_position = {} + file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml') + with open(file_path, 'r') as f: + for pos, val in enumerate(load(f, Loader=FullLoader)): + tmp_position[val] = pos + position = tmp_position + def sort_compare(provider: UserToolProvider) -> int: + # if provider.type == UserToolProvider.ProviderType.MODEL: + # return position.get(f'model_provider.{provider.name}', 10000) return position.get(provider.name, 10000) sorted_providers = sorted(providers, key=sort_compare) - return sorted_providers + return sorted_providers \ No newline at end of file diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 21609ce302..a43e5f218e 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -1,10 +1,10 @@ -from typing import Any, Dict - -from core.helper import encrypter -from core.tools.entities.tool_entities import ToolProviderCredentials -from core.tools.provider.tool_provider import ToolProviderController +from typing import Dict, Any from pydantic import BaseModel +from core.tools.entities.tool_entities import ToolProviderCredentials +from core.tools.provider.tool_provider import ToolProviderController +from core.helper import encrypter +from core.helper.tool_provider_cache import ToolProviderCredentialsCacheType, ToolProviderCredentialsCache class ToolConfiguration(BaseModel): tenant_id: str @@ -63,8 +63,15 @@ class ToolConfiguration(BaseModel): return a deep copy of credentials with decrypted values """ + cache = ToolProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}', + cache_type=ToolProviderCredentialsCacheType.PROVIDER + ) + cached_credentials = cache.get() + if cached_credentials: + return cached_credentials credentials = self._deep_copy(credentials) - # get fields need to be decrypted fields = self.provider_controller.get_credentials_schema() for field_name, field in fields.items(): @@ -74,5 +81,6 @@ class ToolConfiguration(BaseModel): credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name]) except: pass - + + cache.set(credentials) return credentials \ No newline at end of file