diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index 127b8fe76d..7dac252201 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -3,6 +3,7 @@ from threading import Lock from typing import TYPE_CHECKING from contexts.wrapper import RecyclableContextVar +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController if TYPE_CHECKING: from core.model_runtime.entities.model_entities import AIModelEntity @@ -37,3 +38,11 @@ plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(Cont plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar( ContextVar("plugin_model_schemas") ) + +datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = RecyclableContextVar( + ContextVar("datasource_plugin_providers") +) + +datasource_plugin_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( + ContextVar("datasource_plugin_providers_lock") +) diff --git a/api/core/datasource/__base/datasource.py b/api/core/datasource/__base/datasource.py deleted file mode 100644 index 3a67b56e32..0000000000 --- a/api/core/datasource/__base/datasource.py +++ /dev/null @@ -1,221 +0,0 @@ -from abc import ABC, abstractmethod -from collections.abc import Generator -from copy import deepcopy -from typing import TYPE_CHECKING, Any, Optional - -if TYPE_CHECKING: - from models.model import File - -from core.datasource.__base.datasource_runtime import DatasourceRuntime -from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceProviderType -from core.tools.entities.tool_entities import ( - ToolInvokeMessage, - ToolParameter, -) - - -class Datasource(ABC): - """ - The base class of a datasource - """ - - entity: DatasourceEntity - runtime: DatasourceRuntime - - def __init__(self, entity: DatasourceEntity, runtime: DatasourceRuntime) -> None: - self.entity = entity - self.runtime = runtime - - def 
fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "Datasource": - """ - fork a new datasource with metadata - :return: the new datasource - """ - return self.__class__( - entity=self.entity.model_copy(), - runtime=runtime, - ) - - @abstractmethod - def datasource_provider_type(self) -> DatasourceProviderType: - """ - get the datasource provider type - - :return: the tool provider type - """ - - def invoke( - self, - user_id: str, - tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - ) -> Generator[ToolInvokeMessage]: - if self.runtime and self.runtime.runtime_parameters: - tool_parameters.update(self.runtime.runtime_parameters) - - # try parse tool parameters into the correct type - tool_parameters = self._transform_tool_parameters_type(tool_parameters) - - result = self._invoke( - user_id=user_id, - tool_parameters=tool_parameters, - conversation_id=conversation_id, - app_id=app_id, - message_id=message_id, - ) - - if isinstance(result, ToolInvokeMessage): - - def single_generator() -> Generator[ToolInvokeMessage, None, None]: - yield result - - return single_generator() - elif isinstance(result, list): - - def generator() -> Generator[ToolInvokeMessage, None, None]: - yield from result - - return generator() - else: - return result - - def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]: - """ - Transform tool parameters type - """ - # Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials - result = deepcopy(tool_parameters) - for parameter in self.entity.parameters or []: - if parameter.name in tool_parameters: - result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name]) - - return result - - @abstractmethod - def _invoke( - self, - user_id: str, - tool_parameters: dict[str, Any], - conversation_id: Optional[str] = None, - app_id: 
Optional[str] = None, - message_id: Optional[str] = None, - ) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]: - pass - - def get_runtime_parameters( - self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - ) -> list[ToolParameter]: - """ - get the runtime parameters - - interface for developer to dynamic change the parameters of a tool depends on the variables pool - - :return: the runtime parameters - """ - return self.entity.parameters - - def get_merged_runtime_parameters( - self, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - ) -> list[ToolParameter]: - """ - get merged runtime parameters - - :return: merged runtime parameters - """ - parameters = self.entity.parameters - parameters = parameters.copy() - user_parameters = self.get_runtime_parameters() or [] - user_parameters = user_parameters.copy() - - # override parameters - for parameter in user_parameters: - # check if parameter in tool parameters - for tool_parameter in parameters: - if tool_parameter.name == parameter.name: - # override parameter - tool_parameter.type = parameter.type - tool_parameter.form = parameter.form - tool_parameter.required = parameter.required - tool_parameter.default = parameter.default - tool_parameter.options = parameter.options - tool_parameter.llm_description = parameter.llm_description - break - else: - # add new parameter - parameters.append(parameter) - - return parameters - - def create_image_message( - self, - image: str, - ) -> ToolInvokeMessage: - """ - create an image message - - :param image: the url of the image - :return: the image message - """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image) - ) - - def create_file_message(self, file: "File") -> ToolInvokeMessage: - return ToolInvokeMessage( - 
type=ToolInvokeMessage.MessageType.FILE, - message=ToolInvokeMessage.FileMessage(), - meta={"file": file}, - ) - - def create_link_message(self, link: str) -> ToolInvokeMessage: - """ - create a link message - - :param link: the url of the link - :return: the link message - """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link) - ) - - def create_text_message(self, text: str) -> ToolInvokeMessage: - """ - create a text message - - :param text: the text - :return: the text message - """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.TEXT, - message=ToolInvokeMessage.TextMessage(text=text), - ) - - def create_blob_message(self, blob: bytes, meta: Optional[dict] = None) -> ToolInvokeMessage: - """ - create a blob message - - :param blob: the blob - :param meta: the meta info of blob object - :return: the blob message - """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.BLOB, - message=ToolInvokeMessage.BlobMessage(blob=blob), - meta=meta, - ) - - def create_json_message(self, object: dict) -> ToolInvokeMessage: - """ - create a json message - """ - return ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object) - ) diff --git a/api/core/datasource/datasource_tool/tool.py b/api/core/datasource/__base/datasource_plugin.py similarity index 97% rename from api/core/datasource/datasource_tool/tool.py rename to api/core/datasource/__base/datasource_plugin.py index d55c28a9b9..037c0f4630 100644 --- a/api/core/datasource/datasource_tool/tool.py +++ b/api/core/datasource/__base/datasource_plugin.py @@ -1,7 +1,6 @@ from collections.abc import Generator from typing import Any, Optional -from core.datasource.__base.datasource import Datasource from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import ( DatasourceEntity, @@ -13,7 +12,7 @@ from 
core.plugin.manager.datasource import PluginDatasourceManager from core.plugin.utils.converter import convert_parameters_to_plugin_format -class DatasourcePlugin(Datasource): +class DatasourcePlugin: tenant_id: str icon: str plugin_unique_identifier: str diff --git a/api/core/datasource/datasource_tool/provider.py b/api/core/datasource/__base/datasource_provider.py similarity index 92% rename from api/core/datasource/datasource_tool/provider.py rename to api/core/datasource/__base/datasource_provider.py index 820224eeaa..ba66e2a3c4 100644 --- a/api/core/datasource/datasource_tool/provider.py +++ b/api/core/datasource/__base/datasource_provider.py @@ -1,15 +1,15 @@ from typing import Any -from core.datasource.datasource_tool.tool import DatasourceTool +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_runtime import DatasourceRuntime from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType +from core.entities.provider_entities import ProviderConfig from core.plugin.manager.tool import PluginToolManager -from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController -from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin, ToolProviderType from core.tools.errors import ToolProviderCredentialValidationError -class DatasourceToolProviderController(BuiltinToolProviderController): +class DatasourcePluginProviderController(BuiltinToolProviderController): entity: DatasourceProviderEntityWithPlugin tenant_id: str plugin_id: str @@ -45,7 +45,7 @@ class DatasourceToolProviderController(BuiltinToolProviderController): ): raise ToolProviderCredentialValidationError("Invalid credentials") - def get_datasource(self, datasource_name: str) -> DatasourceTool: # type: ignore + def get_datasource(self, datasource_name: str) -> DatasourcePlugin: # type: ignore """ return datasource 
with given name """ @@ -56,9 +56,9 @@ class DatasourceToolProviderController(BuiltinToolProviderController): if not datasource_entity: raise ValueError(f"Datasource with name {datasource_name} not found") - return DatasourceTool( + return DatasourcePlugin( entity=datasource_entity, - runtime=ToolRuntime(tenant_id=self.tenant_id), + runtime=DatasourceRuntime(tenant_id=self.tenant_id), tenant_id=self.tenant_id, icon=self.entity.identity.icon, plugin_unique_identifier=self.plugin_unique_identifier, @@ -69,9 +69,9 @@ class DatasourceToolProviderController(BuiltinToolProviderController): get all datasources """ return [ - DatasourceTool( + DatasourcePlugin( entity=datasource_entity, - runtime=ToolRuntime(tenant_id=self.tenant_id), + runtime=DatasourceRuntime(tenant_id=self.tenant_id), tenant_id=self.tenant_id, icon=self.entity.identity.icon, plugin_unique_identifier=self.plugin_unique_identifier, diff --git a/api/core/datasource/__base/tool_provider.py b/api/core/datasource/__base/tool_provider.py deleted file mode 100644 index d096fc7df7..0000000000 --- a/api/core/datasource/__base/tool_provider.py +++ /dev/null @@ -1,109 +0,0 @@ -from abc import ABC, abstractmethod -from copy import deepcopy -from typing import Any - -from core.entities.provider_entities import ProviderConfig -from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ( - ToolProviderEntity, - ToolProviderType, -) -from core.tools.errors import ToolProviderCredentialValidationError - - -class ToolProviderController(ABC): - entity: ToolProviderEntity - - def __init__(self, entity: ToolProviderEntity) -> None: - self.entity = entity - - def get_credentials_schema(self) -> list[ProviderConfig]: - """ - returns the credentials schema of the provider - - :return: the credentials schema - """ - return deepcopy(self.entity.credentials_schema) - - @abstractmethod - def get_tool(self, tool_name: str) -> Tool: - """ - returns a tool that the provider can provide - - :return: tool - 
""" - pass - - @property - def provider_type(self) -> ToolProviderType: - """ - returns the type of the provider - - :return: type of the provider - """ - return ToolProviderType.BUILT_IN - - def validate_credentials_format(self, credentials: dict[str, Any]) -> None: - """ - validate the format of the credentials of the provider and set the default value if needed - - :param credentials: the credentials of the tool - """ - credentials_schema = dict[str, ProviderConfig]() - if credentials_schema is None: - return - - for credential in self.entity.credentials_schema: - credentials_schema[credential.name] = credential - - credentials_need_to_validate: dict[str, ProviderConfig] = {} - for credential_name in credentials_schema: - credentials_need_to_validate[credential_name] = credentials_schema[credential_name] - - for credential_name in credentials: - if credential_name not in credentials_need_to_validate: - raise ToolProviderCredentialValidationError( - f"credential {credential_name} not found in provider {self.entity.identity.name}" - ) - - # check type - credential_schema = credentials_need_to_validate[credential_name] - if not credential_schema.required and credentials[credential_name] is None: - continue - - if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}: - if not isinstance(credentials[credential_name], str): - raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") - - elif credential_schema.type == ProviderConfig.Type.SELECT: - if not isinstance(credentials[credential_name], str): - raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string") - - options = credential_schema.options - if not isinstance(options, list): - raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list") - - if credentials[credential_name] not in [x.value for x in options]: - raise ToolProviderCredentialValidationError( - 
f"credential {credential_name} should be one of {options}" - ) - - credentials_need_to_validate.pop(credential_name) - - for credential_name in credentials_need_to_validate: - credential_schema = credentials_need_to_validate[credential_name] - if credential_schema.required: - raise ToolProviderCredentialValidationError(f"credential {credential_name} is required") - - # the credential is not set currently, set the default value if needed - if credential_schema.default is not None: - default_value = credential_schema.default - # parse default value into the correct type - if credential_schema.type in { - ProviderConfig.Type.SECRET_INPUT, - ProviderConfig.Type.TEXT_INPUT, - ProviderConfig.Type.SELECT, - }: - default_value = str(default_value) - - credentials[credential_name] = default_value diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py new file mode 100644 index 0000000000..6704d4e73a --- /dev/null +++ b/api/core/datasource/datasource_file_manager.py @@ -0,0 +1,244 @@ +import base64 +import hashlib +import hmac +import logging +import os +import time +from mimetypes import guess_extension, guess_type +from typing import Optional, Union +from uuid import uuid4 + +import httpx + +from configs import dify_config +from core.helper import ssrf_proxy +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.enums import CreatedByRole +from models.model import MessageFile, UploadFile +from models.tools import ToolFile + +logger = logging.getLogger(__name__) + + +class DatasourceFileManager: + @staticmethod + def sign_file(datasource_file_id: str, extension: str) -> str: + """ + sign file to get a temporary url + """ + base_url = dify_config.FILES_URL + file_preview_url = f"{base_url}/files/datasources/{datasource_file_id}{extension}" + + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}" 
+ secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + + @staticmethod + def verify_file(datasource_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: + """ + verify signature + """ + data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + + # verify signature + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT + + @staticmethod + def create_file_by_raw( + *, + user_id: str, + tenant_id: str, + conversation_id: Optional[str], + file_binary: bytes, + mimetype: str, + filename: Optional[str] = None, + ) -> UploadFile: + extension = guess_extension(mimetype) or ".bin" + unique_name = uuid4().hex + unique_filename = f"{unique_name}{extension}" + # default just as before + present_filename = unique_filename + if filename is not None: + has_extension = len(filename.split(".")) > 1 + # Add extension flexibly + present_filename = filename if has_extension else f"{filename}{extension}" + filepath = f"datasources/{tenant_id}/{unique_filename}" + storage.save(filepath, file_binary) + + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type=dify_config.STORAGE_TYPE, + key=filepath, + name=present_filename, + size=len(file_binary), + extension=extension, + mime_type=mimetype, + created_by_role=CreatedByRole.ACCOUNT, + created_by=user_id, + used=False, + hash=hashlib.sha3_256(file_binary).hexdigest(), + source_url="", + ) + + 
db.session.add(upload_file) + db.session.commit() + db.session.refresh(upload_file) + + return upload_file + + @staticmethod + def create_file_by_url( + user_id: str, + tenant_id: str, + file_url: str, + conversation_id: Optional[str] = None, + ) -> UploadFile: + # try to download image + try: + response = ssrf_proxy.get(file_url) + response.raise_for_status() + blob = response.content + except httpx.TimeoutException: + raise ValueError(f"timeout when downloading file from {file_url}") + + mimetype = ( + guess_type(file_url)[0] + or response.headers.get("Content-Type", "").split(";")[0].strip() + or "application/octet-stream" + ) + extension = guess_extension(mimetype) or ".bin" + unique_name = uuid4().hex + filename = f"{unique_name}{extension}" + filepath = f"tools/{tenant_id}/{filename}" + storage.save(filepath, blob) + + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type=dify_config.STORAGE_TYPE, + key=filepath, + name=filename, + size=len(blob), + extension=extension, + mime_type=mimetype, + created_by_role=CreatedByRole.ACCOUNT, + created_by=user_id, + used=False, + hash=hashlib.sha3_256(blob).hexdigest(), + source_url=file_url, + ) + + db.session.add(upload_file) + db.session.commit() + + return upload_file + + @staticmethod + def get_file_binary(id: str) -> Union[tuple[bytes, str], None]: + """ + get file binary + + :param id: the id of the file + + :return: the binary of the file, mime type + """ + upload_file: UploadFile | None = ( + db.session.query(UploadFile) + .filter( + UploadFile.id == id, + ) + .first() + ) + + if not upload_file: + return None + + blob = storage.load_once(upload_file.key) + + return blob, upload_file.mime_type + + @staticmethod + def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]: + """ + get file binary + + :param id: the id of the file + + :return: the binary of the file, mime type + """ + message_file: MessageFile | None = ( + db.session.query(MessageFile) + .filter( + MessageFile.id 
== id, + ) + .first() + ) + + # Check if message_file is not None + if message_file is not None: + # get tool file id + if message_file.url is not None: + tool_file_id = message_file.url.split("/")[-1] + # trim extension + tool_file_id = tool_file_id.split(".")[0] + else: + tool_file_id = None + else: + tool_file_id = None + + tool_file: ToolFile | None = ( + db.session.query(ToolFile) + .filter( + ToolFile.id == tool_file_id, + ) + .first() + ) + + if not tool_file: + return None + + blob = storage.load_once(tool_file.file_key) + + return blob, tool_file.mimetype + + @staticmethod + def get_file_generator_by_upload_file_id(upload_file_id: str): + """ + get file binary + + :param tool_file_id: the id of the tool file + + :return: the binary of the file, mime type + """ + upload_file: UploadFile | None = ( + db.session.query(UploadFile) + .filter( + UploadFile.id == upload_file_id, + ) + .first() + ) + + if not upload_file: + return None, None + + stream = storage.load_stream(upload_file.key) + + return stream, upload_file.mime_type + + +# init tool_file_parser +from core.file.datasource_file_parser import datasource_file_manager + +datasource_file_manager["manager"] = DatasourceFileManager diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py new file mode 100644 index 0000000000..9cde52e39f --- /dev/null +++ b/api/core/datasource/datasource_manager.py @@ -0,0 +1,95 @@ + +import logging +from threading import Lock +from typing import Union + +import contexts +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController +from core.datasource.entities.common_entities import I18nObject +from core.datasource.entities.datasource_entities import DatasourceProviderType +from core.datasource.errors import ToolProviderNotFoundError +from core.plugin.manager.tool import PluginToolManager + +logger = logging.getLogger(__name__) + + 
+class DatasourceManager: + _builtin_provider_lock = Lock() + _hardcoded_providers: dict[str, DatasourcePluginProviderController] = {} + _builtin_providers_loaded = False + _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} + + @classmethod + def get_datasource_plugin_provider(cls, provider: str, tenant_id: str) -> DatasourcePluginProviderController: + """ + get the datasource plugin provider + """ + # check if context is set + try: + contexts.datasource_plugin_providers.get() + except LookupError: + contexts.datasource_plugin_providers.set({}) + contexts.datasource_plugin_providers_lock.set(Lock()) + + with contexts.datasource_plugin_providers_lock.get(): + datasource_plugin_providers = contexts.datasource_plugin_providers.get() + if provider in datasource_plugin_providers: + return datasource_plugin_providers[provider] + + manager = PluginToolManager() + provider_entity = manager.fetch_tool_provider(tenant_id, provider) + if not provider_entity: + raise ToolProviderNotFoundError(f"plugin provider {provider} not found") + + controller = DatasourcePluginProviderController( + entity=provider_entity.declaration, + plugin_id=provider_entity.plugin_id, + plugin_unique_identifier=provider_entity.plugin_unique_identifier, + tenant_id=tenant_id, + ) + + datasource_plugin_providers[provider] = controller + + return controller + + @classmethod + def get_datasource_runtime( + cls, + provider_type: DatasourceProviderType, + provider_id: str, + datasource_name: str, + tenant_id: str, + ) -> DatasourcePlugin: + """ + get the datasource runtime + + :param provider_type: the type of the provider + :param provider_id: the id of the provider + :param datasource_name: the name of the datasource + :param tenant_id: the tenant id + + :return: the datasource plugin + """ + if provider_type == DatasourceProviderType.RAG_PIPELINE: + return cls.get_datasource_plugin_provider(provider_id, tenant_id).get_datasource(datasource_name) + else: + raise 
ToolProviderNotFoundError(f"provider type {provider_type.value} not found") + + + @classmethod + def list_datasource_providers(cls, tenant_id: str) -> list[DatasourcePluginProviderController]: + """ + list all the datasource providers + """ + manager = PluginToolManager() + provider_entities = manager.fetch_tool_providers(tenant_id) + return [ + DatasourcePluginProviderController( + entity=provider.declaration, + plugin_id=provider.plugin_id, + plugin_unique_identifier=provider.plugin_unique_identifier, + tenant_id=tenant_id, + ) + for provider in provider_entities + ] diff --git a/api/core/datasource/entities/constants.py b/api/core/datasource/entities/constants.py index 199c9f0d53..a4dbf6f11f 100644 --- a/api/core/datasource/entities/constants.py +++ b/api/core/datasource/entities/constants.py @@ -1 +1 @@ -TOOL_SELECTOR_MODEL_IDENTITY = "__dify__tool_selector__" +DATASOURCE_SELECTOR_MODEL_IDENTITY = "__dify__datasource_selector__" diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index de580b270e..80e89ef1a9 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -6,6 +6,7 @@ from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator +from core.datasource.entities.constants import DATASOURCE_SELECTOR_MODEL_IDENTITY from core.entities.provider_entities import ProviderConfig from core.plugin.entities.parameters import ( PluginParameter, @@ -16,7 +17,6 @@ from core.plugin.entities.parameters import ( init_frontend_parameter, ) from core.tools.entities.common_entities import I18nObject -from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY class ToolLabelEnum(Enum): @@ -400,7 +400,7 @@ class DatasourceInvokeFrom(Enum): class DatasourceSelector(BaseModel): - dify_model_identity: str = TOOL_SELECTOR_MODEL_IDENTITY + 
dify_model_identity: str = DATASOURCE_SELECTOR_MODEL_IDENTITY class Parameter(BaseModel): name: str = Field(..., description="The name of the parameter") diff --git a/api/core/datasource/entities/file_entities.py b/api/core/datasource/entities/file_entities.py deleted file mode 100644 index 8b13789179..0000000000 --- a/api/core/datasource/entities/file_entities.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/api/core/datasource/entities/tool_bundle.py b/api/core/datasource/entities/tool_bundle.py deleted file mode 100644 index ffeeabbc1c..0000000000 --- a/api/core/datasource/entities/tool_bundle.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - -from core.tools.entities.tool_entities import ToolParameter - - -class ApiToolBundle(BaseModel): - """ - This class is used to store the schema information of an api based tool. - such as the url, the method, the parameters, etc. - """ - - # server_url - server_url: str - # method - method: str - # summary - summary: Optional[str] = None - # operation_id - operation_id: Optional[str] = None - # parameters - parameters: Optional[list[ToolParameter]] = None - # author - author: str - # icon - icon: Optional[str] = None - # openapi operation - openapi: dict diff --git a/api/core/datasource/tool_manager.py b/api/core/datasource/tool_manager.py deleted file mode 100644 index f2d0b74f7c..0000000000 --- a/api/core/datasource/tool_manager.py +++ /dev/null @@ -1,870 +0,0 @@ -import json -import logging -import mimetypes -from collections.abc import Generator -from os import listdir, path -from threading import Lock -from typing import TYPE_CHECKING, Any, Union, cast - -from yarl import URL - -import contexts -from core.plugin.entities.plugin import ToolProviderID -from core.plugin.manager.tool import PluginToolManager -from core.tools.__base.tool_provider import ToolProviderController -from core.tools.__base.tool_runtime import ToolRuntime -from core.tools.plugin_tool.provider import 
PluginToolProviderController -from core.tools.plugin_tool.tool import PluginTool -from core.tools.workflow_as_tool.provider import WorkflowToolProviderController - -if TYPE_CHECKING: - from core.workflow.nodes.tool.entities import ToolEntity - - -from configs import dify_config -from core.agent.entities import AgentToolEntity -from core.app.entities.app_invoke_entities import InvokeFrom -from core.helper.module_import_helper import load_single_subclass_from_source -from core.helper.position_helper import is_filtered -from core.model_runtime.utils.encoders import jsonable_encoder -from core.tools.__base.tool import Tool -from core.tools.builtin_tool.provider import BuiltinToolProviderController -from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort -from core.tools.builtin_tool.tool import BuiltinTool -from core.tools.custom_tool.provider import ApiToolProviderController -from core.tools.custom_tool.tool import ApiTool -from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral -from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ( - ApiProviderAuthType, - ToolInvokeFrom, - ToolParameter, - ToolProviderType, -) -from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError -from core.tools.tool_label_manager import ToolLabelManager -from core.tools.utils.configuration import ( - ProviderConfigEncrypter, - ToolParameterConfigurationManager, -) -from core.tools.workflow_as_tool.tool import WorkflowTool -from extensions.ext_database import db -from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider -from services.tools.tools_transform_service import ToolTransformService - -logger = logging.getLogger(__name__) - - -class ToolManager: - _builtin_provider_lock = Lock() - _hardcoded_providers: dict[str, BuiltinToolProviderController] = {} - _builtin_providers_loaded = False - _builtin_tools_labels: dict[str, 
Union[I18nObject, None]] = {} - - @classmethod - def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController: - """ - get the hardcoded provider - """ - if len(cls._hardcoded_providers) == 0: - # init the builtin providers - cls.load_hardcoded_providers_cache() - - return cls._hardcoded_providers[provider] - - @classmethod - def get_builtin_provider( - cls, provider: str, tenant_id: str - ) -> BuiltinToolProviderController | PluginToolProviderController: - """ - get the builtin provider - - :param provider: the name of the provider - :param tenant_id: the id of the tenant - :return: the provider - """ - # split provider to - - if len(cls._hardcoded_providers) == 0: - # init the builtin providers - cls.load_hardcoded_providers_cache() - - if provider not in cls._hardcoded_providers: - # get plugin provider - plugin_provider = cls.get_plugin_provider(provider, tenant_id) - if plugin_provider: - return plugin_provider - - return cls._hardcoded_providers[provider] - - @classmethod - def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController: - """ - get the plugin provider - """ - # check if context is set - try: - contexts.plugin_tool_providers.get() - except LookupError: - contexts.plugin_tool_providers.set({}) - contexts.plugin_tool_providers_lock.set(Lock()) - - with contexts.plugin_tool_providers_lock.get(): - plugin_tool_providers = contexts.plugin_tool_providers.get() - if provider in plugin_tool_providers: - return plugin_tool_providers[provider] - - manager = PluginToolManager() - provider_entity = manager.fetch_tool_provider(tenant_id, provider) - if not provider_entity: - raise ToolProviderNotFoundError(f"plugin provider {provider} not found") - - controller = PluginToolProviderController( - entity=provider_entity.declaration, - plugin_id=provider_entity.plugin_id, - plugin_unique_identifier=provider_entity.plugin_unique_identifier, - tenant_id=tenant_id, - ) - - plugin_tool_providers[provider] = 
controller - - return controller - - @classmethod - def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None: - """ - get the builtin tool - - :param provider: the name of the provider - :param tool_name: the name of the tool - :param tenant_id: the id of the tenant - :return: the provider, the tool - """ - provider_controller = cls.get_builtin_provider(provider, tenant_id) - tool = provider_controller.get_tool(tool_name) - if tool is None: - raise ToolNotFoundError(f"tool {tool_name} not found") - - return tool - - @classmethod - def get_tool_runtime( - cls, - provider_type: ToolProviderType, - provider_id: str, - tool_name: str, - tenant_id: str, - invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, - ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]: - """ - get the tool runtime - - :param provider_type: the type of the provider - :param provider_id: the id of the provider - :param tool_name: the name of the tool - :param tenant_id: the tenant id - :param invoke_from: invoke from - :param tool_invoke_from: the tool invoke from - - :return: the tool - """ - if provider_type == ToolProviderType.BUILT_IN: - # check if the builtin tool need credentials - provider_controller = cls.get_builtin_provider(provider_id, tenant_id) - - builtin_tool = provider_controller.get_tool(tool_name) - if not builtin_tool: - raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found") - - if not provider_controller.need_credentials: - return cast( - BuiltinTool, - builtin_tool.fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - credentials={}, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ), - ) - - if isinstance(provider_controller, PluginToolProviderController): - provider_id_entity = ToolProviderID(provider_id) - # get credentials - builtin_provider: BuiltinToolProvider | None = ( - db.session.query(BuiltinToolProvider) - 
.filter( - BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == str(provider_id_entity)) - | (BuiltinToolProvider.provider == provider_id_entity.provider_name), - ) - .first() - ) - - if builtin_provider is None: - raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") - else: - builtin_provider = ( - db.session.query(BuiltinToolProvider) - .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) - .first() - ) - - if builtin_provider is None: - raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") - - # decrypt the credentials - credentials = builtin_provider.credentials - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) - - decrypted_credentials = tool_configuration.decrypt(credentials) - - return cast( - BuiltinTool, - builtin_tool.fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - credentials=decrypted_credentials, - runtime_parameters={}, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ), - ) - - elif provider_type == ToolProviderType.API: - api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) - - # decrypt the credentials - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()], - provider_type=api_provider.provider_type.value, - provider_identity=api_provider.entity.identity.name, - ) - decrypted_credentials = tool_configuration.decrypt(credentials) - - return cast( - ApiTool, - api_provider.get_tool(tool_name).fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - credentials=decrypted_credentials, - invoke_from=invoke_from, - 
tool_invoke_from=tool_invoke_from, - ) - ), - ) - elif provider_type == ToolProviderType.WORKFLOW: - workflow_provider = ( - db.session.query(WorkflowToolProvider) - .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) - .first() - ) - - if workflow_provider is None: - raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - - controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) - controller_tools: list[WorkflowTool] = controller.get_tools(tenant_id=workflow_provider.tenant_id) - if controller_tools is None or len(controller_tools) == 0: - raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - - return cast( - WorkflowTool, - controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( - runtime=ToolRuntime( - tenant_id=tenant_id, - credentials={}, - invoke_from=invoke_from, - tool_invoke_from=tool_invoke_from, - ) - ), - ) - elif provider_type == ToolProviderType.APP: - raise NotImplementedError("app provider not implemented") - elif provider_type == ToolProviderType.PLUGIN: - return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) - else: - raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") - - @classmethod - def get_agent_tool_runtime( - cls, - tenant_id: str, - app_id: str, - agent_tool: AgentToolEntity, - invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - ) -> Tool: - """ - get the agent tool runtime - """ - tool_entity = cls.get_tool_runtime( - provider_type=agent_tool.provider_type, - provider_id=agent_tool.provider_id, - tool_name=agent_tool.tool_name, - tenant_id=tenant_id, - invoke_from=invoke_from, - tool_invoke_from=ToolInvokeFrom.AGENT, - ) - runtime_parameters = {} - parameters = tool_entity.get_merged_runtime_parameters() - for parameter in parameters: - # check file types - if ( - parameter.type - in { - ToolParameter.ToolParameterType.SYSTEM_FILES, - 
ToolParameter.ToolParameterType.FILE, - ToolParameter.ToolParameterType.FILES, - } - and parameter.required - ): - raise ValueError(f"file type parameter {parameter.name} not supported in agent") - - if parameter.form == ToolParameter.ToolParameterForm.FORM: - # save tool parameter to tool entity memory - value = parameter.init_frontend_parameter(agent_tool.tool_parameters.get(parameter.name)) - runtime_parameters[parameter.name] = value - - # decrypt runtime parameters - encryption_manager = ToolParameterConfigurationManager( - tenant_id=tenant_id, - tool_runtime=tool_entity, - provider_name=agent_tool.provider_id, - provider_type=agent_tool.provider_type, - identity_id=f"AGENT.{app_id}", - ) - runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) - if tool_entity.runtime is None or tool_entity.runtime.runtime_parameters is None: - raise ValueError("runtime not found or runtime parameters not found") - - tool_entity.runtime.runtime_parameters.update(runtime_parameters) - return tool_entity - - @classmethod - def get_workflow_tool_runtime( - cls, - tenant_id: str, - app_id: str, - node_id: str, - workflow_tool: "ToolEntity", - invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, - ) -> Tool: - """ - get the workflow tool runtime - """ - tool_runtime = cls.get_tool_runtime( - provider_type=workflow_tool.provider_type, - provider_id=workflow_tool.provider_id, - tool_name=workflow_tool.tool_name, - tenant_id=tenant_id, - invoke_from=invoke_from, - tool_invoke_from=ToolInvokeFrom.WORKFLOW, - ) - runtime_parameters = {} - parameters = tool_runtime.get_merged_runtime_parameters() - - for parameter in parameters: - # save tool parameter to tool entity memory - if parameter.form == ToolParameter.ToolParameterForm.FORM: - value = parameter.init_frontend_parameter(workflow_tool.tool_configurations.get(parameter.name)) - runtime_parameters[parameter.name] = value - - # decrypt runtime parameters - encryption_manager = ToolParameterConfigurationManager( 
- tenant_id=tenant_id, - tool_runtime=tool_runtime, - provider_name=workflow_tool.provider_id, - provider_type=workflow_tool.provider_type, - identity_id=f"WORKFLOW.{app_id}.{node_id}", - ) - - if runtime_parameters: - runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) - - tool_runtime.runtime.runtime_parameters.update(runtime_parameters) - return tool_runtime - - @classmethod - def get_tool_runtime_from_plugin( - cls, - tool_type: ToolProviderType, - tenant_id: str, - provider: str, - tool_name: str, - tool_parameters: dict[str, Any], - ) -> Tool: - """ - get tool runtime from plugin - """ - tool_entity = cls.get_tool_runtime( - provider_type=tool_type, - provider_id=provider, - tool_name=tool_name, - tenant_id=tenant_id, - invoke_from=InvokeFrom.SERVICE_API, - tool_invoke_from=ToolInvokeFrom.PLUGIN, - ) - runtime_parameters = {} - parameters = tool_entity.get_merged_runtime_parameters() - for parameter in parameters: - if parameter.form == ToolParameter.ToolParameterForm.FORM: - # save tool parameter to tool entity memory - value = parameter.init_frontend_parameter(tool_parameters.get(parameter.name)) - runtime_parameters[parameter.name] = value - - tool_entity.runtime.runtime_parameters.update(runtime_parameters) - return tool_entity - - @classmethod - def get_hardcoded_provider_icon(cls, provider: str) -> tuple[str, str]: - """ - get the absolute path of the icon of the hardcoded provider - - :param provider: the name of the provider - :return: the absolute path of the icon, the mime type of the icon - """ - # get provider - provider_controller = cls.get_hardcoded_provider(provider) - - absolute_path = path.join( - path.dirname(path.realpath(__file__)), - "builtin_tool", - "providers", - provider, - "_assets", - provider_controller.entity.identity.icon, - ) - # check if the icon exists - if not path.exists(absolute_path): - raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found") - - # get the mime type - 
mime_type, _ = mimetypes.guess_type(absolute_path) - mime_type = mime_type or "application/octet-stream" - - return absolute_path, mime_type - - @classmethod - def list_hardcoded_providers(cls): - # use cache first - if cls._builtin_providers_loaded: - yield from list(cls._hardcoded_providers.values()) - return - - with cls._builtin_provider_lock: - if cls._builtin_providers_loaded: - yield from list(cls._hardcoded_providers.values()) - return - - yield from cls._list_hardcoded_providers() - - @classmethod - def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]: - """ - list all the plugin providers - """ - manager = PluginToolManager() - provider_entities = manager.fetch_tool_providers(tenant_id) - return [ - PluginToolProviderController( - entity=provider.declaration, - plugin_id=provider.plugin_id, - plugin_unique_identifier=provider.plugin_unique_identifier, - tenant_id=tenant_id, - ) - for provider in provider_entities - ] - - @classmethod - def list_builtin_providers( - cls, tenant_id: str - ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]: - """ - list all the builtin providers - """ - yield from cls.list_hardcoded_providers() - # get plugin providers - yield from cls.list_plugin_providers(tenant_id) - - @classmethod - def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]: - """ - list all the builtin providers - """ - for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")): - if provider_path.startswith("__"): - continue - - if path.isdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers", provider_path)): - if provider_path.startswith("__"): - continue - - # init provider - try: - provider_class = load_single_subclass_from_source( - module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}", - script_path=path.join( - 
path.dirname(path.realpath(__file__)), - "builtin_tool", - "providers", - provider_path, - f"{provider_path}.py", - ), - parent_type=BuiltinToolProviderController, - ) - provider: BuiltinToolProviderController = provider_class() - cls._hardcoded_providers[provider.entity.identity.name] = provider - for tool in provider.get_tools(): - cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label - yield provider - - except Exception: - logger.exception(f"load builtin provider {provider}") - continue - # set builtin providers loaded - cls._builtin_providers_loaded = True - - @classmethod - def load_hardcoded_providers_cache(cls): - for _ in cls.list_hardcoded_providers(): - pass - - @classmethod - def clear_hardcoded_providers_cache(cls): - cls._hardcoded_providers = {} - cls._builtin_providers_loaded = False - - @classmethod - def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]: - """ - get the tool label - - :param tool_name: the name of the tool - - :return: the label of the tool - """ - if len(cls._builtin_tools_labels) == 0: - # init the builtin providers - cls.load_hardcoded_providers_cache() - - if tool_name not in cls._builtin_tools_labels: - return None - - return cls._builtin_tools_labels[tool_name] - - @classmethod - def list_providers_from_api( - cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral - ) -> list[ToolProviderApiEntity]: - result_providers: dict[str, ToolProviderApiEntity] = {} - - filters = [] - if not typ: - filters.extend(["builtin", "api", "workflow"]) - else: - filters.append(typ) - - with db.session.no_autoflush: - if "builtin" in filters: - # get builtin providers - builtin_providers = cls.list_builtin_providers(tenant_id) - - # get db builtin providers - db_builtin_providers: list[BuiltinToolProvider] = ( - db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() - ) - - # rewrite db_builtin_providers - for db_provider in db_builtin_providers: - 
tool_provider_id = str(ToolProviderID(db_provider.provider)) - db_provider.provider = tool_provider_id - - def find_db_builtin_provider(provider): - return next((x for x in db_builtin_providers if x.provider == provider), None) - - # append builtin providers - for provider in builtin_providers: - # handle include, exclude - if is_filtered( - include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET), - exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET), - data=provider, - name_func=lambda x: x.identity.name, - ): - continue - - user_provider = ToolTransformService.builtin_provider_to_user_provider( - provider_controller=provider, - db_provider=find_db_builtin_provider(provider.entity.identity.name), - decrypt_credentials=False, - ) - - if isinstance(provider, PluginToolProviderController): - result_providers[f"plugin_provider.{user_provider.name}"] = user_provider - else: - result_providers[f"builtin_provider.{user_provider.name}"] = user_provider - - # get db api providers - - if "api" in filters: - db_api_providers: list[ApiToolProvider] = ( - db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() - ) - - api_provider_controllers: list[dict[str, Any]] = [ - {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)} - for provider in db_api_providers - ] - - # get labels - labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers]) - - for api_provider_controller in api_provider_controllers: - user_provider = ToolTransformService.api_provider_to_user_provider( - provider_controller=api_provider_controller["controller"], - db_provider=api_provider_controller["provider"], - decrypt_credentials=False, - labels=labels.get(api_provider_controller["controller"].provider_id, []), - ) - result_providers[f"api_provider.{user_provider.name}"] = user_provider - - if "workflow" in filters: - # get workflow providers - workflow_providers: 
list[WorkflowToolProvider] = ( - db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all() - ) - - workflow_provider_controllers: list[WorkflowToolProviderController] = [] - for provider in workflow_providers: - try: - workflow_provider_controllers.append( - ToolTransformService.workflow_provider_to_controller(db_provider=provider) - ) - except Exception: - # app has been deleted - pass - - labels = ToolLabelManager.get_tools_labels( - [cast(ToolProviderController, controller) for controller in workflow_provider_controllers] - ) - - for provider_controller in workflow_provider_controllers: - user_provider = ToolTransformService.workflow_provider_to_user_provider( - provider_controller=provider_controller, - labels=labels.get(provider_controller.provider_id, []), - ) - result_providers[f"workflow_provider.{user_provider.name}"] = user_provider - - return BuiltinToolProviderSort.sort(list(result_providers.values())) - - @classmethod - def get_api_provider_controller( - cls, tenant_id: str, provider_id: str - ) -> tuple[ApiToolProviderController, dict[str, Any]]: - """ - get the api provider - - :param tenant_id: the id of the tenant - :param provider_id: the id of the provider - - :return: the provider controller, the credentials - """ - provider: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) - .filter( - ApiToolProvider.id == provider_id, - ApiToolProvider.tenant_id == tenant_id, - ) - .first() - ) - - if provider is None: - raise ToolProviderNotFoundError(f"api provider {provider_id} not found") - - controller = ApiToolProviderController.from_db( - provider, - ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, - ) - controller.load_bundled_tools(provider.tools) - - return controller, provider.credentials - - @classmethod - def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict: - """ - get api provider - """ - """ - get tool provider - 
""" - provider_name = provider - provider_obj: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) - .filter( - ApiToolProvider.tenant_id == tenant_id, - ApiToolProvider.name == provider, - ) - .first() - ) - - if provider_obj is None: - raise ValueError(f"you have not added provider {provider_name}") - - try: - credentials = json.loads(provider_obj.credentials_str) or {} - except Exception: - credentials = {} - - # package tool provider controller - controller = ApiToolProviderController.from_db( - provider_obj, - ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE, - ) - # init tool configuration - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], - provider_type=controller.provider_type.value, - provider_identity=controller.entity.identity.name, - ) - - decrypted_credentials = tool_configuration.decrypt(credentials) - masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) - - try: - icon = json.loads(provider_obj.icon) - except Exception: - icon = {"background": "#252525", "content": "\ud83d\ude01"} - - # add tool labels - labels = ToolLabelManager.get_tool_labels(controller) - - return cast( - dict, - jsonable_encoder( - { - "schema_type": provider_obj.schema_type, - "schema": provider_obj.schema, - "tools": provider_obj.tools, - "icon": icon, - "description": provider_obj.description, - "credentials": masked_credentials, - "privacy_policy": provider_obj.privacy_policy, - "custom_disclaimer": provider_obj.custom_disclaimer, - "labels": labels, - } - ), - ) - - @classmethod - def generate_builtin_tool_icon_url(cls, provider_id: str) -> str: - return str( - URL(dify_config.CONSOLE_API_URL or "/") - / "console" - / "api" - / "workspaces" - / "current" - / "tool-provider" - / "builtin" - / provider_id - / "icon" - ) - - @classmethod - def generate_plugin_tool_icon_url(cls, 
tenant_id: str, filename: str) -> str: - return str( - URL(dify_config.CONSOLE_API_URL or "/") - / "console" - / "api" - / "workspaces" - / "current" - / "plugin" - / "icon" - % {"tenant_id": tenant_id, "filename": filename} - ) - - @classmethod - def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict: - try: - workflow_provider: WorkflowToolProvider | None = ( - db.session.query(WorkflowToolProvider) - .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) - .first() - ) - - if workflow_provider is None: - raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found") - - icon: dict = json.loads(workflow_provider.icon) - return icon - except Exception: - return {"background": "#252525", "content": "\ud83d\ude01"} - - @classmethod - def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict: - try: - api_provider: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) - .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) - .first() - ) - - if api_provider is None: - raise ToolProviderNotFoundError(f"api provider {provider_id} not found") - - icon: dict = json.loads(api_provider.icon) - return icon - except Exception: - return {"background": "#252525", "content": "\ud83d\ude01"} - - @classmethod - def get_tool_icon( - cls, - tenant_id: str, - provider_type: ToolProviderType, - provider_id: str, - ) -> Union[str, dict]: - """ - get the tool icon - - :param tenant_id: the id of the tenant - :param provider_type: the type of the provider - :param provider_id: the id of the provider - :return: - """ - provider_type = provider_type - provider_id = provider_id - if provider_type == ToolProviderType.BUILT_IN: - provider = ToolManager.get_builtin_provider(provider_id, tenant_id) - if isinstance(provider, PluginToolProviderController): - try: - return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) - except 
Exception: - return {"background": "#252525", "content": "\ud83d\ude01"} - return cls.generate_builtin_tool_icon_url(provider_id) - elif provider_type == ToolProviderType.API: - return cls.generate_api_tool_icon_url(tenant_id, provider_id) - elif provider_type == ToolProviderType.WORKFLOW: - return cls.generate_workflow_tool_icon_url(tenant_id, provider_id) - elif provider_type == ToolProviderType.PLUGIN: - provider = ToolManager.get_builtin_provider(provider_id, tenant_id) - if isinstance(provider, PluginToolProviderController): - try: - return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon) - except Exception: - return {"background": "#252525", "content": "\ud83d\ude01"} - raise ValueError(f"plugin provider {provider_id} not found") - else: - raise ValueError(f"provider type {provider_type} not found") - - -ToolManager.load_hardcoded_providers_cache() diff --git a/api/core/file/datasource_file_parser.py b/api/core/file/datasource_file_parser.py new file mode 100644 index 0000000000..52687951ac --- /dev/null +++ b/api/core/file/datasource_file_parser.py @@ -0,0 +1,16 @@ +from typing import TYPE_CHECKING, Any, cast + +from core.datasource import datasource_file_manager +from core.datasource.datasource_file_manager import DatasourceFileManager + +if TYPE_CHECKING: + from core.datasource.datasource_file_manager import DatasourceFileManager + +tool_file_manager: dict[str, Any] = {"manager": None} + + +class DatasourceFileParser: + @staticmethod + def get_datasource_file_manager() -> "DatasourceFileManager": + """Return the object stored under the "manager" key of this module's ``tool_file_manager`` registry (``None`` until something assigns it); the imported ``datasource_file_manager`` name is a module and cannot be subscripted.""" + return cast("DatasourceFileManager", tool_file_manager["manager"])