mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 13:35:55 +08:00
r2
This commit is contained in:
parent
389f15f8e3
commit
d4007ae073
@ -4,6 +4,7 @@ from typing import Any, Optional, TextIO, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.datasource.entities.datasource_entities import DatasourceInvokeMessage
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
@ -105,6 +106,36 @@ class DifyAgentCallbackHandler(BaseModel):
|
||||
|
||||
self.current_loop += 1
|
||||
|
||||
def on_datasource_start(self, datasource_name: str, datasource_inputs: Mapping[str, Any]) -> None:
|
||||
"""Run on datasource start."""
|
||||
if dify_config.DEBUG:
|
||||
print_text("\n[on_datasource_start] DatasourceCall:" + datasource_name + "\n" +
|
||||
str(datasource_inputs) + "\n", color=self.color)
|
||||
|
||||
def on_datasource_end(self, datasource_name: str, datasource_inputs: Mapping[str, Any], datasource_outputs:
|
||||
Iterable[DatasourceInvokeMessage] | str, message_id: Optional[str] = None,
|
||||
timer: Optional[Any] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None) -> None:
|
||||
"""Run on datasource end."""
|
||||
if dify_config.DEBUG:
|
||||
print_text("\n[on_datasource_end]\n", color=self.color)
|
||||
print_text("Datasource: " + datasource_name + "\n", color=self.color)
|
||||
print_text("Inputs: " + str(datasource_inputs) + "\n", color=self.color)
|
||||
print_text("Outputs: " + str(datasource_outputs)[:1000] + "\n", color=self.color)
|
||||
print_text("\n")
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.DATASOURCE_TRACE,
|
||||
message_id=message_id,
|
||||
datasource_name=datasource_name,
|
||||
datasource_inputs=datasource_inputs,
|
||||
datasource_outputs=datasource_outputs,
|
||||
timer=timer,
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def ignore_agent(self) -> bool:
|
||||
"""Whether to ignore agent callbacks."""
|
||||
|
@ -1,36 +1,19 @@
|
||||
import json
|
||||
from collections.abc import Generator, Iterable
|
||||
from copy import deepcopy
|
||||
from datetime import UTC, datetime
|
||||
from mimetypes import guess_type
|
||||
from typing import Any, Optional, Union, cast
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceInvokeMessage,
|
||||
DatasourceInvokeMessageBinary,
|
||||
)
|
||||
from core.file import FileType
|
||||
from core.file.models import FileTransferMethod
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolInvokeMessage,
|
||||
ToolInvokeMessageBinary,
|
||||
ToolInvokeMeta,
|
||||
ToolParameter,
|
||||
)
|
||||
from core.tools.errors import (
|
||||
ToolEngineInvokeError,
|
||||
ToolInvokeError,
|
||||
ToolNotFoundError,
|
||||
ToolNotSupportedError,
|
||||
ToolParameterValidationError,
|
||||
ToolProviderCredentialValidationError,
|
||||
ToolProviderNotFoundError,
|
||||
)
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatedByRole
|
||||
from models.model import Message, MessageFile
|
||||
@ -42,149 +25,39 @@ class DatasourceEngine:
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def agent_invoke(
|
||||
tool: Tool,
|
||||
tool_parameters: Union[str, dict],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
message: Message,
|
||||
invoke_from: InvokeFrom,
|
||||
agent_tool_callback: DifyAgentCallbackHandler,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> tuple[str, list[str], ToolInvokeMeta]:
|
||||
"""
|
||||
Agent invokes the tool with the given arguments.
|
||||
"""
|
||||
# check if arguments is a string
|
||||
if isinstance(tool_parameters, str):
|
||||
# check if this tool has only one parameter
|
||||
parameters = [
|
||||
parameter
|
||||
for parameter in tool.get_runtime_parameters()
|
||||
if parameter.form == ToolParameter.ToolParameterForm.LLM
|
||||
]
|
||||
if parameters and len(parameters) == 1:
|
||||
tool_parameters = {parameters[0].name: tool_parameters}
|
||||
else:
|
||||
try:
|
||||
tool_parameters = json.loads(tool_parameters)
|
||||
except Exception:
|
||||
pass
|
||||
if not isinstance(tool_parameters, dict):
|
||||
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
|
||||
|
||||
try:
|
||||
# hit the callback handler
|
||||
agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
|
||||
|
||||
messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id)
|
||||
invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
|
||||
|
||||
def message_callback(
|
||||
invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]
|
||||
):
|
||||
for message in messages:
|
||||
if isinstance(message, ToolInvokeMeta):
|
||||
invocation_meta_dict["meta"] = message
|
||||
else:
|
||||
yield message
|
||||
|
||||
messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=message_callback(invocation_meta_dict, messages),
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=message.conversation_id,
|
||||
)
|
||||
|
||||
message_list = list(messages)
|
||||
|
||||
# extract binary data from tool invoke message
|
||||
binary_files = ToolEngine._extract_tool_response_binary_and_text(message_list)
|
||||
# create message file
|
||||
message_files = ToolEngine._create_message_files(
|
||||
tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id
|
||||
)
|
||||
|
||||
plain_text = ToolEngine._convert_tool_response_to_str(message_list)
|
||||
|
||||
meta = invocation_meta_dict["meta"]
|
||||
|
||||
# hit the callback handler
|
||||
agent_tool_callback.on_tool_end(
|
||||
tool_name=tool.entity.identity.name,
|
||||
tool_inputs=tool_parameters,
|
||||
tool_outputs=plain_text,
|
||||
message_id=message.id,
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
# transform tool invoke message to get LLM friendly message
|
||||
return plain_text, message_files, meta
|
||||
except ToolProviderCredentialValidationError as e:
|
||||
error_response = "Please check your tool provider credentials"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e:
|
||||
error_response = f"there is not a tool named {tool.entity.identity.name}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except ToolParameterValidationError as e:
|
||||
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except ToolInvokeError as e:
|
||||
error_response = f"tool invoke error: {e}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except ToolEngineInvokeError as e:
|
||||
meta = e.meta
|
||||
error_response = f"tool invoke error: {meta.error}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
return error_response, [], meta
|
||||
except Exception as e:
|
||||
error_response = f"unknown error: {e}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
|
||||
return error_response, [], ToolInvokeMeta.error_instance(error_response)
|
||||
|
||||
@staticmethod
|
||||
def x(
|
||||
tool: Tool,
|
||||
tool_parameters: dict[str, Any],
|
||||
def invoke_first_step(
|
||||
datasource: DatasourcePlugin,
|
||||
datasource_parameters: dict[str, Any],
|
||||
user_id: str,
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
||||
workflow_call_depth: int,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
) -> Generator[DatasourceInvokeMessage, None, None]:
|
||||
"""
|
||||
Workflow invokes the tool with the given arguments.
|
||||
Workflow invokes the datasource with the given arguments.
|
||||
"""
|
||||
try:
|
||||
# hit the callback handler
|
||||
workflow_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
|
||||
workflow_tool_callback.on_datasource_start(datasource_name=datasource.entity.identity.name,
|
||||
datasource_inputs=datasource_parameters)
|
||||
|
||||
if isinstance(tool, WorkflowTool):
|
||||
tool.workflow_call_depth = workflow_call_depth + 1
|
||||
tool.thread_pool_id = thread_pool_id
|
||||
if datasource.runtime and datasource.runtime.runtime_parameters:
|
||||
datasource_parameters = {**datasource.runtime.runtime_parameters, **datasource_parameters}
|
||||
|
||||
if tool.runtime and tool.runtime.runtime_parameters:
|
||||
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
|
||||
|
||||
response = tool.invoke(
|
||||
response = datasource._invoke_first_step(
|
||||
user_id=user_id,
|
||||
tool_parameters=tool_parameters,
|
||||
datasource_parameters=datasource_parameters,
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
# hit the callback handler
|
||||
response = workflow_tool_callback.on_tool_execution(
|
||||
tool_name=tool.entity.identity.name,
|
||||
tool_inputs=tool_parameters,
|
||||
tool_outputs=response,
|
||||
response = workflow_tool_callback.on_datasource_end(
|
||||
datasource_name=datasource.entity.identity.name,
|
||||
datasource_inputs=datasource_parameters,
|
||||
datasource_outputs=response,
|
||||
)
|
||||
|
||||
return response
|
||||
@ -193,61 +66,49 @@ class DatasourceEngine:
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def _invoke(
|
||||
tool: Tool,
|
||||
tool_parameters: dict,
|
||||
def invoke_second_step(
|
||||
datasource: DatasourcePlugin,
|
||||
datasource_parameters: dict[str, Any],
|
||||
user_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]:
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
||||
) -> Generator[DatasourceInvokeMessage, None, None]:
|
||||
"""
|
||||
Invoke the tool with the given arguments.
|
||||
Workflow invokes the datasource with the given arguments.
|
||||
"""
|
||||
started_at = datetime.now(UTC)
|
||||
meta = ToolInvokeMeta(
|
||||
time_cost=0.0,
|
||||
error=None,
|
||||
tool_config={
|
||||
"tool_name": tool.entity.identity.name,
|
||||
"tool_provider": tool.entity.identity.provider,
|
||||
"tool_provider_type": tool.tool_provider_type().value,
|
||||
"tool_parameters": deepcopy(tool.runtime.runtime_parameters),
|
||||
"tool_icon": tool.entity.identity.icon,
|
||||
},
|
||||
)
|
||||
try:
|
||||
yield from tool.invoke(user_id, tool_parameters, conversation_id, app_id, message_id)
|
||||
response = datasource._invoke_second_step(
|
||||
user_id=user_id,
|
||||
datasource_parameters=datasource_parameters,
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
meta.error = str(e)
|
||||
raise ToolEngineInvokeError(meta)
|
||||
finally:
|
||||
ended_at = datetime.now(UTC)
|
||||
meta.time_cost = (ended_at - started_at).total_seconds()
|
||||
yield meta
|
||||
workflow_tool_callback.on_tool_error(e)
|
||||
raise e
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
|
||||
def _convert_datasource_response_to_str(datasource_response: list[DatasourceInvokeMessage]) -> str:
|
||||
"""
|
||||
Handle tool response
|
||||
Handle datasource response
|
||||
"""
|
||||
result = ""
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
result += cast(ToolInvokeMessage.TextMessage, response.message).text
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
for response in datasource_response:
|
||||
if response.type == DatasourceInvokeMessage.MessageType.TEXT:
|
||||
result += cast(DatasourceInvokeMessage.TextMessage, response.message).text
|
||||
elif response.type == DatasourceInvokeMessage.MessageType.LINK:
|
||||
result += (
|
||||
f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}."
|
||||
f"result link: {cast(DatasourceInvokeMessage.TextMessage, response.message).text}."
|
||||
+ " please tell user to check it."
|
||||
)
|
||||
elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
||||
elif response.type in {DatasourceInvokeMessage.MessageType.IMAGE_LINK, DatasourceInvokeMessage.MessageType.IMAGE}:
|
||||
result += (
|
||||
"image has been created and sent to user already, "
|
||||
+ "you do not need to create it, just tell the user to check it now."
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
||||
elif response.type == DatasourceInvokeMessage.MessageType.JSON:
|
||||
result = json.dumps(
|
||||
cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
|
||||
cast(DatasourceInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
|
||||
)
|
||||
else:
|
||||
result += str(response.message)
|
||||
@ -255,14 +116,14 @@ class DatasourceEngine:
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_response_binary_and_text(
|
||||
tool_response: list[ToolInvokeMessage],
|
||||
) -> Generator[ToolInvokeMessageBinary, None, None]:
|
||||
def _extract_datasource_response_binary_and_text(
|
||||
datasource_response: list[DatasourceInvokeMessage],
|
||||
) -> Generator[DatasourceInvokeMessageBinary, None, None]:
|
||||
"""
|
||||
Extract tool response binary
|
||||
Extract datasource response binary
|
||||
"""
|
||||
for response in tool_response:
|
||||
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
||||
for response in datasource_response:
|
||||
if response.type in {DatasourceInvokeMessage.MessageType.IMAGE_LINK, DatasourceInvokeMessage.MessageType.IMAGE}:
|
||||
mimetype = None
|
||||
if not response.meta:
|
||||
raise ValueError("missing meta data")
|
||||
@ -270,7 +131,7 @@ class DatasourceEngine:
|
||||
mimetype = response.meta.get("mime_type")
|
||||
else:
|
||||
try:
|
||||
url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text)
|
||||
url = URL(cast(DatasourceInvokeMessage.TextMessage, response.message).text)
|
||||
extension = url.suffix
|
||||
guess_type_result, _ = guess_type(f"a{extension}")
|
||||
if guess_type_result:
|
||||
@ -281,31 +142,31 @@ class DatasourceEngine:
|
||||
if not mimetype:
|
||||
mimetype = "image/jpeg"
|
||||
|
||||
yield ToolInvokeMessageBinary(
|
||||
yield DatasourceInvokeMessageBinary(
|
||||
mimetype=response.meta.get("mime_type", "image/jpeg"),
|
||||
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
||||
url=cast(DatasourceInvokeMessage.TextMessage, response.message).text,
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
elif response.type == DatasourceInvokeMessage.MessageType.BLOB:
|
||||
if not response.meta:
|
||||
raise ValueError("missing meta data")
|
||||
|
||||
yield ToolInvokeMessageBinary(
|
||||
yield DatasourceInvokeMessageBinary(
|
||||
mimetype=response.meta.get("mime_type", "application/octet-stream"),
|
||||
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
||||
url=cast(DatasourceInvokeMessage.TextMessage, response.message).text,
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
elif response.type == DatasourceInvokeMessage.MessageType.LINK:
|
||||
# check if there is a mime type in meta
|
||||
if response.meta and "mime_type" in response.meta:
|
||||
yield ToolInvokeMessageBinary(
|
||||
yield DatasourceInvokeMessageBinary(
|
||||
mimetype=response.meta.get("mime_type", "application/octet-stream")
|
||||
if response.meta
|
||||
else "application/octet-stream",
|
||||
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
||||
url=cast(DatasourceInvokeMessage.TextMessage, response.message).text,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_message_files(
|
||||
tool_messages: Iterable[ToolInvokeMessageBinary],
|
||||
datasource_messages: Iterable[DatasourceInvokeMessageBinary],
|
||||
agent_message: Message,
|
||||
invoke_from: InvokeFrom,
|
||||
user_id: str,
|
||||
@ -317,7 +178,7 @@ class DatasourceEngine:
|
||||
"""
|
||||
result = []
|
||||
|
||||
for message in tool_messages:
|
||||
for message in datasource_messages:
|
||||
if "image" in message.mimetype:
|
||||
file_type = FileType.IMAGE
|
||||
elif "video" in message.mimetype:
|
||||
|
@ -1,36 +1,36 @@
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.datasource.entities.datasource_entities import DatasourceInvokeMeta
|
||||
|
||||
|
||||
class ToolProviderNotFoundError(ValueError):
|
||||
class DatasourceProviderNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolNotFoundError(ValueError):
|
||||
class DatasourceNotFoundError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolParameterValidationError(ValueError):
|
||||
class DatasourceParameterValidationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolProviderCredentialValidationError(ValueError):
|
||||
class DatasourceProviderCredentialValidationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolNotSupportedError(ValueError):
|
||||
class DatasourceNotSupportedError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolInvokeError(ValueError):
|
||||
class DatasourceInvokeError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolApiSchemaError(ValueError):
|
||||
class DatasourceApiSchemaError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class ToolEngineInvokeError(Exception):
|
||||
meta: ToolInvokeMeta
|
||||
class DatasourceEngineInvokeError(Exception):
|
||||
meta: DatasourceInvokeMeta
|
||||
|
||||
def __init__(self, meta, **kwargs):
|
||||
self.meta = meta
|
||||
|
@ -1,234 +0,0 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from mimetypes import guess_extension, guess_type
|
||||
from typing import Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import MessageFile
|
||||
from models.tools import ToolFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolFileManager:
|
||||
@staticmethod
|
||||
def sign_file(tool_file_id: str, extension: str) -> str:
|
||||
"""
|
||||
sign file to get a temporary url
|
||||
"""
|
||||
base_url = dify_config.FILES_URL
|
||||
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
|
||||
@staticmethod
|
||||
def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
"""
|
||||
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
# verify signature
|
||||
if sign != recalculated_encoded_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_raw(
|
||||
*,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: Optional[str],
|
||||
file_binary: bytes,
|
||||
mimetype: str,
|
||||
filename: Optional[str] = None,
|
||||
) -> ToolFile:
|
||||
extension = guess_extension(mimetype) or ".bin"
|
||||
unique_name = uuid4().hex
|
||||
unique_filename = f"{unique_name}{extension}"
|
||||
# default just as before
|
||||
present_filename = unique_filename
|
||||
if filename is not None:
|
||||
has_extension = len(filename.split(".")) > 1
|
||||
# Add extension flexibly
|
||||
present_filename = filename if has_extension else f"{filename}{extension}"
|
||||
filepath = f"tools/{tenant_id}/{unique_filename}"
|
||||
storage.save(filepath, file_binary)
|
||||
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_key=filepath,
|
||||
mimetype=mimetype,
|
||||
name=present_filename,
|
||||
size=len(file_binary),
|
||||
)
|
||||
|
||||
db.session.add(tool_file)
|
||||
db.session.commit()
|
||||
db.session.refresh(tool_file)
|
||||
|
||||
return tool_file
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_url(
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
file_url: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
) -> ToolFile:
|
||||
# try to download image
|
||||
try:
|
||||
response = ssrf_proxy.get(file_url)
|
||||
response.raise_for_status()
|
||||
blob = response.content
|
||||
except httpx.TimeoutException:
|
||||
raise ValueError(f"timeout when downloading file from {file_url}")
|
||||
|
||||
mimetype = (
|
||||
guess_type(file_url)[0]
|
||||
or response.headers.get("Content-Type", "").split(";")[0].strip()
|
||||
or "application/octet-stream"
|
||||
)
|
||||
extension = guess_extension(mimetype) or ".bin"
|
||||
unique_name = uuid4().hex
|
||||
filename = f"{unique_name}{extension}"
|
||||
filepath = f"tools/{tenant_id}/{filename}"
|
||||
storage.save(filepath, blob)
|
||||
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_key=filepath,
|
||||
mimetype=mimetype,
|
||||
original_url=file_url,
|
||||
name=filename,
|
||||
size=len(blob),
|
||||
)
|
||||
|
||||
db.session.add(tool_file)
|
||||
db.session.commit()
|
||||
|
||||
return tool_file
|
||||
|
||||
@staticmethod
|
||||
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
:param id: the id of the file
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
tool_file: ToolFile | None = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
|
||||
blob = storage.load_once(tool_file.file_key)
|
||||
|
||||
return blob, tool_file.mimetype
|
||||
|
||||
@staticmethod
|
||||
def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
:param id: the id of the file
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
message_file: MessageFile | None = (
|
||||
db.session.query(MessageFile)
|
||||
.filter(
|
||||
MessageFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Check if message_file is not None
|
||||
if message_file is not None:
|
||||
# get tool file id
|
||||
if message_file.url is not None:
|
||||
tool_file_id = message_file.url.split("/")[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split(".")[0]
|
||||
else:
|
||||
tool_file_id = None
|
||||
else:
|
||||
tool_file_id = None
|
||||
|
||||
tool_file: ToolFile | None = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
|
||||
blob = storage.load_once(tool_file.file_key)
|
||||
|
||||
return blob, tool_file.mimetype
|
||||
|
||||
@staticmethod
|
||||
def get_file_generator_by_tool_file_id(tool_file_id: str):
|
||||
"""
|
||||
get file binary
|
||||
|
||||
:param tool_file_id: the id of the tool file
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
tool_file: ToolFile | None = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tool_file:
|
||||
return None, None
|
||||
|
||||
stream = storage.load_stream(tool_file.file_key)
|
||||
|
||||
return stream, tool_file
|
||||
|
||||
|
||||
# init tool_file_parser
|
||||
from core.file.tool_file_parser import tool_file_manager
|
||||
|
||||
tool_file_manager["manager"] = ToolFileManager
|
@ -1,101 +0,0 @@
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.entities.values import default_tool_label_name_list
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ToolLabelBinding
|
||||
|
||||
|
||||
class ToolLabelManager:
|
||||
@classmethod
|
||||
def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]:
|
||||
"""
|
||||
Filter tool labels
|
||||
"""
|
||||
tool_labels = [label for label in tool_labels if label in default_tool_label_name_list]
|
||||
return list(set(tool_labels))
|
||||
|
||||
@classmethod
|
||||
def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
|
||||
"""
|
||||
Update tool labels
|
||||
"""
|
||||
labels = cls.filter_tool_labels(labels)
|
||||
|
||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
provider_id = controller.provider_id
|
||||
else:
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
# delete old labels
|
||||
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete()
|
||||
|
||||
# insert new labels
|
||||
for label in labels:
|
||||
db.session.add(
|
||||
ToolLabelBinding(
|
||||
tool_id=provider_id,
|
||||
tool_type=controller.provider_type.value,
|
||||
label_name=label,
|
||||
)
|
||||
)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
|
||||
"""
|
||||
Get tool labels
|
||||
"""
|
||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
provider_id = controller.provider_id
|
||||
elif isinstance(controller, BuiltinToolProviderController):
|
||||
return controller.tool_labels
|
||||
else:
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
labels = (
|
||||
db.session.query(ToolLabelBinding.label_name)
|
||||
.filter(
|
||||
ToolLabelBinding.tool_id == provider_id,
|
||||
ToolLabelBinding.tool_type == controller.provider_type.value,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [label.label_name for label in labels]
|
||||
|
||||
@classmethod
|
||||
def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
|
||||
"""
|
||||
Get tools labels
|
||||
|
||||
:param tool_providers: list of tool providers
|
||||
|
||||
:return: dict of tool labels
|
||||
:key: tool id
|
||||
:value: list of tool labels
|
||||
"""
|
||||
if not tool_providers:
|
||||
return {}
|
||||
|
||||
for controller in tool_providers:
|
||||
if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
provider_ids = []
|
||||
for controller in tool_providers:
|
||||
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
|
||||
provider_ids.append(controller.provider_id)
|
||||
|
||||
labels: list[ToolLabelBinding] = (
|
||||
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all()
|
||||
)
|
||||
|
||||
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
|
||||
|
||||
for label in labels:
|
||||
tool_labels[label.tool_id].append(label.label_name)
|
||||
|
||||
return tool_labels
|
@ -132,3 +132,4 @@ class TraceTaskName(StrEnum):
|
||||
DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
|
||||
TOOL_TRACE = "tool"
|
||||
GENERATE_NAME_TRACE = "generate_conversation_name"
|
||||
DATASOURCE_TRACE = "datasource"
|
Loading…
x
Reference in New Issue
Block a user