This commit is contained in:
jyong 2025-04-25 15:49:36 +08:00
parent 389f15f8e3
commit d4007ae073
6 changed files with 105 additions and 547 deletions

View File

@ -4,6 +4,7 @@ from typing import Any, Optional, TextIO, Union
from pydantic import BaseModel
from configs import dify_config
from core.datasource.entities.datasource_entities import DatasourceInvokeMessage
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.entities.tool_entities import ToolInvokeMessage
@ -105,6 +106,36 @@ class DifyAgentCallbackHandler(BaseModel):
self.current_loop += 1
def on_datasource_start(self, datasource_name: str, datasource_inputs: Mapping[str, Any]) -> None:
"""Run on datasource start."""
if dify_config.DEBUG:
print_text("\n[on_datasource_start] DatasourceCall:" + datasource_name + "\n" +
str(datasource_inputs) + "\n", color=self.color)
def on_datasource_end(self, datasource_name: str, datasource_inputs: Mapping[str, Any], datasource_outputs:
Iterable[DatasourceInvokeMessage] | str, message_id: Optional[str] = None,
timer: Optional[Any] = None,
trace_manager: Optional[TraceQueueManager] = None) -> None:
"""Run on datasource end."""
if dify_config.DEBUG:
print_text("\n[on_datasource_end]\n", color=self.color)
print_text("Datasource: " + datasource_name + "\n", color=self.color)
print_text("Inputs: " + str(datasource_inputs) + "\n", color=self.color)
print_text("Outputs: " + str(datasource_outputs)[:1000] + "\n", color=self.color)
print_text("\n")
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.DATASOURCE_TRACE,
message_id=message_id,
datasource_name=datasource_name,
datasource_inputs=datasource_inputs,
datasource_outputs=datasource_outputs,
timer=timer,
)
)
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""

View File

