diff --git a/api/constants/mimetypes.py b/api/constants/mimetypes.py new file mode 100644 index 0000000000..38988cdd24 --- /dev/null +++ b/api/constants/mimetypes.py @@ -0,0 +1,7 @@ +# The two constants below should keep in sync. +# Default content type for files which have no explicit content type. + +DEFAULT_MIME_TYPE = "application/octet-stream" +# Default file extension for files which have no explicit content type, should +# correspond to the `DEFAULT_MIME_TYPE` above. +DEFAULT_EXTENSION = ".bin" diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index cfcce81247..7cb97806c8 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -4,7 +4,9 @@ from werkzeug.exceptions import Forbidden, NotFound from controllers.files import api from controllers.files.error import UnsupportedFileTypeError +from core.tools.signature import verify_tool_file_signature from core.tools.tool_file_manager import ToolFileManager +from models import db as global_db class ToolFilePreviewApi(Resource): @@ -19,17 +21,14 @@ class ToolFilePreviewApi(Resource): parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args") args = parser.parse_args() - - if not ToolFileManager.verify_file( - file_id=file_id, - timestamp=args["timestamp"], - nonce=args["nonce"], - sign=args["sign"], + if not verify_tool_file_signature( + file_id=file_id, timestamp=args["timestamp"], nonce=args["nonce"], sign=args["sign"] ): raise Forbidden("Invalid request.") try: - stream, tool_file = ToolFileManager.get_file_generator_by_tool_file_id( + tool_file_manager = ToolFileManager(engine=global_db.engine) + stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id( file_id, ) diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index 28ee0eecf4..178fb9477a 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -53,7 +53,7 @@ class 
PluginUploadFileApi(Resource): raise Forbidden("Invalid request.") try: - tool_file = ToolFileManager.create_file_by_raw( + tool_file = ToolFileManager().create_file_by_raw( user_id=user.id, tenant_id=tenant_id, file_binary=file.read(), diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index fde506639f..a6d826f08b 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -24,7 +24,7 @@ from core.app.entities.task_entities import ( WorkflowTaskState, ) from core.llm_generator.llm_generator import LLMGenerator -from core.tools.tool_file_manager import ToolFileManager +from core.tools.signature import sign_tool_file from extensions.ext_database import db from models.model import AppMode, Conversation, MessageAnnotation, MessageFile from services.annotation_service import AppAnnotationService @@ -154,7 +154,7 @@ class MessageCycleManage: if message_file.url.startswith("http"): url = message_file.url else: - url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension) + url = sign_tool_file(tool_file_id=tool_file_id, extension=extension) return MessageFileStreamResponse( task_id=self._application_generate_entity.task_id, diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index 9a204e9ff6..ada19ef8ce 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -10,12 +10,12 @@ from core.model_runtime.entities import ( VideoPromptMessageContent, ) from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes +from core.tools.signature import sign_tool_file from extensions.ext_storage import storage from . 
import helpers from .enums import FileAttribute from .models import File, FileTransferMethod, FileType -from .tool_file_parser import ToolFileParser def get_attr(*, file: File, attr: FileAttribute): @@ -130,6 +130,6 @@ def _to_url(f: File, /): # add sign url if f.related_id is None or f.extension is None: raise ValueError("Missing file related_id or extension") - return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension) + return sign_tool_file(tool_file_id=f.related_id, extension=f.extension) else: raise ValueError(f"Unsupported transfer method: {f.transfer_method}") diff --git a/api/core/file/models.py b/api/core/file/models.py index f5db6c2d74..aa3b5f629c 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -4,11 +4,11 @@ from typing import Any, Optional from pydantic import BaseModel, Field, model_validator from core.model_runtime.entities.message_entities import ImagePromptMessageContent +from core.tools.signature import sign_tool_file from . import helpers from .constants import FILE_MODEL_IDENTITY from .enums import FileTransferMethod, FileType -from .tool_file_parser import ToolFileParser class ImageConfig(BaseModel): @@ -34,13 +34,21 @@ class FileUploadConfig(BaseModel): class File(BaseModel): + # NOTE: dify_model_identity is a special identifier used to distinguish between + # new and old data formats during serialization and deserialization. dify_model_identity: str = FILE_MODEL_IDENTITY id: Optional[str] = None # message file id tenant_id: str type: FileType transfer_method: FileTransferMethod + # If `transfer_method` is `FileTransferMethod.remote_url`, the + # `remote_url` attribute must not be `None`. remote_url: Optional[str] = None # remote url + # If `transfer_method` is `FileTransferMethod.local_file` or + # `FileTransferMethod.tool_file`, the `related_id` attribute must not be `None`. + # + # It should be set to `ToolFile.id` when `transfer_method` is `tool_file`. 
related_id: Optional[str] = None filename: Optional[str] = None extension: Optional[str] = Field(default=None, description="File extension, should contains dot") @@ -110,9 +118,7 @@ class File(BaseModel): elif self.transfer_method == FileTransferMethod.TOOL_FILE: assert self.related_id is not None assert self.extension is not None - return ToolFileParser.get_tool_file_manager().sign_file( - tool_file_id=self.related_id, extension=self.extension - ) + return sign_tool_file(tool_file_id=self.related_id, extension=self.extension) def to_plugin_parameter(self) -> dict[str, Any]: return { diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py index 6fa101cf36..656c9d48ed 100644 --- a/api/core/file/tool_file_parser.py +++ b/api/core/file/tool_file_parser.py @@ -1,12 +1,19 @@ -from typing import TYPE_CHECKING, Any, cast +from collections.abc import Callable +from typing import TYPE_CHECKING if TYPE_CHECKING: from core.tools.tool_file_manager import ToolFileManager -tool_file_manager: dict[str, Any] = {"manager": None} +_tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None class ToolFileParser: @staticmethod def get_tool_file_manager() -> "ToolFileManager": - return cast("ToolFileManager", tool_file_manager["manager"]) + assert _tool_file_manager_factory is not None + return _tool_file_manager_factory() + + +def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None: + global _tool_file_manager_factory + _tool_file_manager_factory = factory diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 0845ef206e..995a30d44c 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -101,7 +101,7 @@ class ModelInstance: @overload def invoke_llm( self, - prompt_messages: list[PromptMessage], + prompt_messages: Sequence[PromptMessage], model_parameters: Optional[dict] = None, tools: Sequence[PromptMessageTool] | None = None, stop: Optional[list[str]] = None, diff --git 
a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index b1c43d1455..9d010ae28d 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -1,4 +1,5 @@ -from collections.abc import Sequence +from abc import ABC +from collections.abc import Mapping, Sequence from enum import Enum, StrEnum from typing import Annotated, Any, Literal, Optional, Union @@ -60,8 +61,12 @@ class PromptMessageContentType(StrEnum): DOCUMENT = "document" -class PromptMessageContent(BaseModel): - pass +class PromptMessageContent(ABC, BaseModel): + """ + Model class for prompt message content. + """ + + type: PromptMessageContentType class TextPromptMessageContent(PromptMessageContent): @@ -125,7 +130,16 @@ PromptMessageContentUnionTypes = Annotated[ ] -class PromptMessage(BaseModel): +CONTENT_TYPE_MAPPING: Mapping[PromptMessageContentType, type[PromptMessageContent]] = { + PromptMessageContentType.TEXT: TextPromptMessageContent, + PromptMessageContentType.IMAGE: ImagePromptMessageContent, + PromptMessageContentType.AUDIO: AudioPromptMessageContent, + PromptMessageContentType.VIDEO: VideoPromptMessageContent, + PromptMessageContentType.DOCUMENT: DocumentPromptMessageContent, +} + + +class PromptMessage(ABC, BaseModel): """ Model class for prompt message. 
""" @@ -142,6 +156,23 @@ class PromptMessage(BaseModel): """ return not self.content + @field_validator("content", mode="before") + @classmethod + def validate_content(cls, v): + if isinstance(v, list): + prompts = [] + for prompt in v: + if isinstance(prompt, PromptMessageContent): + if not isinstance(prompt, TextPromptMessageContent | MultiModalPromptMessageContent): + prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump()) + elif isinstance(prompt, dict): + prompt = CONTENT_TYPE_MAPPING[prompt["type"]].model_validate(prompt) + else: + raise ValueError(f"invalid prompt message {prompt}") + prompts.append(prompt) + return prompts + return v + @field_serializer("content") def serialize_content( self, content: Optional[Union[str, Sequence[PromptMessageContent]]] diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 6312587861..e2cc576f83 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -2,7 +2,7 @@ import logging import time import uuid from collections.abc import Generator, Sequence -from typing import Optional, Union, cast +from typing import Optional, Union from pydantic import ConfigDict @@ -13,14 +13,15 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, + PromptMessageContentUnionTypes, PromptMessageTool, + TextPromptMessageContent, ) from core.model_runtime.entities.model_entities import ( ModelType, PriceType, ) from core.model_runtime.model_providers.__base.ai_model import AIModel -from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str from core.plugin.impl.model import PluginModelClient logger = logging.getLogger(__name__) @@ -238,7 +239,7 @@ class 
LargeLanguageModel(AIModel): def _invoke_result_generator( self, model: str, - result: Generator, + result: Generator[LLMResultChunk, None, None], credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, @@ -255,11 +256,21 @@ class LargeLanguageModel(AIModel): :return: result generator """ callbacks = callbacks or [] - assistant_message = AssistantPromptMessage(content="") + message_content: list[PromptMessageContentUnionTypes] = [] usage = None system_fingerprint = None real_model = model + def _update_message_content(content: str | list[PromptMessageContentUnionTypes] | None): + if not content: + return + if isinstance(content, list): + message_content.extend(content) + return + if isinstance(content, str): + message_content.append(TextPromptMessageContent(data=content)) + return + try: for chunk in result: # Following https://github.com/langgenius/dify/issues/17799, @@ -281,9 +292,8 @@ class LargeLanguageModel(AIModel): callbacks=callbacks, ) - text = convert_llm_result_chunk_to_str(chunk.delta.message.content) - current_content = cast(str, assistant_message.content) - assistant_message.content = current_content + text + _update_message_content(chunk.delta.message.content) + real_model = chunk.model if chunk.delta.usage: usage = chunk.delta.usage @@ -293,6 +303,7 @@ class LargeLanguageModel(AIModel): except Exception as e: raise self._transform_invoke_error(e) + assistant_message = AssistantPromptMessage(content=message_content) self._trigger_after_invoke_callbacks( model=model, result=LLMResult( diff --git a/api/core/model_runtime/utils/helper.py b/api/core/model_runtime/utils/helper.py index 53789a8e91..5e8a723ec7 100644 --- a/api/core/model_runtime/utils/helper.py +++ b/api/core/model_runtime/utils/helper.py @@ -1,8 +1,6 @@ import pydantic from pydantic import BaseModel -from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes - def dump_model(model: BaseModel) -> dict: if hasattr(pydantic, 
"model_dump"): @@ -10,18 +8,3 @@ def dump_model(model: BaseModel) -> dict: return pydantic.model_dump(model) # type: ignore else: return model.model_dump() - - -def convert_llm_result_chunk_to_str(content: None | str | list[PromptMessageContentUnionTypes]) -> str: - if content is None: - message_text = "" - elif isinstance(content, str): - message_text = content - elif isinstance(content, list): - # Assuming the list contains PromptMessageContent objects with a "data" attribute - message_text = "".join( - item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content - ) - else: - message_text = str(content) - return message_text diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py new file mode 100644 index 0000000000..e80005d7bf --- /dev/null +++ b/api/core/tools/signature.py @@ -0,0 +1,41 @@ +import base64 +import hashlib +import hmac +import os +import time + +from configs import dify_config + + +def sign_tool_file(tool_file_id: str, extension: str) -> str: + """ + sign file to get a temporary url + """ + base_url = dify_config.FILES_URL + file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}" + + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + + +def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool: + """ + verify signature + """ + data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + 
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + + # verify signature + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 7e8d4280d4..c60e4f4e4a 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -9,18 +9,28 @@ from typing import Optional, Union from uuid import uuid4 import httpx +from sqlalchemy.orm import Session from configs import dify_config from core.helper import ssrf_proxy -from extensions.ext_database import db +from extensions.ext_database import db as global_db from extensions.ext_storage import storage from models.model import MessageFile from models.tools import ToolFile logger = logging.getLogger(__name__) +from sqlalchemy.engine import Engine + class ToolFileManager: + _engine: Engine + + def __init__(self, engine: Engine | None = None): + if engine is None: + engine = global_db.engine + self._engine = engine + @staticmethod def sign_file(tool_file_id: str, extension: str) -> str: """ @@ -55,8 +65,8 @@ class ToolFileManager: current_time = int(time.time()) return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT - @staticmethod def create_file_by_raw( + self, *, user_id: str, tenant_id: str, @@ -77,24 +87,25 @@ class ToolFileManager: filepath = f"tools/{tenant_id}/{unique_filename}" storage.save(filepath, file_binary) - tool_file = ToolFile( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=conversation_id, - file_key=filepath, - mimetype=mimetype, - name=present_filename, - size=len(file_binary), - ) + with Session(self._engine, expire_on_commit=False) as session: + tool_file = ToolFile( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=conversation_id, + file_key=filepath, + mimetype=mimetype, + name=present_filename, + 
size=len(file_binary), + ) - db.session.add(tool_file) - db.session.commit() - db.session.refresh(tool_file) + session.add(tool_file) + session.commit() + session.refresh(tool_file) return tool_file - @staticmethod def create_file_by_url( + self, user_id: str, tenant_id: str, file_url: str, @@ -119,24 +130,24 @@ class ToolFileManager: filepath = f"tools/{tenant_id}/{filename}" storage.save(filepath, blob) - tool_file = ToolFile( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=conversation_id, - file_key=filepath, - mimetype=mimetype, - original_url=file_url, - name=filename, - size=len(blob), - ) + with Session(self._engine, expire_on_commit=False) as session: + tool_file = ToolFile( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=conversation_id, + file_key=filepath, + mimetype=mimetype, + original_url=file_url, + name=filename, + size=len(blob), + ) - db.session.add(tool_file) - db.session.commit() + session.add(tool_file) + session.commit() return tool_file - @staticmethod - def get_file_binary(id: str) -> Union[tuple[bytes, str], None]: + def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]: """ get file binary @@ -144,13 +155,14 @@ class ToolFileManager: :return: the binary of the file, mime type """ - tool_file: ToolFile | None = ( - db.session.query(ToolFile) - .filter( - ToolFile.id == id, + with Session(self._engine, expire_on_commit=False) as session: + tool_file: ToolFile | None = ( + session.query(ToolFile) + .filter( + ToolFile.id == id, + ) + .first() ) - .first() - ) if not tool_file: return None @@ -159,8 +171,7 @@ class ToolFileManager: return blob, tool_file.mimetype - @staticmethod - def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]: + def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]: """ get file binary @@ -168,33 +179,34 @@ class ToolFileManager: :return: the binary of the file, mime type """ - message_file: MessageFile | None = ( - 
db.session.query(MessageFile) - .filter( - MessageFile.id == id, + with Session(self._engine, expire_on_commit=False) as session: + message_file: MessageFile | None = ( + session.query(MessageFile) + .filter( + MessageFile.id == id, + ) + .first() ) - .first() - ) - # Check if message_file is not None - if message_file is not None: - # get tool file id - if message_file.url is not None: - tool_file_id = message_file.url.split("/")[-1] - # trim extension - tool_file_id = tool_file_id.split(".")[0] + # Check if message_file is not None + if message_file is not None: + # get tool file id + if message_file.url is not None: + tool_file_id = message_file.url.split("/")[-1] + # trim extension + tool_file_id = tool_file_id.split(".")[0] + else: + tool_file_id = None else: tool_file_id = None - else: - tool_file_id = None - tool_file: ToolFile | None = ( - db.session.query(ToolFile) - .filter( - ToolFile.id == tool_file_id, + tool_file: ToolFile | None = ( + session.query(ToolFile) + .filter( + ToolFile.id == tool_file_id, + ) + .first() ) - .first() - ) if not tool_file: return None @@ -203,8 +215,7 @@ class ToolFileManager: return blob, tool_file.mimetype - @staticmethod - def get_file_generator_by_tool_file_id(tool_file_id: str): + def get_file_generator_by_tool_file_id(self, tool_file_id: str): """ get file binary @@ -212,13 +223,14 @@ class ToolFileManager: :return: the binary of the file, mime type """ - tool_file: ToolFile | None = ( - db.session.query(ToolFile) - .filter( - ToolFile.id == tool_file_id, + with Session(self._engine, expire_on_commit=False) as session: + tool_file: ToolFile | None = ( + session.query(ToolFile) + .filter( + ToolFile.id == tool_file_id, + ) + .first() ) - .first() - ) if not tool_file: return None, None @@ -229,6 +241,11 @@ class ToolFileManager: # init tool_file_parser -from core.file.tool_file_parser import tool_file_manager +from core.file.tool_file_parser import set_tool_file_manager_factory -tool_file_manager["manager"] = 
ToolFileManager + +def _factory() -> ToolFileManager: + return ToolFileManager() + + +set_tool_file_manager_factory(_factory) diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 6fd0c201e3..0dbc9ccbcc 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -31,8 +31,8 @@ class ToolFileMessageTransformer: # try to download image try: assert isinstance(message.message, ToolInvokeMessage.TextMessage) - - file = ToolFileManager.create_file_by_url( + tool_file_manager = ToolFileManager() + file = tool_file_manager.create_file_by_url( user_id=user_id, tenant_id=tenant_id, file_url=message.message.text, @@ -68,7 +68,8 @@ class ToolFileMessageTransformer: # FIXME: should do a type check here. assert isinstance(message.message.blob, bytes) - file = ToolFileManager.create_file_by_raw( + tool_file_manager = ToolFileManager() + file = tool_file_manager.create_file_by_raw( user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index fd2b0f9ae8..1c82637974 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -191,8 +191,9 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): mime_type = ( content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream" ) + tool_file_manager = ToolFileManager() - tool_file = ToolFileManager.create_file_by_raw( + tool_file = tool_file_manager.create_file_by_raw( user_id=self.user_id, tenant_id=self.tenant_id, conversation_id=None, diff --git a/api/core/workflow/nodes/llm/exc.py b/api/core/workflow/nodes/llm/exc.py index 6599221691..42b8f4e6ce 100644 --- a/api/core/workflow/nodes/llm/exc.py +++ b/api/core/workflow/nodes/llm/exc.py @@ -38,3 +38,8 @@ class MemoryRolePrefixRequiredError(LLMNodeError): class 
FileTypeNotSupportError(LLMNodeError): def __init__(self, *, type_name: str): super().__init__(f"{type_name} type is not supported by this model") + + +class UnsupportedPromptContentTypeError(LLMNodeError): + def __init__(self, *, type_name: str) -> None: + super().__init__(f"Prompt content type {type_name} is not supported.") diff --git a/api/core/workflow/nodes/llm/file_saver.py b/api/core/workflow/nodes/llm/file_saver.py new file mode 100644 index 0000000000..c85baade03 --- /dev/null +++ b/api/core/workflow/nodes/llm/file_saver.py @@ -0,0 +1,160 @@ +import mimetypes +import typing as tp + +from sqlalchemy import Engine + +from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE +from core.file import File, FileTransferMethod, FileType +from core.helper import ssrf_proxy +from core.tools.signature import sign_tool_file +from core.tools.tool_file_manager import ToolFileManager +from models import db as global_db + + +class LLMFileSaver(tp.Protocol): + """LLMFileSaver is responsible for saving multimodal output returned by + LLM. + """ + + def save_binary_string( + self, + data: bytes, + mime_type: str, + file_type: FileType, + extension_override: str | None = None, + ) -> File: + """save_binary_string saves the inline file data returned by LLM. + + Currently (2025-04-30), only some of Google Gemini models will return + multimodal output as inline data. + + :param data: the contents of the file + :param mime_type: the media type of the file, specified by rfc6838 + (https://datatracker.ietf.org/doc/html/rfc6838) + :param file_type: The file type of the inline file. + :param extension_override: Override the auto-detected file extension while saving this file. + + The default value is `None`, which means do not override the file extension and guess it + from the `mime_type` attribute while saving the file. + + Setting it to values other than `None` means overriding the file's extension, and + will bypass the extension guessing when saving the file.
+ + In particular, setting it to an empty string (`""`) will leave the file extension empty. + + When it is not `None` or an empty string (`""`), it should be a string beginning with a + dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py` + and `tar.gz` are not. + """ + pass + + def save_remote_url(self, url: str, file_type: FileType) -> File: + """save_remote_url saves the file from a remote url returned by LLM. + + Currently (2025-04-30), no model returns multimodal output as a url. + + :param url: the url of the file. + :param file_type: the file type of the file, check `FileType` enum for reference. + """ + pass + + +EngineFactory: tp.TypeAlias = tp.Callable[[], Engine] + + +class FileSaverImpl(LLMFileSaver): + _engine_factory: EngineFactory + _tenant_id: str + _user_id: str + + def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None): + if engine_factory is None: + + def _factory(): + return global_db.engine + + engine_factory = _factory + self._engine_factory = engine_factory + self._user_id = user_id + self._tenant_id = tenant_id + + def _get_tool_file_manager(self): + return ToolFileManager(engine=self._engine_factory()) + + def save_remote_url(self, url: str, file_type: FileType) -> File: + http_response = ssrf_proxy.get(url) + http_response.raise_for_status() + data = http_response.content + mime_type_from_header = http_response.headers.get("Content-Type") + mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header) + return self.save_binary_string(data, mime_type, file_type, extension_override=extension) + + def save_binary_string( + self, + data: bytes, + mime_type: str, + file_type: FileType, + extension_override: str | None = None, + ) -> File: + tool_file_manager = self._get_tool_file_manager() + tool_file = tool_file_manager.create_file_by_raw( + user_id=self._user_id, + tenant_id=self._tenant_id, + # TODO(QuantumGhost): what is conversation id?
+ conversation_id=None, + file_binary=data, + mimetype=mime_type, + ) + extension_override = _validate_extension_override(extension_override) + extension = _get_extension(mime_type, extension_override) + url = sign_tool_file(tool_file.id, extension) + + return File( + tenant_id=self._tenant_id, + type=file_type, + transfer_method=FileTransferMethod.TOOL_FILE, + filename=tool_file.name, + extension=extension, + mime_type=mime_type, + size=len(data), + related_id=tool_file.id, + url=url, + # TODO(QuantumGhost): how should I set the following key? + # What's the difference between `remote_url` and `url`? + # What's the purpose of `storage_key` and `dify_model_identity`? + storage_key=tool_file.file_key, + ) + + +def _get_extension(mime_type: str, extension_override: str | None = None) -> str: + """_get_extension returns the extension of the file. + + If the `extension_override` parameter is set, this function should honor it and + return its value. + """ + if extension_override is not None: + return extension_override + return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION + + +def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]: + """_extract_content_type_and_extension tries to + guess the content type of the file from the url and the `Content-Type` header in the response. + """ + if content_type_header: + extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION + return content_type_header, extension + content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE + extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION + return content_type, extension + + +def _validate_extension_override(extension_override: str | None) -> str | None: + # `extension_override` is allowed to be `None` or `""`. + if extension_override is None: + return None + if extension_override == "": + return "" + if not extension_override.startswith("."): + raise ValueError("extension_override should start with '.'
if not None or empty.", extension_override) + return extension_override diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 35b146e5d9..5481bd383a 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1,3 +1,5 @@ +import base64 +import io import json import logging from collections.abc import Generator, Mapping, Sequence @@ -21,7 +23,7 @@ from core.model_runtime.entities import ( PromptMessageContentType, TextPromptMessageContent, ) -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessageContentUnionTypes, @@ -38,7 +40,6 @@ from core.model_runtime.entities.model_entities import ( ) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder -from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str from core.plugin.entities.plugin import ModelProviderID from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil @@ -95,9 +96,13 @@ from .exc import ( TemplateTypeNotSupportError, VariableNotFoundError, ) +from .file_saver import FileSaverImpl, LLMFileSaver if TYPE_CHECKING: from core.file.models import File + from core.workflow.graph_engine.entities.graph import Graph + from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams + from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState logger = logging.getLogger(__name__) @@ -106,6 +111,43 @@ class LLMNode(BaseNode[LLMNodeData]): _node_data_cls = LLMNodeData _node_type = NodeType.LLM + # Instance attributes specific to LLMNode. 
+ # Output variable for file + _file_outputs: list["File"] + + _llm_file_saver: LLMFileSaver + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph: "Graph", + graph_runtime_state: "GraphRuntimeState", + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None, + *, + llm_file_saver: LLMFileSaver | None = None, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph=graph, + graph_runtime_state=graph_runtime_state, + previous_node_id=previous_node_id, + thread_pool_id=thread_pool_id, + ) + # LLM file outputs, used for MultiModal outputs. + self._file_outputs: list[File] = [] + + if llm_file_saver is None: + llm_file_saver = FileSaverImpl( + user_id=graph_init_params.user_id, + tenant_id=graph_init_params.tenant_id, + ) + self._llm_file_saver = llm_file_saver + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]: """Process structured output if enabled""" @@ -215,6 +257,9 @@ class LLMNode(BaseNode[LLMNodeData]): structured_output = process_structured_output(result_text) if structured_output: outputs["structured_output"] = structured_output + if self._file_outputs is not None: + outputs["files"] = self._file_outputs + yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -240,6 +285,7 @@ class LLMNode(BaseNode[LLMNodeData]): ) ) except Exception as e: + logger.exception("error while executing llm node") yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -268,44 +314,45 @@ class LLMNode(BaseNode[LLMNodeData]): return self._handle_invoke_result(invoke_result=invoke_result) - def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]: + def _handle_invoke_result( + self, invoke_result: LLMResult | 
Generator[LLMResultChunk, None, None] + ) -> Generator[NodeEvent, None, None]: + # For blocking mode if isinstance(invoke_result, LLMResult): - message_text = convert_llm_result_chunk_to_str(invoke_result.message.content) - - yield ModelInvokeCompletedEvent( - text=message_text, - usage=invoke_result.usage, - finish_reason=None, - ) + event = self._handle_blocking_result(invoke_result=invoke_result) + yield event return - model = None + # For streaming mode + model = "" prompt_messages: list[PromptMessage] = [] - full_text = "" - usage = None + + usage = LLMUsage.empty_usage() finish_reason = None + full_text_buffer = io.StringIO() for result in invoke_result: - text = convert_llm_result_chunk_to_str(result.delta.message.content) - full_text += text + contents = result.delta.message.content + for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents): + full_text_buffer.write(text_part) + yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[self.node_id, "text"]) - yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"]) - - if not model: + # Update the whole metadata + if not model and result.model: model = result.model - - if not prompt_messages: - prompt_messages = result.prompt_messages - - if not usage and result.delta.usage: + if len(prompt_messages) == 0: + # TODO(QuantumGhost): it seems that this update has no visible effect. + # What's the purpose of the line below?
+ prompt_messages = list(result.prompt_messages) + if usage.prompt_tokens == 0 and result.delta.usage: usage = result.delta.usage - - if not finish_reason and result.delta.finish_reason: + if finish_reason is None and result.delta.finish_reason: finish_reason = result.delta.finish_reason - if not usage: - usage = LLMUsage.empty_usage() + yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason) - yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason) + def _image_file_to_markdown(self, file: "File", /): + text_chunk = f"})" + return text_chunk def _transform_chat_messages( self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, / @@ -963,6 +1010,42 @@ class LLMNode(BaseNode[LLMNodeData]): return prompt_messages + def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent: + buffer = io.StringIO() + for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content): + buffer.write(text_part) + + return ModelInvokeCompletedEvent( + text=buffer.getvalue(), + usage=invoke_result.usage, + finish_reason=None, + ) + + def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File": + """_save_multimodal_image_output saves multi-modal contents generated by LLM plugins. + + There are two kinds of multimodal outputs: + + - Inlined data encoded in base64, which would be saved to storage directly. + - Remote files referenced by a URL, which would be downloaded and then saved to storage. + + Currently, only image files are supported. + """ + # Inject the saver somehow...
+ _saver = self._llm_file_saver + + # If the content carries a URL, fetch the remote file; otherwise decode the inlined base64 data. + if content.url != "": + saved_file = _saver.save_remote_url(content.url, FileType.IMAGE) + else: + saved_file = _saver.save_binary_string( + data=base64.b64decode(content.base64_data), + mime_type=content.mime_type, + file_type=FileType.IMAGE, + ) + self._file_outputs.append(saved_file) + return saved_file + def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict: """ Handle structured output for models with native JSON schema support. @@ -1123,6 +1206,41 @@ class LLMNode(BaseNode[LLMNodeData]): else SupportStructuredOutputStatus.UNSUPPORTED ) + def _save_multimodal_output_and_convert_result_to_markdown( + self, + contents: str | list[PromptMessageContentUnionTypes] | None, + ) -> Generator[str, None, None]: + """Convert intermediate prompt messages into strings and yield them to the caller. + + If the messages contain non-textual content (e.g., multimedia like images or videos), + it will be saved separately, and the corresponding Markdown representation will + be yielded to the caller. + """ + + # NOTE(QuantumGhost): This function should yield results to the caller immediately + # whenever new content or partial content is available. Avoid any intermediate buffering + # of results. Additionally, do not yield empty strings; instead, yield from an empty list + # if necessary.
+ if contents is None: + yield from [] + return + if isinstance(contents, str): + yield contents + elif isinstance(contents, list): + for item in contents: + if isinstance(item, TextPromptMessageContent): + yield item.data + elif isinstance(item, ImagePromptMessageContent): + file = self._save_multimodal_image_output(item) + # NOTE: `_save_multimodal_image_output` already records the file in `self._file_outputs`; do not append it again. + yield self._image_file_to_markdown(file) + else: + logger.warning("unknown item type encountered, type=%s", type(item)) + yield str(item) + else: + logger.warning("unknown contents type encountered, type=%s", type(contents)) + yield str(contents) + def _combine_message_content_with_role( *, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole diff --git a/api/models/engine.py b/api/models/engine.py index dda93bc941..05c1cacdcb 100644 --- a/api/models/engine.py +++ b/api/models/engine.py @@ -10,4 +10,16 @@ POSTGRES_INDEXES_NAMING_CONVENTION = { } metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION) + +# ****** IMPORTANT NOTICE ****** +# +# NOTE(QuantumGhost): Avoid directly importing and using `db` in modules outside of the +# `controllers` package. +# +# Instead, import `db` within the `controllers` package and pass it as an argument to +# functions or class constructors. +# +# Directly importing `db` in other modules can make the code more difficult to read, test, and maintain. +# +# Whenever possible, avoid this pattern in new code.
db = SQLAlchemy(metadata=metadata) diff --git a/api/models/model.py b/api/models/model.py index 901e92284a..fd05d67e9a 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast from core.plugin.entities.plugin import GenericProviderID from core.tools.entities.tool_entities import ToolProviderType +from core.tools.signature import sign_tool_file from services.plugin.plugin_service import PluginService if TYPE_CHECKING: @@ -23,7 +24,6 @@ from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from core.file import helpers as file_helpers -from core.file.tool_file_parser import ToolFileParser from libs.helper import generate_string from models.base import Base from models.enums import CreatedByRole @@ -986,9 +986,7 @@ class Message(db.Model): # type: ignore[name-defined] if not tool_file_id: continue - sign_url = ToolFileParser.get_tool_file_manager().sign_file( - tool_file_id=tool_file_id, extension=extension - ) + sign_url = sign_tool_file(tool_file_id=tool_file_id, extension=extension) elif "file-preview" in url: # get upload file id upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview?\?timestamp=" diff --git a/api/models/tools.py b/api/models/tools.py index aef1490729..05604b9330 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -263,8 +263,8 @@ class ToolConversationVariables(Base): class ToolFile(Base): - """ - store the file created by agent + """This table stores file metadata generated in workflows, + not only files created by agent. 
""" __tablename__ = "tool_files" diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py new file mode 100644 index 0000000000..7c722660bc --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py @@ -0,0 +1,192 @@ +import uuid +from typing import NamedTuple +from unittest import mock + +import httpx +import pytest +from sqlalchemy import Engine + +from core.file import FileTransferMethod, FileType, models +from core.helper import ssrf_proxy +from core.tools import signature +from core.tools.tool_file_manager import ToolFileManager +from core.workflow.nodes.llm.file_saver import ( + FileSaverImpl, + _extract_content_type_and_extension, + _get_extension, + _validate_extension_override, +) +from models import ToolFile + +_PNG_DATA = b"\x89PNG\r\n\x1a\n" + + +def _gen_id(): + return str(uuid.uuid4()) + + +class TestFileSaverImpl: + def test_save_binary_string(self, monkeypatch): + user_id = _gen_id() + tenant_id = _gen_id() + file_type = FileType.IMAGE + mime_type = "image/png" + mock_signed_url = "https://example.com/image.png" + mock_tool_file = ToolFile( + id=_gen_id(), + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + file_key="test-file-key", + mimetype=mime_type, + original_url=None, + name=f"{_gen_id()}.png", + size=len(_PNG_DATA), + ) + mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager) + mocked_engine = mock.MagicMock(spec=Engine) + + mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file + monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager) + # Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here. + mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file) + # Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here. 
+ monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file) + mocked_sign_file.return_value = mock_signed_url + + storage_file_manager = FileSaverImpl( + user_id=user_id, + tenant_id=tenant_id, + engine_factory=mocked_engine, + ) + + file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type) + assert file.tenant_id == tenant_id + assert file.type == file_type + assert file.transfer_method == FileTransferMethod.TOOL_FILE + assert file.extension == ".png" + assert file.mime_type == mime_type + assert file.size == len(_PNG_DATA) + assert file.related_id == mock_tool_file.id + + assert file.generate_url() == mock_signed_url + + mocked_tool_file_manager.create_file_by_raw.assert_called_once_with( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + file_binary=_PNG_DATA, + mimetype=mime_type, + ) + mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png") + + def test_save_remote_url_request_failed(self, monkeypatch): + _TEST_URL = "https://example.com/image.png" + mock_request = httpx.Request("GET", _TEST_URL) + mock_response = httpx.Response( + status_code=401, + request=mock_request, + ) + file_saver = FileSaverImpl( + user_id=_gen_id(), + tenant_id=_gen_id(), + ) + mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response) + monkeypatch.setattr(ssrf_proxy, "get", mock_get) + + with pytest.raises(httpx.HTTPStatusError) as exc: + file_saver.save_remote_url(_TEST_URL, FileType.IMAGE) + mock_get.assert_called_once_with(_TEST_URL) + assert exc.value.response.status_code == 401 + + def test_save_remote_url_success(self, monkeypatch): + _TEST_URL = "https://example.com/image.png" + mime_type = "image/png" + user_id = _gen_id() + tenant_id = _gen_id() + + mock_request = httpx.Request("GET", _TEST_URL) + mock_response = httpx.Response( + status_code=200, + content=b"test-data", + headers={"Content-Type": mime_type}, + request=mock_request, + ) + + file_saver = FileSaverImpl(user_id=user_id, 
tenant_id=tenant_id) + mock_tool_file = ToolFile( + id=_gen_id(), + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + file_key="test-file-key", + mimetype=mime_type, + original_url=None, + name=f"{_gen_id()}.png", + size=len(_PNG_DATA), + ) + mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response) + monkeypatch.setattr(ssrf_proxy, "get", mock_get) + mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file) + monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string) + + file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE) + mock_save_binary_string.assert_called_once_with( + mock_response.content, + mime_type, + FileType.IMAGE, + extension_override=".png", + ) + assert file == mock_tool_file + + +def test_validate_extension_override(): + class TestCase(NamedTuple): + extension_override: str | None + expected: str | None + + cases = [TestCase(None, None), TestCase("", ""), ".png", ".png", ".tar.gz", ".tar.gz"] + + for valid_ext_override in [None, "", ".png", ".tar.gz"]: + assert valid_ext_override == _validate_extension_override(valid_ext_override) + + for invalid_ext_override in ["png", "tar.gz"]: + with pytest.raises(ValueError) as exc: + _validate_extension_override(invalid_ext_override) + + +class TestExtractContentTypeAndExtension: + def test_with_both_content_type_and_extension(self): + content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png") + assert content_type == "image/png" + assert extension == ".png" + + def test_url_with_file_extension(self): + for content_type in [None, ""]: + content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type) + assert content_type == "image/png" + assert extension == ".png" + + def test_response_with_content_type(self): + content_type, extension = _extract_content_type_and_extension("https://example.com/image", 
"image/png") + assert content_type == "image/png" + assert extension == ".png" + + def test_no_content_type_and_no_extension(self): + for content_type in [None, ""]: + content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type) + assert content_type == "application/octet-stream" + assert extension == ".bin" + + +class TestGetExtension: + def test_with_extension_override(self): + mime_type = "image/png" + for override in [".jpg", ""]: + extension = _get_extension(mime_type, override) + assert extension == override + + def test_without_extension_override(self): + mime_type = "image/png" + extension = _get_extension(mime_type) + assert extension == ".png" diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 5c3e5540c4..519dd73787 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -1,5 +1,8 @@ +import base64 +import uuid from collections.abc import Sequence from typing import Optional +from unittest import mock import pytest @@ -30,6 +33,7 @@ from core.workflow.nodes.llm.entities import ( VisionConfig, VisionConfigOptions, ) +from core.workflow.nodes.llm.file_saver import LLMFileSaver from core.workflow.nodes.llm.node import LLMNode from models.enums import UserFrom from models.provider import ProviderType @@ -49,8 +53,8 @@ class MockTokenBufferMemory: @pytest.fixture -def llm_node(): - data = LLMNodeData( +def llm_node_data() -> LLMNodeData: + return LLMNodeData( title="Test LLM", model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), prompt_template=[], @@ -64,42 +68,65 @@ def llm_node(): ), ), ) + + +@pytest.fixture +def graph_init_params() -> GraphInitParams: + return GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config={}, + user_id="1", + 
user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ) + + +@pytest.fixture +def graph() -> Graph: + return Graph( + root_node_id="1", + answer_stream_generate_routes=AnswerStreamGenerateRoute( + answer_dependencies={}, + answer_generate_route={}, + ), + end_stream_param=EndStreamParam( + end_dependencies={}, + end_stream_variable_selector_mapping={}, + ), + ) + + +@pytest.fixture +def graph_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( system_variables={}, user_inputs={}, ) + return GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ) + + +@pytest.fixture +def llm_node( + llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState +) -> LLMNode: + mock_file_saver = mock.MagicMock(spec=LLMFileSaver) node = LLMNode( id="1", config={ "id": "1", - "data": data.model_dump(), + "data": llm_node_data.model_dump(), }, - graph_init_params=GraphInitParams( - tenant_id="1", - app_id="1", - workflow_type=WorkflowType.WORKFLOW, - workflow_id="1", - graph_config={}, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, - ), - graph=Graph( - root_node_id="1", - answer_stream_generate_routes=AnswerStreamGenerateRoute( - answer_dependencies={}, - answer_generate_route={}, - ), - end_stream_param=EndStreamParam( - end_dependencies={}, - end_stream_variable_selector_mapping={}, - ), - ), - graph_runtime_state=GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ), + graph_init_params=graph_init_params, + graph=graph, + graph_runtime_state=graph_runtime_state, + llm_file_saver=mock_file_saver, ) return node @@ -465,3 +492,167 @@ def test_handle_list_messages_basic(llm_node): assert len(result) == 1 assert isinstance(result[0], UserPromptMessage) assert result[0].content == [TextPromptMessageContent(data="Hello, world")] + + +@pytest.fixture +def llm_node_for_multimodal( + llm_node_data, 
graph_init_params, graph, graph_runtime_state +) -> tuple[LLMNode, LLMFileSaver]: + mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver) + node = LLMNode( + id="1", + config={ + "id": "1", + "data": llm_node_data.model_dump(), + }, + graph_init_params=graph_init_params, + graph=graph, + graph_runtime_state=graph_runtime_state, + llm_file_saver=mock_file_saver, + ) + return node, mock_file_saver + + +class TestLLMNodeSaveMultiModalImageOutput: + def test_llm_node_save_inline_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]): + llm_node, mock_file_saver = llm_node_for_multimodal + content = ImagePromptMessageContent( + format="png", + base64_data=base64.b64encode(b"test-data").decode(), + mime_type="image/png", + ) + mock_file = File( + id=str(uuid.uuid4()), + tenant_id="1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id=str(uuid.uuid4()), + filename="test-file.png", + extension=".png", + mime_type="image/png", + size=9, + ) + mock_file_saver.save_binary_string.return_value = mock_file + file = llm_node._save_multimodal_image_output(content=content) + assert llm_node._file_outputs == [mock_file] + assert file == mock_file + mock_file_saver.save_binary_string.assert_called_once_with( + data=b"test-data", mime_type="image/png", file_type=FileType.IMAGE + ) + + def test_llm_node_save_url_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]): + llm_node, mock_file_saver = llm_node_for_multimodal + content = ImagePromptMessageContent( + format="png", + url="https://example.com/image.png", + mime_type="image/jpg", + ) + mock_file = File( + id=str(uuid.uuid4()), + tenant_id="1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id=str(uuid.uuid4()), + filename="test-file.png", + extension=".png", + mime_type="image/png", + size=9, + ) + mock_file_saver.save_remote_url.return_value = mock_file + file = llm_node._save_multimodal_image_output(content=content) + 
assert llm_node._file_outputs == [mock_file] + assert file == mock_file + mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE) + + +def test_llm_node_image_file_to_markdown(llm_node: LLMNode): + mock_file = mock.MagicMock(spec=File) + mock_file.generate_url.return_value = "https://example.com/image.png" + markdown = llm_node._image_file_to_markdown(mock_file) + assert markdown == "" + + +class TestSaveMultimodalOutputAndConvertResultToMarkdown: + def test_str_content(self, llm_node_for_multimodal): + llm_node, mock_file_saver = llm_node_for_multimodal + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world") + assert list(gen) == ["hello world"] + mock_file_saver.save_binary_string.assert_not_called() + mock_file_saver.save_remote_url.assert_not_called() + + def test_text_prompt_message_content(self, llm_node_for_multimodal): + llm_node, mock_file_saver = llm_node_for_multimodal + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( + [TextPromptMessageContent(data="hello world")] + ) + assert list(gen) == ["hello world"] + mock_file_saver.save_binary_string.assert_not_called() + mock_file_saver.save_remote_url.assert_not_called() + + def test_image_content_with_inline_data(self, llm_node_for_multimodal, monkeypatch): + llm_node, mock_file_saver = llm_node_for_multimodal + + image_raw_data = b"PNG_DATA" + image_b64_data = base64.b64encode(image_raw_data).decode() + + mock_saved_file = File( + id=str(uuid.uuid4()), + tenant_id="1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + filename="test.png", + extension=".png", + size=len(image_raw_data), + related_id=str(uuid.uuid4()), + url="https://example.com/test.png", + storage_key="test_storage_key", + ) + mock_file_saver.save_binary_string.return_value = mock_saved_file + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( + [ + ImagePromptMessageContent( + format="png", + 
base64_data=image_b64_data, + mime_type="image/png", + ) + ] + ) + yielded_strs = list(gen) + assert len(yielded_strs) == 1 + + # This assertion requires careful handling. + # `FILES_URL` settings can vary across environments, which might lead to fragile tests. + # + # Rather than asserting the complete URL returned by _save_multimodal_output_and_convert_result_to_markdown, + # we verify that the result includes the markdown image syntax and the expected file URL path. + expected_file_url_path = f"/files/tools/{mock_saved_file.related_id}.png" + assert yielded_strs[0].startswith(" + assert expected_file_url_path in yielded_strs[0] + assert yielded_strs[0].endswith(")") + mock_file_saver.save_binary_string.assert_called_once_with( + data=image_raw_data, + mime_type="image/png", + file_type=FileType.IMAGE, + ) + assert mock_saved_file in llm_node._file_outputs + + def test_unknown_content_type(self, llm_node_for_multimodal): + llm_node, mock_file_saver = llm_node_for_multimodal + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"])) + assert list(gen) == ["frozenset({'hello world'})"] + mock_file_saver.save_binary_string.assert_not_called() + mock_file_saver.save_remote_url.assert_not_called() + + def test_unknown_item_type(self, llm_node_for_multimodal): + llm_node, mock_file_saver = llm_node_for_multimodal + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])]) + assert list(gen) == ["frozenset({'hello world'})"] + mock_file_saver.save_binary_string.assert_not_called() + mock_file_saver.save_remote_url.assert_not_called() + + def test_none_content(self, llm_node_for_multimodal): + llm_node, mock_file_saver = llm_node_for_multimodal + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None) + assert list(gen) == [] + mock_file_saver.save_binary_string.assert_not_called() + mock_file_saver.save_remote_url.assert_not_called() diff --git 
a/web/app/components/base/mermaid/index.tsx b/web/app/components/base/mermaid/index.tsx index 8fd8ae8b59..a484261a51 100644 --- a/web/app/components/base/mermaid/index.tsx +++ b/web/app/components/base/mermaid/index.tsx @@ -476,15 +476,15 @@ const Flowchart = React.forwardRef((props: { 'bg-white': currentTheme === Theme.light, 'bg-slate-900': currentTheme === Theme.dark, }), - mermaidDiv: cn('mermaid cursor-pointer h-auto w-full relative', { + mermaidDiv: cn('mermaid relative h-auto w-full cursor-pointer', { 'bg-white': currentTheme === Theme.light, 'bg-slate-900': currentTheme === Theme.dark, }), - errorMessage: cn('py-4 px-[26px]', { + errorMessage: cn('px-[26px] py-4', { 'text-red-500': currentTheme === Theme.light, 'text-red-400': currentTheme === Theme.dark, }), - errorIcon: cn('w-6 h-6', { + errorIcon: cn('h-6 w-6', { 'text-red-500': currentTheme === Theme.light, 'text-red-400': currentTheme === Theme.dark, }), @@ -492,7 +492,7 @@ const Flowchart = React.forwardRef((props: { 'text-gray-700': currentTheme === Theme.light, 'text-gray-300': currentTheme === Theme.dark, }), - themeToggle: cn('flex items-center justify-center w-10 h-10 rounded-full transition-all duration-300 shadow-md backdrop-blur-sm', { + themeToggle: cn('flex h-10 w-10 items-center justify-center rounded-full shadow-md backdrop-blur-sm transition-all duration-300', { 'bg-white/80 hover:bg-white hover:shadow-lg text-gray-700 border border-gray-200': currentTheme === Theme.light, 'bg-slate-800/80 hover:bg-slate-700 hover:shadow-lg text-yellow-300 border border-slate-600': currentTheme === Theme.dark, }), @@ -501,7 +501,7 @@ const Flowchart = React.forwardRef((props: { // Style classes for look options const getLookButtonClass = (lookType: 'classic' | 'handDrawn') => { return cn( - 'flex items-center justify-center mb-4 w-[calc((100%-8px)/2)] h-8 rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg cursor-pointer system-sm-medium text-text-secondary', 
+ 'system-sm-medium mb-4 flex h-8 w-[calc((100%-8px)/2)] cursor-pointer items-center justify-center rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg text-text-secondary', look === lookType && 'border-[1.5px] border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary', currentTheme === Theme.dark && 'border-slate-600 bg-slate-800 text-slate-300', look === lookType && currentTheme === Theme.dark && 'border-blue-500 bg-slate-700 text-white', @@ -512,7 +512,7 @@ const Flowchart = React.forwardRef((props: {