mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 19:05:55 +08:00
feat: tool credentials cache and introduce _position.yaml (#2386)
This commit is contained in:
parent
6278ff0f30
commit
5010706d8b
49
api/core/helper/tool_provider_cache.py
Normal file
49
api/core/helper/tool_provider_cache.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import json
|
||||||
|
from enum import Enum
|
||||||
|
from json import JSONDecodeError
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from extensions.ext_redis import redis_client
|
||||||
|
|
||||||
|
|
||||||
|
class ToolProviderCredentialsCacheType(Enum):
|
||||||
|
PROVIDER = "tool_provider"
|
||||||
|
|
||||||
|
class ToolProviderCredentialsCache:
|
||||||
|
def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
|
||||||
|
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
|
||||||
|
|
||||||
|
def get(self) -> Optional[dict]:
|
||||||
|
"""
|
||||||
|
Get cached model provider credentials.
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
cached_provider_credentials = redis_client.get(self.cache_key)
|
||||||
|
if cached_provider_credentials:
|
||||||
|
try:
|
||||||
|
cached_provider_credentials = cached_provider_credentials.decode('utf-8')
|
||||||
|
cached_provider_credentials = json.loads(cached_provider_credentials)
|
||||||
|
except JSONDecodeError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return cached_provider_credentials
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set(self, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Cache model provider credentials.
|
||||||
|
|
||||||
|
:param credentials: provider credentials
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
|
||||||
|
|
||||||
|
def delete(self) -> None:
|
||||||
|
"""
|
||||||
|
Delete cached model provider credentials.
|
||||||
|
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
redis_client.delete(self.cache_key)
|
15
api/core/tools/provider/_position.yaml
Normal file
15
api/core/tools/provider/_position.yaml
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
- google
|
||||||
|
- bing
|
||||||
|
- wikipedia
|
||||||
|
- dalle
|
||||||
|
- azuredalle
|
||||||
|
- webscraper
|
||||||
|
- wolframalpha
|
||||||
|
- github
|
||||||
|
- chart
|
||||||
|
- time
|
||||||
|
- yahoo
|
||||||
|
- stablediffusion
|
||||||
|
- vectorizer
|
||||||
|
- youtube
|
||||||
|
- gaode
|
@ -1,31 +1,29 @@
|
|||||||
from typing import List
|
|
||||||
|
|
||||||
from core.tools.entities.user_entities import UserToolProvider
|
from core.tools.entities.user_entities import UserToolProvider
|
||||||
|
from core.tools.entities.tool_entities import ToolProviderType
|
||||||
|
from typing import List
|
||||||
|
from yaml import load, FullLoader
|
||||||
|
|
||||||
position = {
|
import os.path
|
||||||
'google': 1,
|
|
||||||
'bing': 2,
|
|
||||||
'wikipedia': 2,
|
|
||||||
'dalle': 3,
|
|
||||||
'webscraper': 4,
|
|
||||||
'wolframalpha': 5,
|
|
||||||
'chart': 6,
|
|
||||||
'time': 7,
|
|
||||||
'yahoo': 8,
|
|
||||||
'stablediffusion': 9,
|
|
||||||
'vectorizer': 10,
|
|
||||||
'youtube': 11,
|
|
||||||
'github': 12,
|
|
||||||
'gaode': 13
|
|
||||||
}
|
|
||||||
|
|
||||||
|
position = {}
|
||||||
|
|
||||||
class BuiltinToolProviderSort:
|
class BuiltinToolProviderSort:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sort(providers: List[UserToolProvider]) -> List[UserToolProvider]:
|
def sort(providers: List[UserToolProvider]) -> List[UserToolProvider]:
|
||||||
|
global position
|
||||||
|
if not position:
|
||||||
|
tmp_position = {}
|
||||||
|
file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
|
||||||
|
with open(file_path, 'r') as f:
|
||||||
|
for pos, val in enumerate(load(f, Loader=FullLoader)):
|
||||||
|
tmp_position[val] = pos
|
||||||
|
position = tmp_position
|
||||||
|
|
||||||
def sort_compare(provider: UserToolProvider) -> int:
|
def sort_compare(provider: UserToolProvider) -> int:
|
||||||
|
# if provider.type == UserToolProvider.ProviderType.MODEL:
|
||||||
|
# return position.get(f'model_provider.{provider.name}', 10000)
|
||||||
return position.get(provider.name, 10000)
|
return position.get(provider.name, 10000)
|
||||||
|
|
||||||
sorted_providers = sorted(providers, key=sort_compare)
|
sorted_providers = sorted(providers, key=sort_compare)
|
||||||
|
|
||||||
return sorted_providers
|
return sorted_providers
|
@ -1,10 +1,10 @@
|
|||||||
from typing import Any, Dict
|
from typing import Dict, Any
|
||||||
|
|
||||||
from core.helper import encrypter
|
|
||||||
from core.tools.entities.tool_entities import ToolProviderCredentials
|
|
||||||
from core.tools.provider.tool_provider import ToolProviderController
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.tools.entities.tool_entities import ToolProviderCredentials
|
||||||
|
from core.tools.provider.tool_provider import ToolProviderController
|
||||||
|
from core.helper import encrypter
|
||||||
|
from core.helper.tool_provider_cache import ToolProviderCredentialsCacheType, ToolProviderCredentialsCache
|
||||||
|
|
||||||
class ToolConfiguration(BaseModel):
|
class ToolConfiguration(BaseModel):
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
@ -63,8 +63,15 @@ class ToolConfiguration(BaseModel):
|
|||||||
|
|
||||||
return a deep copy of credentials with decrypted values
|
return a deep copy of credentials with decrypted values
|
||||||
"""
|
"""
|
||||||
|
cache = ToolProviderCredentialsCache(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
|
||||||
|
cache_type=ToolProviderCredentialsCacheType.PROVIDER
|
||||||
|
)
|
||||||
|
cached_credentials = cache.get()
|
||||||
|
if cached_credentials:
|
||||||
|
return cached_credentials
|
||||||
credentials = self._deep_copy(credentials)
|
credentials = self._deep_copy(credentials)
|
||||||
|
|
||||||
# get fields need to be decrypted
|
# get fields need to be decrypted
|
||||||
fields = self.provider_controller.get_credentials_schema()
|
fields = self.provider_controller.get_credentials_schema()
|
||||||
for field_name, field in fields.items():
|
for field_name, field in fields.items():
|
||||||
@ -74,5 +81,6 @@ class ToolConfiguration(BaseModel):
|
|||||||
credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
|
credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
cache.set(credentials)
|
||||||
return credentials
|
return credentials
|
Loading…
x
Reference in New Issue
Block a user