@ -1,36 +1,19 @@
import json
from collections.abc import Generator, Iterable
from copy import deepcopy
from datetime import UTC, datetime
from mimetypes import guess_type
from typing import Any, Optional, Union, cast
from typing import Any, Optional, cast
from yarl import URL
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.entities.datasource_entities import (
DatasourceInvokeMessage,
DatasourceInvokeMessageBinary,
)
from core.file import FileType
from core.file.models import FileTransferMethod
from core.ops.ops_trace_manager import TraceQueueManager
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolInvokeMessage,
ToolInvokeMessageBinary,
ToolInvokeMeta,
ToolParameter,
)
from core.tools.errors import (
ToolEngineInvokeError,
ToolInvokeError,
ToolNotFoundError,
ToolNotSupportedError,
ToolParameterValidationError,
ToolProviderCredentialValidationError,
ToolProviderNotFoundError,
)
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.enums import CreatedByRole
from models.model import Message, MessageFile
@ -42,149 +25,39 @@ class DatasourceEngine:
"""
@staticmethod
def agent_invoke(
tool: Tool,
tool_parameters: Union[str, dict],
user_id: str,
tenant_id: str,
message: Message,
invoke_from: InvokeFrom,
agent_tool_callback: DifyAgentCallbackHandler,
trace_manager: Optional[TraceQueueManager] = None,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> tuple[str, list[str], ToolInvokeMeta]:
"""
Agent invokes the tool with the given arguments.
"""
# check if arguments is a string
if isinstance(tool_parameters, str):
# check if this tool has only one parameter
parameters = [
parameter
for parameter in tool.get_runtime_parameters()
if parameter.form == ToolParameter.ToolParameterForm.LLM
]
if parameters and len(parameters) == 1:
tool_parameters = {parameters[0].name: tool_parameters}
else:
try:
tool_parameters = json.loads(tool_parameters)
except Exception:
pass
if not isinstance(tool_parameters, dict):
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
try:
# hit the callback handler
agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id)
invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
def message_callback(
invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]
):
for message in messages:
if isinstance(message, ToolInvokeMeta):
invocation_meta_dict["meta"] = message
else:
yield message
messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=message_callback(invocation_meta_dict, messages),
user_id=user_id,
tenant_id=tenant_id,
conversation_id=message.conversation_id,
)
message_list = list(messages)
# extract binary data from tool invoke message
binary_files = ToolEngine._extract_tool_response_binary_and_text(message_list)
# create message file
message_files = ToolEngine._create_message_files(
tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id
)
plain_text = ToolEngine._convert_tool_response_to_str(message_list)
meta = invocation_meta_dict["meta"]
# hit the callback handler
agent_tool_callback.on_tool_end(
tool_name=tool.entity.identity.name,
tool_inputs=tool_parameters,
tool_outputs=plain_text,
message_id=message.id,
trace_manager=trace_manager,
)
# transform tool invoke message to get LLM friendly message
return plain_text, message_files, meta
except ToolProviderCredentialValidationError as e:
error_response = "Please check your tool provider credentials"
agent_tool_callback.on_tool_error(e)
except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e:
error_response = f"there is not a tool named {tool.entity.identity.name}"
agent_tool_callback.on_tool_error(e)
except ToolParameterValidationError as e:
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
agent_tool_callback.on_tool_error(e)
except ToolInvokeError as e:
error_response = f"tool invoke error: {e}"
agent_tool_callback.on_tool_error(e)
except ToolEngineInvokeError as e:
meta = e.meta
error_response = f"tool invoke error: {meta.error}"
agent_tool_callback.on_tool_error(e)
return error_response, [], meta
except Exception as e:
error_response = f"unknown error: {e}"
agent_tool_callback.on_tool_error(e)
return error_response, [], ToolInvokeMeta.error_instance(error_response)
@staticmethod
def x(
tool: Tool,
tool_parameters: dict[str, Any],
def invoke_first_step(
datasource: DatasourcePlugin,
datasource_parameters: dict[str, Any],
user_id: str,
workflow_tool_callback: DifyWorkflowCallbackHandler,
workflow_call_depth: int,
thread_pool_id: Optional[str] = None,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]:
) -> Generator[DatasourceInvokeMessage, None, None]:
"""
Workflow invokes the tool with the given arguments.
Workflow invokes the datasource with the given arguments.
"""
try:
# hit the callback handler
workflow_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
workflow_tool_callback.on_datasource_start(datasource_name=datasource.entity.identity.name,
datasource_inputs=datasource_parameters)
if isinstance(tool, WorkflowTool):
tool.workflow_call_depth = workflow_call_depth + 1
tool.thread_pool_id = thread_pool_id
if datasource.runtime and datasource.runtime.runtime_parameters:
datasource_parameters = {**datasource.runtime.runtime_parameters, **datasource_parameters}
if tool.runtime and tool.runtime.runtime_parameters:
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
response = tool.invoke(
response = datasource._invoke_first_step(
user_id=user_id,
tool_parameters=tool_parameters,
datasource_parameters=datasource_parameters,
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
)
# hit the callback handler
response = workflow_tool_callback.on_tool_execution(
tool_name=tool.entity.identity.name,
tool_inputs=tool_parameters,
tool_outputs=response,
response = workflow_tool_callback.on_datasource_end(
datasource_name=datasource.entity.identity.name,
datasource_inputs=datasource_parameters,
datasource_outputs=response,
)
return response
@ -193,61 +66,49 @@ class DatasourceEngine:
raise e
@staticmethod
def _invoke(
tool: Tool,
tool_parameters: dict,
def invoke_second_step(
datasource: DatasourcePlugin,
datasource_parameters: dict[str, Any],
user_id: str,
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]:
workflow_tool_callback: DifyWorkflowCallbackHandler,
) -> Generator[DatasourceInvokeMessage, None, None]:
"""
Invoke the tool with the given arguments.
Workflow invokes the datasource with the given arguments.
"""
started_at = datetime.now(UTC)
meta = ToolInvokeMeta(
time_cost=0.0,
error=None,
tool_config={
"tool_name": tool.entity.identity.name,
"tool_provider": tool.entity.identity.provider,
"tool_provider_type": tool.tool_provider_type().value,
"tool_parameters": deepcopy(tool.runtime.runtime_parameters),
"tool_icon": tool.entity.identity.icon,
},
)
try:
yield from tool.invoke(user_id, tool_parameters, conversation_id, app_id, message_id)
response = datasource._invoke_second_step(
user_id=user_id,
datasource_parameters=datasource_parameters,
)
return response
except Exception as e:
meta.error = str(e)
raise ToolEngineInvokeError(meta)
finally:
ended_at = datetime.now(UTC)
meta.time_cost = (ended_at - started_at).total_seconds()
yield meta
workflow_tool_callback.on_tool_error(e)
raise e
@staticmethod
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
def _convert_datasource_response_to_str(datasource_response: list[DatasourceInvokeMessage]) -> str:
"""
Handle tool response
Handle datasource response
"""
result = ""
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.TEXT:
result += cast(ToolInvokeMessage.TextMessage, response.message).text
elif response.type == ToolInvokeMessage.MessageType.LINK:
for response in datasource_response:
if response.type == DatasourceInvokeMessage.MessageType.TEXT:
result += cast(DatasourceInvokeMessage.TextMessage, response.message).text
elif response.type == DatasourceInvokeMessage.MessageType.LINK:
result += (
f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}."
f"result link: {cast(DatasourceInvokeMessage.TextMessage, response.message).text}."
+ " please tell user to check it."
)
elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
elif response.type in {DatasourceInvokeMessage.MessageType.IMAGE_LINK, DatasourceInvokeMessage.MessageType.IMAGE}:
result += (
"image has been created and sent to user already, "
+ "you do not need to create it, just tell the user to check it now."
)
elif response.type == ToolInvokeMessage.MessageType.JSON:
elif response.type == DatasourceInvokeMessage.MessageType.JSON:
result = json.dumps(
cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
cast(DatasourceInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
)
else:
result += str(response.message)
@ -255,14 +116,14 @@ class DatasourceEngine:
return result
@staticmethod
def _extract_tool_response_binary_and_text(
tool_response: list[ToolInvokeMessage],
) -> Generator[ToolInvokeMessageBinary, None, None]:
def _extract_datasource_response_binary_and_text(
datasource_response: list[DatasourceInvokeMessage],
) -> Generator[DatasourceInvokeMessageBinary, None, None]:
"""
Extract tool response binary
Extract datasource response binary
"""
for response in tool_response:
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
for response in datasource_response:
if response.type in {DatasourceInvokeMessage.MessageType.IMAGE_LINK, DatasourceInvokeMessage.MessageType.IMAGE}:
mimetype = None
if not response.meta:
raise ValueError("missing meta data")
@ -270,7 +131,7 @@ class DatasourceEngine:
mimetype = response.meta.get("mime_type")
else:
try:
url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text)
url = URL(cast(DatasourceInvokeMessage.TextMessage, response.message).text)
extension = url.suffix
guess_type_result, _ = guess_type(f"a{extension}")
if guess_type_result:
@ -281,31 +142,31 @@ class DatasourceEngine:
if not mimetype:
mimetype = "image/jpeg"
yield ToolInvokeMessageBinary(
yield DatasourceInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "image/jpeg"),
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
url=cast(DatasourceInvokeMessage.TextMessage, response.message).text,
)
elif response.type == ToolInvokeMessage.MessageType.BLOB:
elif response.type == DatasourceInvokeMessage.MessageType.BLOB:
if not response.meta:
raise ValueError("missing meta data")
yield ToolInvokeMessageBinary(
yield DatasourceInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "application/octet-stream"),
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
url=cast(DatasourceInvokeMessage.TextMessage, response.message).text,
)
elif response.type == ToolInvokeMessage.MessageType.LINK:
elif response.type == DatasourceInvokeMessage.MessageType.LINK:
# check if there is a mime type in meta
if response.meta and "mime_type" in response.meta:
yield ToolInvokeMessageBinary(
yield DatasourceInvokeMessageBinary(
mimetype=response.meta.get("mime_type", "application/octet-stream")
if response.meta
else "application/octet-stream",
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
url=cast(DatasourceInvokeMessage.TextMessage, response.message).text,
)
@staticmethod
def _create_message_files(
tool_messages: Iterable[ToolInvokeMessageBinary],
datasource_messages: Iterable[DatasourceInvokeMessageBinary],
agent_message: Message,
invoke_from: InvokeFrom,
user_id: str,
@ -317,7 +178,7 @@ class DatasourceEngine:
"""
result = []
for message in tool_messages:
for message in datasource_messages:
if "image" in message.mimetype:
file_type = FileType.IMAGE
elif "video" in message.mimetype:

