diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 56f9b433f3..30c0ff000d 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -49,8 +49,8 @@ jobs: if: steps.changed-files.outputs.any_changed == 'true' run: | uv run --directory api ruff --version - uv run --directory api ruff check ./ - uv run --directory api ruff format --check ./ + uv run --directory api ruff check --diff ./ + uv run --directory api ruff format --check --diff ./ - name: Dotenv check if: steps.changed-files.outputs.any_changed == 'true' diff --git a/api/constants/mimetypes.py b/api/constants/mimetypes.py new file mode 100644 index 0000000000..38988cdd24 --- /dev/null +++ b/api/constants/mimetypes.py @@ -0,0 +1,7 @@ +# The two constants below should be kept in sync. +# Default content type for files which have no explicit content type. + +DEFAULT_MIME_TYPE = "application/octet-stream" +# Default file extension for files which have no explicit content type, should +# correspond to the `DEFAULT_MIME_TYPE` above. +DEFAULT_EXTENSION = ".bin" diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index cfcce81247..7cb97806c8 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -4,7 +4,9 @@ from werkzeug.exceptions import Forbidden, NotFound from controllers.files import api from controllers.files.error import UnsupportedFileTypeError +from core.tools.signature import verify_tool_file_signature from core.tools.tool_file_manager import ToolFileManager +from models import db as global_db class ToolFilePreviewApi(Resource): @@ -19,17 +21,14 @@ class ToolFilePreviewApi(Resource): parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args") args = parser.parse_args() - - if not ToolFileManager.verify_file( - file_id=file_id, - timestamp=args["timestamp"], - nonce=args["nonce"], - sign=args["sign"], + if not verify_tool_file_signature( + file_id=file_id, timestamp=args["timestamp"], nonce=args["nonce"], sign=args["sign"] ): raise Forbidden("Invalid request.") try: - stream, tool_file = ToolFileManager.get_file_generator_by_tool_file_id( + tool_file_manager = ToolFileManager(engine=global_db.engine) + stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id( file_id, ) diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index 28ee0eecf4..178fb9477a 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -53,7 +53,7 @@ class PluginUploadFileApi(Resource): raise Forbidden("Invalid request.") try: - tool_file = ToolFileManager.create_file_by_raw( + tool_file = ToolFileManager().create_file_by_raw( user_id=user.id, tenant_id=tenant_id, file_binary=file.read(), diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py index fde506639f..a6d826f08b 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manage.py @@ -24,7 +24,7 @@ from core.app.entities.task_entities import ( WorkflowTaskState, ) from core.llm_generator.llm_generator import LLMGenerator -from core.tools.tool_file_manager import ToolFileManager +from core.tools.signature import sign_tool_file from extensions.ext_database import db from models.model import AppMode, Conversation, MessageAnnotation, MessageFile from services.annotation_service import AppAnnotationService @@ -154,7 +154,7 @@ class MessageCycleManage: if
message_file.url.startswith("http"): url = message_file.url else: - url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension) + url = sign_tool_file(tool_file_id=tool_file_id, extension=extension) return MessageFileStreamResponse( task_id=self._application_generate_entity.task_id, diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index 9a204e9ff6..ada19ef8ce 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -10,12 +10,12 @@ from core.model_runtime.entities import ( VideoPromptMessageContent, ) from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes +from core.tools.signature import sign_tool_file from extensions.ext_storage import storage from . import helpers from .enums import FileAttribute from .models import File, FileTransferMethod, FileType -from .tool_file_parser import ToolFileParser def get_attr(*, file: File, attr: FileAttribute): @@ -130,6 +130,6 @@ def _to_url(f: File, /): # add sign url if f.related_id is None or f.extension is None: raise ValueError("Missing file related_id or extension") - return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension) + return sign_tool_file(tool_file_id=f.related_id, extension=f.extension) else: raise ValueError(f"Unsupported transfer method: {f.transfer_method}") diff --git a/api/core/file/models.py b/api/core/file/models.py index f5db6c2d74..aa3b5f629c 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -4,11 +4,11 @@ from typing import Any, Optional from pydantic import BaseModel, Field, model_validator from core.model_runtime.entities.message_entities import ImagePromptMessageContent +from core.tools.signature import sign_tool_file from . import helpers from .constants import FILE_MODEL_IDENTITY from .enums import FileTransferMethod, FileType -from .tool_file_parser import ToolFileParser class ImageConfig(BaseModel): @@ -34,13 +34,21 @@ class FileUploadConfig(BaseModel): class File(BaseModel): + # NOTE: dify_model_identity is a special identifier used to distinguish between + # new and old data formats during serialization and deserialization. dify_model_identity: str = FILE_MODEL_IDENTITY id: Optional[str] = None # message file id tenant_id: str type: FileType transfer_method: FileTransferMethod + # If `transfer_method` is `FileTransferMethod.remote_url`, the + # `remote_url` attribute must not be `None`. remote_url: Optional[str] = None # remote url + # If `transfer_method` is `FileTransferMethod.local_file` or + # `FileTransferMethod.tool_file`, the `related_id` attribute must not be `None`. + # + # It should be set to `ToolFile.id` when `transfer_method` is `tool_file`. 
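+ #
+ # A hypothetical illustration (values assumed): a file produced by a tool carries
+ # `transfer_method=FileTransferMethod.TOOL_FILE` together with `related_id=tool_file.id`,
+ # and leaves `remote_url` as `None`.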
related_id: Optional[str] = None filename: Optional[str] = None extension: Optional[str] = Field(default=None, description="File extension, should contains dot") @@ -110,9 +118,7 @@ class File(BaseModel): elif self.transfer_method == FileTransferMethod.TOOL_FILE: assert self.related_id is not None assert self.extension is not None - return ToolFileParser.get_tool_file_manager().sign_file( - tool_file_id=self.related_id, extension=self.extension - ) + return sign_tool_file(tool_file_id=self.related_id, extension=self.extension) def to_plugin_parameter(self) -> dict[str, Any]: return { diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py index 6fa101cf36..656c9d48ed 100644 --- a/api/core/file/tool_file_parser.py +++ b/api/core/file/tool_file_parser.py @@ -1,12 +1,19 @@ -from typing import TYPE_CHECKING, Any, cast +from collections.abc import Callable +from typing import TYPE_CHECKING if TYPE_CHECKING: from core.tools.tool_file_manager import ToolFileManager -tool_file_manager: dict[str, Any] = {"manager": None} +_tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None class ToolFileParser: @staticmethod def get_tool_file_manager() -> "ToolFileManager": - return cast("ToolFileManager", tool_file_manager["manager"]) + assert _tool_file_manager_factory is not None + return _tool_file_manager_factory() + + +def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None: + global _tool_file_manager_factory + _tool_file_manager_factory = factory diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 0845ef206e..995a30d44c 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -101,7 +101,7 @@ class ModelInstance: @overload def invoke_llm( self, - prompt_messages: list[PromptMessage], + prompt_messages: Sequence[PromptMessage], model_parameters: Optional[dict] = None, tools: Sequence[PromptMessageTool] | None = None, stop: Optional[list[str]] = None, diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index b1c43d1455..9d010ae28d 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -1,4 +1,5 @@ -from collections.abc import Sequence +from abc import ABC +from collections.abc import Mapping, Sequence from enum import Enum, StrEnum from typing import Annotated, Any, Literal, Optional, Union @@ -60,8 +61,12 @@ class PromptMessageContentType(StrEnum): DOCUMENT = "document" -class PromptMessageContent(BaseModel): - pass +class PromptMessageContent(ABC, BaseModel): + """ + Model class for prompt message content. + """ + + type: PromptMessageContentType class TextPromptMessageContent(PromptMessageContent): @@ -125,7 +130,16 @@ PromptMessageContentUnionTypes = Annotated[ ] -class PromptMessage(BaseModel): +CONTENT_TYPE_MAPPING: Mapping[PromptMessageContentType, type[PromptMessageContent]] = { + PromptMessageContentType.TEXT: TextPromptMessageContent, + PromptMessageContentType.IMAGE: ImagePromptMessageContent, + PromptMessageContentType.AUDIO: AudioPromptMessageContent, + PromptMessageContentType.VIDEO: VideoPromptMessageContent, + PromptMessageContentType.DOCUMENT: DocumentPromptMessageContent, +} + + +class PromptMessage(ABC, BaseModel): """ Model class for prompt message. 
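For example, a minimal sketch of the `validate_content` hook defined below
(both forms validate to the same typed content):

    UserPromptMessage(content=[{"type": "text", "data": "hello"}])
    UserPromptMessage(content=[TextPromptMessageContent(data="hello")])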
""" @@ -142,6 +156,23 @@ class PromptMessage(BaseModel): """ return not self.content + @field_validator("content", mode="before") + @classmethod + def validate_content(cls, v): + if isinstance(v, list): + prompts = [] + for prompt in v: + if isinstance(prompt, PromptMessageContent): + if not isinstance(prompt, TextPromptMessageContent | MultiModalPromptMessageContent): + prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump()) + elif isinstance(prompt, dict): + prompt = CONTENT_TYPE_MAPPING[prompt["type"]].model_validate(prompt) + else: + raise ValueError(f"invalid prompt message {prompt}") + prompts.append(prompt) + return prompts + return v + @field_serializer("content") def serialize_content( self, content: Optional[Union[str, Sequence[PromptMessageContent]]] diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 3c5a2dce4f..7d5ce1e47e 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -24,7 +24,6 @@ from core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity from core.plugin.impl.model import PluginModelClient @@ -253,15 +252,3 @@ class AIModel(BaseModel): raise Exception(f"Invalid model parameter rule name {name}") return default_parameter_rule - - def _get_num_tokens_by_gpt2(self, text: str) -> int: - """ - Get number of tokens for given prompt messages by gpt2 - Some provider models do not provide an interface for obtaining the number of tokens. - Here, the gpt2 tokenizer is used to calculate the number of tokens. - This method can be executed offline, and the gpt2 tokenizer has been cached in the project. - - :param text: plain text of prompt. 
You need to convert the original message to plain text - :return: number of tokens - """ - return GPT2Tokenizer.get_num_tokens(text) diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 6312587861..e2cc576f83 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -2,7 +2,7 @@ import logging import time import uuid from collections.abc import Generator, Sequence -from typing import Optional, Union, cast +from typing import Optional, Union from pydantic import ConfigDict @@ -13,14 +13,15 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, + PromptMessageContentUnionTypes, PromptMessageTool, + TextPromptMessageContent, ) from core.model_runtime.entities.model_entities import ( ModelType, PriceType, ) from core.model_runtime.model_providers.__base.ai_model import AIModel -from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str from core.plugin.impl.model import PluginModelClient logger = logging.getLogger(__name__) @@ -238,7 +239,7 @@ class LargeLanguageModel(AIModel): def _invoke_result_generator( self, model: str, - result: Generator, + result: Generator[LLMResultChunk, None, None], credentials: dict, prompt_messages: Sequence[PromptMessage], model_parameters: dict, @@ -255,11 +256,21 @@ class LargeLanguageModel(AIModel): :return: result generator """ callbacks = callbacks or [] - assistant_message = AssistantPromptMessage(content="") + message_content: list[PromptMessageContentUnionTypes] = [] usage = None system_fingerprint = None real_model = model + def _update_message_content(content: str | list[PromptMessageContentUnionTypes] | None): + if not content: + return + if isinstance(content, list): + message_content.extend(content) + return + if isinstance(content, str): + message_content.append(TextPromptMessageContent(data=content)) + return + try: for chunk in result: # Following https://github.com/langgenius/dify/issues/17799, @@ -281,9 +292,8 @@ class LargeLanguageModel(AIModel): callbacks=callbacks, ) - text = convert_llm_result_chunk_to_str(chunk.delta.message.content) - current_content = cast(str, assistant_message.content) - assistant_message.content = current_content + text + _update_message_content(chunk.delta.message.content) + real_model = chunk.model if chunk.delta.usage: usage = chunk.delta.usage @@ -293,6 +303,7 @@ class LargeLanguageModel(AIModel): except Exception as e: raise self._transform_invoke_error(e) + assistant_message = AssistantPromptMessage(content=message_content) self._trigger_after_invoke_callbacks( model=model, result=LLMResult( diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py index 2f6f4fbbef..b7db0b78bc 100644 --- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py +++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py @@ -30,6 +30,8 @@ class GPT2Tokenizer: @staticmethod def get_encoder() -> Any: global _tokenizer, _lock + if _tokenizer is not None: + return _tokenizer with _lock: if _tokenizer is None: # Try to use tiktoken to get the tokenizer because it is faster diff --git a/api/core/model_runtime/utils/helper.py 
b/api/core/model_runtime/utils/helper.py index 53789a8e91..5e8a723ec7 100644 --- a/api/core/model_runtime/utils/helper.py +++ b/api/core/model_runtime/utils/helper.py @@ -1,8 +1,6 @@ import pydantic from pydantic import BaseModel -from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes - def dump_model(model: BaseModel) -> dict: if hasattr(pydantic, "model_dump"): @@ -10,18 +8,3 @@ def dump_model(model: BaseModel) -> dict: return pydantic.model_dump(model) # type: ignore else: return model.model_dump() - - -def convert_llm_result_chunk_to_str(content: None | str | list[PromptMessageContentUnionTypes]) -> str: - if content is None: - message_text = "" - elif isinstance(content, str): - message_text = content - elif isinstance(content, list): - # Assuming the list contains PromptMessageContent objects with a "data" attribute - message_text = "".join( - item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content - ) - else: - message_text = str(content) - return message_text diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index 34b4056cf5..b711e8434a 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -159,50 +159,6 @@ class TextSplitter(BaseDocumentTransformer, ABC): ) return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs) - @classmethod - def from_tiktoken_encoder( - cls: type[TS], - encoding_name: str = "gpt2", - model_name: Optional[str] = None, - allowed_special: Union[Literal["all"], Set[str]] = set(), - disallowed_special: Union[Literal["all"], Collection[str]] = "all", - **kwargs: Any, - ) -> TS: - """Text splitter that uses tiktoken encoder to count length.""" - try: - import tiktoken - except ImportError: - raise ImportError( - "Could not import tiktoken python package. " - "This is needed in order to calculate max_tokens_for_prompt. " - "Please install it with `pip install tiktoken`." 
- ) - - if model_name is not None: - enc = tiktoken.encoding_for_model(model_name) - else: - enc = tiktoken.get_encoding(encoding_name) - - def _tiktoken_encoder(text: str) -> int: - return len( - enc.encode( - text, - allowed_special=allowed_special, - disallowed_special=disallowed_special, - ) - ) - - if issubclass(cls, TokenTextSplitter): - extra_kwargs = { - "encoding_name": encoding_name, - "model_name": model_name, - "allowed_special": allowed_special, - "disallowed_special": disallowed_special, - } - kwargs = {**kwargs, **extra_kwargs} - - return cls(length_function=lambda x: [_tiktoken_encoder(text) for text in x], **kwargs) - def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: """Transform sequence of documents by splitting them.""" return self.split_documents(list(documents)) diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py new file mode 100644 index 0000000000..e80005d7bf --- /dev/null +++ b/api/core/tools/signature.py @@ -0,0 +1,41 @@ +import base64 +import hashlib +import hmac +import os +import time + +from configs import dify_config + + +def sign_tool_file(tool_file_id: str, extension: str) -> str: + """ + sign file to get a temporary url + """ + base_url = dify_config.FILES_URL + file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}" + + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" + + +def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool: + """ + verify signature + """ + data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + + # verify signature + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 7e8d4280d4..c60e4f4e4a 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -9,18 +9,28 @@ from typing import Optional, Union from uuid import uuid4 import httpx +from sqlalchemy.orm import Session from configs import dify_config from core.helper import ssrf_proxy -from extensions.ext_database import db +from extensions.ext_database import db as global_db from extensions.ext_storage import storage from models.model import MessageFile from models.tools import ToolFile logger = logging.getLogger(__name__) +from sqlalchemy.engine import Engine + class ToolFileManager: + _engine: Engine + + def __init__(self, engine: Engine | None = None): + if engine is None: + engine = global_db.engine + self._engine = engine + @staticmethod def sign_file(tool_file_id: str, extension: str) -> str: """ @@ -55,8 +65,8 @@ class ToolFileManager: current_time = int(time.time()) return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT - @staticmethod def create_file_by_raw( + 
self, *, user_id: str, tenant_id: str, @@ -77,24 +87,25 @@ class ToolFileManager: filepath = f"tools/{tenant_id}/{unique_filename}" storage.save(filepath, file_binary) - tool_file = ToolFile( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=conversation_id, - file_key=filepath, - mimetype=mimetype, - name=present_filename, - size=len(file_binary), - ) + with Session(self._engine, expire_on_commit=False) as session: + tool_file = ToolFile( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=conversation_id, + file_key=filepath, + mimetype=mimetype, + name=present_filename, + size=len(file_binary), + ) - db.session.add(tool_file) - db.session.commit() - db.session.refresh(tool_file) + session.add(tool_file) + session.commit() + session.refresh(tool_file) return tool_file - @staticmethod def create_file_by_url( + self, user_id: str, tenant_id: str, file_url: str, @@ -119,24 +130,24 @@ class ToolFileManager: filepath = f"tools/{tenant_id}/{filename}" storage.save(filepath, blob) - tool_file = ToolFile( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=conversation_id, - file_key=filepath, - mimetype=mimetype, - original_url=file_url, - name=filename, - size=len(blob), - ) + with Session(self._engine, expire_on_commit=False) as session: + tool_file = ToolFile( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=conversation_id, + file_key=filepath, + mimetype=mimetype, + original_url=file_url, + name=filename, + size=len(blob), + ) - db.session.add(tool_file) - db.session.commit() + session.add(tool_file) + session.commit() return tool_file - @staticmethod - def get_file_binary(id: str) -> Union[tuple[bytes, str], None]: + def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]: """ get file binary @@ -144,13 +155,14 @@ class ToolFileManager: :return: the binary of the file, mime type """ - tool_file: ToolFile | None = ( - db.session.query(ToolFile) - .filter( - ToolFile.id == id, + with Session(self._engine, expire_on_commit=False) as session: + tool_file: ToolFile | None = ( + session.query(ToolFile) + .filter( + ToolFile.id == id, + ) + .first() ) - .first() - ) if not tool_file: return None @@ -159,8 +171,7 @@ class ToolFileManager: return blob, tool_file.mimetype - @staticmethod - def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]: + def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]: """ get file binary @@ -168,33 +179,34 @@ class ToolFileManager: :return: the binary of the file, mime type """ - message_file: MessageFile | None = ( - db.session.query(MessageFile) - .filter( - MessageFile.id == id, + with Session(self._engine, expire_on_commit=False) as session: + message_file: MessageFile | None = ( + session.query(MessageFile) + .filter( + MessageFile.id == id, + ) + .first() ) - .first() - ) - # Check if message_file is not None - if message_file is not None: - # get tool file id - if message_file.url is not None: - tool_file_id = message_file.url.split("/")[-1] - # trim extension - tool_file_id = tool_file_id.split(".")[0] + # Check if message_file is not None + if message_file is not None: + # get tool file id + if message_file.url is not None: + tool_file_id = message_file.url.split("/")[-1] + # trim extension + tool_file_id = tool_file_id.split(".")[0] + else: + tool_file_id = None else: tool_file_id = None - else: - tool_file_id = None - tool_file: ToolFile | None = ( - db.session.query(ToolFile) - .filter( - ToolFile.id == tool_file_id, + tool_file: ToolFile | None 
= ( + session.query(ToolFile) + .filter( + ToolFile.id == tool_file_id, + ) + .first() ) - .first() - ) if not tool_file: return None @@ -203,8 +215,7 @@ class ToolFileManager: return blob, tool_file.mimetype - @staticmethod - def get_file_generator_by_tool_file_id(tool_file_id: str): + def get_file_generator_by_tool_file_id(self, tool_file_id: str): """ get file binary @@ -212,13 +223,14 @@ class ToolFileManager: :return: the binary of the file, mime type """ - tool_file: ToolFile | None = ( - db.session.query(ToolFile) - .filter( - ToolFile.id == tool_file_id, + with Session(self._engine, expire_on_commit=False) as session: + tool_file: ToolFile | None = ( + session.query(ToolFile) + .filter( + ToolFile.id == tool_file_id, + ) + .first() ) - .first() - ) if not tool_file: return None, None @@ -229,6 +241,11 @@ class ToolFileManager: # init tool_file_parser -from core.file.tool_file_parser import tool_file_manager +from core.file.tool_file_parser import set_tool_file_manager_factory -tool_file_manager["manager"] = ToolFileManager + +def _factory() -> ToolFileManager: + return ToolFileManager() + + +set_tool_file_manager_factory(_factory) diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 6fd0c201e3..0dbc9ccbcc 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -31,8 +31,8 @@ class ToolFileMessageTransformer: # try to download image try: assert isinstance(message.message, ToolInvokeMessage.TextMessage) - - file = ToolFileManager.create_file_by_url( + tool_file_manager = ToolFileManager() + file = tool_file_manager.create_file_by_url( user_id=user_id, tenant_id=tenant_id, file_url=message.message.text, @@ -68,7 +68,8 @@ class ToolFileMessageTransformer: # FIXME: should do a type check here. 
assert isinstance(message.message.blob, bytes) - file = ToolFileManager.create_file_by_url( + tool_file_manager = ToolFileManager() + file = tool_file_manager.create_file_by_url( user_id=user_id, tenant_id=tenant_id, file_url=message.message.text, @@ -68,7 +68,8 @@ # FIXME: should do a type check here. assert isinstance(message.message.blob, bytes) - file = ToolFileManager.create_file_by_raw( + tool_file_manager = ToolFileManager() + file = tool_file_manager.create_file_by_raw( user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index fd2b0f9ae8..1c82637974 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -191,8 +191,9 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): mime_type = ( content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream" ) + tool_file_manager = ToolFileManager() - tool_file = ToolFileManager.create_file_by_raw( + tool_file = tool_file_manager.create_file_by_raw( user_id=self.user_id, tenant_id=self.tenant_id, conversation_id=None, diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 4ec033572c..66be95cdd9 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -507,7 +507,7 @@ class KnowledgeRetrievalNode(LLMNode): filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) < value) case "after" | ">": filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) > value) - case "≤" | ">=": + case "≤" | "<=": filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) <= value) case "≥" | ">=": filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) >= value) diff --git a/api/core/workflow/nodes/llm/exc.py b/api/core/workflow/nodes/llm/exc.py index 6599221691..42b8f4e6ce 100644 --- a/api/core/workflow/nodes/llm/exc.py +++ b/api/core/workflow/nodes/llm/exc.py @@ -38,3 +38,8 @@ class MemoryRolePrefixRequiredError(LLMNodeError): class FileTypeNotSupportError(LLMNodeError): def __init__(self, *, type_name: str): super().__init__(f"{type_name} type is not supported by this model") + + +class UnsupportedPromptContentTypeError(LLMNodeError): + def __init__(self, *, type_name: str) -> None: + super().__init__(f"Prompt content type {type_name} is not supported.") diff --git a/api/core/workflow/nodes/llm/file_saver.py b/api/core/workflow/nodes/llm/file_saver.py new file mode 100644 index 0000000000..c85baade03 --- /dev/null +++ b/api/core/workflow/nodes/llm/file_saver.py @@ -0,0 +1,160 @@ +import mimetypes +import typing as tp + +from sqlalchemy import Engine + +from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE +from core.file import File, FileTransferMethod, FileType +from core.helper import ssrf_proxy +from core.tools.signature import sign_tool_file +from core.tools.tool_file_manager import ToolFileManager +from models import db as global_db + + +class LLMFileSaver(tp.Protocol): + """LLMFileSaver is responsible for saving multimodal output returned by + LLM. + """ + + def save_binary_string( + self, + data: bytes, + mime_type: str, + file_type: FileType, + extension_override: str | None = None, + ) -> File: + """save_binary_string saves the inline file data returned by LLM. + + Currently (2025-04-30), only some Google Gemini models will return + multimodal output as inline data.
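+
+ A hypothetical call (names assumed for illustration):
+ `saver.save_binary_string(data=png_bytes, mime_type="image/png", file_type=FileType.IMAGE)`
+ would return a `File` with `transfer_method=FileTransferMethod.TOOL_FILE` and
+ a signed preview URL.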
+ + :param data: the contents of the file + :param mime_type: the media type of the file, specified by rfc6838 + (https://datatracker.ietf.org/doc/html/rfc6838) + :param file_type: The file type of the inline file. + :param extension_override: Override the auto-detected file extension while saving this file. + + The default value is `None`, which means the extension is not overridden but is + guessed from the `mime_type` attribute while saving the file. + + Setting it to a value other than `None` overrides the file's extension and + bypasses extension guessing while saving the file. + + In particular, setting it to an empty string (`""`) leaves the file extension empty. + + When it is not `None` or an empty string (`""`), it should be a string beginning with a + dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py` + and `tar.gz` are not. + """ + pass + + def save_remote_url(self, url: str, file_type: FileType) -> File: + """save_remote_url saves the file from a remote URL returned by LLM. + + Currently (2025-04-30), no model returns multimodal output as a URL. + + :param url: the url of the file. + :param file_type: the file type of the file, check `FileType` enum for reference. + """ + pass + + +EngineFactory: tp.TypeAlias = tp.Callable[[], Engine] + + +class FileSaverImpl(LLMFileSaver): + _engine_factory: EngineFactory + _tenant_id: str + _user_id: str + + def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None): + if engine_factory is None: + + def _factory(): + return global_db.engine + + engine_factory = _factory + self._engine_factory = engine_factory + self._user_id = user_id + self._tenant_id = tenant_id + + def _get_tool_file_manager(self): + return ToolFileManager(engine=self._engine_factory()) + + def save_remote_url(self, url: str, file_type: FileType) -> File: + http_response = ssrf_proxy.get(url) + http_response.raise_for_status() + data = http_response.content + mime_type_from_header = http_response.headers.get("Content-Type") + mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header) + return self.save_binary_string(data, mime_type, file_type, extension_override=extension) + + def save_binary_string( + self, + data: bytes, + mime_type: str, + file_type: FileType, + extension_override: str | None = None, + ) -> File: + tool_file_manager = self._get_tool_file_manager() + tool_file = tool_file_manager.create_file_by_raw( + user_id=self._user_id, + tenant_id=self._tenant_id, + # TODO(QuantumGhost): what is conversation id? + conversation_id=None, + file_binary=data, + mimetype=mime_type, + ) + extension_override = _validate_extension_override(extension_override) + extension = _get_extension(mime_type, extension_override) + url = sign_tool_file(tool_file.id, extension) + + return File( + tenant_id=self._tenant_id, + type=file_type, + transfer_method=FileTransferMethod.TOOL_FILE, + filename=tool_file.name, + extension=extension, + mime_type=mime_type, + size=len(data), + related_id=tool_file.id, + url=url, + # TODO(QuantumGhost): how should I set the following key? + # What's the difference between `remote_url` and `url`? + # What's the purpose of `storage_key` and `dify_model_identity`? + storage_key=tool_file.file_key, + ) + + +def _get_extension(mime_type: str, extension_override: str | None = None) -> str: + """_get_extension returns the extension of the file. + + If the `extension_override` parameter is set, this function should honor it and + return its value.
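+
+ A minimal sketch of the expected behavior (assuming the standard `mimetypes` tables):
+
+ _get_extension("image/png") # -> ".png"
+ _get_extension("image/png", ".jpg") # -> ".jpg" (the override wins)
+ _get_extension("application/x-unknown") # -> DEFAULT_EXTENSION (".bin")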
+ """ + if extension_override is not None: + return extension_override + return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION + + +def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]: + """_extract_content_type_and_extension tries to + guess the content type of the file from the url and the `Content-Type` header in the response. + """ + if content_type_header: + extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION + return content_type_header, extension + content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE + extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION + return content_type, extension + + +def _validate_extension_override(extension_override: str | None) -> str | None: + # `extension_override` is allowed to be `None` or `""`. + if extension_override is None: + return None + if extension_override == "": + return "" + if not extension_override.startswith("."): + raise ValueError("extension_override should start with '.' if not None or empty.", extension_override) + return extension_override diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 35b146e5d9..5481bd383a 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1,3 +1,5 @@ +import base64 +import io import json import logging from collections.abc import Generator, Mapping, Sequence @@ -21,7 +23,7 @@ from core.model_runtime.entities import ( PromptMessageContentType, TextPromptMessageContent, ) -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessageContentUnionTypes, @@ -38,7 +40,6 @@ from core.model_runtime.entities.model_entities import ( ) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder -from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str from core.plugin.entities.plugin import ModelProviderID from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil @@ -95,9 +96,13 @@ from .exc import ( TemplateTypeNotSupportError, VariableNotFoundError, ) +from .file_saver import FileSaverImpl, LLMFileSaver if TYPE_CHECKING: from core.file.models import File + from core.workflow.graph_engine.entities.graph import Graph + from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams + from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState logger = logging.getLogger(__name__) @@ -106,6 +111,43 @@ class LLMNode(BaseNode[LLMNodeData]): _node_data_cls = LLMNodeData _node_type = NodeType.LLM + # Instance attributes specific to LLMNode.
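+ # These are set per instance in `__init__` below; `_llm_file_saver` defaults to a
+ # `FileSaverImpl` bound to the workflow's user and tenant when no saver is injected.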
+ # Output variable for file + _file_outputs: list["File"] + + _llm_file_saver: LLMFileSaver + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph: "Graph", + graph_runtime_state: "GraphRuntimeState", + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None, + *, + llm_file_saver: LLMFileSaver | None = None, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph=graph, + graph_runtime_state=graph_runtime_state, + previous_node_id=previous_node_id, + thread_pool_id=thread_pool_id, + ) + # LLM file outputs, used for MultiModal outputs. + self._file_outputs: list[File] = [] + + if llm_file_saver is None: + llm_file_saver = FileSaverImpl( + user_id=graph_init_params.user_id, + tenant_id=graph_init_params.tenant_id, + ) + self._llm_file_saver = llm_file_saver + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]: """Process structured output if enabled""" @@ -215,6 +257,9 @@ class LLMNode(BaseNode[LLMNodeData]): structured_output = process_structured_output(result_text) if structured_output: outputs["structured_output"] = structured_output + if self._file_outputs is not None: + outputs["files"] = self._file_outputs + yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -240,6 +285,7 @@ ) ) except Exception as e: + logger.exception("error while executing llm node") yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -268,44 +314,45 @@ return self._handle_invoke_result(invoke_result=invoke_result) - def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]: + def _handle_invoke_result( + self, invoke_result: LLMResult | Generator[LLMResultChunk, None, None] + ) -> Generator[NodeEvent, None, None]: + # For blocking mode if isinstance(invoke_result, LLMResult): - message_text = convert_llm_result_chunk_to_str(invoke_result.message.content) - - yield ModelInvokeCompletedEvent( - text=message_text, - usage=invoke_result.usage, - finish_reason=None, - ) + event = self._handle_blocking_result(invoke_result=invoke_result) + yield event return - model = None + # For streaming mode + model = "" prompt_messages: list[PromptMessage] = [] - full_text = "" - usage = None + + usage = LLMUsage.empty_usage() finish_reason = None + full_text_buffer = io.StringIO() for result in invoke_result: - text = convert_llm_result_chunk_to_str(result.delta.message.content) - full_text += text + contents = result.delta.message.content + for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents): + full_text_buffer.write(text_part) + yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[self.node_id, "text"]) - yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"]) - - if not model: + # Update the whole metadata + if not model and result.model: model = result.model - - if not prompt_messages: - prompt_messages = result.prompt_messages - - if not usage and result.delta.usage: + if len(prompt_messages) == 0: + # TODO(QuantumGhost): it seems that this update has no visible effect. + # What's the purpose of the line below?
+ prompt_messages = list(result.prompt_messages) + if usage.prompt_tokens == 0 and result.delta.usage: usage = result.delta.usage - - if not finish_reason and result.delta.finish_reason: + if finish_reason is None and result.delta.finish_reason: finish_reason = result.delta.finish_reason - if not usage: - usage = LLMUsage.empty_usage() + yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason) - yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason) + def _image_file_to_markdown(self, file: "File", /): + text_chunk = f"![]({file.generate_url()})" + return text_chunk def _transform_chat_messages( self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, / @@ -963,6 +1010,42 @@ return prompt_messages + def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent: + buffer = io.StringIO() + for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content): + buffer.write(text_part) + + return ModelInvokeCompletedEvent( + text=buffer.getvalue(), + usage=invoke_result.usage, + finish_reason=None, + ) + + def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File": + """_save_multimodal_image_output saves multi-modal contents generated by LLM plugins. + + There are two kinds of multimodal outputs: + + - Inlined data encoded in base64, which would be saved to storage directly. + - Remote files referenced by a URL, which would be downloaded and then saved to storage. + + Currently, only image files are supported. + """ + # Use the file saver injected via the constructor. + _saver = self._llm_file_saver + + # If the content references a remote URL, download and save it; + # otherwise decode and save the inline base64 data. + if content.url != "": + saved_file = _saver.save_remote_url(content.url, FileType.IMAGE) + else: + saved_file = _saver.save_binary_string( + data=base64.b64decode(content.base64_data), + mime_type=content.mime_type, + file_type=FileType.IMAGE, + ) + self._file_outputs.append(saved_file) + return saved_file + def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict: """ Handle structured output for models with native JSON schema support. @@ -1123,6 +1206,41 @@ else SupportStructuredOutputStatus.UNSUPPORTED ) + def _save_multimodal_output_and_convert_result_to_markdown( + self, + contents: str | list[PromptMessageContentUnionTypes] | None, + ) -> Generator[str, None, None]: + """Convert intermediate prompt messages into strings and yield them to the caller. + + If the messages contain non-textual content (e.g., multimedia like images or videos), + it will be saved separately, and the corresponding Markdown representation will + be yielded to the caller. + """ + + # NOTE(QuantumGhost): This function should yield results to the caller immediately + # whenever new content or partial content is available. Avoid any intermediate buffering + # of results. Additionally, do not yield empty strings; instead, yield from an empty list + # if necessary.
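+ #
+ # An illustrative example (assumed input): given
+ # contents=[TextPromptMessageContent(data="hi"), ImagePromptMessageContent(...)],
+ # this yields "hi" first, then the markdown image link produced by
+ # `_image_file_to_markdown` for the saved image.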
+ if contents is None: + yield from [] + return + if isinstance(contents, str): + yield contents + elif isinstance(contents, list): + for item in contents: + if isinstance(item, TextPromptMessageContent): + yield item.data + elif isinstance(item, ImagePromptMessageContent): + file = self._save_multimodal_image_output(item) + yield self._image_file_to_markdown(file) + else: + logger.warning("unknown item type encountered, type=%s", type(item)) + yield str(item) + else: + logger.warning("unknown contents type encountered, type=%s", type(contents)) + yield str(contents) + def _combine_message_content_with_role( *, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole diff --git a/api/models/engine.py b/api/models/engine.py index dda93bc941..05c1cacdcb 100644 --- a/api/models/engine.py +++ b/api/models/engine.py @@ -10,4 +10,16 @@ POSTGRES_INDEXES_NAMING_CONVENTION = { } metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION) + +# ****** IMPORTANT NOTICE ****** +# +# NOTE(QuantumGhost): Avoid directly importing and using `db` in modules outside of the +# `controllers` package. +# +# Instead, import `db` within the `controllers` package and pass it as an argument to +# functions or class constructors. +# +# Directly importing `db` in other modules can make the code more difficult to read, test, and maintain. +# +# Whenever possible, avoid this pattern in new code. db = SQLAlchemy(metadata=metadata) diff --git a/api/models/model.py b/api/models/model.py index 901e92284a..fd05d67e9a 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast from core.plugin.entities.plugin import GenericProviderID from core.tools.entities.tool_entities import ToolProviderType +from core.tools.signature import sign_tool_file from services.plugin.plugin_service import PluginService if TYPE_CHECKING: @@ -23,7 +24,6 @@ from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from core.file import helpers as file_helpers -from core.file.tool_file_parser import ToolFileParser from libs.helper import generate_string from models.base import Base from models.enums import CreatedByRole @@ -986,9 +986,7 @@ class Message(db.Model): # type: ignore[name-defined] if not tool_file_id: continue - sign_url = ToolFileParser.get_tool_file_manager().sign_file( - tool_file_id=tool_file_id, extension=extension - ) + sign_url = sign_tool_file(tool_file_id=tool_file_id, extension=extension) elif "file-preview" in url: # get upload file id upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview?\?timestamp=" diff --git a/api/models/tools.py b/api/models/tools.py index aef1490729..05604b9330 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -263,8 +263,8 @@ class ToolConversationVariables(Base): class ToolFile(Base): - """ - store the file created by agent + """This table stores file metadata generated in workflows, + not only files created by agents. """ __tablename__ = "tool_files" diff --git a/api/pyproject.toml b/api/pyproject.toml index f3526ec717..dbaa3588b7 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -77,7 +77,7 @@ dependencies = [ "sentry-sdk[flask]~=1.44.1", "sqlalchemy~=2.0.29", "starlette==0.41.0", - "tiktoken~=0.8.0", + "tiktoken~=0.9.0", "tokenizers~=0.15.0", "transformers~=4.35.0", "unstructured[docx,epub,md,ppt,pptx]~=0.16.1", diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py new file mode 100644 index 0000000000..7c722660bc --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py @@ -0,0 +1,192 @@ +import uuid +from typing import NamedTuple +from unittest import mock + +import httpx +import pytest +from sqlalchemy import Engine + +from core.file import FileTransferMethod, FileType, models +from core.helper import ssrf_proxy +from core.tools import signature +from core.tools.tool_file_manager import ToolFileManager +from core.workflow.nodes.llm.file_saver import ( + FileSaverImpl, + _extract_content_type_and_extension, + _get_extension, + _validate_extension_override, +) +from models import ToolFile + +_PNG_DATA = b"\x89PNG\r\n\x1a\n" + + +def _gen_id(): + return str(uuid.uuid4()) + + +class TestFileSaverImpl: + def test_save_binary_string(self, monkeypatch): + user_id = _gen_id() + tenant_id = _gen_id() + file_type = FileType.IMAGE + mime_type = "image/png" + mock_signed_url = "https://example.com/image.png" + mock_tool_file = ToolFile( + id=_gen_id(), + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + file_key="test-file-key", + mimetype=mime_type, + original_url=None, + name=f"{_gen_id()}.png", + size=len(_PNG_DATA), + ) + mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager) + mocked_engine = mock.MagicMock(spec=Engine) + + mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file + monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager) + # Since `File.generate_url` uses `signature.sign_tool_file` directly, we also need to patch it here. + mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file)
monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file) + mocked_sign_file.return_value = mock_signed_url + + storage_file_manager = FileSaverImpl( + user_id=user_id, + tenant_id=tenant_id, + engine_factory=mocked_engine, + ) + + file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type) + assert file.tenant_id == tenant_id + assert file.type == file_type + assert file.transfer_method == FileTransferMethod.TOOL_FILE + assert file.extension == ".png" + assert file.mime_type == mime_type + assert file.size == len(_PNG_DATA) + assert file.related_id == mock_tool_file.id + + assert file.generate_url() == mock_signed_url + + mocked_tool_file_manager.create_file_by_raw.assert_called_once_with( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + file_binary=_PNG_DATA, + mimetype=mime_type, + ) + mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png") + + def test_save_remote_url_request_failed(self, monkeypatch): + _TEST_URL = "https://example.com/image.png" + mock_request = httpx.Request("GET", _TEST_URL) + mock_response = httpx.Response( + status_code=401, + request=mock_request, + ) + file_saver = FileSaverImpl( + user_id=_gen_id(), + tenant_id=_gen_id(), + ) + mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response) + monkeypatch.setattr(ssrf_proxy, "get", mock_get) + + with pytest.raises(httpx.HTTPStatusError) as exc: + file_saver.save_remote_url(_TEST_URL, FileType.IMAGE) + mock_get.assert_called_once_with(_TEST_URL) + assert exc.value.response.status_code == 401 + + def test_save_remote_url_success(self, monkeypatch): + _TEST_URL = "https://example.com/image.png" + mime_type = "image/png" + user_id = _gen_id() + tenant_id = _gen_id() + + mock_request = httpx.Request("GET", _TEST_URL) + mock_response = httpx.Response( + status_code=200, + content=b"test-data", + headers={"Content-Type": mime_type}, + request=mock_request, + ) + + file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id) + mock_tool_file = ToolFile( + id=_gen_id(), + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + file_key="test-file-key", + mimetype=mime_type, + original_url=None, + name=f"{_gen_id()}.png", + size=len(_PNG_DATA), + ) + mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response) + monkeypatch.setattr(ssrf_proxy, "get", mock_get) + mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file) + monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string) + + file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE) + mock_save_binary_string.assert_called_once_with( + mock_response.content, + mime_type, + FileType.IMAGE, + extension_override=".png", + ) + assert file == mock_tool_file + + +def test_validate_extension_override(): + class TestCase(NamedTuple): + extension_override: str | None + expected: str | None + + cases = [TestCase(None, None), TestCase("", ""), TestCase(".png", ".png"), TestCase(".tar.gz", ".tar.gz")] + + for valid_ext_override in [None, "", ".png", ".tar.gz"]: + assert valid_ext_override == _validate_extension_override(valid_ext_override) + + for invalid_ext_override in ["png", "tar.gz"]: + with pytest.raises(ValueError) as exc: + _validate_extension_override(invalid_ext_override) + + +class TestExtractContentTypeAndExtension: + def test_with_both_content_type_and_extension(self): + content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png") + assert content_type == "image/png" +
assert extension == ".png" + + def test_url_with_file_extension(self): + for content_type in [None, ""]: + content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type) + assert content_type == "image/png" + assert extension == ".png" + + def test_response_with_content_type(self): + content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png") + assert content_type == "image/png" + assert extension == ".png" + + def test_no_content_type_and_no_extension(self): + for content_type in [None, ""]: + content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type) + assert content_type == "application/octet-stream" + assert extension == ".bin" + + +class TestGetExtension: + def test_with_extension_override(self): + mime_type = "image/png" + for override in [".jpg", ""]: + extension = _get_extension(mime_type, override) + assert extension == override + + def test_without_extension_override(self): + mime_type = "image/png" + extension = _get_extension(mime_type) + assert extension == ".png" diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 5c3e5540c4..519dd73787 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -1,5 +1,8 @@ +import base64 +import uuid from collections.abc import Sequence from typing import Optional +from unittest import mock import pytest @@ -30,6 +33,7 @@ from core.workflow.nodes.llm.entities import ( VisionConfig, VisionConfigOptions, ) +from core.workflow.nodes.llm.file_saver import LLMFileSaver from core.workflow.nodes.llm.node import LLMNode from models.enums import UserFrom from models.provider import ProviderType @@ -49,8 +53,8 @@ class MockTokenBufferMemory: @pytest.fixture -def llm_node(): - data = LLMNodeData( +def llm_node_data() -> LLMNodeData: + return LLMNodeData( title="Test LLM", model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), prompt_template=[], @@ -64,42 +68,65 @@ def llm_node(): ), ), ) + + +@pytest.fixture +def graph_init_params() -> GraphInitParams: + return GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="1", + graph_config={}, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.SERVICE_API, + call_depth=0, + ) + + +@pytest.fixture +def graph() -> Graph: + return Graph( + root_node_id="1", + answer_stream_generate_routes=AnswerStreamGenerateRoute( + answer_dependencies={}, + answer_generate_route={}, + ), + end_stream_param=EndStreamParam( + end_dependencies={}, + end_stream_variable_selector_mapping={}, + ), + ) + + +@pytest.fixture +def graph_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( system_variables={}, user_inputs={}, ) + return GraphRuntimeState( + variable_pool=variable_pool, + start_at=0, + ) + + +@pytest.fixture +def llm_node( + llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState +) -> LLMNode: + mock_file_saver = mock.MagicMock(spec=LLMFileSaver) node = LLMNode( id="1", config={ "id": "1", - "data": data.model_dump(), + "data": llm_node_data.model_dump(), }, - graph_init_params=GraphInitParams( - tenant_id="1", - app_id="1", - workflow_type=WorkflowType.WORKFLOW, - workflow_id="1", - graph_config={}, - user_id="1", - user_from=UserFrom.ACCOUNT, - 
invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, - ), - graph=Graph( - root_node_id="1", - answer_stream_generate_routes=AnswerStreamGenerateRoute( - answer_dependencies={}, - answer_generate_route={}, - ), - end_stream_param=EndStreamParam( - end_dependencies={}, - end_stream_variable_selector_mapping={}, - ), - ), - graph_runtime_state=GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ), + graph_init_params=graph_init_params, + graph=graph, + graph_runtime_state=graph_runtime_state, + llm_file_saver=mock_file_saver, ) return node @@ -465,3 +492,167 @@ def test_handle_list_messages_basic(llm_node): assert len(result) == 1 assert isinstance(result[0], UserPromptMessage) assert result[0].content == [TextPromptMessageContent(data="Hello, world")] + + +@pytest.fixture +def llm_node_for_multimodal( + llm_node_data, graph_init_params, graph, graph_runtime_state +) -> tuple[LLMNode, LLMFileSaver]: + mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver) + node = LLMNode( + id="1", + config={ + "id": "1", + "data": llm_node_data.model_dump(), + }, + graph_init_params=graph_init_params, + graph=graph, + graph_runtime_state=graph_runtime_state, + llm_file_saver=mock_file_saver, + ) + return node, mock_file_saver + + +class TestLLMNodeSaveMultiModalImageOutput: + def test_llm_node_save_inline_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]): + llm_node, mock_file_saver = llm_node_for_multimodal + content = ImagePromptMessageContent( + format="png", + base64_data=base64.b64encode(b"test-data").decode(), + mime_type="image/png", + ) + mock_file = File( + id=str(uuid.uuid4()), + tenant_id="1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id=str(uuid.uuid4()), + filename="test-file.png", + extension=".png", + mime_type="image/png", + size=9, + ) + mock_file_saver.save_binary_string.return_value = mock_file + file = llm_node._save_multimodal_image_output(content=content) + assert llm_node._file_outputs == [mock_file] + assert file == mock_file + mock_file_saver.save_binary_string.assert_called_once_with( + data=b"test-data", mime_type="image/png", file_type=FileType.IMAGE + ) + + def test_llm_node_save_url_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]): + llm_node, mock_file_saver = llm_node_for_multimodal + content = ImagePromptMessageContent( + format="png", + url="https://example.com/image.png", + mime_type="image/jpg", + ) + mock_file = File( + id=str(uuid.uuid4()), + tenant_id="1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id=str(uuid.uuid4()), + filename="test-file.png", + extension=".png", + mime_type="image/png", + size=9, + ) + mock_file_saver.save_remote_url.return_value = mock_file + file = llm_node._save_multimodal_image_output(content=content) + assert llm_node._file_outputs == [mock_file] + assert file == mock_file + mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE) + + +def test_llm_node_image_file_to_markdown(llm_node: LLMNode): + mock_file = mock.MagicMock(spec=File) + mock_file.generate_url.return_value = "https://example.com/image.png" + markdown = llm_node._image_file_to_markdown(mock_file) + assert markdown == "![](https://example.com/image.png)" + + +class TestSaveMultimodalOutputAndConvertResultToMarkdown: + def test_str_content(self, llm_node_for_multimodal): + llm_node, mock_file_saver = llm_node_for_multimodal + gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world") + assert list(gen) == ["hello world"] +
+        mock_file_saver.save_binary_string.assert_not_called()
+        mock_file_saver.save_remote_url.assert_not_called()
+
+    def test_text_prompt_message_content(self, llm_node_for_multimodal):
+        llm_node, mock_file_saver = llm_node_for_multimodal
+        gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
+            [TextPromptMessageContent(data="hello world")]
+        )
+        assert list(gen) == ["hello world"]
+        mock_file_saver.save_binary_string.assert_not_called()
+        mock_file_saver.save_remote_url.assert_not_called()
+
+    def test_image_content_with_inline_data(self, llm_node_for_multimodal, monkeypatch):
+        llm_node, mock_file_saver = llm_node_for_multimodal
+
+        image_raw_data = b"PNG_DATA"
+        image_b64_data = base64.b64encode(image_raw_data).decode()
+
+        mock_saved_file = File(
+            id=str(uuid.uuid4()),
+            tenant_id="1",
+            type=FileType.IMAGE,
+            transfer_method=FileTransferMethod.TOOL_FILE,
+            filename="test.png",
+            extension=".png",
+            size=len(image_raw_data),
+            related_id=str(uuid.uuid4()),
+            url="https://example.com/test.png",
+            storage_key="test_storage_key",
+        )
+        mock_file_saver.save_binary_string.return_value = mock_saved_file
+        gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
+            [
+                ImagePromptMessageContent(
+                    format="png",
+                    base64_data=image_b64_data,
+                    mime_type="image/png",
+                )
+            ]
+        )
+        yielded_strs = list(gen)
+        assert len(yielded_strs) == 1
+
+        # This assertion requires careful handling.
+        # `FILES_URL` settings can vary across environments, which might lead to fragile tests.
+        #
+        # Rather than asserting the complete URL returned by _save_multimodal_output_and_convert_result_to_markdown,
+        # we verify that the result includes the markdown image syntax and the expected file URL path.
+        expected_file_url_path = f"/files/tools/{mock_saved_file.related_id}.png"
+        assert yielded_strs[0].startswith("![](")
+        assert expected_file_url_path in yielded_strs[0]
+        assert yielded_strs[0].endswith(")")
+        mock_file_saver.save_binary_string.assert_called_once_with(
+            data=image_raw_data,
+            mime_type="image/png",
+            file_type=FileType.IMAGE,
+        )
+        assert mock_saved_file in llm_node._file_outputs
+
+    def test_unknown_content_type(self, llm_node_for_multimodal):
+        llm_node, mock_file_saver = llm_node_for_multimodal
+        gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"]))
+        assert list(gen) == ["frozenset({'hello world'})"]
+        mock_file_saver.save_binary_string.assert_not_called()
+        mock_file_saver.save_remote_url.assert_not_called()
+
+    def test_unknown_item_type(self, llm_node_for_multimodal):
+        llm_node, mock_file_saver = llm_node_for_multimodal
+        gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])])
+        assert list(gen) == ["frozenset({'hello world'})"]
+        mock_file_saver.save_binary_string.assert_not_called()
+        mock_file_saver.save_remote_url.assert_not_called()
+
+    def test_none_content(self, llm_node_for_multimodal):
+        llm_node, mock_file_saver = llm_node_for_multimodal
+        gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None)
+        assert list(gen) == []
+        mock_file_saver.save_binary_string.assert_not_called()
+        mock_file_saver.save_remote_url.assert_not_called()
diff --git a/api/uv.lock b/api/uv.lock
index 9ae14dbd25..9594c54b3d 100644
--- a/api/uv.lock
+++ b/api/uv.lock
@@ -1398,7 +1398,7 @@ requires-dist = [
     { name = "sentry-sdk", extras = ["flask"], specifier = "~=1.44.1" },
     { name = "sqlalchemy", specifier = "~=2.0.29" },
     { name = "starlette", specifier = "==0.41.0" },
"tiktoken", specifier = "~=0.8.0" }, + { name = "tiktoken", specifier = "~=0.9.0" }, { name = "tokenizers", specifier = "~=0.15.0" }, { name = "transformers", specifier = "~=4.35.0" }, { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.16.1" }, @@ -5335,26 +5335,26 @@ wheels = [ [[package]] name = "tiktoken" -version = "0.8.0" +version = "0.9.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "regex" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/37/02/576ff3a6639e755c4f70997b2d315f56d6d71e0d046f4fb64cb81a3fb099/tiktoken-0.8.0.tar.gz", hash = "sha256:9ccbb2740f24542534369c5635cfd9b2b3c2490754a78ac8831d99f89f94eeb2", size = 35107 } +sdist = { url = "https://files.pythonhosted.org/packages/ea/cf/756fedf6981e82897f2d570dd25fa597eb3f4459068ae0572d7e888cfd6f/tiktoken-0.9.0.tar.gz", hash = "sha256:d02a5ca6a938e0490e1ff957bc48c8b078c88cb83977be1625b1fd8aac792c5d", size = 35991 } wheels = [ - { url = "https://files.pythonhosted.org/packages/f6/1e/ca48e7bfeeccaf76f3a501bd84db1fa28b3c22c9d1a1f41af9fb7579c5f6/tiktoken-0.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d622d8011e6d6f239297efa42a2657043aaed06c4f68833550cac9e9bc723ef1", size = 1039700 }, - { url = "https://files.pythonhosted.org/packages/8c/f8/f0101d98d661b34534769c3818f5af631e59c36ac6d07268fbfc89e539ce/tiktoken-0.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2efaf6199717b4485031b4d6edb94075e4d79177a172f38dd934d911b588d54a", size = 982413 }, - { url = "https://files.pythonhosted.org/packages/ac/3c/2b95391d9bd520a73830469f80a96e3790e6c0a5ac2444f80f20b4b31051/tiktoken-0.8.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5637e425ce1fc49cf716d88df3092048359a4b3bbb7da762840426e937ada06d", size = 1144242 }, - { url = "https://files.pythonhosted.org/packages/01/c4/c4a4360de845217b6aa9709c15773484b50479f36bb50419c443204e5de9/tiktoken-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fb0e352d1dbe15aba082883058b3cce9e48d33101bdaac1eccf66424feb5b47", size = 1176588 }, - { url = "https://files.pythonhosted.org/packages/f8/a3/ef984e976822cd6c2227c854f74d2e60cf4cd6fbfca46251199914746f78/tiktoken-0.8.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:56edfefe896c8f10aba372ab5706b9e3558e78db39dd497c940b47bf228bc419", size = 1237261 }, - { url = "https://files.pythonhosted.org/packages/1e/86/eea2309dc258fb86c7d9b10db536434fc16420feaa3b6113df18b23db7c2/tiktoken-0.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:326624128590def898775b722ccc327e90b073714227175ea8febbc920ac0a99", size = 884537 }, - { url = "https://files.pythonhosted.org/packages/c1/22/34b2e136a6f4af186b6640cbfd6f93400783c9ef6cd550d9eab80628d9de/tiktoken-0.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:881839cfeae051b3628d9823b2e56b5cc93a9e2efb435f4cf15f17dc45f21586", size = 1039357 }, - { url = "https://files.pythonhosted.org/packages/04/d2/c793cf49c20f5855fd6ce05d080c0537d7418f22c58e71f392d5e8c8dbf7/tiktoken-0.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fe9399bdc3f29d428f16a2f86c3c8ec20be3eac5f53693ce4980371c3245729b", size = 982616 }, - { url = "https://files.pythonhosted.org/packages/b3/a1/79846e5ef911cd5d75c844de3fa496a10c91b4b5f550aad695c5df153d72/tiktoken-0.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9a58deb7075d5b69237a3ff4bb51a726670419db6ea62bdcd8bd80c78497d7ab", size = 1144011 }, - { url = 
"https://files.pythonhosted.org/packages/26/32/e0e3a859136e95c85a572e4806dc58bf1ddf651108ae8b97d5f3ebe1a244/tiktoken-0.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2908c0d043a7d03ebd80347266b0e58440bdef5564f84f4d29fb235b5df3b04", size = 1175432 }, - { url = "https://files.pythonhosted.org/packages/c7/89/926b66e9025b97e9fbabeaa59048a736fe3c3e4530a204109571104f921c/tiktoken-0.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:294440d21a2a51e12d4238e68a5972095534fe9878be57d905c476017bff99fc", size = 1236576 }, - { url = "https://files.pythonhosted.org/packages/45/e2/39d4aa02a52bba73b2cd21ba4533c84425ff8786cc63c511d68c8897376e/tiktoken-0.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:d8f3192733ac4d77977432947d563d7e1b310b96497acd3c196c9bddb36ed9db", size = 883824 }, + { url = "https://files.pythonhosted.org/packages/4d/ae/4613a59a2a48e761c5161237fc850eb470b4bb93696db89da51b79a871f1/tiktoken-0.9.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f32cc56168eac4851109e9b5d327637f15fd662aa30dd79f964b7c39fbadd26e", size = 1065987 }, + { url = "https://files.pythonhosted.org/packages/3f/86/55d9d1f5b5a7e1164d0f1538a85529b5fcba2b105f92db3622e5d7de6522/tiktoken-0.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:45556bc41241e5294063508caf901bf92ba52d8ef9222023f83d2483a3055348", size = 1009155 }, + { url = "https://files.pythonhosted.org/packages/03/58/01fb6240df083b7c1916d1dcb024e2b761213c95d576e9f780dfb5625a76/tiktoken-0.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03935988a91d6d3216e2ec7c645afbb3d870b37bcb67ada1943ec48678e7ee33", size = 1142898 }, + { url = "https://files.pythonhosted.org/packages/b1/73/41591c525680cd460a6becf56c9b17468d3711b1df242c53d2c7b2183d16/tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b3d80aad8d2c6b9238fc1a5524542087c52b860b10cbf952429ffb714bc1136", size = 1197535 }, + { url = "https://files.pythonhosted.org/packages/7d/7c/1069f25521c8f01a1a182f362e5c8e0337907fae91b368b7da9c3e39b810/tiktoken-0.9.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b2a21133be05dc116b1d0372af051cd2c6aa1d2188250c9b553f9fa49301b336", size = 1259548 }, + { url = "https://files.pythonhosted.org/packages/6f/07/c67ad1724b8e14e2b4c8cca04b15da158733ac60136879131db05dda7c30/tiktoken-0.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:11a20e67fdf58b0e2dea7b8654a288e481bb4fc0289d3ad21291f8d0849915fb", size = 893895 }, + { url = "https://files.pythonhosted.org/packages/cf/e5/21ff33ecfa2101c1bb0f9b6df750553bd873b7fb532ce2cb276ff40b197f/tiktoken-0.9.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:e88f121c1c22b726649ce67c089b90ddda8b9662545a8aeb03cfef15967ddd03", size = 1065073 }, + { url = "https://files.pythonhosted.org/packages/8e/03/a95e7b4863ee9ceec1c55983e4cc9558bcfd8f4f80e19c4f8a99642f697d/tiktoken-0.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a6600660f2f72369acb13a57fb3e212434ed38b045fd8cc6cdd74947b4b5d210", size = 1008075 }, + { url = "https://files.pythonhosted.org/packages/40/10/1305bb02a561595088235a513ec73e50b32e74364fef4de519da69bc8010/tiktoken-0.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:95e811743b5dfa74f4b227927ed86cbc57cad4df859cb3b643be797914e41794", size = 1140754 }, + { url = "https://files.pythonhosted.org/packages/1b/40/da42522018ca496432ffd02793c3a72a739ac04c3794a4914570c9bb2925/tiktoken-0.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:99376e1370d59bcf6935c933cb9ba64adc29033b7e73f5f7569f3aad86552b22", size = 1196678 }, + { url = "https://files.pythonhosted.org/packages/5c/41/1e59dddaae270ba20187ceb8aa52c75b24ffc09f547233991d5fd822838b/tiktoken-0.9.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:badb947c32739fb6ddde173e14885fb3de4d32ab9d8c591cbd013c22b4c31dd2", size = 1259283 }, + { url = "https://files.pythonhosted.org/packages/5b/64/b16003419a1d7728d0d8c0d56a4c24325e7b10a21a9dd1fc0f7115c02f0a/tiktoken-0.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:5a62d7a25225bafed786a524c1b9f0910a1128f4232615bf3f8257a73aaa3b16", size = 894897 }, ] [[package]] diff --git a/web/app/components/base/markdown.tsx b/web/app/components/base/markdown.tsx index 6ea84a2842..ca2ec72791 100644 --- a/web/app/components/base/markdown.tsx +++ b/web/app/components/base/markdown.tsx @@ -121,7 +121,7 @@ export function PreCode(props: { children: any }) { // visit https://reactjs.org/docs/error-decoder.html?invariant=185 for the full message // or use the non-minified dev environment for full errors and additional helpful warnings. -const CodeBlock: any = memo(({ inline, className, children, ...props }: any) => { +const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any) => { const { theme } = useTheme() const [isSVG, setIsSVG] = useState(true) const match = /language-(\w+)/.exec(className || '') @@ -258,7 +258,7 @@ const Link = ({ node, children, ...props }: any) => { const { onSend } = useChatContext() const hidden_text = decodeURIComponent(node.properties.href.toString().split('abbr:')[1]) - return onSend?.(hidden_text)} title={node.children[0]?.value}>{node.children[0]?.value} + return onSend?.(hidden_text)} title={node.children[0]?.value || ''}>{node.children[0]?.value || ''} } else { return {children || 'Download'} diff --git a/web/app/components/base/mermaid/index.tsx b/web/app/components/base/mermaid/index.tsx index 8fd8ae8b59..a484261a51 100644 --- a/web/app/components/base/mermaid/index.tsx +++ b/web/app/components/base/mermaid/index.tsx @@ -476,15 +476,15 @@ const Flowchart = React.forwardRef((props: { 'bg-white': currentTheme === Theme.light, 'bg-slate-900': currentTheme === Theme.dark, }), - mermaidDiv: cn('mermaid cursor-pointer h-auto w-full relative', { + mermaidDiv: cn('mermaid relative h-auto w-full cursor-pointer', { 'bg-white': currentTheme === Theme.light, 'bg-slate-900': currentTheme === Theme.dark, }), - errorMessage: cn('py-4 px-[26px]', { + errorMessage: cn('px-[26px] py-4', { 'text-red-500': currentTheme === Theme.light, 'text-red-400': currentTheme === Theme.dark, }), - errorIcon: cn('w-6 h-6', { + errorIcon: cn('h-6 w-6', { 'text-red-500': currentTheme === Theme.light, 'text-red-400': currentTheme === Theme.dark, }), @@ -492,7 +492,7 @@ const Flowchart = React.forwardRef((props: { 'text-gray-700': currentTheme === Theme.light, 'text-gray-300': currentTheme === Theme.dark, }), - themeToggle: cn('flex items-center justify-center w-10 h-10 rounded-full transition-all duration-300 shadow-md backdrop-blur-sm', { + themeToggle: cn('flex h-10 w-10 items-center justify-center rounded-full shadow-md backdrop-blur-sm transition-all duration-300', { 'bg-white/80 hover:bg-white hover:shadow-lg text-gray-700 border border-gray-200': currentTheme === Theme.light, 'bg-slate-800/80 hover:bg-slate-700 hover:shadow-lg text-yellow-300 border border-slate-600': currentTheme === Theme.dark, }), @@ -501,7 +501,7 @@ const Flowchart = React.forwardRef((props: { // Style classes for look options const 
getLookButtonClass = (lookType: 'classic' | 'handDrawn') => { return cn( - 'flex items-center justify-center mb-4 w-[calc((100%-8px)/2)] h-8 rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg cursor-pointer system-sm-medium text-text-secondary', + 'system-sm-medium mb-4 flex h-8 w-[calc((100%-8px)/2)] cursor-pointer items-center justify-center rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg text-text-secondary', look === lookType && 'border-[1.5px] border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary', currentTheme === Theme.dark && 'border-slate-600 bg-slate-800 text-slate-300', look === lookType && currentTheme === Theme.dark && 'border-blue-500 bg-slate-700 text-white', @@ -512,7 +512,7 @@ const Flowchart = React.forwardRef((props: {
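Aside, not part of the patch: the comment block in the LLM-node test above explains why the assertions avoid pinning the full signed URL, since `FILES_URL` varies by environment. Below is a minimal sketch of that assertion strategy in isolation; `assert_is_image_markdown` is a hypothetical helper name, not something this diff adds, and it assumes the generated markdown has the form "![](<signed url>)" with tool files served under /files/tools/<related_id><extension>.

    # Sketch of the URL-agnostic assertion strategy; the helper name and the
    # example URL are illustrative only.
    def assert_is_image_markdown(result: str, related_id: str, extension: str = ".png") -> None:
        expected_path = f"/files/tools/{related_id}{extension}"
        # The markdown wrapper and the tool-file path are stable across
        # environments; the host and the signature query string are not.
        assert result.startswith("![](")
        assert expected_path in result
        assert result.endswith(")")

    # Passes regardless of the FILES_URL host or the signing parameters.
    assert_is_image_markdown(
        "![](https://files.example.com/files/tools/abc123.png?timestamp=1&nonce=n&sign=s)",
        related_id="abc123",
    )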