diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index 65d899a002..38f5e51b63 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -4,6 +4,7 @@ from typing import Any, Optional, TextIO, Union from pydantic import BaseModel from configs import dify_config +from core.datasource.entities.datasource_entities import DatasourceInvokeMessage from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.tools.entities.tool_entities import ToolInvokeMessage @@ -105,6 +106,36 @@ class DifyAgentCallbackHandler(BaseModel): self.current_loop += 1 + def on_datasource_start(self, datasource_name: str, datasource_inputs: Mapping[str, Any]) -> None: + """Run on datasource start.""" + if dify_config.DEBUG: + print_text("\n[on_datasource_start] DatasourceCall:" + datasource_name + "\n" + + str(datasource_inputs) + "\n", color=self.color) + + def on_datasource_end(self, datasource_name: str, datasource_inputs: Mapping[str, Any], datasource_outputs: + Iterable[DatasourceInvokeMessage] | str, message_id: Optional[str] = None, + timer: Optional[Any] = None, + trace_manager: Optional[TraceQueueManager] = None) -> None: + """Run on datasource end.""" + if dify_config.DEBUG: + print_text("\n[on_datasource_end]\n", color=self.color) + print_text("Datasource: " + datasource_name + "\n", color=self.color) + print_text("Inputs: " + str(datasource_inputs) + "\n", color=self.color) + print_text("Outputs: " + str(datasource_outputs)[:1000] + "\n", color=self.color) + print_text("\n") + + if trace_manager: + trace_manager.add_trace_task( + TraceTask( + TraceTaskName.DATASOURCE_TRACE, + message_id=message_id, + datasource_name=datasource_name, + datasource_inputs=datasource_inputs, + datasource_outputs=datasource_outputs, + timer=timer, + ) + ) + @property def ignore_agent(self) -> bool: """Whether to ignore agent callbacks.""" diff --git a/api/core/datasource/datasource_engine.py b/api/core/datasource/datasource_engine.py index 423f78a787..86a3b9d0a0 100644 --- a/api/core/datasource/datasource_engine.py +++ b/api/core/datasource/datasource_engine.py @@ -1,36 +1,19 @@ import json from collections.abc import Generator, Iterable -from copy import deepcopy -from datetime import UTC, datetime from mimetypes import guess_type -from typing import Any, Optional, Union, cast +from typing import Any, Optional, cast from yarl import URL from core.app.entities.app_invoke_entities import InvokeFrom -from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.entities.datasource_entities import ( + DatasourceInvokeMessage, + DatasourceInvokeMessageBinary, +) from core.file import FileType from core.file.models import FileTransferMethod -from core.ops.ops_trace_manager import TraceQueueManager -from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ( - ToolInvokeMessage, - ToolInvokeMessageBinary, - ToolInvokeMeta, - ToolParameter, -) -from core.tools.errors import ( - ToolEngineInvokeError, - ToolInvokeError, - ToolNotFoundError, - ToolNotSupportedError, - ToolParameterValidationError, - ToolProviderCredentialValidationError, - ToolProviderNotFoundError, -) -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db from models.enums import CreatedByRole from models.model import Message, MessageFile @@ -42,149 +25,39 @@ class DatasourceEngine: """ @staticmethod - def agent_invoke( - tool: Tool, - tool_parameters: Union[str, dict], - user_id: str, - tenant_id: str, - message: Message, - invoke_from: InvokeFrom, - agent_tool_callback: DifyAgentCallbackHandler, - trace_manager: Optional[TraceQueueManager] = None, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - ) -> tuple[str, list[str], ToolInvokeMeta]: - """ - Agent invokes the tool with the given arguments. - """ - # check if arguments is a string - if isinstance(tool_parameters, str): - # check if this tool has only one parameter - parameters = [ - parameter - for parameter in tool.get_runtime_parameters() - if parameter.form == ToolParameter.ToolParameterForm.LLM - ] - if parameters and len(parameters) == 1: - tool_parameters = {parameters[0].name: tool_parameters} - else: - try: - tool_parameters = json.loads(tool_parameters) - except Exception: - pass - if not isinstance(tool_parameters, dict): - raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") - - try: - # hit the callback handler - agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters) - - messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id) - invocation_meta_dict: dict[str, ToolInvokeMeta] = {} - - def message_callback( - invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None] - ): - for message in messages: - if isinstance(message, ToolInvokeMeta): - invocation_meta_dict["meta"] = message - else: - yield message - - messages = ToolFileMessageTransformer.transform_tool_invoke_messages( - messages=message_callback(invocation_meta_dict, messages), - user_id=user_id, - tenant_id=tenant_id, - conversation_id=message.conversation_id, - ) - - message_list = list(messages) - - # extract binary data from tool invoke message - binary_files = ToolEngine._extract_tool_response_binary_and_text(message_list) - # create message file - message_files = ToolEngine._create_message_files( - tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id - ) - - plain_text = ToolEngine._convert_tool_response_to_str(message_list) - - meta = invocation_meta_dict["meta"] - - # hit the callback handler - agent_tool_callback.on_tool_end( - tool_name=tool.entity.identity.name, - tool_inputs=tool_parameters, - tool_outputs=plain_text, - message_id=message.id, - trace_manager=trace_manager, - ) - - # transform tool invoke message to get LLM friendly message - return plain_text, message_files, meta - except ToolProviderCredentialValidationError as e: - error_response = "Please check your tool provider credentials" - agent_tool_callback.on_tool_error(e) - except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e: - error_response = f"there is not a tool named {tool.entity.identity.name}" - agent_tool_callback.on_tool_error(e) - except ToolParameterValidationError as e: - error_response = f"tool parameters validation error: {e}, please check your tool parameters" - agent_tool_callback.on_tool_error(e) - except ToolInvokeError as e: - error_response = f"tool invoke error: {e}" - agent_tool_callback.on_tool_error(e) - except ToolEngineInvokeError as e: - meta = e.meta - error_response = f"tool invoke error: {meta.error}" - agent_tool_callback.on_tool_error(e) - return error_response, [], meta - except Exception as e: - error_response = f"unknown error: {e}" - agent_tool_callback.on_tool_error(e) - - return error_response, [], ToolInvokeMeta.error_instance(error_response) - - @staticmethod - def x( - tool: Tool, - tool_parameters: dict[str, Any], + def invoke_first_step( + datasource: DatasourcePlugin, + datasource_parameters: dict[str, Any], user_id: str, workflow_tool_callback: DifyWorkflowCallbackHandler, - workflow_call_depth: int, - thread_pool_id: Optional[str] = None, conversation_id: Optional[str] = None, app_id: Optional[str] = None, message_id: Optional[str] = None, - ) -> Generator[ToolInvokeMessage, None, None]: + ) -> Generator[DatasourceInvokeMessage, None, None]: """ - Workflow invokes the tool with the given arguments. + Workflow invokes the datasource with the given arguments. """ try: # hit the callback handler - workflow_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters) + workflow_tool_callback.on_datasource_start(datasource_name=datasource.entity.identity.name, + datasource_inputs=datasource_parameters) - if isinstance(tool, WorkflowTool): - tool.workflow_call_depth = workflow_call_depth + 1 - tool.thread_pool_id = thread_pool_id + if datasource.runtime and datasource.runtime.runtime_parameters: + datasource_parameters = {**datasource.runtime.runtime_parameters, **datasource_parameters} - if tool.runtime and tool.runtime.runtime_parameters: - tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters} - - response = tool.invoke( + response = datasource._invoke_first_step( user_id=user_id, - tool_parameters=tool_parameters, + datasource_parameters=datasource_parameters, conversation_id=conversation_id, app_id=app_id, message_id=message_id, ) # hit the callback handler - response = workflow_tool_callback.on_tool_execution( - tool_name=tool.entity.identity.name, - tool_inputs=tool_parameters, - tool_outputs=response, + response = workflow_tool_callback.on_datasource_end( + datasource_name=datasource.entity.identity.name, + datasource_inputs=datasource_parameters, + datasource_outputs=response, ) return response @@ -193,61 +66,49 @@ class DatasourceEngine: raise e @staticmethod - def _invoke( - tool: Tool, - tool_parameters: dict, + def invoke_second_step( + datasource: DatasourcePlugin, + datasource_parameters: dict[str, Any], user_id: str, - conversation_id: Optional[str] = None, - app_id: Optional[str] = None, - message_id: Optional[str] = None, - ) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]: + workflow_tool_callback: DifyWorkflowCallbackHandler, + ) -> Generator[DatasourceInvokeMessage, None, None]: """ - Invoke the tool with the given arguments. + Workflow invokes the datasource with the given arguments. """ - started_at = datetime.now(UTC) - meta = ToolInvokeMeta( - time_cost=0.0, - error=None, - tool_config={ - "tool_name": tool.entity.identity.name, - "tool_provider": tool.entity.identity.provider, - "tool_provider_type": tool.tool_provider_type().value, - "tool_parameters": deepcopy(tool.runtime.runtime_parameters), - "tool_icon": tool.entity.identity.icon, - }, - ) try: - yield from tool.invoke(user_id, tool_parameters, conversation_id, app_id, message_id) + response = datasource._invoke_second_step( + user_id=user_id, + datasource_parameters=datasource_parameters, + ) + + return response except Exception as e: - meta.error = str(e) - raise ToolEngineInvokeError(meta) - finally: - ended_at = datetime.now(UTC) - meta.time_cost = (ended_at - started_at).total_seconds() - yield meta + workflow_tool_callback.on_tool_error(e) + raise e + @staticmethod - def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str: + def _convert_datasource_response_to_str(datasource_response: list[DatasourceInvokeMessage]) -> str: """ - Handle tool response + Handle datasource response """ result = "" - for response in tool_response: - if response.type == ToolInvokeMessage.MessageType.TEXT: - result += cast(ToolInvokeMessage.TextMessage, response.message).text - elif response.type == ToolInvokeMessage.MessageType.LINK: + for response in datasource_response: + if response.type == DatasourceInvokeMessage.MessageType.TEXT: + result += cast(DatasourceInvokeMessage.TextMessage, response.message).text + elif response.type == DatasourceInvokeMessage.MessageType.LINK: result += ( - f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}." + f"result link: {cast(DatasourceInvokeMessage.TextMessage, response.message).text}." + " please tell user to check it." ) - elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: + elif response.type in {DatasourceInvokeMessage.MessageType.IMAGE_LINK, DatasourceInvokeMessage.MessageType.IMAGE}: result += ( "image has been created and sent to user already, " + "you do not need to create it, just tell the user to check it now." ) - elif response.type == ToolInvokeMessage.MessageType.JSON: + elif response.type == DatasourceInvokeMessage.MessageType.JSON: result = json.dumps( - cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False + cast(DatasourceInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False ) else: result += str(response.message) @@ -255,14 +116,14 @@ class DatasourceEngine: return result @staticmethod - def _extract_tool_response_binary_and_text( - tool_response: list[ToolInvokeMessage], - ) -> Generator[ToolInvokeMessageBinary, None, None]: + def _extract_datasource_response_binary_and_text( + datasource_response: list[DatasourceInvokeMessage], + ) -> Generator[DatasourceInvokeMessageBinary, None, None]: """ - Extract tool response binary + Extract datasource response binary """ - for response in tool_response: - if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: + for response in datasource_response: + if response.type in {DatasourceInvokeMessage.MessageType.IMAGE_LINK, DatasourceInvokeMessage.MessageType.IMAGE}: mimetype = None if not response.meta: raise ValueError("missing meta data") @@ -270,7 +131,7 @@ class DatasourceEngine: mimetype = response.meta.get("mime_type") else: try: - url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text) + url = URL(cast(DatasourceInvokeMessage.TextMessage, response.message).text) extension = url.suffix guess_type_result, _ = guess_type(f"a{extension}") if guess_type_result: @@ -281,31 +142,31 @@ class DatasourceEngine: if not mimetype: mimetype = "image/jpeg" - yield ToolInvokeMessageBinary( + yield DatasourceInvokeMessageBinary( mimetype=response.meta.get("mime_type", "image/jpeg"), - url=cast(ToolInvokeMessage.TextMessage, response.message).text, + url=cast(DatasourceInvokeMessage.TextMessage, response.message).text, ) - elif response.type == ToolInvokeMessage.MessageType.BLOB: + elif response.type == DatasourceInvokeMessage.MessageType.BLOB: if not response.meta: raise ValueError("missing meta data") - yield ToolInvokeMessageBinary( + yield DatasourceInvokeMessageBinary( mimetype=response.meta.get("mime_type", "application/octet-stream"), - url=cast(ToolInvokeMessage.TextMessage, response.message).text, + url=cast(DatasourceInvokeMessage.TextMessage, response.message).text, ) - elif response.type == ToolInvokeMessage.MessageType.LINK: + elif response.type == DatasourceInvokeMessage.MessageType.LINK: # check if there is a mime type in meta if response.meta and "mime_type" in response.meta: - yield ToolInvokeMessageBinary( + yield DatasourceInvokeMessageBinary( mimetype=response.meta.get("mime_type", "application/octet-stream") if response.meta else "application/octet-stream", - url=cast(ToolInvokeMessage.TextMessage, response.message).text, + url=cast(DatasourceInvokeMessage.TextMessage, response.message).text, ) @staticmethod def _create_message_files( - tool_messages: Iterable[ToolInvokeMessageBinary], + datasource_messages: Iterable[DatasourceInvokeMessageBinary], agent_message: Message, invoke_from: InvokeFrom, user_id: str, @@ -317,7 +178,7 @@ class DatasourceEngine: """ result = [] - for message in tool_messages: + for message in datasource_messages: if "image" in message.mimetype: file_type = FileType.IMAGE elif "video" in message.mimetype: diff --git a/api/core/datasource/errors.py b/api/core/datasource/errors.py index c5f9ca4774..c7fc2f85b9 100644 --- a/api/core/datasource/errors.py +++ b/api/core/datasource/errors.py @@ -1,36 +1,36 @@ -from core.tools.entities.tool_entities import ToolInvokeMeta +from core.datasource.entities.datasource_entities import DatasourceInvokeMeta -class ToolProviderNotFoundError(ValueError): +class DatasourceProviderNotFoundError(ValueError): pass -class ToolNotFoundError(ValueError): +class DatasourceNotFoundError(ValueError): pass -class ToolParameterValidationError(ValueError): +class DatasourceParameterValidationError(ValueError): pass -class ToolProviderCredentialValidationError(ValueError): +class DatasourceProviderCredentialValidationError(ValueError): pass -class ToolNotSupportedError(ValueError): +class DatasourceNotSupportedError(ValueError): pass -class ToolInvokeError(ValueError): +class DatasourceInvokeError(ValueError): pass -class ToolApiSchemaError(ValueError): +class DatasourceApiSchemaError(ValueError): pass -class ToolEngineInvokeError(Exception): - meta: ToolInvokeMeta +class DatasourceEngineInvokeError(Exception): + meta: DatasourceInvokeMeta def __init__(self, meta, **kwargs): self.meta = meta diff --git a/api/core/datasource/tool_file_manager.py b/api/core/datasource/tool_file_manager.py deleted file mode 100644 index 7e8d4280d4..0000000000 --- a/api/core/datasource/tool_file_manager.py +++ /dev/null @@ -1,234 +0,0 @@ -import base64 -import hashlib -import hmac -import logging -import os -import time -from mimetypes import guess_extension, guess_type -from typing import Optional, Union -from uuid import uuid4 - -import httpx - -from configs import dify_config -from core.helper import ssrf_proxy -from extensions.ext_database import db -from extensions.ext_storage import storage -from models.model import MessageFile -from models.tools import ToolFile - -logger = logging.getLogger(__name__) - - -class ToolFileManager: - @staticmethod - def sign_file(tool_file_id: str, extension: str) -> str: - """ - sign file to get a temporary url - """ - base_url = dify_config.FILES_URL - file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}" - - timestamp = str(int(time.time())) - nonce = os.urandom(16).hex() - data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" - sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - - return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" - - @staticmethod - def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - """ - verify signature - """ - data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - # verify signature - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT - - @staticmethod - def create_file_by_raw( - *, - user_id: str, - tenant_id: str, - conversation_id: Optional[str], - file_binary: bytes, - mimetype: str, - filename: Optional[str] = None, - ) -> ToolFile: - extension = guess_extension(mimetype) or ".bin" - unique_name = uuid4().hex - unique_filename = f"{unique_name}{extension}" - # default just as before - present_filename = unique_filename - if filename is not None: - has_extension = len(filename.split(".")) > 1 - # Add extension flexibly - present_filename = filename if has_extension else f"{filename}{extension}" - filepath = f"tools/{tenant_id}/{unique_filename}" - storage.save(filepath, file_binary) - - tool_file = ToolFile( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=conversation_id, - file_key=filepath, - mimetype=mimetype, - name=present_filename, - size=len(file_binary), - ) - - db.session.add(tool_file) - db.session.commit() - db.session.refresh(tool_file) - - return tool_file - - @staticmethod - def create_file_by_url( - user_id: str, - tenant_id: str, - file_url: str, - conversation_id: Optional[str] = None, - ) -> ToolFile: - # try to download image - try: - response = ssrf_proxy.get(file_url) - response.raise_for_status() - blob = response.content - except httpx.TimeoutException: - raise ValueError(f"timeout when downloading file from {file_url}") - - mimetype = ( - guess_type(file_url)[0] - or response.headers.get("Content-Type", "").split(";")[0].strip() - or "application/octet-stream" - ) - extension = guess_extension(mimetype) or ".bin" - unique_name = uuid4().hex - filename = f"{unique_name}{extension}" - filepath = f"tools/{tenant_id}/{filename}" - storage.save(filepath, blob) - - tool_file = ToolFile( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=conversation_id, - file_key=filepath, - mimetype=mimetype, - original_url=file_url, - name=filename, - size=len(blob), - ) - - db.session.add(tool_file) - db.session.commit() - - return tool_file - - @staticmethod - def get_file_binary(id: str) -> Union[tuple[bytes, str], None]: - """ - get file binary - - :param id: the id of the file - - :return: the binary of the file, mime type - """ - tool_file: ToolFile | None = ( - db.session.query(ToolFile) - .filter( - ToolFile.id == id, - ) - .first() - ) - - if not tool_file: - return None - - blob = storage.load_once(tool_file.file_key) - - return blob, tool_file.mimetype - - @staticmethod - def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]: - """ - get file binary - - :param id: the id of the file - - :return: the binary of the file, mime type - """ - message_file: MessageFile | None = ( - db.session.query(MessageFile) - .filter( - MessageFile.id == id, - ) - .first() - ) - - # Check if message_file is not None - if message_file is not None: - # get tool file id - if message_file.url is not None: - tool_file_id = message_file.url.split("/")[-1] - # trim extension - tool_file_id = tool_file_id.split(".")[0] - else: - tool_file_id = None - else: - tool_file_id = None - - tool_file: ToolFile | None = ( - db.session.query(ToolFile) - .filter( - ToolFile.id == tool_file_id, - ) - .first() - ) - - if not tool_file: - return None - - blob = storage.load_once(tool_file.file_key) - - return blob, tool_file.mimetype - - @staticmethod - def get_file_generator_by_tool_file_id(tool_file_id: str): - """ - get file binary - - :param tool_file_id: the id of the tool file - - :return: the binary of the file, mime type - """ - tool_file: ToolFile | None = ( - db.session.query(ToolFile) - .filter( - ToolFile.id == tool_file_id, - ) - .first() - ) - - if not tool_file: - return None, None - - stream = storage.load_stream(tool_file.file_key) - - return stream, tool_file - - -# init tool_file_parser -from core.file.tool_file_parser import tool_file_manager - -tool_file_manager["manager"] = ToolFileManager diff --git a/api/core/datasource/tool_label_manager.py b/api/core/datasource/tool_label_manager.py deleted file mode 100644 index 4787d7d79c..0000000000 --- a/api/core/datasource/tool_label_manager.py +++ /dev/null @@ -1,101 +0,0 @@ -from core.tools.__base.tool_provider import ToolProviderController -from core.tools.builtin_tool.provider import BuiltinToolProviderController -from core.tools.custom_tool.provider import ApiToolProviderController -from core.tools.entities.values import default_tool_label_name_list -from core.tools.workflow_as_tool.provider import WorkflowToolProviderController -from extensions.ext_database import db -from models.tools import ToolLabelBinding - - -class ToolLabelManager: - @classmethod - def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]: - """ - Filter tool labels - """ - tool_labels = [label for label in tool_labels if label in default_tool_label_name_list] - return list(set(tool_labels)) - - @classmethod - def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]): - """ - Update tool labels - """ - labels = cls.filter_tool_labels(labels) - - if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): - provider_id = controller.provider_id - else: - raise ValueError("Unsupported tool type") - - # delete old labels - db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete() - - # insert new labels - for label in labels: - db.session.add( - ToolLabelBinding( - tool_id=provider_id, - tool_type=controller.provider_type.value, - label_name=label, - ) - ) - - db.session.commit() - - @classmethod - def get_tool_labels(cls, controller: ToolProviderController) -> list[str]: - """ - Get tool labels - """ - if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): - provider_id = controller.provider_id - elif isinstance(controller, BuiltinToolProviderController): - return controller.tool_labels - else: - raise ValueError("Unsupported tool type") - - labels = ( - db.session.query(ToolLabelBinding.label_name) - .filter( - ToolLabelBinding.tool_id == provider_id, - ToolLabelBinding.tool_type == controller.provider_type.value, - ) - .all() - ) - - return [label.label_name for label in labels] - - @classmethod - def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]: - """ - Get tools labels - - :param tool_providers: list of tool providers - - :return: dict of tool labels - :key: tool id - :value: list of tool labels - """ - if not tool_providers: - return {} - - for controller in tool_providers: - if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController): - raise ValueError("Unsupported tool type") - - provider_ids = [] - for controller in tool_providers: - assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController) - provider_ids.append(controller.provider_id) - - labels: list[ToolLabelBinding] = ( - db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all() - ) - - tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels} - - for label in labels: - tool_labels[label.tool_id].append(label.label_name) - - return tool_labels diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index f0e34c0cd7..be6f3c007a 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -132,3 +132,4 @@ class TraceTaskName(StrEnum): DATASET_RETRIEVAL_TRACE = "dataset_retrieval" TOOL_TRACE = "tool" GENERATE_NAME_TRACE = "generate_conversation_name" + DATASOURCE_TRACE = "datasource" \ No newline at end of file