mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-14 06:05:51 +08:00
Merge remote-tracking branch 'upstream/main' into deploy/dev
This commit is contained in:
commit
c8449df128
4
.github/workflows/style.yml
vendored
4
.github/workflows/style.yml
vendored
@ -49,8 +49,8 @@ jobs:
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: |
|
||||
uv run --directory api ruff --version
|
||||
uv run --directory api ruff check ./
|
||||
uv run --directory api ruff format --check ./
|
||||
uv run --directory api ruff check --diff ./
|
||||
uv run --directory api ruff format --check --diff ./
|
||||
|
||||
- name: Dotenv check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
|
7
api/constants/mimetypes.py
Normal file
7
api/constants/mimetypes.py
Normal file
@ -0,0 +1,7 @@
|
||||
# The two constants below should keep in sync.
|
||||
# Default content type for files which have no explicit content type.
|
||||
|
||||
DEFAULT_MIME_TYPE = "application/octet-stream"
|
||||
# Default file extension for files which have no explicit content type, should
|
||||
# correspond to the `DEFAULT_MIME_TYPE` above.
|
||||
DEFAULT_EXTENSION = ".bin"
|
@ -4,7 +4,9 @@ from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from controllers.files import api
|
||||
from controllers.files.error import UnsupportedFileTypeError
|
||||
from core.tools.signature import verify_tool_file_signature
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from models import db as global_db
|
||||
|
||||
|
||||
class ToolFilePreviewApi(Resource):
|
||||
@ -19,17 +21,14 @@ class ToolFilePreviewApi(Resource):
|
||||
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not ToolFileManager.verify_file(
|
||||
file_id=file_id,
|
||||
timestamp=args["timestamp"],
|
||||
nonce=args["nonce"],
|
||||
sign=args["sign"],
|
||||
if not verify_tool_file_signature(
|
||||
file_id=file_id, timestamp=args["timestamp"], nonce=args["nonce"], sign=args["sign"]
|
||||
):
|
||||
raise Forbidden("Invalid request.")
|
||||
|
||||
try:
|
||||
stream, tool_file = ToolFileManager.get_file_generator_by_tool_file_id(
|
||||
tool_file_manager = ToolFileManager(engine=global_db.engine)
|
||||
stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id(
|
||||
file_id,
|
||||
)
|
||||
|
||||
|
@ -53,7 +53,7 @@ class PluginUploadFileApi(Resource):
|
||||
raise Forbidden("Invalid request.")
|
||||
|
||||
try:
|
||||
tool_file = ToolFileManager.create_file_by_raw(
|
||||
tool_file = ToolFileManager().create_file_by_raw(
|
||||
user_id=user.id,
|
||||
tenant_id=tenant_id,
|
||||
file_binary=file.read(),
|
||||
|
@ -24,7 +24,7 @@ from core.app.entities.task_entities import (
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.tools.signature import sign_tool_file
|
||||
from extensions.ext_database import db
|
||||
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
|
||||
from services.annotation_service import AppAnnotationService
|
||||
@ -154,7 +154,7 @@ class MessageCycleManage:
|
||||
if message_file.url.startswith("http"):
|
||||
url = message_file.url
|
||||
else:
|
||||
url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension)
|
||||
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
|
||||
|
||||
return MessageFileStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
|
@ -10,12 +10,12 @@ from core.model_runtime.entities import (
|
||||
VideoPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
||||
from core.tools.signature import sign_tool_file
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
from . import helpers
|
||||
from .enums import FileAttribute
|
||||
from .models import File, FileTransferMethod, FileType
|
||||
from .tool_file_parser import ToolFileParser
|
||||
|
||||
|
||||
def get_attr(*, file: File, attr: FileAttribute):
|
||||
@ -130,6 +130,6 @@ def _to_url(f: File, /):
|
||||
# add sign url
|
||||
if f.related_id is None or f.extension is None:
|
||||
raise ValueError("Missing file related_id or extension")
|
||||
return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension)
|
||||
return sign_tool_file(tool_file_id=f.related_id, extension=f.extension)
|
||||
else:
|
||||
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
||||
|
@ -4,11 +4,11 @@ from typing import Any, Optional
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from core.tools.signature import sign_tool_file
|
||||
|
||||
from . import helpers
|
||||
from .constants import FILE_MODEL_IDENTITY
|
||||
from .enums import FileTransferMethod, FileType
|
||||
from .tool_file_parser import ToolFileParser
|
||||
|
||||
|
||||
class ImageConfig(BaseModel):
|
||||
@ -34,13 +34,21 @@ class FileUploadConfig(BaseModel):
|
||||
|
||||
|
||||
class File(BaseModel):
|
||||
# NOTE: dify_model_identity is a special identifier used to distinguish between
|
||||
# new and old data formats during serialization and deserialization.
|
||||
dify_model_identity: str = FILE_MODEL_IDENTITY
|
||||
|
||||
id: Optional[str] = None # message file id
|
||||
tenant_id: str
|
||||
type: FileType
|
||||
transfer_method: FileTransferMethod
|
||||
# If `transfer_method` is `FileTransferMethod.remote_url`, the
|
||||
# `remote_url` attribute must not be `None`.
|
||||
remote_url: Optional[str] = None # remote url
|
||||
# If `transfer_method` is `FileTransferMethod.local_file` or
|
||||
# `FileTransferMethod.tool_file`, the `related_id` attribute must not be `None`.
|
||||
#
|
||||
# It should be set to `ToolFile.id` when `transfer_method` is `tool_file`.
|
||||
related_id: Optional[str] = None
|
||||
filename: Optional[str] = None
|
||||
extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
|
||||
@ -110,9 +118,7 @@ class File(BaseModel):
|
||||
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert self.related_id is not None
|
||||
assert self.extension is not None
|
||||
return ToolFileParser.get_tool_file_manager().sign_file(
|
||||
tool_file_id=self.related_id, extension=self.extension
|
||||
)
|
||||
return sign_tool_file(tool_file_id=self.related_id, extension=self.extension)
|
||||
|
||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||
return {
|
||||
|
@ -1,12 +1,19 @@
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
|
||||
tool_file_manager: dict[str, Any] = {"manager": None}
|
||||
_tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None
|
||||
|
||||
|
||||
class ToolFileParser:
|
||||
@staticmethod
|
||||
def get_tool_file_manager() -> "ToolFileManager":
|
||||
return cast("ToolFileManager", tool_file_manager["manager"])
|
||||
assert _tool_file_manager_factory is not None
|
||||
return _tool_file_manager_factory()
|
||||
|
||||
|
||||
def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None:
|
||||
global _tool_file_manager_factory
|
||||
_tool_file_manager_factory = factory
|
||||
|
@ -101,7 +101,7 @@ class ModelInstance:
|
||||
@overload
|
||||
def invoke_llm(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: Optional[dict] = None,
|
||||
tools: Sequence[PromptMessageTool] | None = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
|
@ -1,4 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from abc import ABC
|
||||
from collections.abc import Mapping, Sequence
|
||||
from enum import Enum, StrEnum
|
||||
from typing import Annotated, Any, Literal, Optional, Union
|
||||
|
||||
@ -60,8 +61,12 @@ class PromptMessageContentType(StrEnum):
|
||||
DOCUMENT = "document"
|
||||
|
||||
|
||||
class PromptMessageContent(BaseModel):
|
||||
pass
|
||||
class PromptMessageContent(ABC, BaseModel):
|
||||
"""
|
||||
Model class for prompt message content.
|
||||
"""
|
||||
|
||||
type: PromptMessageContentType
|
||||
|
||||
|
||||
class TextPromptMessageContent(PromptMessageContent):
|
||||
@ -125,7 +130,16 @@ PromptMessageContentUnionTypes = Annotated[
|
||||
]
|
||||
|
||||
|
||||
class PromptMessage(BaseModel):
|
||||
CONTENT_TYPE_MAPPING: Mapping[PromptMessageContentType, type[PromptMessageContent]] = {
|
||||
PromptMessageContentType.TEXT: TextPromptMessageContent,
|
||||
PromptMessageContentType.IMAGE: ImagePromptMessageContent,
|
||||
PromptMessageContentType.AUDIO: AudioPromptMessageContent,
|
||||
PromptMessageContentType.VIDEO: VideoPromptMessageContent,
|
||||
PromptMessageContentType.DOCUMENT: DocumentPromptMessageContent,
|
||||
}
|
||||
|
||||
|
||||
class PromptMessage(ABC, BaseModel):
|
||||
"""
|
||||
Model class for prompt message.
|
||||
"""
|
||||
@ -142,6 +156,23 @@ class PromptMessage(BaseModel):
|
||||
"""
|
||||
return not self.content
|
||||
|
||||
@field_validator("content", mode="before")
|
||||
@classmethod
|
||||
def validate_content(cls, v):
|
||||
if isinstance(v, list):
|
||||
prompts = []
|
||||
for prompt in v:
|
||||
if isinstance(prompt, PromptMessageContent):
|
||||
if not isinstance(prompt, TextPromptMessageContent | MultiModalPromptMessageContent):
|
||||
prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump())
|
||||
elif isinstance(prompt, dict):
|
||||
prompt = CONTENT_TYPE_MAPPING[prompt["type"]].model_validate(prompt)
|
||||
else:
|
||||
raise ValueError(f"invalid prompt message {prompt}")
|
||||
prompts.append(prompt)
|
||||
return prompts
|
||||
return v
|
||||
|
||||
@field_serializer("content")
|
||||
def serialize_content(
|
||||
self, content: Optional[Union[str, Sequence[PromptMessageContent]]]
|
||||
|
@ -24,7 +24,6 @@ from core.model_runtime.errors.invoke import (
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
@ -253,15 +252,3 @@ class AIModel(BaseModel):
|
||||
raise Exception(f"Invalid model parameter rule name {name}")
|
||||
|
||||
return default_parameter_rule
|
||||
|
||||
def _get_num_tokens_by_gpt2(self, text: str) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages by gpt2
|
||||
Some provider models do not provide an interface for obtaining the number of tokens.
|
||||
Here, the gpt2 tokenizer is used to calculate the number of tokens.
|
||||
This method can be executed offline, and the gpt2 tokenizer has been cached in the project.
|
||||
|
||||
:param text: plain text of prompt. You need to convert the original message to plain text
|
||||
:return: number of tokens
|
||||
"""
|
||||
return GPT2Tokenizer.get_num_tokens(text)
|
||||
|
@ -2,7 +2,7 @@ import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import Optional, Union, cast
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
@ -13,14 +13,15 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageContentUnionTypes,
|
||||
PromptMessageTool,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
ModelType,
|
||||
PriceType,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str
|
||||
from core.plugin.impl.model import PluginModelClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -238,7 +239,7 @@ class LargeLanguageModel(AIModel):
|
||||
def _invoke_result_generator(
|
||||
self,
|
||||
model: str,
|
||||
result: Generator,
|
||||
result: Generator[LLMResultChunk, None, None],
|
||||
credentials: dict,
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
model_parameters: dict,
|
||||
@ -255,11 +256,21 @@ class LargeLanguageModel(AIModel):
|
||||
:return: result generator
|
||||
"""
|
||||
callbacks = callbacks or []
|
||||
assistant_message = AssistantPromptMessage(content="")
|
||||
message_content: list[PromptMessageContentUnionTypes] = []
|
||||
usage = None
|
||||
system_fingerprint = None
|
||||
real_model = model
|
||||
|
||||
def _update_message_content(content: str | list[PromptMessageContentUnionTypes] | None):
|
||||
if not content:
|
||||
return
|
||||
if isinstance(content, list):
|
||||
message_content.extend(content)
|
||||
return
|
||||
if isinstance(content, str):
|
||||
message_content.append(TextPromptMessageContent(data=content))
|
||||
return
|
||||
|
||||
try:
|
||||
for chunk in result:
|
||||
# Following https://github.com/langgenius/dify/issues/17799,
|
||||
@ -281,9 +292,8 @@ class LargeLanguageModel(AIModel):
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
text = convert_llm_result_chunk_to_str(chunk.delta.message.content)
|
||||
current_content = cast(str, assistant_message.content)
|
||||
assistant_message.content = current_content + text
|
||||
_update_message_content(chunk.delta.message.content)
|
||||
|
||||
real_model = chunk.model
|
||||
if chunk.delta.usage:
|
||||
usage = chunk.delta.usage
|
||||
@ -293,6 +303,7 @@ class LargeLanguageModel(AIModel):
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
assistant_message = AssistantPromptMessage(content=message_content)
|
||||
self._trigger_after_invoke_callbacks(
|
||||
model=model,
|
||||
result=LLMResult(
|
||||
|
@ -30,6 +30,8 @@ class GPT2Tokenizer:
|
||||
@staticmethod
|
||||
def get_encoder() -> Any:
|
||||
global _tokenizer, _lock
|
||||
if _tokenizer is not None:
|
||||
return _tokenizer
|
||||
with _lock:
|
||||
if _tokenizer is None:
|
||||
# Try to use tiktoken to get the tokenizer because it is faster
|
||||
|
@ -1,8 +1,6 @@
|
||||
import pydantic
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
||||
|
||||
|
||||
def dump_model(model: BaseModel) -> dict:
|
||||
if hasattr(pydantic, "model_dump"):
|
||||
@ -10,18 +8,3 @@ def dump_model(model: BaseModel) -> dict:
|
||||
return pydantic.model_dump(model) # type: ignore
|
||||
else:
|
||||
return model.model_dump()
|
||||
|
||||
|
||||
def convert_llm_result_chunk_to_str(content: None | str | list[PromptMessageContentUnionTypes]) -> str:
|
||||
if content is None:
|
||||
message_text = ""
|
||||
elif isinstance(content, str):
|
||||
message_text = content
|
||||
elif isinstance(content, list):
|
||||
# Assuming the list contains PromptMessageContent objects with a "data" attribute
|
||||
message_text = "".join(
|
||||
item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
|
||||
)
|
||||
else:
|
||||
message_text = str(content)
|
||||
return message_text
|
||||
|
@ -159,50 +159,6 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
||||
)
|
||||
return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_tiktoken_encoder(
|
||||
cls: type[TS],
|
||||
encoding_name: str = "gpt2",
|
||||
model_name: Optional[str] = None,
|
||||
allowed_special: Union[Literal["all"], Set[str]] = set(),
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
||||
**kwargs: Any,
|
||||
) -> TS:
|
||||
"""Text splitter that uses tiktoken encoder to count length."""
|
||||
try:
|
||||
import tiktoken
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import tiktoken python package. "
|
||||
"This is needed in order to calculate max_tokens_for_prompt. "
|
||||
"Please install it with `pip install tiktoken`."
|
||||
)
|
||||
|
||||
if model_name is not None:
|
||||
enc = tiktoken.encoding_for_model(model_name)
|
||||
else:
|
||||
enc = tiktoken.get_encoding(encoding_name)
|
||||
|
||||
def _tiktoken_encoder(text: str) -> int:
|
||||
return len(
|
||||
enc.encode(
|
||||
text,
|
||||
allowed_special=allowed_special,
|
||||
disallowed_special=disallowed_special,
|
||||
)
|
||||
)
|
||||
|
||||
if issubclass(cls, TokenTextSplitter):
|
||||
extra_kwargs = {
|
||||
"encoding_name": encoding_name,
|
||||
"model_name": model_name,
|
||||
"allowed_special": allowed_special,
|
||||
"disallowed_special": disallowed_special,
|
||||
}
|
||||
kwargs = {**kwargs, **extra_kwargs}
|
||||
|
||||
return cls(length_function=lambda x: [_tiktoken_encoder(text) for text in x], **kwargs)
|
||||
|
||||
def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
|
||||
"""Transform sequence of documents by splitting them."""
|
||||
return self.split_documents(list(documents))
|
||||
|
41
api/core/tools/signature.py
Normal file
41
api/core/tools/signature.py
Normal file
@ -0,0 +1,41 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import os
|
||||
import time
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
def sign_tool_file(tool_file_id: str, extension: str) -> str:
|
||||
"""
|
||||
sign file to get a temporary url
|
||||
"""
|
||||
base_url = dify_config.FILES_URL
|
||||
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
|
||||
|
||||
timestamp = str(int(time.time()))
|
||||
nonce = os.urandom(16).hex()
|
||||
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||
|
||||
|
||||
def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||
"""
|
||||
verify signature
|
||||
"""
|
||||
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
|
||||
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||
|
||||
# verify signature
|
||||
if sign != recalculated_encoded_sign:
|
||||
return False
|
||||
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
@ -9,18 +9,28 @@ from typing import Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_database import db as global_db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import MessageFile
|
||||
from models.tools import ToolFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
|
||||
class ToolFileManager:
|
||||
_engine: Engine
|
||||
|
||||
def __init__(self, engine: Engine | None = None):
|
||||
if engine is None:
|
||||
engine = global_db.engine
|
||||
self._engine = engine
|
||||
|
||||
@staticmethod
|
||||
def sign_file(tool_file_id: str, extension: str) -> str:
|
||||
"""
|
||||
@ -55,8 +65,8 @@ class ToolFileManager:
|
||||
current_time = int(time.time())
|
||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_raw(
|
||||
self,
|
||||
*,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
@ -77,24 +87,25 @@ class ToolFileManager:
|
||||
filepath = f"tools/{tenant_id}/{unique_filename}"
|
||||
storage.save(filepath, file_binary)
|
||||
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_key=filepath,
|
||||
mimetype=mimetype,
|
||||
name=present_filename,
|
||||
size=len(file_binary),
|
||||
)
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_key=filepath,
|
||||
mimetype=mimetype,
|
||||
name=present_filename,
|
||||
size=len(file_binary),
|
||||
)
|
||||
|
||||
db.session.add(tool_file)
|
||||
db.session.commit()
|
||||
db.session.refresh(tool_file)
|
||||
session.add(tool_file)
|
||||
session.commit()
|
||||
session.refresh(tool_file)
|
||||
|
||||
return tool_file
|
||||
|
||||
@staticmethod
|
||||
def create_file_by_url(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
file_url: str,
|
||||
@ -119,24 +130,24 @@ class ToolFileManager:
|
||||
filepath = f"tools/{tenant_id}/{filename}"
|
||||
storage.save(filepath, blob)
|
||||
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_key=filepath,
|
||||
mimetype=mimetype,
|
||||
original_url=file_url,
|
||||
name=filename,
|
||||
size=len(blob),
|
||||
)
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
tool_file = ToolFile(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_key=filepath,
|
||||
mimetype=mimetype,
|
||||
original_url=file_url,
|
||||
name=filename,
|
||||
size=len(blob),
|
||||
)
|
||||
|
||||
db.session.add(tool_file)
|
||||
db.session.commit()
|
||||
session.add(tool_file)
|
||||
session.commit()
|
||||
|
||||
return tool_file
|
||||
|
||||
@staticmethod
|
||||
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
|
||||
def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
@ -144,13 +155,14 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
tool_file: ToolFile | None = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == id,
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
@ -159,8 +171,7 @@ class ToolFileManager:
|
||||
|
||||
return blob, tool_file.mimetype
|
||||
|
||||
@staticmethod
|
||||
def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
|
||||
def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]:
|
||||
"""
|
||||
get file binary
|
||||
|
||||
@ -168,33 +179,34 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
message_file: MessageFile | None = (
|
||||
db.session.query(MessageFile)
|
||||
.filter(
|
||||
MessageFile.id == id,
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
message_file: MessageFile | None = (
|
||||
session.query(MessageFile)
|
||||
.filter(
|
||||
MessageFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Check if message_file is not None
|
||||
if message_file is not None:
|
||||
# get tool file id
|
||||
if message_file.url is not None:
|
||||
tool_file_id = message_file.url.split("/")[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split(".")[0]
|
||||
# Check if message_file is not None
|
||||
if message_file is not None:
|
||||
# get tool file id
|
||||
if message_file.url is not None:
|
||||
tool_file_id = message_file.url.split("/")[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split(".")[0]
|
||||
else:
|
||||
tool_file_id = None
|
||||
else:
|
||||
tool_file_id = None
|
||||
else:
|
||||
tool_file_id = None
|
||||
|
||||
tool_file: ToolFile | None = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
@ -203,8 +215,7 @@ class ToolFileManager:
|
||||
|
||||
return blob, tool_file.mimetype
|
||||
|
||||
@staticmethod
|
||||
def get_file_generator_by_tool_file_id(tool_file_id: str):
|
||||
def get_file_generator_by_tool_file_id(self, tool_file_id: str):
|
||||
"""
|
||||
get file binary
|
||||
|
||||
@ -212,13 +223,14 @@ class ToolFileManager:
|
||||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
tool_file: ToolFile | None = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
with Session(self._engine, expire_on_commit=False) as session:
|
||||
tool_file: ToolFile | None = (
|
||||
session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not tool_file:
|
||||
return None, None
|
||||
@ -229,6 +241,11 @@ class ToolFileManager:
|
||||
|
||||
|
||||
# init tool_file_parser
|
||||
from core.file.tool_file_parser import tool_file_manager
|
||||
from core.file.tool_file_parser import set_tool_file_manager_factory
|
||||
|
||||
tool_file_manager["manager"] = ToolFileManager
|
||||
|
||||
def _factory() -> ToolFileManager:
|
||||
return ToolFileManager()
|
||||
|
||||
|
||||
set_tool_file_manager_factory(_factory)
|
||||
|
@ -31,8 +31,8 @@ class ToolFileMessageTransformer:
|
||||
# try to download image
|
||||
try:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
file = ToolFileManager.create_file_by_url(
|
||||
tool_file_manager = ToolFileManager()
|
||||
file = tool_file_manager.create_file_by_url(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
file_url=message.message.text,
|
||||
@ -68,7 +68,8 @@ class ToolFileMessageTransformer:
|
||||
|
||||
# FIXME: should do a type check here.
|
||||
assert isinstance(message.message.blob, bytes)
|
||||
file = ToolFileManager.create_file_by_raw(
|
||||
tool_file_manager = ToolFileManager()
|
||||
file = tool_file_manager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
|
@ -191,8 +191,9 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
||||
mime_type = (
|
||||
content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
|
||||
)
|
||||
tool_file_manager = ToolFileManager()
|
||||
|
||||
tool_file = ToolFileManager.create_file_by_raw(
|
||||
tool_file = tool_file_manager.create_file_by_raw(
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=None,
|
||||
|
@ -507,7 +507,7 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) < value)
|
||||
case "after" | ">":
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) > value)
|
||||
case "≤" | ">=":
|
||||
case "≤" | "<=":
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) <= value)
|
||||
case "≥" | ">=":
|
||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) >= value)
|
||||
|
@ -38,3 +38,8 @@ class MemoryRolePrefixRequiredError(LLMNodeError):
|
||||
class FileTypeNotSupportError(LLMNodeError):
|
||||
def __init__(self, *, type_name: str):
|
||||
super().__init__(f"{type_name} type is not supported by this model")
|
||||
|
||||
|
||||
class UnsupportedPromptContentTypeError(LLMNodeError):
|
||||
def __init__(self, *, type_name: str) -> None:
|
||||
super().__init__(f"Prompt content type {type_name} is not supported.")
|
||||
|
160
api/core/workflow/nodes/llm/file_saver.py
Normal file
160
api/core/workflow/nodes/llm/file_saver.py
Normal file
@ -0,0 +1,160 @@
|
||||
import mimetypes
|
||||
import typing as tp
|
||||
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools.signature import sign_tool_file
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from models import db as global_db
|
||||
|
||||
|
||||
class LLMFileSaver(tp.Protocol):
|
||||
"""LLMFileSaver is responsible for save multimodal output returned by
|
||||
LLM.
|
||||
"""
|
||||
|
||||
def save_binary_string(
|
||||
self,
|
||||
data: bytes,
|
||||
mime_type: str,
|
||||
file_type: FileType,
|
||||
extension_override: str | None = None,
|
||||
) -> File:
|
||||
"""save_binary_string saves the inline file data returned by LLM.
|
||||
|
||||
Currently (2025-04-30), only some of Google Gemini models will return
|
||||
multimodal output as inline data.
|
||||
|
||||
:param data: the contents of the file
|
||||
:param mime_type: the media type of the file, specified by rfc6838
|
||||
(https://datatracker.ietf.org/doc/html/rfc6838)
|
||||
:param file_type: The file type of the inline file.
|
||||
:param extension_override: Override the auto-detected file extension while saving this file.
|
||||
|
||||
The default value is `None`, which means do not override the file extension and guessing it
|
||||
from the `mime_type` attribute while saving the file.
|
||||
|
||||
Setting it to values other than `None` means override the file's extension, and
|
||||
will bypass the extension guessing saving the file.
|
||||
|
||||
Specially, setting it to empty string (`""`) will leave the file extension empty.
|
||||
|
||||
When it is not `None` or empty string (`""`), it should be a string beginning with a
|
||||
dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
|
||||
and `tar.gz` are not.
|
||||
"""
|
||||
pass
|
||||
|
||||
def save_remote_url(self, url: str, file_type: FileType) -> File:
|
||||
"""save_remote_url saves the file from a remote url returned by LLM.
|
||||
|
||||
Currently (2025-04-30), no model returns multimodel output as a url.
|
||||
|
||||
:param url: the url of the file.
|
||||
:param file_type: the file type of the file, check `FileType` enum for reference.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
|
||||
|
||||
|
||||
class FileSaverImpl(LLMFileSaver):
|
||||
_engine_factory: EngineFactory
|
||||
_tenant_id: str
|
||||
_user_id: str
|
||||
|
||||
def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
|
||||
if engine_factory is None:
|
||||
|
||||
def _factory():
|
||||
return global_db.engine
|
||||
|
||||
engine_factory = _factory
|
||||
self._engine_factory = engine_factory
|
||||
self._user_id = user_id
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
def _get_tool_file_manager(self):
|
||||
return ToolFileManager(engine=self._engine_factory())
|
||||
|
||||
def save_remote_url(self, url: str, file_type: FileType) -> File:
|
||||
http_response = ssrf_proxy.get(url)
|
||||
http_response.raise_for_status()
|
||||
data = http_response.content
|
||||
mime_type_from_header = http_response.headers.get("Content-Type")
|
||||
mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header)
|
||||
return self.save_binary_string(data, mime_type, file_type, extension_override=extension)
|
||||
|
||||
def save_binary_string(
|
||||
self,
|
||||
data: bytes,
|
||||
mime_type: str,
|
||||
file_type: FileType,
|
||||
extension_override: str | None = None,
|
||||
) -> File:
|
||||
tool_file_manager = self._get_tool_file_manager()
|
||||
tool_file = tool_file_manager.create_file_by_raw(
|
||||
user_id=self._user_id,
|
||||
tenant_id=self._tenant_id,
|
||||
# TODO(QuantumGhost): what is conversation id?
|
||||
conversation_id=None,
|
||||
file_binary=data,
|
||||
mimetype=mime_type,
|
||||
)
|
||||
extension_override = _validate_extension_override(extension_override)
|
||||
extension = _get_extension(mime_type, extension_override)
|
||||
url = sign_tool_file(tool_file.id, extension)
|
||||
|
||||
return File(
|
||||
tenant_id=self._tenant_id,
|
||||
type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
filename=tool_file.name,
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
size=len(data),
|
||||
related_id=tool_file.id,
|
||||
url=url,
|
||||
# TODO(QuantumGhost): how should I set the following key?
|
||||
# What's the difference between `remote_url` and `url`?
|
||||
# What's the purpose of `storage_key` and `dify_model_identity`?
|
||||
storage_key=tool_file.file_key,
|
||||
)
|
||||
|
||||
|
||||
def _get_extension(mime_type: str, extension_override: str | None = None) -> str:
|
||||
"""get_extension return the extension of file.
|
||||
|
||||
If the `extension_override` parameter is set, this function should honor it and
|
||||
return its value.
|
||||
"""
|
||||
if extension_override is not None:
|
||||
return extension_override
|
||||
return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION
|
||||
|
||||
|
||||
def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]:
|
||||
"""_extract_content_type_and_extension tries to
|
||||
guess content type of file from url and `Content-Type` header in response.
|
||||
"""
|
||||
if content_type_header:
|
||||
extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION
|
||||
return content_type_header, extension
|
||||
content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE
|
||||
extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION
|
||||
return content_type, extension
|
||||
|
||||
|
||||
def _validate_extension_override(extension_override: str | None) -> str | None:
|
||||
# `extension_override` is allow to be `None or `""`.
|
||||
if extension_override is None:
|
||||
return None
|
||||
if extension_override == "":
|
||||
return ""
|
||||
if not extension_override.startswith("."):
|
||||
raise ValueError("extension_override should start with '.' if not None or empty.", extension_override)
|
||||
return extension_override
|
@ -1,3 +1,5 @@
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
@ -21,7 +23,7 @@ from core.model_runtime.entities import (
|
||||
PromptMessageContentType,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageContentUnionTypes,
|
||||
@ -38,7 +40,6 @@ from core.model_runtime.entities.model_entities import (
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
@ -95,9 +96,13 @@ from .exc import (
|
||||
TemplateTypeNotSupportError,
|
||||
VariableNotFoundError,
|
||||
)
|
||||
from .file_saver import FileSaverImpl, LLMFileSaver
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -106,6 +111,43 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
_node_data_cls = LLMNodeData
|
||||
_node_type = NodeType.LLM
|
||||
|
||||
# Instance attributes specific to LLMNode.
|
||||
# Output variable for file
|
||||
_file_outputs: list["File"]
|
||||
|
||||
_llm_file_saver: LLMFileSaver
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph: "Graph",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
previous_node_id: Optional[str] = None,
|
||||
thread_pool_id: Optional[str] = None,
|
||||
*,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
previous_node_id=previous_node_id,
|
||||
thread_pool_id=thread_pool_id,
|
||||
)
|
||||
# LLM file outputs, used for MultiModal outputs.
|
||||
self._file_outputs: list[File] = []
|
||||
|
||||
if llm_file_saver is None:
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=graph_init_params.user_id,
|
||||
tenant_id=graph_init_params.tenant_id,
|
||||
)
|
||||
self._llm_file_saver = llm_file_saver
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
|
||||
"""Process structured output if enabled"""
|
||||
@ -215,6 +257,9 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
structured_output = process_structured_output(result_text)
|
||||
if structured_output:
|
||||
outputs["structured_output"] = structured_output
|
||||
if self._file_outputs is not None:
|
||||
outputs["files"] = self._file_outputs
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
@ -240,6 +285,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("error while executing llm node")
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
@ -268,44 +314,45 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
|
||||
return self._handle_invoke_result(invoke_result=invoke_result)
|
||||
|
||||
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
|
||||
def _handle_invoke_result(
|
||||
self, invoke_result: LLMResult | Generator[LLMResultChunk, None, None]
|
||||
) -> Generator[NodeEvent, None, None]:
|
||||
# For blocking mode
|
||||
if isinstance(invoke_result, LLMResult):
|
||||
message_text = convert_llm_result_chunk_to_str(invoke_result.message.content)
|
||||
|
||||
yield ModelInvokeCompletedEvent(
|
||||
text=message_text,
|
||||
usage=invoke_result.usage,
|
||||
finish_reason=None,
|
||||
)
|
||||
event = self._handle_blocking_result(invoke_result=invoke_result)
|
||||
yield event
|
||||
return
|
||||
|
||||
model = None
|
||||
# For streaming mode
|
||||
model = ""
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
full_text = ""
|
||||
usage = None
|
||||
|
||||
usage = LLMUsage.empty_usage()
|
||||
finish_reason = None
|
||||
full_text_buffer = io.StringIO()
|
||||
for result in invoke_result:
|
||||
text = convert_llm_result_chunk_to_str(result.delta.message.content)
|
||||
full_text += text
|
||||
contents = result.delta.message.content
|
||||
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents):
|
||||
full_text_buffer.write(text_part)
|
||||
yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[self.node_id, "text"])
|
||||
|
||||
yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])
|
||||
|
||||
if not model:
|
||||
# Update the whole metadata
|
||||
if not model and result.model:
|
||||
model = result.model
|
||||
|
||||
if not prompt_messages:
|
||||
prompt_messages = result.prompt_messages
|
||||
|
||||
if not usage and result.delta.usage:
|
||||
if len(prompt_messages) == 0:
|
||||
# TODO(QuantumGhost): it seems that this update has no visable effect.
|
||||
# What's the purpose of the line below?
|
||||
prompt_messages = list(result.prompt_messages)
|
||||
if usage.prompt_tokens == 0 and result.delta.usage:
|
||||
usage = result.delta.usage
|
||||
|
||||
if not finish_reason and result.delta.finish_reason:
|
||||
if finish_reason is None and result.delta.finish_reason:
|
||||
finish_reason = result.delta.finish_reason
|
||||
|
||||
if not usage:
|
||||
usage = LLMUsage.empty_usage()
|
||||
yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
|
||||
|
||||
yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason)
|
||||
def _image_file_to_markdown(self, file: "File", /):
|
||||
text_chunk = f"})"
|
||||
return text_chunk
|
||||
|
||||
def _transform_chat_messages(
|
||||
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
|
||||
@ -963,6 +1010,42 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent:
|
||||
buffer = io.StringIO()
|
||||
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content):
|
||||
buffer.write(text_part)
|
||||
|
||||
return ModelInvokeCompletedEvent(
|
||||
text=buffer.getvalue(),
|
||||
usage=invoke_result.usage,
|
||||
finish_reason=None,
|
||||
)
|
||||
|
||||
def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File":
|
||||
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.
|
||||
|
||||
There are two kinds of multimodal outputs:
|
||||
|
||||
- Inlined data encoded in base64, which would be saved to storage directly.
|
||||
- Remote files referenced by an url, which would be downloaded and then saved to storage.
|
||||
|
||||
Currently, only image files are supported.
|
||||
"""
|
||||
# Inject the saver somehow...
|
||||
_saver = self._llm_file_saver
|
||||
|
||||
# If this
|
||||
if content.url != "":
|
||||
saved_file = _saver.save_remote_url(content.url, FileType.IMAGE)
|
||||
else:
|
||||
saved_file = _saver.save_binary_string(
|
||||
data=base64.b64decode(content.base64_data),
|
||||
mime_type=content.mime_type,
|
||||
file_type=FileType.IMAGE,
|
||||
)
|
||||
self._file_outputs.append(saved_file)
|
||||
return saved_file
|
||||
|
||||
def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
|
||||
"""
|
||||
Handle structured output for models with native JSON schema support.
|
||||
@ -1123,6 +1206,41 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
else SupportStructuredOutputStatus.UNSUPPORTED
|
||||
)
|
||||
|
||||
def _save_multimodal_output_and_convert_result_to_markdown(
|
||||
self,
|
||||
contents: str | list[PromptMessageContentUnionTypes] | None,
|
||||
) -> Generator[str, None, None]:
|
||||
"""Convert intermediate prompt messages into strings and yield them to the caller.
|
||||
|
||||
If the messages contain non-textual content (e.g., multimedia like images or videos),
|
||||
it will be saved separately, and the corresponding Markdown representation will
|
||||
be yielded to the caller.
|
||||
"""
|
||||
|
||||
# NOTE(QuantumGhost): This function should yield results to the caller immediately
|
||||
# whenever new content or partial content is available. Avoid any intermediate buffering
|
||||
# of results. Additionally, do not yield empty strings; instead, yield from an empty list
|
||||
# if necessary.
|
||||
if contents is None:
|
||||
yield from []
|
||||
return
|
||||
if isinstance(contents, str):
|
||||
yield contents
|
||||
elif isinstance(contents, list):
|
||||
for item in contents:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
yield item.data
|
||||
elif isinstance(item, ImagePromptMessageContent):
|
||||
file = self._save_multimodal_image_output(item)
|
||||
self._file_outputs.append(file)
|
||||
yield self._image_file_to_markdown(file)
|
||||
else:
|
||||
logger.warning("unknown item type encountered, type=%s", type(item))
|
||||
yield str(item)
|
||||
else:
|
||||
logger.warning("unknown contents type encountered, type=%s", type(contents))
|
||||
yield str(contents)
|
||||
|
||||
|
||||
def _combine_message_content_with_role(
|
||||
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
|
||||
|
@ -10,4 +10,16 @@ POSTGRES_INDEXES_NAMING_CONVENTION = {
|
||||
}
|
||||
|
||||
metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION)
|
||||
|
||||
# ****** IMPORTANT NOTICE ******
|
||||
#
|
||||
# NOTE(QuantumGhost): Avoid directly importing and using `db` in modules outside of the
|
||||
# `controllers` package.
|
||||
#
|
||||
# Instead, import `db` within the `controllers` package and pass it as an argument to
|
||||
# functions or class constructors.
|
||||
#
|
||||
# Directly importing `db` in other modules can make the code more difficult to read, test, and maintain.
|
||||
#
|
||||
# Whenever possible, avoid this pattern in new code.
|
||||
db = SQLAlchemy(metadata=metadata)
|
||||
|
@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.signature import sign_tool_file
|
||||
from services.plugin.plugin_service import PluginService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -23,7 +24,6 @@ from configs import dify_config
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||
from core.file import helpers as file_helpers
|
||||
from core.file.tool_file_parser import ToolFileParser
|
||||
from libs.helper import generate_string
|
||||
from models.base import Base
|
||||
from models.enums import CreatedByRole
|
||||
@ -986,9 +986,7 @@ class Message(db.Model): # type: ignore[name-defined]
|
||||
if not tool_file_id:
|
||||
continue
|
||||
|
||||
sign_url = ToolFileParser.get_tool_file_manager().sign_file(
|
||||
tool_file_id=tool_file_id, extension=extension
|
||||
)
|
||||
sign_url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
|
||||
elif "file-preview" in url:
|
||||
# get upload file id
|
||||
upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview?\?timestamp="
|
||||
|
@ -263,8 +263,8 @@ class ToolConversationVariables(Base):
|
||||
|
||||
|
||||
class ToolFile(Base):
|
||||
"""
|
||||
store the file created by agent
|
||||
"""This table stores file metadata generated in workflows,
|
||||
not only files created by agent.
|
||||
"""
|
||||
|
||||
__tablename__ = "tool_files"
|
||||
|
@ -77,7 +77,7 @@ dependencies = [
|
||||
"sentry-sdk[flask]~=1.44.1",
|
||||
"sqlalchemy~=2.0.29",
|
||||
"starlette==0.41.0",
|
||||
"tiktoken~=0.8.0",
|
||||
"tiktoken~=0.9.0",
|
||||
"tokenizers~=0.15.0",
|
||||
"transformers~=4.35.0",
|
||||
"unstructured[docx,epub,md,ppt,pptx]~=0.16.1",
|
||||
|
192
api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py
Normal file
192
api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py
Normal file
@ -0,0 +1,192 @@
|
||||
import uuid
|
||||
from typing import NamedTuple
|
||||
from unittest import mock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from sqlalchemy import Engine
|
||||
|
||||
from core.file import FileTransferMethod, FileType, models
|
||||
from core.helper import ssrf_proxy
|
||||
from core.tools import signature
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.nodes.llm.file_saver import (
|
||||
FileSaverImpl,
|
||||
_extract_content_type_and_extension,
|
||||
_get_extension,
|
||||
_validate_extension_override,
|
||||
)
|
||||
from models import ToolFile
|
||||
|
||||
_PNG_DATA = b"\x89PNG\r\n\x1a\n"
|
||||
|
||||
|
||||
def _gen_id():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class TestFileSaverImpl:
|
||||
def test_save_binary_string(self, monkeypatch):
|
||||
user_id = _gen_id()
|
||||
tenant_id = _gen_id()
|
||||
file_type = FileType.IMAGE
|
||||
mime_type = "image/png"
|
||||
mock_signed_url = "https://example.com/image.png"
|
||||
mock_tool_file = ToolFile(
|
||||
id=_gen_id(),
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
file_key="test-file-key",
|
||||
mimetype=mime_type,
|
||||
original_url=None,
|
||||
name=f"{_gen_id()}.png",
|
||||
size=len(_PNG_DATA),
|
||||
)
|
||||
mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
|
||||
mocked_engine = mock.MagicMock(spec=Engine)
|
||||
|
||||
mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
|
||||
monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
|
||||
# Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here.
|
||||
mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file)
|
||||
# Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here.
|
||||
monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file)
|
||||
mocked_sign_file.return_value = mock_signed_url
|
||||
|
||||
storage_file_manager = FileSaverImpl(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
engine_factory=mocked_engine,
|
||||
)
|
||||
|
||||
file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type)
|
||||
assert file.tenant_id == tenant_id
|
||||
assert file.type == file_type
|
||||
assert file.transfer_method == FileTransferMethod.TOOL_FILE
|
||||
assert file.extension == ".png"
|
||||
assert file.mime_type == mime_type
|
||||
assert file.size == len(_PNG_DATA)
|
||||
assert file.related_id == mock_tool_file.id
|
||||
|
||||
assert file.generate_url() == mock_signed_url
|
||||
|
||||
mocked_tool_file_manager.create_file_by_raw.assert_called_once_with(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
file_binary=_PNG_DATA,
|
||||
mimetype=mime_type,
|
||||
)
|
||||
mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")
|
||||
|
||||
def test_save_remote_url_request_failed(self, monkeypatch):
|
||||
_TEST_URL = "https://example.com/image.png"
|
||||
mock_request = httpx.Request("GET", _TEST_URL)
|
||||
mock_response = httpx.Response(
|
||||
status_code=401,
|
||||
request=mock_request,
|
||||
)
|
||||
file_saver = FileSaverImpl(
|
||||
user_id=_gen_id(),
|
||||
tenant_id=_gen_id(),
|
||||
)
|
||||
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
|
||||
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError) as exc:
|
||||
file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
|
||||
mock_get.assert_called_once_with(_TEST_URL)
|
||||
assert exc.value.response.status_code == 401
|
||||
|
||||
def test_save_remote_url_success(self, monkeypatch):
|
||||
_TEST_URL = "https://example.com/image.png"
|
||||
mime_type = "image/png"
|
||||
user_id = _gen_id()
|
||||
tenant_id = _gen_id()
|
||||
|
||||
mock_request = httpx.Request("GET", _TEST_URL)
|
||||
mock_response = httpx.Response(
|
||||
status_code=200,
|
||||
content=b"test-data",
|
||||
headers={"Content-Type": mime_type},
|
||||
request=mock_request,
|
||||
)
|
||||
|
||||
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id)
|
||||
mock_tool_file = ToolFile(
|
||||
id=_gen_id(),
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
file_key="test-file-key",
|
||||
mimetype=mime_type,
|
||||
original_url=None,
|
||||
name=f"{_gen_id()}.png",
|
||||
size=len(_PNG_DATA),
|
||||
)
|
||||
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
|
||||
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
|
||||
mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file)
|
||||
monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string)
|
||||
|
||||
file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
|
||||
mock_save_binary_string.assert_called_once_with(
|
||||
mock_response.content,
|
||||
mime_type,
|
||||
FileType.IMAGE,
|
||||
extension_override=".png",
|
||||
)
|
||||
assert file == mock_tool_file
|
||||
|
||||
|
||||
def test_validate_extension_override():
|
||||
class TestCase(NamedTuple):
|
||||
extension_override: str | None
|
||||
expected: str | None
|
||||
|
||||
cases = [TestCase(None, None), TestCase("", ""), ".png", ".png", ".tar.gz", ".tar.gz"]
|
||||
|
||||
for valid_ext_override in [None, "", ".png", ".tar.gz"]:
|
||||
assert valid_ext_override == _validate_extension_override(valid_ext_override)
|
||||
|
||||
for invalid_ext_override in ["png", "tar.gz"]:
|
||||
with pytest.raises(ValueError) as exc:
|
||||
_validate_extension_override(invalid_ext_override)
|
||||
|
||||
|
||||
class TestExtractContentTypeAndExtension:
|
||||
def test_with_both_content_type_and_extension(self):
|
||||
content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png")
|
||||
assert content_type == "image/png"
|
||||
assert extension == ".png"
|
||||
|
||||
def test_url_with_file_extension(self):
|
||||
for content_type in [None, ""]:
|
||||
content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type)
|
||||
assert content_type == "image/png"
|
||||
assert extension == ".png"
|
||||
|
||||
def test_response_with_content_type(self):
|
||||
content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png")
|
||||
assert content_type == "image/png"
|
||||
assert extension == ".png"
|
||||
|
||||
def test_no_content_type_and_no_extension(self):
|
||||
for content_type in [None, ""]:
|
||||
content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type)
|
||||
assert content_type == "application/octet-stream"
|
||||
assert extension == ".bin"
|
||||
|
||||
|
||||
class TestGetExtension:
|
||||
def test_with_extension_override(self):
|
||||
mime_type = "image/png"
|
||||
for override in [".jpg", ""]:
|
||||
extension = _get_extension(mime_type, override)
|
||||
assert extension == override
|
||||
|
||||
def test_without_extension_override(self):
|
||||
mime_type = "image/png"
|
||||
extension = _get_extension(mime_type)
|
||||
assert extension == ".png"
|
@ -1,5 +1,8 @@
|
||||
import base64
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
@ -30,6 +33,7 @@ from core.workflow.nodes.llm.entities import (
|
||||
VisionConfig,
|
||||
VisionConfigOptions,
|
||||
)
|
||||
from core.workflow.nodes.llm.file_saver import LLMFileSaver
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from models.enums import UserFrom
|
||||
from models.provider import ProviderType
|
||||
@ -49,8 +53,8 @@ class MockTokenBufferMemory:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_node():
|
||||
data = LLMNodeData(
|
||||
def llm_node_data() -> LLMNodeData:
|
||||
return LLMNodeData(
|
||||
title="Test LLM",
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
||||
prompt_template=[],
|
||||
@ -64,42 +68,65 @@ def llm_node():
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def graph_init_params() -> GraphInitParams:
|
||||
return GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def graph() -> Graph:
|
||||
return Graph(
|
||||
root_node_id="1",
|
||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||
answer_dependencies={},
|
||||
answer_generate_route={},
|
||||
),
|
||||
end_stream_param=EndStreamParam(
|
||||
end_dependencies={},
|
||||
end_stream_variable_selector_mapping={},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def graph_runtime_state() -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
return GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_node(
|
||||
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState
|
||||
) -> LLMNode:
|
||||
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
|
||||
node = LLMNode(
|
||||
id="1",
|
||||
config={
|
||||
"id": "1",
|
||||
"data": data.model_dump(),
|
||||
"data": llm_node_data.model_dump(),
|
||||
},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config={},
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
),
|
||||
graph=Graph(
|
||||
root_node_id="1",
|
||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||
answer_dependencies={},
|
||||
answer_generate_route={},
|
||||
),
|
||||
end_stream_param=EndStreamParam(
|
||||
end_dependencies={},
|
||||
end_stream_variable_selector_mapping={},
|
||||
),
|
||||
),
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=0,
|
||||
),
|
||||
graph_init_params=graph_init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
llm_file_saver=mock_file_saver,
|
||||
)
|
||||
return node
|
||||
|
||||
@ -465,3 +492,167 @@ def test_handle_list_messages_basic(llm_node):
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], UserPromptMessage)
|
||||
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_node_for_multimodal(
|
||||
llm_node_data, graph_init_params, graph, graph_runtime_state
|
||||
) -> tuple[LLMNode, LLMFileSaver]:
|
||||
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
|
||||
node = LLMNode(
|
||||
id="1",
|
||||
config={
|
||||
"id": "1",
|
||||
"data": llm_node_data.model_dump(),
|
||||
},
|
||||
graph_init_params=graph_init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
llm_file_saver=mock_file_saver,
|
||||
)
|
||||
return node, mock_file_saver
|
||||
|
||||
|
||||
class TestLLMNodeSaveMultiModalImageOutput:
|
||||
def test_llm_node_save_inline_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]):
|
||||
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||
content = ImagePromptMessageContent(
|
||||
format="png",
|
||||
base64_data=base64.b64encode(b"test-data").decode(),
|
||||
mime_type="image/png",
|
||||
)
|
||||
mock_file = File(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id="1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id=str(uuid.uuid4()),
|
||||
filename="test-file.png",
|
||||
extension=".png",
|
||||
mime_type="image/png",
|
||||
size=9,
|
||||
)
|
||||
mock_file_saver.save_binary_string.return_value = mock_file
|
||||
file = llm_node._save_multimodal_image_output(content=content)
|
||||
assert llm_node._file_outputs == [mock_file]
|
||||
assert file == mock_file
|
||||
mock_file_saver.save_binary_string.assert_called_once_with(
|
||||
data=b"test-data", mime_type="image/png", file_type=FileType.IMAGE
|
||||
)
|
||||
|
||||
def test_llm_node_save_url_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]):
|
||||
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||
content = ImagePromptMessageContent(
|
||||
format="png",
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/jpg",
|
||||
)
|
||||
mock_file = File(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id="1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id=str(uuid.uuid4()),
|
||||
filename="test-file.png",
|
||||
extension=".png",
|
||||
mime_type="image/png",
|
||||
size=9,
|
||||
)
|
||||
mock_file_saver.save_remote_url.return_value = mock_file
|
||||
file = llm_node._save_multimodal_image_output(content=content)
|
||||
assert llm_node._file_outputs == [mock_file]
|
||||
assert file == mock_file
|
||||
mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE)
|
||||
|
||||
|
||||
def test_llm_node_image_file_to_markdown(llm_node: LLMNode):
|
||||
mock_file = mock.MagicMock(spec=File)
|
||||
mock_file.generate_url.return_value = "https://example.com/image.png"
|
||||
markdown = llm_node._image_file_to_markdown(mock_file)
|
||||
assert markdown == ""
|
||||
|
||||
|
||||
class TestSaveMultimodalOutputAndConvertResultToMarkdown:
|
||||
def test_str_content(self, llm_node_for_multimodal):
|
||||
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world")
|
||||
assert list(gen) == ["hello world"]
|
||||
mock_file_saver.save_binary_string.assert_not_called()
|
||||
mock_file_saver.save_remote_url.assert_not_called()
|
||||
|
||||
def test_text_prompt_message_content(self, llm_node_for_multimodal):
|
||||
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
|
||||
[TextPromptMessageContent(data="hello world")]
|
||||
)
|
||||
assert list(gen) == ["hello world"]
|
||||
mock_file_saver.save_binary_string.assert_not_called()
|
||||
mock_file_saver.save_remote_url.assert_not_called()
|
||||
|
||||
def test_image_content_with_inline_data(self, llm_node_for_multimodal, monkeypatch):
|
||||
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||
|
||||
image_raw_data = b"PNG_DATA"
|
||||
image_b64_data = base64.b64encode(image_raw_data).decode()
|
||||
|
||||
mock_saved_file = File(
|
||||
id=str(uuid.uuid4()),
|
||||
tenant_id="1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
filename="test.png",
|
||||
extension=".png",
|
||||
size=len(image_raw_data),
|
||||
related_id=str(uuid.uuid4()),
|
||||
url="https://example.com/test.png",
|
||||
storage_key="test_storage_key",
|
||||
)
|
||||
mock_file_saver.save_binary_string.return_value = mock_saved_file
|
||||
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
|
||||
[
|
||||
ImagePromptMessageContent(
|
||||
format="png",
|
||||
base64_data=image_b64_data,
|
||||
mime_type="image/png",
|
||||
)
|
||||
]
|
||||
)
|
||||
yielded_strs = list(gen)
|
||||
assert len(yielded_strs) == 1
|
||||
|
||||
# This assertion requires careful handling.
|
||||
# `FILES_URL` settings can vary across environments, which might lead to fragile tests.
|
||||
#
|
||||
# Rather than asserting the complete URL returned by _save_multimodal_output_and_convert_result_to_markdown,
|
||||
# we verify that the result includes the markdown image syntax and the expected file URL path.
|
||||
expected_file_url_path = f"/files/tools/{mock_saved_file.related_id}.png"
|
||||
assert yielded_strs[0].startswith("
|
||||
assert expected_file_url_path in yielded_strs[0]
|
||||
assert yielded_strs[0].endswith(")")
|
||||
mock_file_saver.save_binary_string.assert_called_once_with(
|
||||
data=image_raw_data,
|
||||
mime_type="image/png",
|
||||
file_type=FileType.IMAGE,
|
||||
)
|
||||
assert mock_saved_file in llm_node._file_outputs
|
||||
|
||||
def test_unknown_content_type(self, llm_node_for_multimodal):
|
||||
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"]))
|
||||
assert list(gen) == ["frozenset({'hello world'})"]
|
||||
mock_file_saver.save_binary_string.assert_not_called()
|
||||
mock_file_saver.save_remote_url.assert_not_called()
|
||||
|
||||
def test_unknown_item_type(self, llm_node_for_multimodal):
|
||||
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])])
|
||||
assert list(gen) == ["frozenset({'hello world'})"]
|
||||
mock_file_saver.save_binary_string.assert_not_called()
|
||||
mock_file_saver.save_remote_url.assert_not_called()
|
||||
|
||||
def test_none_content(self, llm_node_for_multimodal):
|
||||
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None)
|
||||
assert list(gen) == []
|
||||
mock_file_saver.save_binary_string.assert_not_called()
|
||||
mock_file_saver.save_remote_url.assert_not_called()
|
||||
|
30
api/uv.lock
generated
30
api/uv.lock
generated
@ -1398,7 +1398,7 @@ requires-dist = [
|
||||
{ name = "sentry-sdk", extras = ["flask"], specifier = "~=1.44.1" },
|
||||
{ name = "sqlalchemy", specifier = "~=2.0.29" },
|
||||
{ name = "starlette", specifier = "==0.41.0" },
|
||||
{ name = "tiktoken", specifier = "~=0.8.0" },
|
||||
{ name = "tiktoken", specifier = "~=0.9.0" },
|
||||
{ name = "tokenizers", specifier = "~=0.15.0" },
|
||||
{ name = "transformers", specifier = "~=4.35.0" },
|
||||
{ name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.16.1" },
|
||||
@ -5335,26 +5335,26 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "tiktoken"
|
||||
version = "0.8.0"
|
||||
version = "0.9.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "regex" },
|
||||
{ name = "requests" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/37/02/576ff3a6639e755c4f70997b2d315f56d6d71e0d046f4fb64cb81a3fb099/tiktoken-0.8.0.tar.gz", hash = "sha256:9ccbb2740f24542534369c5635cfd9b2b3c2490754a78ac8831d99f89f94eeb2", size = 35107 }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ea/cf/756fedf6981e82897f2d570dd25fa597eb3f4459068ae0572d7e888cfd6f/tiktoken-0.9.0.tar.gz", hash = "sha256:d02a5ca6a938e0490e1ff957bc48c8b078c88cb83977be1625b1fd8aac792c5d", size = 35991 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f6/1e/ca48e7bfeeccaf76f3a501bd84db1fa28b3c22c9d1a1f41af9fb7579c5f6/tiktoken-0.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d622d8011e6d6f239297efa42a2657043aaed06c4f68833550cac9e9bc723ef1", size = 1039700 },
|
||||
{ url = "https://files.pythonhosted.org/packages/8c/f8/f0101d98d661b34534769c3818f5af631e59c36ac6d07268fbfc89e539ce/tiktoken-0.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2efaf6199717b4485031b4d6edb94075e4d79177a172f38dd934d911b588d54a", size = 982413 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/3c/2b95391d9bd520a73830469f80a96e3790e6c0a5ac2444f80f20b4b31051/tiktoken-0.8.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5637e425ce1fc49cf716d88df3092048359a4b3bbb7da762840426e937ada06d", size = 1144242 },
|
||||
{ url = "https://files.pythonhosted.org/packages/01/c4/c4a4360de845217b6aa9709c15773484b50479f36bb50419c443204e5de9/tiktoken-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fb0e352d1dbe15aba082883058b3cce9e48d33101bdaac1eccf66424feb5b47", size = 1176588 },
|
||||
{ url = "https://files.pythonhosted.org/packages/f8/a3/ef984e976822cd6c2227c854f74d2e60cf4cd6fbfca46251199914746f78/tiktoken-0.8.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:56edfefe896c8f10aba372ab5706b9e3558e78db39dd497c940b47bf228bc419", size = 1237261 },
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/86/eea2309dc258fb86c7d9b10db536434fc16420feaa3b6113df18b23db7c2/tiktoken-0.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:326624128590def898775b722ccc327e90b073714227175ea8febbc920ac0a99", size = 884537 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c1/22/34b2e136a6f4af186b6640cbfd6f93400783c9ef6cd550d9eab80628d9de/tiktoken-0.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:881839cfeae051b3628d9823b2e56b5cc93a9e2efb435f4cf15f17dc45f21586", size = 1039357 },
|
||||
{ url = "https://files.pythonhosted.org/packages/04/d2/c793cf49c20f5855fd6ce05d080c0537d7418f22c58e71f392d5e8c8dbf7/tiktoken-0.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fe9399bdc3f29d428f16a2f86c3c8ec20be3eac5f53693ce4980371c3245729b", size = 982616 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/a1/79846e5ef911cd5d75c844de3fa496a10c91b4b5f550aad695c5df153d72/tiktoken-0.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9a58deb7075d5b69237a3ff4bb51a726670419db6ea62bdcd8bd80c78497d7ab", size = 1144011 },
|
||||
{ url = "https://files.pythonhosted.org/packages/26/32/e0e3a859136e95c85a572e4806dc58bf1ddf651108ae8b97d5f3ebe1a244/tiktoken-0.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2908c0d043a7d03ebd80347266b0e58440bdef5564f84f4d29fb235b5df3b04", size = 1175432 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/89/926b66e9025b97e9fbabeaa59048a736fe3c3e4530a204109571104f921c/tiktoken-0.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:294440d21a2a51e12d4238e68a5972095534fe9878be57d905c476017bff99fc", size = 1236576 },
|
||||
{ url = "https://files.pythonhosted.org/packages/45/e2/39d4aa02a52bba73b2cd21ba4533c84425ff8786cc63c511d68c8897376e/tiktoken-0.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:d8f3192733ac4d77977432947d563d7e1b310b96497acd3c196c9bddb36ed9db", size = 883824 },
|
||||
{ url = "https://files.pythonhosted.org/packages/4d/ae/4613a59a2a48e761c5161237fc850eb470b4bb93696db89da51b79a871f1/tiktoken-0.9.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f32cc56168eac4851109e9b5d327637f15fd662aa30dd79f964b7c39fbadd26e", size = 1065987 },
|
||||
{ url = "https://files.pythonhosted.org/packages/3f/86/55d9d1f5b5a7e1164d0f1538a85529b5fcba2b105f92db3622e5d7de6522/tiktoken-0.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:45556bc41241e5294063508caf901bf92ba52d8ef9222023f83d2483a3055348", size = 1009155 },
|
||||
{ url = "https://files.pythonhosted.org/packages/03/58/01fb6240df083b7c1916d1dcb024e2b761213c95d576e9f780dfb5625a76/tiktoken-0.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03935988a91d6d3216e2ec7c645afbb3d870b37bcb67ada1943ec48678e7ee33", size = 1142898 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b1/73/41591c525680cd460a6becf56c9b17468d3711b1df242c53d2c7b2183d16/tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b3d80aad8d2c6b9238fc1a5524542087c52b860b10cbf952429ffb714bc1136", size = 1197535 },
|
||||
{ url = "https://files.pythonhosted.org/packages/7d/7c/1069f25521c8f01a1a182f362e5c8e0337907fae91b368b7da9c3e39b810/tiktoken-0.9.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b2a21133be05dc116b1d0372af051cd2c6aa1d2188250c9b553f9fa49301b336", size = 1259548 },
|
||||
{ url = "https://files.pythonhosted.org/packages/6f/07/c67ad1724b8e14e2b4c8cca04b15da158733ac60136879131db05dda7c30/tiktoken-0.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:11a20e67fdf58b0e2dea7b8654a288e481bb4fc0289d3ad21291f8d0849915fb", size = 893895 },
|
||||
{ url = "https://files.pythonhosted.org/packages/cf/e5/21ff33ecfa2101c1bb0f9b6df750553bd873b7fb532ce2cb276ff40b197f/tiktoken-0.9.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:e88f121c1c22b726649ce67c089b90ddda8b9662545a8aeb03cfef15967ddd03", size = 1065073 },
|
||||
{ url = "https://files.pythonhosted.org/packages/8e/03/a95e7b4863ee9ceec1c55983e4cc9558bcfd8f4f80e19c4f8a99642f697d/tiktoken-0.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a6600660f2f72369acb13a57fb3e212434ed38b045fd8cc6cdd74947b4b5d210", size = 1008075 },
|
||||
{ url = "https://files.pythonhosted.org/packages/40/10/1305bb02a561595088235a513ec73e50b32e74364fef4de519da69bc8010/tiktoken-0.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:95e811743b5dfa74f4b227927ed86cbc57cad4df859cb3b643be797914e41794", size = 1140754 },
|
||||
{ url = "https://files.pythonhosted.org/packages/1b/40/da42522018ca496432ffd02793c3a72a739ac04c3794a4914570c9bb2925/tiktoken-0.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99376e1370d59bcf6935c933cb9ba64adc29033b7e73f5f7569f3aad86552b22", size = 1196678 },
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/41/1e59dddaae270ba20187ceb8aa52c75b24ffc09f547233991d5fd822838b/tiktoken-0.9.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:badb947c32739fb6ddde173e14885fb3de4d32ab9d8c591cbd013c22b4c31dd2", size = 1259283 },
|
||||
{ url = "https://files.pythonhosted.org/packages/5b/64/b16003419a1d7728d0d8c0d56a4c24325e7b10a21a9dd1fc0f7115c02f0a/tiktoken-0.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:5a62d7a25225bafed786a524c1b9f0910a1128f4232615bf3f8257a73aaa3b16", size = 894897 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -121,7 +121,7 @@ export function PreCode(props: { children: any }) {
|
||||
// visit https://reactjs.org/docs/error-decoder.html?invariant=185 for the full message
|
||||
// or use the non-minified dev environment for full errors and additional helpful warnings.
|
||||
|
||||
const CodeBlock: any = memo(({ inline, className, children, ...props }: any) => {
|
||||
const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any) => {
|
||||
const { theme } = useTheme()
|
||||
const [isSVG, setIsSVG] = useState(true)
|
||||
const match = /language-(\w+)/.exec(className || '')
|
||||
@ -258,7 +258,7 @@ const Link = ({ node, children, ...props }: any) => {
|
||||
const { onSend } = useChatContext()
|
||||
const hidden_text = decodeURIComponent(node.properties.href.toString().split('abbr:')[1])
|
||||
|
||||
return <abbr className="cursor-pointer underline !decoration-primary-700 decoration-dashed" onClick={() => onSend?.(hidden_text)} title={node.children[0]?.value}>{node.children[0]?.value}</abbr>
|
||||
return <abbr className="cursor-pointer underline !decoration-primary-700 decoration-dashed" onClick={() => onSend?.(hidden_text)} title={node.children[0]?.value || ''}>{node.children[0]?.value || ''}</abbr>
|
||||
}
|
||||
else {
|
||||
return <a {...props} target="_blank" className="cursor-pointer underline !decoration-primary-700 decoration-dashed">{children || 'Download'}</a>
|
||||
|
@ -476,15 +476,15 @@ const Flowchart = React.forwardRef((props: {
|
||||
'bg-white': currentTheme === Theme.light,
|
||||
'bg-slate-900': currentTheme === Theme.dark,
|
||||
}),
|
||||
mermaidDiv: cn('mermaid cursor-pointer h-auto w-full relative', {
|
||||
mermaidDiv: cn('mermaid relative h-auto w-full cursor-pointer', {
|
||||
'bg-white': currentTheme === Theme.light,
|
||||
'bg-slate-900': currentTheme === Theme.dark,
|
||||
}),
|
||||
errorMessage: cn('py-4 px-[26px]', {
|
||||
errorMessage: cn('px-[26px] py-4', {
|
||||
'text-red-500': currentTheme === Theme.light,
|
||||
'text-red-400': currentTheme === Theme.dark,
|
||||
}),
|
||||
errorIcon: cn('w-6 h-6', {
|
||||
errorIcon: cn('h-6 w-6', {
|
||||
'text-red-500': currentTheme === Theme.light,
|
||||
'text-red-400': currentTheme === Theme.dark,
|
||||
}),
|
||||
@ -492,7 +492,7 @@ const Flowchart = React.forwardRef((props: {
|
||||
'text-gray-700': currentTheme === Theme.light,
|
||||
'text-gray-300': currentTheme === Theme.dark,
|
||||
}),
|
||||
themeToggle: cn('flex items-center justify-center w-10 h-10 rounded-full transition-all duration-300 shadow-md backdrop-blur-sm', {
|
||||
themeToggle: cn('flex h-10 w-10 items-center justify-center rounded-full shadow-md backdrop-blur-sm transition-all duration-300', {
|
||||
'bg-white/80 hover:bg-white hover:shadow-lg text-gray-700 border border-gray-200': currentTheme === Theme.light,
|
||||
'bg-slate-800/80 hover:bg-slate-700 hover:shadow-lg text-yellow-300 border border-slate-600': currentTheme === Theme.dark,
|
||||
}),
|
||||
@ -501,7 +501,7 @@ const Flowchart = React.forwardRef((props: {
|
||||
// Style classes for look options
|
||||
const getLookButtonClass = (lookType: 'classic' | 'handDrawn') => {
|
||||
return cn(
|
||||
'flex items-center justify-center mb-4 w-[calc((100%-8px)/2)] h-8 rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg cursor-pointer system-sm-medium text-text-secondary',
|
||||
'system-sm-medium mb-4 flex h-8 w-[calc((100%-8px)/2)] cursor-pointer items-center justify-center rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg text-text-secondary',
|
||||
look === lookType && 'border-[1.5px] border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary',
|
||||
currentTheme === Theme.dark && 'border-slate-600 bg-slate-800 text-slate-300',
|
||||
look === lookType && currentTheme === Theme.dark && 'border-blue-500 bg-slate-700 text-white',
|
||||
@ -512,7 +512,7 @@ const Flowchart = React.forwardRef((props: {
|
||||
<div ref={ref as React.RefObject<HTMLDivElement>} className={themeClasses.container}>
|
||||
<div className={themeClasses.segmented}>
|
||||
<div className="msh-segmented-group">
|
||||
<label className="msh-segmented-item flex items-center space-x-1 m-2 w-[200px]">
|
||||
<label className="msh-segmented-item m-2 flex w-[200px] items-center space-x-1">
|
||||
<div
|
||||
key='classic'
|
||||
className={getLookButtonClass('classic')}
|
||||
@ -534,7 +534,7 @@ const Flowchart = React.forwardRef((props: {
|
||||
<div ref={containerRef} style={{ position: 'absolute', visibility: 'hidden', height: 0, overflow: 'hidden' }} />
|
||||
|
||||
{isLoading && !svgCode && (
|
||||
<div className='py-4 px-[26px]'>
|
||||
<div className='px-[26px] py-4'>
|
||||
<LoadingAnim type='text'/>
|
||||
{!isCodeComplete && (
|
||||
<div className="mt-2 text-sm text-gray-500">
|
||||
@ -546,7 +546,7 @@ const Flowchart = React.forwardRef((props: {
|
||||
|
||||
{svgCode && (
|
||||
<div className={themeClasses.mermaidDiv} style={{ objectFit: 'cover' }} onClick={() => setImagePreviewUrl(svgCode)}>
|
||||
<div className="absolute left-2 bottom-2 z-[100]">
|
||||
<div className="absolute bottom-2 left-2 z-[100]">
|
||||
<button
|
||||
onClick={(e) => {
|
||||
e.stopPropagation()
|
||||
|
@ -15,6 +15,7 @@ import { useProviderContext } from '@/context/provider-context'
|
||||
import GridMask from '@/app/components/base/grid-mask'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import classNames from '@/utils/classnames'
|
||||
import { useGetPricingPageLanguage } from '@/context/i18n'
|
||||
|
||||
type Props = {
|
||||
onCancel: () => void
|
||||
@ -33,6 +34,11 @@ const Pricing: FC<Props> = ({
|
||||
|
||||
useKeyPress(['esc'], onCancel)
|
||||
|
||||
const pricingPageLanguage = useGetPricingPageLanguage()
|
||||
const pricingPageURL = pricingPageLanguage
|
||||
? `https://dify.ai/${pricingPageLanguage}/pricing#plans-and-features`
|
||||
: 'https://dify.ai/pricing#plans-and-features'
|
||||
|
||||
return createPortal(
|
||||
<div
|
||||
className='fixed inset-0 bottom-0 left-0 right-0 top-0 z-[1000] bg-background-overlay-backdrop p-4 backdrop-blur-[6px]'
|
||||
@ -127,7 +133,7 @@ const Pricing: FC<Props> = ({
|
||||
</div>
|
||||
<div className='flex items-center justify-center py-4'>
|
||||
<div className='flex items-center justify-center gap-x-0.5 rounded-lg px-3 py-2 text-components-button-secondary-accent-text hover:cursor-pointer hover:bg-state-accent-hover'>
|
||||
<Link href='https://dify.ai/pricing#plans-and-features' className='system-sm-medium'>{t('billing.plansCommon.comparePlanAndFeatures')}</Link>
|
||||
<Link href={pricingPageURL} className='system-sm-medium'>{t('billing.plansCommon.comparePlanAndFeatures')}</Link>
|
||||
<RiArrowRightUpLine className='size-4' />
|
||||
</div>
|
||||
</div>
|
||||
|
@ -2,7 +2,7 @@
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Fragment, useState } from 'react'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import { useContext, useContextSelector } from 'use-context-selector'
|
||||
import { useContextSelector } from 'use-context-selector'
|
||||
import {
|
||||
RiAccountCircleLine,
|
||||
RiArrowRightUpLine,
|
||||
@ -23,13 +23,12 @@ import GithubStar from '../github-star'
|
||||
import Support from './support'
|
||||
import Compliance from './compliance'
|
||||
import PremiumBadge from '@/app/components/base/premium-badge'
|
||||
import I18n from '@/context/i18n'
|
||||
import { useGetDocLanguage } from '@/context/i18n'
|
||||
import Avatar from '@/app/components/base/avatar'
|
||||
import { logout } from '@/service/common'
|
||||
import AppContext, { useAppContext } from '@/context/app-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { useModalContext } from '@/context/modal-context'
|
||||
import { LanguagesSupported } from '@/i18n/language'
|
||||
import { LicenseStatus } from '@/types/feature'
|
||||
import { IS_CLOUD_EDITION } from '@/config'
|
||||
import cn from '@/utils/classnames'
|
||||
@ -43,11 +42,11 @@ export default function AppSelector() {
|
||||
const [aboutVisible, setAboutVisible] = useState(false)
|
||||
const systemFeatures = useContextSelector(AppContext, v => v.systemFeatures)
|
||||
|
||||
const { locale } = useContext(I18n)
|
||||
const { t } = useTranslation()
|
||||
const { userProfile, langeniusVersionInfo, isCurrentWorkspaceOwner } = useAppContext()
|
||||
const { isEducationAccount } = useProviderContext()
|
||||
const { setShowAccountSettingModal } = useModalContext()
|
||||
const docLanguage = useGetDocLanguage()
|
||||
|
||||
const handleLogout = async () => {
|
||||
await logout({
|
||||
@ -132,9 +131,7 @@ export default function AppSelector() {
|
||||
className={cn(itemClassName, 'group justify-between',
|
||||
'data-[active]:bg-state-base-hover',
|
||||
)}
|
||||
href={
|
||||
locale !== LanguagesSupported[1] ? 'https://docs.dify.ai/' : `https://docs.dify.ai/v/${locale.toLowerCase()}/`
|
||||
}
|
||||
href={`https://docs.dify.ai/${docLanguage}/introduction`}
|
||||
target='_blank' rel='noopener noreferrer'>
|
||||
<RiBookOpenLine className='size-4 shrink-0 text-text-tertiary' />
|
||||
<div className='system-md-regular grow px-1 text-text-secondary'>{t('common.userProfile.helpCenter')}</div>
|
||||
|
@ -316,7 +316,7 @@ export const Workflow: FC<WorkflowProps> = memo(({
|
||||
nodesConnectable={!nodesReadOnly}
|
||||
nodesFocusable={!nodesReadOnly}
|
||||
edgesFocusable={!nodesReadOnly}
|
||||
panOnScroll
|
||||
panOnScroll={false}
|
||||
panOnDrag={controlMode === ControlMode.Hand && !workflowReadOnly}
|
||||
zoomOnPinch={!workflowReadOnly}
|
||||
zoomOnScroll={!workflowReadOnly}
|
||||
|
@ -3,7 +3,7 @@ import {
|
||||
useContext,
|
||||
} from 'use-context-selector'
|
||||
import type { Locale } from '@/i18n'
|
||||
import { getLanguage } from '@/i18n/language'
|
||||
import { getDocLanguage, getLanguage, getPricingPageLanguage } from '@/i18n/language'
|
||||
import { noop } from 'lodash-es'
|
||||
|
||||
type II18NContext = {
|
||||
@ -24,5 +24,15 @@ export const useGetLanguage = () => {
|
||||
|
||||
return getLanguage(locale)
|
||||
}
|
||||
export const useGetDocLanguage = () => {
|
||||
const { locale } = useI18N()
|
||||
|
||||
return getDocLanguage(locale)
|
||||
}
|
||||
export const useGetPricingPageLanguage = () => {
|
||||
const { locale } = useI18N()
|
||||
|
||||
return getPricingPageLanguage(locale)
|
||||
}
|
||||
|
||||
export default I18NContext
|
||||
|
@ -39,6 +39,24 @@ export const getLanguage = (locale: string) => {
|
||||
return LanguagesSupported[0].replace('-', '_')
|
||||
}
|
||||
|
||||
const DOC_LANGUAGE: Record<string, string> = {
|
||||
'zh-Hans': 'zh-hans',
|
||||
'ja-JP': 'ja-jp',
|
||||
'en-US': 'en',
|
||||
}
|
||||
|
||||
export const getDocLanguage = (locale: string) => {
|
||||
return DOC_LANGUAGE[locale] || 'en'
|
||||
}
|
||||
|
||||
const PRICING_PAGE_LANGUAGE: Record<string, string> = {
|
||||
'ja-JP': 'jp',
|
||||
}
|
||||
|
||||
export const getPricingPageLanguage = (locale: string) => {
|
||||
return PRICING_PAGE_LANGUAGE[locale] || ''
|
||||
}
|
||||
|
||||
export const NOTICE_I18N = {
|
||||
title: {
|
||||
en_US: 'Important Notice',
|
||||
|
Loading…
x
Reference in New Issue
Block a user