mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 17:15:54 +08:00
r2
This commit is contained in:
parent
389f15f8e3
commit
d4007ae073
@ -4,6 +4,7 @@ from typing import Any, Optional, TextIO, Union
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from core.datasource.entities.datasource_entities import DatasourceInvokeMessage
|
||||||
from core.ops.entities.trace_entity import TraceTaskName
|
from core.ops.entities.trace_entity import TraceTaskName
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
@ -105,6 +106,36 @@ class DifyAgentCallbackHandler(BaseModel):
|
|||||||
|
|
||||||
self.current_loop += 1
|
self.current_loop += 1
|
||||||
|
|
||||||
|
def on_datasource_start(self, datasource_name: str, datasource_inputs: Mapping[str, Any]) -> None:
|
||||||
|
"""Run on datasource start."""
|
||||||
|
if dify_config.DEBUG:
|
||||||
|
print_text("\n[on_datasource_start] DatasourceCall:" + datasource_name + "\n" +
|
||||||
|
str(datasource_inputs) + "\n", color=self.color)
|
||||||
|
|
||||||
|
def on_datasource_end(self, datasource_name: str, datasource_inputs: Mapping[str, Any], datasource_outputs:
|
||||||
|
Iterable[DatasourceInvokeMessage] | str, message_id: Optional[str] = None,
|
||||||
|
timer: Optional[Any] = None,
|
||||||
|
trace_manager: Optional[TraceQueueManager] = None) -> None:
|
||||||
|
"""Run on datasource end."""
|
||||||
|
if dify_config.DEBUG:
|
||||||
|
print_text("\n[on_datasource_end]\n", color=self.color)
|
||||||
|
print_text("Datasource: " + datasource_name + "\n", color=self.color)
|
||||||
|
print_text("Inputs: " + str(datasource_inputs) + "\n", color=self.color)
|
||||||
|
print_text("Outputs: " + str(datasource_outputs)[:1000] + "\n", color=self.color)
|
||||||
|
print_text("\n")
|
||||||
|
|
||||||
|
if trace_manager:
|
||||||
|
trace_manager.add_trace_task(
|
||||||
|
TraceTask(
|
||||||
|
TraceTaskName.DATASOURCE_TRACE,
|
||||||
|
message_id=message_id,
|
||||||
|
datasource_name=datasource_name,
|
||||||
|
datasource_inputs=datasource_inputs,
|
||||||
|
datasource_outputs=datasource_outputs,
|
||||||
|
timer=timer,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ignore_agent(self) -> bool:
|
def ignore_agent(self) -> bool:
|
||||||
"""Whether to ignore agent callbacks."""
|
"""Whether to ignore agent callbacks."""
|
||||||
|
@ -1,36 +1,19 @@
|
|||||||
import json
|
import json
|
||||||
from collections.abc import Generator, Iterable
|
from collections.abc import Generator, Iterable
|
||||||
from copy import deepcopy
|
|
||||||
from datetime import UTC, datetime
|
|
||||||
from mimetypes import guess_type
|
from mimetypes import guess_type
|
||||||
from typing import Any, Optional, Union, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
|
||||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||||
|
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||||
|
from core.datasource.entities.datasource_entities import (
|
||||||
|
DatasourceInvokeMessage,
|
||||||
|
DatasourceInvokeMessageBinary,
|
||||||
|
)
|
||||||
from core.file import FileType
|
from core.file import FileType
|
||||||
from core.file.models import FileTransferMethod
|
from core.file.models import FileTransferMethod
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
|
||||||
from core.tools.__base.tool import Tool
|
|
||||||
from core.tools.entities.tool_entities import (
|
|
||||||
ToolInvokeMessage,
|
|
||||||
ToolInvokeMessageBinary,
|
|
||||||
ToolInvokeMeta,
|
|
||||||
ToolParameter,
|
|
||||||
)
|
|
||||||
from core.tools.errors import (
|
|
||||||
ToolEngineInvokeError,
|
|
||||||
ToolInvokeError,
|
|
||||||
ToolNotFoundError,
|
|
||||||
ToolNotSupportedError,
|
|
||||||
ToolParameterValidationError,
|
|
||||||
ToolProviderCredentialValidationError,
|
|
||||||
ToolProviderNotFoundError,
|
|
||||||
)
|
|
||||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
|
||||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.enums import CreatedByRole
|
from models.enums import CreatedByRole
|
||||||
from models.model import Message, MessageFile
|
from models.model import Message, MessageFile
|
||||||
@ -42,149 +25,39 @@ class DatasourceEngine:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def agent_invoke(
|
def invoke_first_step(
|
||||||
tool: Tool,
|
datasource: DatasourcePlugin,
|
||||||
tool_parameters: Union[str, dict],
|
datasource_parameters: dict[str, Any],
|
||||||
user_id: str,
|
|
||||||
tenant_id: str,
|
|
||||||
message: Message,
|
|
||||||
invoke_from: InvokeFrom,
|
|
||||||
agent_tool_callback: DifyAgentCallbackHandler,
|
|
||||||
trace_manager: Optional[TraceQueueManager] = None,
|
|
||||||
conversation_id: Optional[str] = None,
|
|
||||||
app_id: Optional[str] = None,
|
|
||||||
message_id: Optional[str] = None,
|
|
||||||
) -> tuple[str, list[str], ToolInvokeMeta]:
|
|
||||||
"""
|
|
||||||
Agent invokes the tool with the given arguments.
|
|
||||||
"""
|
|
||||||
# check if arguments is a string
|
|
||||||
if isinstance(tool_parameters, str):
|
|
||||||
# check if this tool has only one parameter
|
|
||||||
parameters = [
|
|
||||||
parameter
|
|
||||||
for parameter in tool.get_runtime_parameters()
|
|
||||||
if parameter.form == ToolParameter.ToolParameterForm.LLM
|
|
||||||
]
|
|
||||||
if parameters and len(parameters) == 1:
|
|
||||||
tool_parameters = {parameters[0].name: tool_parameters}
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
tool_parameters = json.loads(tool_parameters)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
if not isinstance(tool_parameters, dict):
|
|
||||||
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# hit the callback handler
|
|
||||||
agent_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
|
|
||||||
|
|
||||||
messages = ToolEngine._invoke(tool, tool_parameters, user_id, conversation_id, app_id, message_id)
|
|
||||||
invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
|
|
||||||
|
|
||||||
def message_callback(
|
|
||||||
invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]
|
|
||||||
):
|
|
||||||
for message in messages:
|
|
||||||
if isinstance(message, ToolInvokeMeta):
|
|
||||||
invocation_meta_dict["meta"] = message
|
|
||||||
else:
|
|
||||||
yield message
|
|
||||||
|
|
||||||
messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
|
||||||
messages=message_callback(invocation_meta_dict, messages),
|
|
||||||
user_id=user_id,
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
conversation_id=message.conversation_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
message_list = list(messages)
|
|
||||||
|
|
||||||
# extract binary data from tool invoke message
|
|
||||||
binary_files = ToolEngine._extract_tool_response_binary_and_text(message_list)
|
|
||||||
# create message file
|
|
||||||
message_files = ToolEngine._create_message_files(
|
|
||||||
tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id
|
|
||||||
)
|
|
||||||
|
|
||||||
plain_text = ToolEngine._convert_tool_response_to_str(message_list)
|
|
||||||
|
|
||||||
meta = invocation_meta_dict["meta"]
|
|
||||||
|
|
||||||
# hit the callback handler
|
|
||||||
agent_tool_callback.on_tool_end(
|
|
||||||
tool_name=tool.entity.identity.name,
|
|
||||||
tool_inputs=tool_parameters,
|
|
||||||
tool_outputs=plain_text,
|
|
||||||
message_id=message.id,
|
|
||||||
trace_manager=trace_manager,
|
|
||||||
)
|
|
||||||
|
|
||||||
# transform tool invoke message to get LLM friendly message
|
|
||||||
return plain_text, message_files, meta
|
|
||||||
except ToolProviderCredentialValidationError as e:
|
|
||||||
error_response = "Please check your tool provider credentials"
|
|
||||||
agent_tool_callback.on_tool_error(e)
|
|
||||||
except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e:
|
|
||||||
error_response = f"there is not a tool named {tool.entity.identity.name}"
|
|
||||||
agent_tool_callback.on_tool_error(e)
|
|
||||||
except ToolParameterValidationError as e:
|
|
||||||
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
|
|
||||||
agent_tool_callback.on_tool_error(e)
|
|
||||||
except ToolInvokeError as e:
|
|
||||||
error_response = f"tool invoke error: {e}"
|
|
||||||
agent_tool_callback.on_tool_error(e)
|
|
||||||
except ToolEngineInvokeError as e:
|
|
||||||
meta = e.meta
|
|
||||||
error_response = f"tool invoke error: {meta.error}"
|
|
||||||
agent_tool_callback.on_tool_error(e)
|
|
||||||
return error_response, [], meta
|
|
||||||
except Exception as e:
|
|
||||||
error_response = f"unknown error: {e}"
|
|
||||||
agent_tool_callback.on_tool_error(e)
|
|
||||||
|
|
||||||
return error_response, [], ToolInvokeMeta.error_instance(error_response)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def x(
|
|
||||||
tool: Tool,
|
|
||||||
tool_parameters: dict[str, Any],
|
|
||||||
user_id: str,
|
user_id: str,
|
||||||
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
||||||
workflow_call_depth: int,
|
|
||||||
thread_pool_id: Optional[str] = None,
|
|
||||||
conversation_id: Optional[str] = None,
|
conversation_id: Optional[str] = None,
|
||||||
app_id: Optional[str] = None,
|
app_id: Optional[str] = None,
|
||||||
message_id: Optional[str] = None,
|
message_id: Optional[str] = None,
|
||||||
) -> Generator[ToolInvokeMessage, None, None]:
|
) -> Generator[DatasourceInvokeMessage, None, None]:
|
||||||
"""
|
"""
|
||||||
Workflow invokes the tool with the given arguments.
|
Workflow invokes the datasource with the given arguments.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# hit the callback handler
|
# hit the callback handler
|
||||||
workflow_tool_callback.on_tool_start(tool_name=tool.entity.identity.name, tool_inputs=tool_parameters)
|
workflow_tool_callback.on_datasource_start(datasource_name=datasource.entity.identity.name,
|
||||||
|
datasource_inputs=datasource_parameters)
|
||||||
|
|
||||||
if isinstance(tool, WorkflowTool):
|
if datasource.runtime and datasource.runtime.runtime_parameters:
|
||||||
tool.workflow_call_depth = workflow_call_depth + 1
|
datasource_parameters = {**datasource.runtime.runtime_parameters, **datasource_parameters}
|
||||||
tool.thread_pool_id = thread_pool_id
|
|
||||||
|
|
||||||
if tool.runtime and tool.runtime.runtime_parameters:
|
response = datasource._invoke_first_step(
|
||||||
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
|
|
||||||
|
|
||||||
response = tool.invoke(
|
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
tool_parameters=tool_parameters,
|
datasource_parameters=datasource_parameters,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# hit the callback handler
|
# hit the callback handler
|
||||||
response = workflow_tool_callback.on_tool_execution(
|
response = workflow_tool_callback.on_datasource_end(
|
||||||
tool_name=tool.entity.identity.name,
|
datasource_name=datasource.entity.identity.name,
|
||||||
tool_inputs=tool_parameters,
|
datasource_inputs=datasource_parameters,
|
||||||
tool_outputs=response,
|
datasource_outputs=response,
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
@ -193,61 +66,49 @@ class DatasourceEngine:
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _invoke(
|
def invoke_second_step(
|
||||||
tool: Tool,
|
datasource: DatasourcePlugin,
|
||||||
tool_parameters: dict,
|
datasource_parameters: dict[str, Any],
|
||||||
user_id: str,
|
user_id: str,
|
||||||
conversation_id: Optional[str] = None,
|
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
||||||
app_id: Optional[str] = None,
|
) -> Generator[DatasourceInvokeMessage, None, None]:
|
||||||
message_id: Optional[str] = None,
|
|
||||||
) -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]:
|
|
||||||
"""
|
"""
|
||||||
Invoke the tool with the given arguments.
|
Workflow invokes the datasource with the given arguments.
|
||||||
"""
|
"""
|
||||||
started_at = datetime.now(UTC)
|
|
||||||
meta = ToolInvokeMeta(
|
|
||||||
time_cost=0.0,
|
|
||||||
error=None,
|
|
||||||
tool_config={
|
|
||||||
"tool_name": tool.entity.identity.name,
|
|
||||||
"tool_provider": tool.entity.identity.provider,
|
|
||||||
"tool_provider_type": tool.tool_provider_type().value,
|
|
||||||
"tool_parameters": deepcopy(tool.runtime.runtime_parameters),
|
|
||||||
"tool_icon": tool.entity.identity.icon,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
yield from tool.invoke(user_id, tool_parameters, conversation_id, app_id, message_id)
|
response = datasource._invoke_second_step(
|
||||||
|
user_id=user_id,
|
||||||
|
datasource_parameters=datasource_parameters,
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
meta.error = str(e)
|
workflow_tool_callback.on_tool_error(e)
|
||||||
raise ToolEngineInvokeError(meta)
|
raise e
|
||||||
finally:
|
|
||||||
ended_at = datetime.now(UTC)
|
|
||||||
meta.time_cost = (ended_at - started_at).total_seconds()
|
|
||||||
yield meta
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
|
def _convert_datasource_response_to_str(datasource_response: list[DatasourceInvokeMessage]) -> str:
|
||||||
"""
|
"""
|
||||||
Handle tool response
|
Handle datasource response
|
||||||
"""
|
"""
|
||||||
result = ""
|
result = ""
|
||||||
for response in tool_response:
|
for response in datasource_response:
|
||||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
if response.type == DatasourceInvokeMessage.MessageType.TEXT:
|
||||||
result += cast(ToolInvokeMessage.TextMessage, response.message).text
|
result += cast(DatasourceInvokeMessage.TextMessage, response.message).text
|
||||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
elif response.type == DatasourceInvokeMessage.MessageType.LINK:
|
||||||
result += (
|
result += (
|
||||||
f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}."
|
f"result link: {cast(DatasourceInvokeMessage.TextMessage, response.message).text}."
|
||||||
+ " please tell user to check it."
|
+ " please tell user to check it."
|
||||||
)
|
)
|
||||||
elif response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
elif response.type in {DatasourceInvokeMessage.MessageType.IMAGE_LINK, DatasourceInvokeMessage.MessageType.IMAGE}:
|
||||||
result += (
|
result += (
|
||||||
"image has been created and sent to user already, "
|
"image has been created and sent to user already, "
|
||||||
+ "you do not need to create it, just tell the user to check it now."
|
+ "you do not need to create it, just tell the user to check it now."
|
||||||
)
|
)
|
||||||
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
elif response.type == DatasourceInvokeMessage.MessageType.JSON:
|
||||||
result = json.dumps(
|
result = json.dumps(
|
||||||
cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
|
cast(DatasourceInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result += str(response.message)
|
result += str(response.message)
|
||||||
@ -255,14 +116,14 @@ class DatasourceEngine:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_tool_response_binary_and_text(
|
def _extract_datasource_response_binary_and_text(
|
||||||
tool_response: list[ToolInvokeMessage],
|
datasource_response: list[DatasourceInvokeMessage],
|
||||||
) -> Generator[ToolInvokeMessageBinary, None, None]:
|
) -> Generator[DatasourceInvokeMessageBinary, None, None]:
|
||||||
"""
|
"""
|
||||||
Extract tool response binary
|
Extract datasource response binary
|
||||||
"""
|
"""
|
||||||
for response in tool_response:
|
for response in datasource_response:
|
||||||
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
if response.type in {DatasourceInvokeMessage.MessageType.IMAGE_LINK, DatasourceInvokeMessage.MessageType.IMAGE}:
|
||||||
mimetype = None
|
mimetype = None
|
||||||
if not response.meta:
|
if not response.meta:
|
||||||
raise ValueError("missing meta data")
|
raise ValueError("missing meta data")
|
||||||
@ -270,7 +131,7 @@ class DatasourceEngine:
|
|||||||
mimetype = response.meta.get("mime_type")
|
mimetype = response.meta.get("mime_type")
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text)
|
url = URL(cast(DatasourceInvokeMessage.TextMessage, response.message).text)
|
||||||
extension = url.suffix
|
extension = url.suffix
|
||||||
guess_type_result, _ = guess_type(f"a{extension}")
|
guess_type_result, _ = guess_type(f"a{extension}")
|
||||||
if guess_type_result:
|
if guess_type_result:
|
||||||
@ -281,31 +142,31 @@ class DatasourceEngine:
|
|||||||
if not mimetype:
|
if not mimetype:
|
||||||
mimetype = "image/jpeg"
|
mimetype = "image/jpeg"
|
||||||
|
|
||||||
yield ToolInvokeMessageBinary(
|
yield DatasourceInvokeMessageBinary(
|
||||||
mimetype=response.meta.get("mime_type", "image/jpeg"),
|
mimetype=response.meta.get("mime_type", "image/jpeg"),
|
||||||
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
url=cast(DatasourceInvokeMessage.TextMessage, response.message).text,
|
||||||
)
|
)
|
||||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
elif response.type == DatasourceInvokeMessage.MessageType.BLOB:
|
||||||
if not response.meta:
|
if not response.meta:
|
||||||
raise ValueError("missing meta data")
|
raise ValueError("missing meta data")
|
||||||
|
|
||||||
yield ToolInvokeMessageBinary(
|
yield DatasourceInvokeMessageBinary(
|
||||||
mimetype=response.meta.get("mime_type", "application/octet-stream"),
|
mimetype=response.meta.get("mime_type", "application/octet-stream"),
|
||||||
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
url=cast(DatasourceInvokeMessage.TextMessage, response.message).text,
|
||||||
)
|
)
|
||||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
elif response.type == DatasourceInvokeMessage.MessageType.LINK:
|
||||||
# check if there is a mime type in meta
|
# check if there is a mime type in meta
|
||||||
if response.meta and "mime_type" in response.meta:
|
if response.meta and "mime_type" in response.meta:
|
||||||
yield ToolInvokeMessageBinary(
|
yield DatasourceInvokeMessageBinary(
|
||||||
mimetype=response.meta.get("mime_type", "application/octet-stream")
|
mimetype=response.meta.get("mime_type", "application/octet-stream")
|
||||||
if response.meta
|
if response.meta
|
||||||
else "application/octet-stream",
|
else "application/octet-stream",
|
||||||
url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
url=cast(DatasourceInvokeMessage.TextMessage, response.message).text,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_message_files(
|
def _create_message_files(
|
||||||
tool_messages: Iterable[ToolInvokeMessageBinary],
|
datasource_messages: Iterable[DatasourceInvokeMessageBinary],
|
||||||
agent_message: Message,
|
agent_message: Message,
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
@ -317,7 +178,7 @@ class DatasourceEngine:
|
|||||||
"""
|
"""
|
||||||
result = []
|
result = []
|
||||||
|
|
||||||
for message in tool_messages:
|
for message in datasource_messages:
|
||||||
if "image" in message.mimetype:
|
if "image" in message.mimetype:
|
||||||
file_type = FileType.IMAGE
|
file_type = FileType.IMAGE
|
||||||
elif "video" in message.mimetype:
|
elif "video" in message.mimetype:
|
||||||
|
@ -1,36 +1,36 @@
|
|||||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
from core.datasource.entities.datasource_entities import DatasourceInvokeMeta
|
||||||
|
|
||||||
|
|
||||||
class ToolProviderNotFoundError(ValueError):
|
class DatasourceProviderNotFoundError(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ToolNotFoundError(ValueError):
|
class DatasourceNotFoundError(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ToolParameterValidationError(ValueError):
|
class DatasourceParameterValidationError(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ToolProviderCredentialValidationError(ValueError):
|
class DatasourceProviderCredentialValidationError(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ToolNotSupportedError(ValueError):
|
class DatasourceNotSupportedError(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ToolInvokeError(ValueError):
|
class DatasourceInvokeError(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ToolApiSchemaError(ValueError):
|
class DatasourceApiSchemaError(ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ToolEngineInvokeError(Exception):
|
class DatasourceEngineInvokeError(Exception):
|
||||||
meta: ToolInvokeMeta
|
meta: DatasourceInvokeMeta
|
||||||
|
|
||||||
def __init__(self, meta, **kwargs):
|
def __init__(self, meta, **kwargs):
|
||||||
self.meta = meta
|
self.meta = meta
|
||||||
|
@ -1,234 +0,0 @@
|
|||||||
import base64
|
|
||||||
import hashlib
|
|
||||||
import hmac
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from mimetypes import guess_extension, guess_type
|
|
||||||
from typing import Optional, Union
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
from configs import dify_config
|
|
||||||
from core.helper import ssrf_proxy
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from extensions.ext_storage import storage
|
|
||||||
from models.model import MessageFile
|
|
||||||
from models.tools import ToolFile
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ToolFileManager:
|
|
||||||
@staticmethod
|
|
||||||
def sign_file(tool_file_id: str, extension: str) -> str:
|
|
||||||
"""
|
|
||||||
sign file to get a temporary url
|
|
||||||
"""
|
|
||||||
base_url = dify_config.FILES_URL
|
|
||||||
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
|
|
||||||
|
|
||||||
timestamp = str(int(time.time()))
|
|
||||||
nonce = os.urandom(16).hex()
|
|
||||||
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
|
|
||||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
|
||||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
|
||||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
|
||||||
|
|
||||||
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
|
||||||
"""
|
|
||||||
verify signature
|
|
||||||
"""
|
|
||||||
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
|
|
||||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
|
||||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
|
||||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
|
||||||
|
|
||||||
# verify signature
|
|
||||||
if sign != recalculated_encoded_sign:
|
|
||||||
return False
|
|
||||||
|
|
||||||
current_time = int(time.time())
|
|
||||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_file_by_raw(
|
|
||||||
*,
|
|
||||||
user_id: str,
|
|
||||||
tenant_id: str,
|
|
||||||
conversation_id: Optional[str],
|
|
||||||
file_binary: bytes,
|
|
||||||
mimetype: str,
|
|
||||||
filename: Optional[str] = None,
|
|
||||||
) -> ToolFile:
|
|
||||||
extension = guess_extension(mimetype) or ".bin"
|
|
||||||
unique_name = uuid4().hex
|
|
||||||
unique_filename = f"{unique_name}{extension}"
|
|
||||||
# default just as before
|
|
||||||
present_filename = unique_filename
|
|
||||||
if filename is not None:
|
|
||||||
has_extension = len(filename.split(".")) > 1
|
|
||||||
# Add extension flexibly
|
|
||||||
present_filename = filename if has_extension else f"{filename}{extension}"
|
|
||||||
filepath = f"tools/{tenant_id}/{unique_filename}"
|
|
||||||
storage.save(filepath, file_binary)
|
|
||||||
|
|
||||||
tool_file = ToolFile(
|
|
||||||
user_id=user_id,
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
file_key=filepath,
|
|
||||||
mimetype=mimetype,
|
|
||||||
name=present_filename,
|
|
||||||
size=len(file_binary),
|
|
||||||
)
|
|
||||||
|
|
||||||
db.session.add(tool_file)
|
|
||||||
db.session.commit()
|
|
||||||
db.session.refresh(tool_file)
|
|
||||||
|
|
||||||
return tool_file
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_file_by_url(
|
|
||||||
user_id: str,
|
|
||||||
tenant_id: str,
|
|
||||||
file_url: str,
|
|
||||||
conversation_id: Optional[str] = None,
|
|
||||||
) -> ToolFile:
|
|
||||||
# try to download image
|
|
||||||
try:
|
|
||||||
response = ssrf_proxy.get(file_url)
|
|
||||||
response.raise_for_status()
|
|
||||||
blob = response.content
|
|
||||||
except httpx.TimeoutException:
|
|
||||||
raise ValueError(f"timeout when downloading file from {file_url}")
|
|
||||||
|
|
||||||
mimetype = (
|
|
||||||
guess_type(file_url)[0]
|
|
||||||
or response.headers.get("Content-Type", "").split(";")[0].strip()
|
|
||||||
or "application/octet-stream"
|
|
||||||
)
|
|
||||||
extension = guess_extension(mimetype) or ".bin"
|
|
||||||
unique_name = uuid4().hex
|
|
||||||
filename = f"{unique_name}{extension}"
|
|
||||||
filepath = f"tools/{tenant_id}/{filename}"
|
|
||||||
storage.save(filepath, blob)
|
|
||||||
|
|
||||||
tool_file = ToolFile(
|
|
||||||
user_id=user_id,
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
file_key=filepath,
|
|
||||||
mimetype=mimetype,
|
|
||||||
original_url=file_url,
|
|
||||||
name=filename,
|
|
||||||
size=len(blob),
|
|
||||||
)
|
|
||||||
|
|
||||||
db.session.add(tool_file)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
return tool_file
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
|
|
||||||
"""
|
|
||||||
get file binary
|
|
||||||
|
|
||||||
:param id: the id of the file
|
|
||||||
|
|
||||||
:return: the binary of the file, mime type
|
|
||||||
"""
|
|
||||||
tool_file: ToolFile | None = (
|
|
||||||
db.session.query(ToolFile)
|
|
||||||
.filter(
|
|
||||||
ToolFile.id == id,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not tool_file:
|
|
||||||
return None
|
|
||||||
|
|
||||||
blob = storage.load_once(tool_file.file_key)
|
|
||||||
|
|
||||||
return blob, tool_file.mimetype
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
|
|
||||||
"""
|
|
||||||
get file binary
|
|
||||||
|
|
||||||
:param id: the id of the file
|
|
||||||
|
|
||||||
:return: the binary of the file, mime type
|
|
||||||
"""
|
|
||||||
message_file: MessageFile | None = (
|
|
||||||
db.session.query(MessageFile)
|
|
||||||
.filter(
|
|
||||||
MessageFile.id == id,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if message_file is not None
|
|
||||||
if message_file is not None:
|
|
||||||
# get tool file id
|
|
||||||
if message_file.url is not None:
|
|
||||||
tool_file_id = message_file.url.split("/")[-1]
|
|
||||||
# trim extension
|
|
||||||
tool_file_id = tool_file_id.split(".")[0]
|
|
||||||
else:
|
|
||||||
tool_file_id = None
|
|
||||||
else:
|
|
||||||
tool_file_id = None
|
|
||||||
|
|
||||||
tool_file: ToolFile | None = (
|
|
||||||
db.session.query(ToolFile)
|
|
||||||
.filter(
|
|
||||||
ToolFile.id == tool_file_id,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not tool_file:
|
|
||||||
return None
|
|
||||||
|
|
||||||
blob = storage.load_once(tool_file.file_key)
|
|
||||||
|
|
||||||
return blob, tool_file.mimetype
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_file_generator_by_tool_file_id(tool_file_id: str):
|
|
||||||
"""
|
|
||||||
get file binary
|
|
||||||
|
|
||||||
:param tool_file_id: the id of the tool file
|
|
||||||
|
|
||||||
:return: the binary of the file, mime type
|
|
||||||
"""
|
|
||||||
tool_file: ToolFile | None = (
|
|
||||||
db.session.query(ToolFile)
|
|
||||||
.filter(
|
|
||||||
ToolFile.id == tool_file_id,
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not tool_file:
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
stream = storage.load_stream(tool_file.file_key)
|
|
||||||
|
|
||||||
return stream, tool_file
|
|
||||||
|
|
||||||
|
|
||||||
# init tool_file_parser
|
|
||||||
from core.file.tool_file_parser import tool_file_manager
|
|
||||||
|
|
||||||
tool_file_manager["manager"] = ToolFileManager
|
|
@ -1,101 +0,0 @@
|
|||||||
from core.tools.__base.tool_provider import ToolProviderController
|
|
||||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
|
||||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
|
||||||
from core.tools.entities.values import default_tool_label_name_list
|
|
||||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.tools import ToolLabelBinding
|
|
||||||
|
|
||||||
|
|
||||||
class ToolLabelManager:
|
|
||||||
@classmethod
|
|
||||||
def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]:
|
|
||||||
"""
|
|
||||||
Filter tool labels
|
|
||||||
"""
|
|
||||||
tool_labels = [label for label in tool_labels if label in default_tool_label_name_list]
|
|
||||||
return list(set(tool_labels))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
|
|
||||||
"""
|
|
||||||
Update tool labels
|
|
||||||
"""
|
|
||||||
labels = cls.filter_tool_labels(labels)
|
|
||||||
|
|
||||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
|
||||||
provider_id = controller.provider_id
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported tool type")
|
|
||||||
|
|
||||||
# delete old labels
|
|
||||||
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete()
|
|
||||||
|
|
||||||
# insert new labels
|
|
||||||
for label in labels:
|
|
||||||
db.session.add(
|
|
||||||
ToolLabelBinding(
|
|
||||||
tool_id=provider_id,
|
|
||||||
tool_type=controller.provider_type.value,
|
|
||||||
label_name=label,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
|
|
||||||
"""
|
|
||||||
Get tool labels
|
|
||||||
"""
|
|
||||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
|
||||||
provider_id = controller.provider_id
|
|
||||||
elif isinstance(controller, BuiltinToolProviderController):
|
|
||||||
return controller.tool_labels
|
|
||||||
else:
|
|
||||||
raise ValueError("Unsupported tool type")
|
|
||||||
|
|
||||||
labels = (
|
|
||||||
db.session.query(ToolLabelBinding.label_name)
|
|
||||||
.filter(
|
|
||||||
ToolLabelBinding.tool_id == provider_id,
|
|
||||||
ToolLabelBinding.tool_type == controller.provider_type.value,
|
|
||||||
)
|
|
||||||
.all()
|
|
||||||
)
|
|
||||||
|
|
||||||
return [label.label_name for label in labels]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
|
|
||||||
"""
|
|
||||||
Get tools labels
|
|
||||||
|
|
||||||
:param tool_providers: list of tool providers
|
|
||||||
|
|
||||||
:return: dict of tool labels
|
|
||||||
:key: tool id
|
|
||||||
:value: list of tool labels
|
|
||||||
"""
|
|
||||||
if not tool_providers:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
for controller in tool_providers:
|
|
||||||
if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
|
||||||
raise ValueError("Unsupported tool type")
|
|
||||||
|
|
||||||
provider_ids = []
|
|
||||||
for controller in tool_providers:
|
|
||||||
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
|
|
||||||
provider_ids.append(controller.provider_id)
|
|
||||||
|
|
||||||
labels: list[ToolLabelBinding] = (
|
|
||||||
db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all()
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
|
|
||||||
|
|
||||||
for label in labels:
|
|
||||||
tool_labels[label.tool_id].append(label.label_name)
|
|
||||||
|
|
||||||
return tool_labels
|
|
@ -132,3 +132,4 @@ class TraceTaskName(StrEnum):
|
|||||||
DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
|
DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
|
||||||
TOOL_TRACE = "tool"
|
TOOL_TRACE = "tool"
|
||||||
GENERATE_NAME_TRACE = "generate_conversation_name"
|
GENERATE_NAME_TRACE = "generate_conversation_name"
|
||||||
|
DATASOURCE_TRACE = "datasource"
|
Loading…
x
Reference in New Issue
Block a user