View File

@ -1,36 +1,36 @@
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.datasource.entities.datasource_entities import DatasourceInvokeMeta
class ToolProviderNotFoundError(ValueError):
class DatasourceProviderNotFoundError(ValueError):
pass
class ToolNotFoundError(ValueError):
class DatasourceNotFoundError(ValueError):
pass
class ToolParameterValidationError(ValueError):
class DatasourceParameterValidationError(ValueError):
pass
class ToolProviderCredentialValidationError(ValueError):
class DatasourceProviderCredentialValidationError(ValueError):
pass
class ToolNotSupportedError(ValueError):
class DatasourceNotSupportedError(ValueError):
pass
class ToolInvokeError(ValueError):
class DatasourceInvokeError(ValueError):
pass
class ToolApiSchemaError(ValueError):
class DatasourceApiSchemaError(ValueError):
pass
class ToolEngineInvokeError(Exception):
meta: ToolInvokeMeta
class DatasourceEngineInvokeError(Exception):
meta: DatasourceInvokeMeta
def __init__(self, meta, **kwargs):
self.meta = meta

View File

@ -1,234 +0,0 @@
import base64
import hashlib
import hmac
import logging
import os
import time
from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from uuid import uuid4
import httpx
from configs import dify_config
from core.helper import ssrf_proxy
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import MessageFile
from models.tools import ToolFile
logger = logging.getLogger(__name__)
class ToolFileManager:
@staticmethod
def sign_file(tool_file_id: str, extension: str) -> str:
"""
sign file to get a temporary url
"""
base_url = dify_config.FILES_URL
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@staticmethod
def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
"""
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
@staticmethod
def create_file_by_raw(
*,
user_id: str,
tenant_id: str,
conversation_id: Optional[str],
file_binary: bytes,
mimetype: str,
filename: Optional[str] = None,
) -> ToolFile:
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
unique_filename = f"{unique_name}{extension}"
# default just as before
present_filename = unique_filename
if filename is not None:
has_extension = len(filename.split(".")) > 1
# Add extension flexibly
present_filename = filename if has_extension else f"{filename}{extension}"
filepath = f"tools/{tenant_id}/{unique_filename}"
storage.save(filepath, file_binary)
tool_file = ToolFile(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_key=filepath,
mimetype=mimetype,
name=present_filename,
size=len(file_binary),
)
db.session.add(tool_file)
db.session.commit()
db.session.refresh(tool_file)
return tool_file
@staticmethod
def create_file_by_url(
user_id: str,
tenant_id: str,
file_url: str,
conversation_id: Optional[str] = None,
) -> ToolFile:
# try to download image
try:
response = ssrf_proxy.get(file_url)
response.raise_for_status()
blob = response.content
except httpx.TimeoutException:
raise ValueError(f"timeout when downloading file from {file_url}")
mimetype = (
guess_type(file_url)[0]
or response.headers.get("Content-Type", "").split(";")[0].strip()
or "application/octet-stream"
)
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
filename = f"{unique_name}{extension}"
filepath = f"tools/{tenant_id}/{filename}"
storage.save(filepath, blob)
tool_file = ToolFile(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_key=filepath,
mimetype=mimetype,
original_url=file_url,
name=filename,
size=len(blob),
)
db.session.add(tool_file)
db.session.commit()
return tool_file
@staticmethod
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
"""
get file binary
:param id: the id of the file
:return: the binary of the file, mime type
"""
tool_file: ToolFile | None = (
db.session.query(ToolFile)
.filter(
ToolFile.id == id,
)
.first()
)
if not tool_file:
return None
blob = storage.load_once(tool_file.file_key)
return blob, tool_file.mimetype
@staticmethod
def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
"""
get file binary
:param id: the id of the file
:return: the binary of the file, mime type
"""
message_file: MessageFile | None = (
db.session.query(MessageFile)
.filter(
MessageFile.id == id,
)
.first()
)
# Check if message_file is not None
if message_file is not None:
# get tool file id
if message_file.url is not None:
tool_file_id = message_file.url.split("/")[-1]
# trim extension
tool_file_id = tool_file_id.split(".")[0]
else:
tool_file_id = None
else:
tool_file_id = None
tool_file: ToolFile | None = (
db.session.query(ToolFile)
.filter(
ToolFile.id == tool_file_id,
)
.first()
)
if not tool_file:
return None
blob = storage.load_once(tool_file.file_key)
return blob, tool_file.mimetype
@staticmethod
def get_file_generator_by_tool_file_id(tool_file_id: str):
"""
get file binary
:param tool_file_id: the id of the tool file
:return: the binary of the file, mime type
"""
tool_file: ToolFile | None = (
db.session.query(ToolFile)
.filter(
ToolFile.id == tool_file_id,
)
.first()
)
if not tool_file:
return None, None
stream = storage.load_stream(tool_file.file_key)
return stream, tool_file
# init tool_file_parser
from core.file.tool_file_parser import tool_file_manager
tool_file_manager["manager"] = ToolFileManager

View File

@ -1,101 +0,0 @@
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.entities.values import default_tool_label_name_list
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from extensions.ext_database import db
from models.tools import ToolLabelBinding
class ToolLabelManager:
@classmethod
def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]:
"""
Filter tool labels
"""
tool_labels = [label for label in tool_labels if label in default_tool_label_name_list]
return list(set(tool_labels))
@classmethod
def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
"""
Update tool labels
"""
labels = cls.filter_tool_labels(labels)
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
provider_id = controller.provider_id
else:
raise ValueError("Unsupported tool type")
# delete old labels
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete()
# insert new labels
for label in labels:
db.session.add(
ToolLabelBinding(
tool_id=provider_id,
tool_type=controller.provider_type.value,
label_name=label,
)
)
db.session.commit()
@classmethod
def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
"""
Get tool labels
"""
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
provider_id = controller.provider_id
elif isinstance(controller, BuiltinToolProviderController):
return controller.tool_labels
else:
raise ValueError("Unsupported tool type")
labels = (
db.session.query(ToolLabelBinding.label_name)
.filter(
ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value,
)
.all()
)
return [label.label_name for label in labels]
@classmethod
def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
"""
Get tools labels
:param tool_providers: list of tool providers
:return: dict of tool labels
:key: tool id
:value: list of tool labels
"""
if not tool_providers:
return {}
for controller in tool_providers:
if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
raise ValueError("Unsupported tool type")
provider_ids = []
for controller in tool_providers:
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
provider_ids.append(controller.provider_id)
labels: list[ToolLabelBinding] = (
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all()
)
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
for label in labels:
tool_labels[label.tool_id].append(label.label_name)
return tool_labels

View File

@ -132,3 +132,4 @@ class TraceTaskName(StrEnum):
DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
TOOL_TRACE = "tool"
GENERATE_NAME_TRACE = "generate_conversation_name"
DATASOURCE_TRACE = "datasource"