mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-15 11:15:54 +08:00
Merge remote-tracking branch 'upstream/main' into deploy/dev
This commit is contained in:
commit
c8449df128
4
.github/workflows/style.yml
vendored
4
.github/workflows/style.yml
vendored
@ -49,8 +49,8 @@ jobs:
|
|||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
run: |
|
run: |
|
||||||
uv run --directory api ruff --version
|
uv run --directory api ruff --version
|
||||||
uv run --directory api ruff check ./
|
uv run --directory api ruff check --diff ./
|
||||||
uv run --directory api ruff format --check ./
|
uv run --directory api ruff format --check --diff ./
|
||||||
|
|
||||||
- name: Dotenv check
|
- name: Dotenv check
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
|
7
api/constants/mimetypes.py
Normal file
7
api/constants/mimetypes.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
# The two constants below should keep in sync.
|
||||||
|
# Default content type for files which have no explicit content type.
|
||||||
|
|
||||||
|
DEFAULT_MIME_TYPE = "application/octet-stream"
|
||||||
|
# Default file extension for files which have no explicit content type, should
|
||||||
|
# correspond to the `DEFAULT_MIME_TYPE` above.
|
||||||
|
DEFAULT_EXTENSION = ".bin"
|
@ -4,7 +4,9 @@ from werkzeug.exceptions import Forbidden, NotFound
|
|||||||
|
|
||||||
from controllers.files import api
|
from controllers.files import api
|
||||||
from controllers.files.error import UnsupportedFileTypeError
|
from controllers.files.error import UnsupportedFileTypeError
|
||||||
|
from core.tools.signature import verify_tool_file_signature
|
||||||
from core.tools.tool_file_manager import ToolFileManager
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
|
from models import db as global_db
|
||||||
|
|
||||||
|
|
||||||
class ToolFilePreviewApi(Resource):
|
class ToolFilePreviewApi(Resource):
|
||||||
@ -19,17 +21,14 @@ class ToolFilePreviewApi(Resource):
|
|||||||
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
|
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
if not verify_tool_file_signature(
|
||||||
if not ToolFileManager.verify_file(
|
file_id=file_id, timestamp=args["timestamp"], nonce=args["nonce"], sign=args["sign"]
|
||||||
file_id=file_id,
|
|
||||||
timestamp=args["timestamp"],
|
|
||||||
nonce=args["nonce"],
|
|
||||||
sign=args["sign"],
|
|
||||||
):
|
):
|
||||||
raise Forbidden("Invalid request.")
|
raise Forbidden("Invalid request.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
stream, tool_file = ToolFileManager.get_file_generator_by_tool_file_id(
|
tool_file_manager = ToolFileManager(engine=global_db.engine)
|
||||||
|
stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id(
|
||||||
file_id,
|
file_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -53,7 +53,7 @@ class PluginUploadFileApi(Resource):
|
|||||||
raise Forbidden("Invalid request.")
|
raise Forbidden("Invalid request.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tool_file = ToolFileManager.create_file_by_raw(
|
tool_file = ToolFileManager().create_file_by_raw(
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
file_binary=file.read(),
|
file_binary=file.read(),
|
||||||
|
@ -24,7 +24,7 @@ from core.app.entities.task_entities import (
|
|||||||
WorkflowTaskState,
|
WorkflowTaskState,
|
||||||
)
|
)
|
||||||
from core.llm_generator.llm_generator import LLMGenerator
|
from core.llm_generator.llm_generator import LLMGenerator
|
||||||
from core.tools.tool_file_manager import ToolFileManager
|
from core.tools.signature import sign_tool_file
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
|
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
|
||||||
from services.annotation_service import AppAnnotationService
|
from services.annotation_service import AppAnnotationService
|
||||||
@ -154,7 +154,7 @@ class MessageCycleManage:
|
|||||||
if message_file.url.startswith("http"):
|
if message_file.url.startswith("http"):
|
||||||
url = message_file.url
|
url = message_file.url
|
||||||
else:
|
else:
|
||||||
url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension)
|
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
|
||||||
|
|
||||||
return MessageFileStreamResponse(
|
return MessageFileStreamResponse(
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
@ -10,12 +10,12 @@ from core.model_runtime.entities import (
|
|||||||
VideoPromptMessageContent,
|
VideoPromptMessageContent,
|
||||||
)
|
)
|
||||||
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
||||||
|
from core.tools.signature import sign_tool_file
|
||||||
from extensions.ext_storage import storage
|
from extensions.ext_storage import storage
|
||||||
|
|
||||||
from . import helpers
|
from . import helpers
|
||||||
from .enums import FileAttribute
|
from .enums import FileAttribute
|
||||||
from .models import File, FileTransferMethod, FileType
|
from .models import File, FileTransferMethod, FileType
|
||||||
from .tool_file_parser import ToolFileParser
|
|
||||||
|
|
||||||
|
|
||||||
def get_attr(*, file: File, attr: FileAttribute):
|
def get_attr(*, file: File, attr: FileAttribute):
|
||||||
@ -130,6 +130,6 @@ def _to_url(f: File, /):
|
|||||||
# add sign url
|
# add sign url
|
||||||
if f.related_id is None or f.extension is None:
|
if f.related_id is None or f.extension is None:
|
||||||
raise ValueError("Missing file related_id or extension")
|
raise ValueError("Missing file related_id or extension")
|
||||||
return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension)
|
return sign_tool_file(tool_file_id=f.related_id, extension=f.extension)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
||||||
|
@ -4,11 +4,11 @@ from typing import Any, Optional
|
|||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||||
|
from core.tools.signature import sign_tool_file
|
||||||
|
|
||||||
from . import helpers
|
from . import helpers
|
||||||
from .constants import FILE_MODEL_IDENTITY
|
from .constants import FILE_MODEL_IDENTITY
|
||||||
from .enums import FileTransferMethod, FileType
|
from .enums import FileTransferMethod, FileType
|
||||||
from .tool_file_parser import ToolFileParser
|
|
||||||
|
|
||||||
|
|
||||||
class ImageConfig(BaseModel):
|
class ImageConfig(BaseModel):
|
||||||
@ -34,13 +34,21 @@ class FileUploadConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class File(BaseModel):
|
class File(BaseModel):
|
||||||
|
# NOTE: dify_model_identity is a special identifier used to distinguish between
|
||||||
|
# new and old data formats during serialization and deserialization.
|
||||||
dify_model_identity: str = FILE_MODEL_IDENTITY
|
dify_model_identity: str = FILE_MODEL_IDENTITY
|
||||||
|
|
||||||
id: Optional[str] = None # message file id
|
id: Optional[str] = None # message file id
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
type: FileType
|
type: FileType
|
||||||
transfer_method: FileTransferMethod
|
transfer_method: FileTransferMethod
|
||||||
|
# If `transfer_method` is `FileTransferMethod.remote_url`, the
|
||||||
|
# `remote_url` attribute must not be `None`.
|
||||||
remote_url: Optional[str] = None # remote url
|
remote_url: Optional[str] = None # remote url
|
||||||
|
# If `transfer_method` is `FileTransferMethod.local_file` or
|
||||||
|
# `FileTransferMethod.tool_file`, the `related_id` attribute must not be `None`.
|
||||||
|
#
|
||||||
|
# It should be set to `ToolFile.id` when `transfer_method` is `tool_file`.
|
||||||
related_id: Optional[str] = None
|
related_id: Optional[str] = None
|
||||||
filename: Optional[str] = None
|
filename: Optional[str] = None
|
||||||
extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
|
extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
|
||||||
@ -110,9 +118,7 @@ class File(BaseModel):
|
|||||||
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||||
assert self.related_id is not None
|
assert self.related_id is not None
|
||||||
assert self.extension is not None
|
assert self.extension is not None
|
||||||
return ToolFileParser.get_tool_file_manager().sign_file(
|
return sign_tool_file(tool_file_id=self.related_id, extension=self.extension)
|
||||||
tool_file_id=self.related_id, extension=self.extension
|
|
||||||
)
|
|
||||||
|
|
||||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
|
@ -1,12 +1,19 @@
|
|||||||
from typing import TYPE_CHECKING, Any, cast
|
from collections.abc import Callable
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.tools.tool_file_manager import ToolFileManager
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
|
|
||||||
tool_file_manager: dict[str, Any] = {"manager": None}
|
_tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None
|
||||||
|
|
||||||
|
|
||||||
class ToolFileParser:
|
class ToolFileParser:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_tool_file_manager() -> "ToolFileManager":
|
def get_tool_file_manager() -> "ToolFileManager":
|
||||||
return cast("ToolFileManager", tool_file_manager["manager"])
|
assert _tool_file_manager_factory is not None
|
||||||
|
return _tool_file_manager_factory()
|
||||||
|
|
||||||
|
|
||||||
|
def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None:
|
||||||
|
global _tool_file_manager_factory
|
||||||
|
_tool_file_manager_factory = factory
|
||||||
|
@ -101,7 +101,7 @@ class ModelInstance:
|
|||||||
@overload
|
@overload
|
||||||
def invoke_llm(
|
def invoke_llm(
|
||||||
self,
|
self,
|
||||||
prompt_messages: list[PromptMessage],
|
prompt_messages: Sequence[PromptMessage],
|
||||||
model_parameters: Optional[dict] = None,
|
model_parameters: Optional[dict] = None,
|
||||||
tools: Sequence[PromptMessageTool] | None = None,
|
tools: Sequence[PromptMessageTool] | None = None,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[list[str]] = None,
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from collections.abc import Sequence
|
from abc import ABC
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
from enum import Enum, StrEnum
|
from enum import Enum, StrEnum
|
||||||
from typing import Annotated, Any, Literal, Optional, Union
|
from typing import Annotated, Any, Literal, Optional, Union
|
||||||
|
|
||||||
@ -60,8 +61,12 @@ class PromptMessageContentType(StrEnum):
|
|||||||
DOCUMENT = "document"
|
DOCUMENT = "document"
|
||||||
|
|
||||||
|
|
||||||
class PromptMessageContent(BaseModel):
|
class PromptMessageContent(ABC, BaseModel):
|
||||||
pass
|
"""
|
||||||
|
Model class for prompt message content.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: PromptMessageContentType
|
||||||
|
|
||||||
|
|
||||||
class TextPromptMessageContent(PromptMessageContent):
|
class TextPromptMessageContent(PromptMessageContent):
|
||||||
@ -125,7 +130,16 @@ PromptMessageContentUnionTypes = Annotated[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class PromptMessage(BaseModel):
|
CONTENT_TYPE_MAPPING: Mapping[PromptMessageContentType, type[PromptMessageContent]] = {
|
||||||
|
PromptMessageContentType.TEXT: TextPromptMessageContent,
|
||||||
|
PromptMessageContentType.IMAGE: ImagePromptMessageContent,
|
||||||
|
PromptMessageContentType.AUDIO: AudioPromptMessageContent,
|
||||||
|
PromptMessageContentType.VIDEO: VideoPromptMessageContent,
|
||||||
|
PromptMessageContentType.DOCUMENT: DocumentPromptMessageContent,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PromptMessage(ABC, BaseModel):
|
||||||
"""
|
"""
|
||||||
Model class for prompt message.
|
Model class for prompt message.
|
||||||
"""
|
"""
|
||||||
@ -142,6 +156,23 @@ class PromptMessage(BaseModel):
|
|||||||
"""
|
"""
|
||||||
return not self.content
|
return not self.content
|
||||||
|
|
||||||
|
@field_validator("content", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_content(cls, v):
|
||||||
|
if isinstance(v, list):
|
||||||
|
prompts = []
|
||||||
|
for prompt in v:
|
||||||
|
if isinstance(prompt, PromptMessageContent):
|
||||||
|
if not isinstance(prompt, TextPromptMessageContent | MultiModalPromptMessageContent):
|
||||||
|
prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump())
|
||||||
|
elif isinstance(prompt, dict):
|
||||||
|
prompt = CONTENT_TYPE_MAPPING[prompt["type"]].model_validate(prompt)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"invalid prompt message {prompt}")
|
||||||
|
prompts.append(prompt)
|
||||||
|
return prompts
|
||||||
|
return v
|
||||||
|
|
||||||
@field_serializer("content")
|
@field_serializer("content")
|
||||||
def serialize_content(
|
def serialize_content(
|
||||||
self, content: Optional[Union[str, Sequence[PromptMessageContent]]]
|
self, content: Optional[Union[str, Sequence[PromptMessageContent]]]
|
||||||
|
@ -24,7 +24,6 @@ from core.model_runtime.errors.invoke import (
|
|||||||
InvokeRateLimitError,
|
InvokeRateLimitError,
|
||||||
InvokeServerUnavailableError,
|
InvokeServerUnavailableError,
|
||||||
)
|
)
|
||||||
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
|
||||||
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
|
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
|
||||||
from core.plugin.impl.model import PluginModelClient
|
from core.plugin.impl.model import PluginModelClient
|
||||||
|
|
||||||
@ -253,15 +252,3 @@ class AIModel(BaseModel):
|
|||||||
raise Exception(f"Invalid model parameter rule name {name}")
|
raise Exception(f"Invalid model parameter rule name {name}")
|
||||||
|
|
||||||
return default_parameter_rule
|
return default_parameter_rule
|
||||||
|
|
||||||
def _get_num_tokens_by_gpt2(self, text: str) -> int:
|
|
||||||
"""
|
|
||||||
Get number of tokens for given prompt messages by gpt2
|
|
||||||
Some provider models do not provide an interface for obtaining the number of tokens.
|
|
||||||
Here, the gpt2 tokenizer is used to calculate the number of tokens.
|
|
||||||
This method can be executed offline, and the gpt2 tokenizer has been cached in the project.
|
|
||||||
|
|
||||||
:param text: plain text of prompt. You need to convert the original message to plain text
|
|
||||||
:return: number of tokens
|
|
||||||
"""
|
|
||||||
return GPT2Tokenizer.get_num_tokens(text)
|
|
||||||
|
@ -2,7 +2,7 @@ import logging
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Generator, Sequence
|
from collections.abc import Generator, Sequence
|
||||||
from typing import Optional, Union, cast
|
from typing import Optional, Union
|
||||||
|
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
|
|
||||||
@ -13,14 +13,15 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
|
|||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
|
PromptMessageContentUnionTypes,
|
||||||
PromptMessageTool,
|
PromptMessageTool,
|
||||||
|
TextPromptMessageContent,
|
||||||
)
|
)
|
||||||
from core.model_runtime.entities.model_entities import (
|
from core.model_runtime.entities.model_entities import (
|
||||||
ModelType,
|
ModelType,
|
||||||
PriceType,
|
PriceType,
|
||||||
)
|
)
|
||||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||||
from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str
|
|
||||||
from core.plugin.impl.model import PluginModelClient
|
from core.plugin.impl.model import PluginModelClient
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -238,7 +239,7 @@ class LargeLanguageModel(AIModel):
|
|||||||
def _invoke_result_generator(
|
def _invoke_result_generator(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
result: Generator,
|
result: Generator[LLMResultChunk, None, None],
|
||||||
credentials: dict,
|
credentials: dict,
|
||||||
prompt_messages: Sequence[PromptMessage],
|
prompt_messages: Sequence[PromptMessage],
|
||||||
model_parameters: dict,
|
model_parameters: dict,
|
||||||
@ -255,11 +256,21 @@ class LargeLanguageModel(AIModel):
|
|||||||
:return: result generator
|
:return: result generator
|
||||||
"""
|
"""
|
||||||
callbacks = callbacks or []
|
callbacks = callbacks or []
|
||||||
assistant_message = AssistantPromptMessage(content="")
|
message_content: list[PromptMessageContentUnionTypes] = []
|
||||||
usage = None
|
usage = None
|
||||||
system_fingerprint = None
|
system_fingerprint = None
|
||||||
real_model = model
|
real_model = model
|
||||||
|
|
||||||
|
def _update_message_content(content: str | list[PromptMessageContentUnionTypes] | None):
|
||||||
|
if not content:
|
||||||
|
return
|
||||||
|
if isinstance(content, list):
|
||||||
|
message_content.extend(content)
|
||||||
|
return
|
||||||
|
if isinstance(content, str):
|
||||||
|
message_content.append(TextPromptMessageContent(data=content))
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for chunk in result:
|
for chunk in result:
|
||||||
# Following https://github.com/langgenius/dify/issues/17799,
|
# Following https://github.com/langgenius/dify/issues/17799,
|
||||||
@ -281,9 +292,8 @@ class LargeLanguageModel(AIModel):
|
|||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
|
|
||||||
text = convert_llm_result_chunk_to_str(chunk.delta.message.content)
|
_update_message_content(chunk.delta.message.content)
|
||||||
current_content = cast(str, assistant_message.content)
|
|
||||||
assistant_message.content = current_content + text
|
|
||||||
real_model = chunk.model
|
real_model = chunk.model
|
||||||
if chunk.delta.usage:
|
if chunk.delta.usage:
|
||||||
usage = chunk.delta.usage
|
usage = chunk.delta.usage
|
||||||
@ -293,6 +303,7 @@ class LargeLanguageModel(AIModel):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise self._transform_invoke_error(e)
|
raise self._transform_invoke_error(e)
|
||||||
|
|
||||||
|
assistant_message = AssistantPromptMessage(content=message_content)
|
||||||
self._trigger_after_invoke_callbacks(
|
self._trigger_after_invoke_callbacks(
|
||||||
model=model,
|
model=model,
|
||||||
result=LLMResult(
|
result=LLMResult(
|
||||||
|
@ -30,6 +30,8 @@ class GPT2Tokenizer:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def get_encoder() -> Any:
|
def get_encoder() -> Any:
|
||||||
global _tokenizer, _lock
|
global _tokenizer, _lock
|
||||||
|
if _tokenizer is not None:
|
||||||
|
return _tokenizer
|
||||||
with _lock:
|
with _lock:
|
||||||
if _tokenizer is None:
|
if _tokenizer is None:
|
||||||
# Try to use tiktoken to get the tokenizer because it is faster
|
# Try to use tiktoken to get the tokenizer because it is faster
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
import pydantic
|
import pydantic
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
|
||||||
|
|
||||||
|
|
||||||
def dump_model(model: BaseModel) -> dict:
|
def dump_model(model: BaseModel) -> dict:
|
||||||
if hasattr(pydantic, "model_dump"):
|
if hasattr(pydantic, "model_dump"):
|
||||||
@ -10,18 +8,3 @@ def dump_model(model: BaseModel) -> dict:
|
|||||||
return pydantic.model_dump(model) # type: ignore
|
return pydantic.model_dump(model) # type: ignore
|
||||||
else:
|
else:
|
||||||
return model.model_dump()
|
return model.model_dump()
|
||||||
|
|
||||||
|
|
||||||
def convert_llm_result_chunk_to_str(content: None | str | list[PromptMessageContentUnionTypes]) -> str:
|
|
||||||
if content is None:
|
|
||||||
message_text = ""
|
|
||||||
elif isinstance(content, str):
|
|
||||||
message_text = content
|
|
||||||
elif isinstance(content, list):
|
|
||||||
# Assuming the list contains PromptMessageContent objects with a "data" attribute
|
|
||||||
message_text = "".join(
|
|
||||||
item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
message_text = str(content)
|
|
||||||
return message_text
|
|
||||||
|
@ -159,50 +159,6 @@ class TextSplitter(BaseDocumentTransformer, ABC):
|
|||||||
)
|
)
|
||||||
return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)
|
return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_tiktoken_encoder(
|
|
||||||
cls: type[TS],
|
|
||||||
encoding_name: str = "gpt2",
|
|
||||||
model_name: Optional[str] = None,
|
|
||||||
allowed_special: Union[Literal["all"], Set[str]] = set(),
|
|
||||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> TS:
|
|
||||||
"""Text splitter that uses tiktoken encoder to count length."""
|
|
||||||
try:
|
|
||||||
import tiktoken
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
"Could not import tiktoken python package. "
|
|
||||||
"This is needed in order to calculate max_tokens_for_prompt. "
|
|
||||||
"Please install it with `pip install tiktoken`."
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_name is not None:
|
|
||||||
enc = tiktoken.encoding_for_model(model_name)
|
|
||||||
else:
|
|
||||||
enc = tiktoken.get_encoding(encoding_name)
|
|
||||||
|
|
||||||
def _tiktoken_encoder(text: str) -> int:
|
|
||||||
return len(
|
|
||||||
enc.encode(
|
|
||||||
text,
|
|
||||||
allowed_special=allowed_special,
|
|
||||||
disallowed_special=disallowed_special,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if issubclass(cls, TokenTextSplitter):
|
|
||||||
extra_kwargs = {
|
|
||||||
"encoding_name": encoding_name,
|
|
||||||
"model_name": model_name,
|
|
||||||
"allowed_special": allowed_special,
|
|
||||||
"disallowed_special": disallowed_special,
|
|
||||||
}
|
|
||||||
kwargs = {**kwargs, **extra_kwargs}
|
|
||||||
|
|
||||||
return cls(length_function=lambda x: [_tiktoken_encoder(text) for text in x], **kwargs)
|
|
||||||
|
|
||||||
def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
|
def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
|
||||||
"""Transform sequence of documents by splitting them."""
|
"""Transform sequence of documents by splitting them."""
|
||||||
return self.split_documents(list(documents))
|
return self.split_documents(list(documents))
|
||||||
|
41
api/core/tools/signature.py
Normal file
41
api/core/tools/signature.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
|
||||||
|
|
||||||
|
def sign_tool_file(tool_file_id: str, extension: str) -> str:
|
||||||
|
"""
|
||||||
|
sign file to get a temporary url
|
||||||
|
"""
|
||||||
|
base_url = dify_config.FILES_URL
|
||||||
|
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
|
||||||
|
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
nonce = os.urandom(16).hex()
|
||||||
|
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
|
||||||
|
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||||
|
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||||
|
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||||
|
|
||||||
|
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||||
|
|
||||||
|
|
||||||
|
def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||||
|
"""
|
||||||
|
verify signature
|
||||||
|
"""
|
||||||
|
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
|
||||||
|
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||||
|
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||||
|
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||||
|
|
||||||
|
# verify signature
|
||||||
|
if sign != recalculated_encoded_sign:
|
||||||
|
return False
|
||||||
|
|
||||||
|
current_time = int(time.time())
|
||||||
|
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
@ -9,18 +9,28 @@ from typing import Optional, Union
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.helper import ssrf_proxy
|
from core.helper import ssrf_proxy
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db as global_db
|
||||||
from extensions.ext_storage import storage
|
from extensions.ext_storage import storage
|
||||||
from models.model import MessageFile
|
from models.model import MessageFile
|
||||||
from models.tools import ToolFile
|
from models.tools import ToolFile
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
|
|
||||||
class ToolFileManager:
|
class ToolFileManager:
|
||||||
|
_engine: Engine
|
||||||
|
|
||||||
|
def __init__(self, engine: Engine | None = None):
|
||||||
|
if engine is None:
|
||||||
|
engine = global_db.engine
|
||||||
|
self._engine = engine
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sign_file(tool_file_id: str, extension: str) -> str:
|
def sign_file(tool_file_id: str, extension: str) -> str:
|
||||||
"""
|
"""
|
||||||
@ -55,8 +65,8 @@ class ToolFileManager:
|
|||||||
current_time = int(time.time())
|
current_time = int(time.time())
|
||||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_file_by_raw(
|
def create_file_by_raw(
|
||||||
|
self,
|
||||||
*,
|
*,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
@ -77,24 +87,25 @@ class ToolFileManager:
|
|||||||
filepath = f"tools/{tenant_id}/{unique_filename}"
|
filepath = f"tools/{tenant_id}/{unique_filename}"
|
||||||
storage.save(filepath, file_binary)
|
storage.save(filepath, file_binary)
|
||||||
|
|
||||||
tool_file = ToolFile(
|
with Session(self._engine, expire_on_commit=False) as session:
|
||||||
user_id=user_id,
|
tool_file = ToolFile(
|
||||||
tenant_id=tenant_id,
|
user_id=user_id,
|
||||||
conversation_id=conversation_id,
|
tenant_id=tenant_id,
|
||||||
file_key=filepath,
|
conversation_id=conversation_id,
|
||||||
mimetype=mimetype,
|
file_key=filepath,
|
||||||
name=present_filename,
|
mimetype=mimetype,
|
||||||
size=len(file_binary),
|
name=present_filename,
|
||||||
)
|
size=len(file_binary),
|
||||||
|
)
|
||||||
|
|
||||||
db.session.add(tool_file)
|
session.add(tool_file)
|
||||||
db.session.commit()
|
session.commit()
|
||||||
db.session.refresh(tool_file)
|
session.refresh(tool_file)
|
||||||
|
|
||||||
return tool_file
|
return tool_file
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_file_by_url(
|
def create_file_by_url(
|
||||||
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
file_url: str,
|
file_url: str,
|
||||||
@ -119,24 +130,24 @@ class ToolFileManager:
|
|||||||
filepath = f"tools/{tenant_id}/{filename}"
|
filepath = f"tools/{tenant_id}/{filename}"
|
||||||
storage.save(filepath, blob)
|
storage.save(filepath, blob)
|
||||||
|
|
||||||
tool_file = ToolFile(
|
with Session(self._engine, expire_on_commit=False) as session:
|
||||||
user_id=user_id,
|
tool_file = ToolFile(
|
||||||
tenant_id=tenant_id,
|
user_id=user_id,
|
||||||
conversation_id=conversation_id,
|
tenant_id=tenant_id,
|
||||||
file_key=filepath,
|
conversation_id=conversation_id,
|
||||||
mimetype=mimetype,
|
file_key=filepath,
|
||||||
original_url=file_url,
|
mimetype=mimetype,
|
||||||
name=filename,
|
original_url=file_url,
|
||||||
size=len(blob),
|
name=filename,
|
||||||
)
|
size=len(blob),
|
||||||
|
)
|
||||||
|
|
||||||
db.session.add(tool_file)
|
session.add(tool_file)
|
||||||
db.session.commit()
|
session.commit()
|
||||||
|
|
||||||
return tool_file
|
return tool_file
|
||||||
|
|
||||||
@staticmethod
|
def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]:
|
||||||
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
|
|
||||||
"""
|
"""
|
||||||
get file binary
|
get file binary
|
||||||
|
|
||||||
@ -144,13 +155,14 @@ class ToolFileManager:
|
|||||||
|
|
||||||
:return: the binary of the file, mime type
|
:return: the binary of the file, mime type
|
||||||
"""
|
"""
|
||||||
tool_file: ToolFile | None = (
|
with Session(self._engine, expire_on_commit=False) as session:
|
||||||
db.session.query(ToolFile)
|
tool_file: ToolFile | None = (
|
||||||
.filter(
|
session.query(ToolFile)
|
||||||
ToolFile.id == id,
|
.filter(
|
||||||
|
ToolFile.id == id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
)
|
)
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not tool_file:
|
if not tool_file:
|
||||||
return None
|
return None
|
||||||
@ -159,8 +171,7 @@ class ToolFileManager:
|
|||||||
|
|
||||||
return blob, tool_file.mimetype
|
return blob, tool_file.mimetype
|
||||||
|
|
||||||
@staticmethod
|
def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]:
|
||||||
def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
|
|
||||||
"""
|
"""
|
||||||
get file binary
|
get file binary
|
||||||
|
|
||||||
@ -168,33 +179,34 @@ class ToolFileManager:
|
|||||||
|
|
||||||
:return: the binary of the file, mime type
|
:return: the binary of the file, mime type
|
||||||
"""
|
"""
|
||||||
message_file: MessageFile | None = (
|
with Session(self._engine, expire_on_commit=False) as session:
|
||||||
db.session.query(MessageFile)
|
message_file: MessageFile | None = (
|
||||||
.filter(
|
session.query(MessageFile)
|
||||||
MessageFile.id == id,
|
.filter(
|
||||||
|
MessageFile.id == id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
)
|
)
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if message_file is not None
|
# Check if message_file is not None
|
||||||
if message_file is not None:
|
if message_file is not None:
|
||||||
# get tool file id
|
# get tool file id
|
||||||
if message_file.url is not None:
|
if message_file.url is not None:
|
||||||
tool_file_id = message_file.url.split("/")[-1]
|
tool_file_id = message_file.url.split("/")[-1]
|
||||||
# trim extension
|
# trim extension
|
||||||
tool_file_id = tool_file_id.split(".")[0]
|
tool_file_id = tool_file_id.split(".")[0]
|
||||||
|
else:
|
||||||
|
tool_file_id = None
|
||||||
else:
|
else:
|
||||||
tool_file_id = None
|
tool_file_id = None
|
||||||
else:
|
|
||||||
tool_file_id = None
|
|
||||||
|
|
||||||
tool_file: ToolFile | None = (
|
tool_file: ToolFile | None = (
|
||||||
db.session.query(ToolFile)
|
session.query(ToolFile)
|
||||||
.filter(
|
.filter(
|
||||||
ToolFile.id == tool_file_id,
|
ToolFile.id == tool_file_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
)
|
)
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not tool_file:
|
if not tool_file:
|
||||||
return None
|
return None
|
||||||
@ -203,8 +215,7 @@ class ToolFileManager:
|
|||||||
|
|
||||||
return blob, tool_file.mimetype
|
return blob, tool_file.mimetype
|
||||||
|
|
||||||
@staticmethod
|
def get_file_generator_by_tool_file_id(self, tool_file_id: str):
|
||||||
def get_file_generator_by_tool_file_id(tool_file_id: str):
|
|
||||||
"""
|
"""
|
||||||
get file binary
|
get file binary
|
||||||
|
|
||||||
@ -212,13 +223,14 @@ class ToolFileManager:
|
|||||||
|
|
||||||
:return: the binary of the file, mime type
|
:return: the binary of the file, mime type
|
||||||
"""
|
"""
|
||||||
tool_file: ToolFile | None = (
|
with Session(self._engine, expire_on_commit=False) as session:
|
||||||
db.session.query(ToolFile)
|
tool_file: ToolFile | None = (
|
||||||
.filter(
|
session.query(ToolFile)
|
||||||
ToolFile.id == tool_file_id,
|
.filter(
|
||||||
|
ToolFile.id == tool_file_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
)
|
)
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not tool_file:
|
if not tool_file:
|
||||||
return None, None
|
return None, None
|
||||||
@ -229,6 +241,11 @@ class ToolFileManager:
|
|||||||
|
|
||||||
|
|
||||||
# init tool_file_parser
|
# init tool_file_parser
|
||||||
from core.file.tool_file_parser import tool_file_manager
|
from core.file.tool_file_parser import set_tool_file_manager_factory
|
||||||
|
|
||||||
tool_file_manager["manager"] = ToolFileManager
|
|
||||||
|
def _factory() -> ToolFileManager:
|
||||||
|
return ToolFileManager()
|
||||||
|
|
||||||
|
|
||||||
|
set_tool_file_manager_factory(_factory)
|
||||||
|
@ -31,8 +31,8 @@ class ToolFileMessageTransformer:
|
|||||||
# try to download image
|
# try to download image
|
||||||
try:
|
try:
|
||||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||||
|
tool_file_manager = ToolFileManager()
|
||||||
file = ToolFileManager.create_file_by_url(
|
file = tool_file_manager.create_file_by_url(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
file_url=message.message.text,
|
file_url=message.message.text,
|
||||||
@ -68,7 +68,8 @@ class ToolFileMessageTransformer:
|
|||||||
|
|
||||||
# FIXME: should do a type check here.
|
# FIXME: should do a type check here.
|
||||||
assert isinstance(message.message.blob, bytes)
|
assert isinstance(message.message.blob, bytes)
|
||||||
file = ToolFileManager.create_file_by_raw(
|
tool_file_manager = ToolFileManager()
|
||||||
|
file = tool_file_manager.create_file_by_raw(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
|
@ -191,8 +191,9 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
|||||||
mime_type = (
|
mime_type = (
|
||||||
content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
|
content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
|
||||||
)
|
)
|
||||||
|
tool_file_manager = ToolFileManager()
|
||||||
|
|
||||||
tool_file = ToolFileManager.create_file_by_raw(
|
tool_file = tool_file_manager.create_file_by_raw(
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
conversation_id=None,
|
conversation_id=None,
|
||||||
|
@ -507,7 +507,7 @@ class KnowledgeRetrievalNode(LLMNode):
|
|||||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) < value)
|
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) < value)
|
||||||
case "after" | ">":
|
case "after" | ">":
|
||||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) > value)
|
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) > value)
|
||||||
case "≤" | ">=":
|
case "≤" | "<=":
|
||||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) <= value)
|
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) <= value)
|
||||||
case "≥" | ">=":
|
case "≥" | ">=":
|
||||||
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) >= value)
|
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) >= value)
|
||||||
|
@ -38,3 +38,8 @@ class MemoryRolePrefixRequiredError(LLMNodeError):
|
|||||||
class FileTypeNotSupportError(LLMNodeError):
|
class FileTypeNotSupportError(LLMNodeError):
|
||||||
def __init__(self, *, type_name: str):
|
def __init__(self, *, type_name: str):
|
||||||
super().__init__(f"{type_name} type is not supported by this model")
|
super().__init__(f"{type_name} type is not supported by this model")
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedPromptContentTypeError(LLMNodeError):
|
||||||
|
def __init__(self, *, type_name: str) -> None:
|
||||||
|
super().__init__(f"Prompt content type {type_name} is not supported.")
|
||||||
|
160
api/core/workflow/nodes/llm/file_saver.py
Normal file
160
api/core/workflow/nodes/llm/file_saver.py
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
import mimetypes
|
||||||
|
import typing as tp
|
||||||
|
|
||||||
|
from sqlalchemy import Engine
|
||||||
|
|
||||||
|
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
|
||||||
|
from core.file import File, FileTransferMethod, FileType
|
||||||
|
from core.helper import ssrf_proxy
|
||||||
|
from core.tools.signature import sign_tool_file
|
||||||
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
|
from models import db as global_db
|
||||||
|
|
||||||
|
|
||||||
|
class LLMFileSaver(tp.Protocol):
|
||||||
|
"""LLMFileSaver is responsible for save multimodal output returned by
|
||||||
|
LLM.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def save_binary_string(
|
||||||
|
self,
|
||||||
|
data: bytes,
|
||||||
|
mime_type: str,
|
||||||
|
file_type: FileType,
|
||||||
|
extension_override: str | None = None,
|
||||||
|
) -> File:
|
||||||
|
"""save_binary_string saves the inline file data returned by LLM.
|
||||||
|
|
||||||
|
Currently (2025-04-30), only some of Google Gemini models will return
|
||||||
|
multimodal output as inline data.
|
||||||
|
|
||||||
|
:param data: the contents of the file
|
||||||
|
:param mime_type: the media type of the file, specified by rfc6838
|
||||||
|
(https://datatracker.ietf.org/doc/html/rfc6838)
|
||||||
|
:param file_type: The file type of the inline file.
|
||||||
|
:param extension_override: Override the auto-detected file extension while saving this file.
|
||||||
|
|
||||||
|
The default value is `None`, which means do not override the file extension and guessing it
|
||||||
|
from the `mime_type` attribute while saving the file.
|
||||||
|
|
||||||
|
Setting it to values other than `None` means override the file's extension, and
|
||||||
|
will bypass the extension guessing saving the file.
|
||||||
|
|
||||||
|
Specially, setting it to empty string (`""`) will leave the file extension empty.
|
||||||
|
|
||||||
|
When it is not `None` or empty string (`""`), it should be a string beginning with a
|
||||||
|
dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
|
||||||
|
and `tar.gz` are not.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def save_remote_url(self, url: str, file_type: FileType) -> File:
|
||||||
|
"""save_remote_url saves the file from a remote url returned by LLM.
|
||||||
|
|
||||||
|
Currently (2025-04-30), no model returns multimodel output as a url.
|
||||||
|
|
||||||
|
:param url: the url of the file.
|
||||||
|
:param file_type: the file type of the file, check `FileType` enum for reference.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
|
||||||
|
|
||||||
|
|
||||||
|
class FileSaverImpl(LLMFileSaver):
|
||||||
|
_engine_factory: EngineFactory
|
||||||
|
_tenant_id: str
|
||||||
|
_user_id: str
|
||||||
|
|
||||||
|
def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
|
||||||
|
if engine_factory is None:
|
||||||
|
|
||||||
|
def _factory():
|
||||||
|
return global_db.engine
|
||||||
|
|
||||||
|
engine_factory = _factory
|
||||||
|
self._engine_factory = engine_factory
|
||||||
|
self._user_id = user_id
|
||||||
|
self._tenant_id = tenant_id
|
||||||
|
|
||||||
|
def _get_tool_file_manager(self):
|
||||||
|
return ToolFileManager(engine=self._engine_factory())
|
||||||
|
|
||||||
|
def save_remote_url(self, url: str, file_type: FileType) -> File:
|
||||||
|
http_response = ssrf_proxy.get(url)
|
||||||
|
http_response.raise_for_status()
|
||||||
|
data = http_response.content
|
||||||
|
mime_type_from_header = http_response.headers.get("Content-Type")
|
||||||
|
mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header)
|
||||||
|
return self.save_binary_string(data, mime_type, file_type, extension_override=extension)
|
||||||
|
|
||||||
|
def save_binary_string(
|
||||||
|
self,
|
||||||
|
data: bytes,
|
||||||
|
mime_type: str,
|
||||||
|
file_type: FileType,
|
||||||
|
extension_override: str | None = None,
|
||||||
|
) -> File:
|
||||||
|
tool_file_manager = self._get_tool_file_manager()
|
||||||
|
tool_file = tool_file_manager.create_file_by_raw(
|
||||||
|
user_id=self._user_id,
|
||||||
|
tenant_id=self._tenant_id,
|
||||||
|
# TODO(QuantumGhost): what is conversation id?
|
||||||
|
conversation_id=None,
|
||||||
|
file_binary=data,
|
||||||
|
mimetype=mime_type,
|
||||||
|
)
|
||||||
|
extension_override = _validate_extension_override(extension_override)
|
||||||
|
extension = _get_extension(mime_type, extension_override)
|
||||||
|
url = sign_tool_file(tool_file.id, extension)
|
||||||
|
|
||||||
|
return File(
|
||||||
|
tenant_id=self._tenant_id,
|
||||||
|
type=file_type,
|
||||||
|
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||||
|
filename=tool_file.name,
|
||||||
|
extension=extension,
|
||||||
|
mime_type=mime_type,
|
||||||
|
size=len(data),
|
||||||
|
related_id=tool_file.id,
|
||||||
|
url=url,
|
||||||
|
# TODO(QuantumGhost): how should I set the following key?
|
||||||
|
# What's the difference between `remote_url` and `url`?
|
||||||
|
# What's the purpose of `storage_key` and `dify_model_identity`?
|
||||||
|
storage_key=tool_file.file_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_extension(mime_type: str, extension_override: str | None = None) -> str:
|
||||||
|
"""get_extension return the extension of file.
|
||||||
|
|
||||||
|
If the `extension_override` parameter is set, this function should honor it and
|
||||||
|
return its value.
|
||||||
|
"""
|
||||||
|
if extension_override is not None:
|
||||||
|
return extension_override
|
||||||
|
return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]:
|
||||||
|
"""_extract_content_type_and_extension tries to
|
||||||
|
guess content type of file from url and `Content-Type` header in response.
|
||||||
|
"""
|
||||||
|
if content_type_header:
|
||||||
|
extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION
|
||||||
|
return content_type_header, extension
|
||||||
|
content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE
|
||||||
|
extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION
|
||||||
|
return content_type, extension
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_extension_override(extension_override: str | None) -> str | None:
|
||||||
|
# `extension_override` is allow to be `None or `""`.
|
||||||
|
if extension_override is None:
|
||||||
|
return None
|
||||||
|
if extension_override == "":
|
||||||
|
return ""
|
||||||
|
if not extension_override.startswith("."):
|
||||||
|
raise ValueError("extension_override should start with '.' if not None or empty.", extension_override)
|
||||||
|
return extension_override
|
@ -1,3 +1,5 @@
|
|||||||
|
import base64
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
@ -21,7 +23,7 @@ from core.model_runtime.entities import (
|
|||||||
PromptMessageContentType,
|
PromptMessageContentType,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
)
|
)
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
PromptMessageContentUnionTypes,
|
PromptMessageContentUnionTypes,
|
||||||
@ -38,7 +40,6 @@ from core.model_runtime.entities.model_entities import (
|
|||||||
)
|
)
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str
|
|
||||||
from core.plugin.entities.plugin import ModelProviderID
|
from core.plugin.entities.plugin import ModelProviderID
|
||||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||||
@ -95,9 +96,13 @@ from .exc import (
|
|||||||
TemplateTypeNotSupportError,
|
TemplateTypeNotSupportError,
|
||||||
VariableNotFoundError,
|
VariableNotFoundError,
|
||||||
)
|
)
|
||||||
|
from .file_saver import FileSaverImpl, LLMFileSaver
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.file.models import File
|
from core.file.models import File
|
||||||
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
|
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||||
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -106,6 +111,43 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
_node_data_cls = LLMNodeData
|
_node_data_cls = LLMNodeData
|
||||||
_node_type = NodeType.LLM
|
_node_type = NodeType.LLM
|
||||||
|
|
||||||
|
# Instance attributes specific to LLMNode.
|
||||||
|
# Output variable for file
|
||||||
|
_file_outputs: list["File"]
|
||||||
|
|
||||||
|
_llm_file_saver: LLMFileSaver
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
id: str,
|
||||||
|
config: Mapping[str, Any],
|
||||||
|
graph_init_params: "GraphInitParams",
|
||||||
|
graph: "Graph",
|
||||||
|
graph_runtime_state: "GraphRuntimeState",
|
||||||
|
previous_node_id: Optional[str] = None,
|
||||||
|
thread_pool_id: Optional[str] = None,
|
||||||
|
*,
|
||||||
|
llm_file_saver: LLMFileSaver | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(
|
||||||
|
id=id,
|
||||||
|
config=config,
|
||||||
|
graph_init_params=graph_init_params,
|
||||||
|
graph=graph,
|
||||||
|
graph_runtime_state=graph_runtime_state,
|
||||||
|
previous_node_id=previous_node_id,
|
||||||
|
thread_pool_id=thread_pool_id,
|
||||||
|
)
|
||||||
|
# LLM file outputs, used for MultiModal outputs.
|
||||||
|
self._file_outputs: list[File] = []
|
||||||
|
|
||||||
|
if llm_file_saver is None:
|
||||||
|
llm_file_saver = FileSaverImpl(
|
||||||
|
user_id=graph_init_params.user_id,
|
||||||
|
tenant_id=graph_init_params.tenant_id,
|
||||||
|
)
|
||||||
|
self._llm_file_saver = llm_file_saver
|
||||||
|
|
||||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||||
def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
|
def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
|
||||||
"""Process structured output if enabled"""
|
"""Process structured output if enabled"""
|
||||||
@ -215,6 +257,9 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
structured_output = process_structured_output(result_text)
|
structured_output = process_structured_output(result_text)
|
||||||
if structured_output:
|
if structured_output:
|
||||||
outputs["structured_output"] = structured_output
|
outputs["structured_output"] = structured_output
|
||||||
|
if self._file_outputs is not None:
|
||||||
|
outputs["files"] = self._file_outputs
|
||||||
|
|
||||||
yield RunCompletedEvent(
|
yield RunCompletedEvent(
|
||||||
run_result=NodeRunResult(
|
run_result=NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
@ -240,6 +285,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.exception("error while executing llm node")
|
||||||
yield RunCompletedEvent(
|
yield RunCompletedEvent(
|
||||||
run_result=NodeRunResult(
|
run_result=NodeRunResult(
|
||||||
status=WorkflowNodeExecutionStatus.FAILED,
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
@ -268,44 +314,45 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
|
|
||||||
return self._handle_invoke_result(invoke_result=invoke_result)
|
return self._handle_invoke_result(invoke_result=invoke_result)
|
||||||
|
|
||||||
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
|
def _handle_invoke_result(
|
||||||
|
self, invoke_result: LLMResult | Generator[LLMResultChunk, None, None]
|
||||||
|
) -> Generator[NodeEvent, None, None]:
|
||||||
|
# For blocking mode
|
||||||
if isinstance(invoke_result, LLMResult):
|
if isinstance(invoke_result, LLMResult):
|
||||||
message_text = convert_llm_result_chunk_to_str(invoke_result.message.content)
|
event = self._handle_blocking_result(invoke_result=invoke_result)
|
||||||
|
yield event
|
||||||
yield ModelInvokeCompletedEvent(
|
|
||||||
text=message_text,
|
|
||||||
usage=invoke_result.usage,
|
|
||||||
finish_reason=None,
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
model = None
|
# For streaming mode
|
||||||
|
model = ""
|
||||||
prompt_messages: list[PromptMessage] = []
|
prompt_messages: list[PromptMessage] = []
|
||||||
full_text = ""
|
|
||||||
usage = None
|
usage = LLMUsage.empty_usage()
|
||||||
finish_reason = None
|
finish_reason = None
|
||||||
|
full_text_buffer = io.StringIO()
|
||||||
for result in invoke_result:
|
for result in invoke_result:
|
||||||
text = convert_llm_result_chunk_to_str(result.delta.message.content)
|
contents = result.delta.message.content
|
||||||
full_text += text
|
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents):
|
||||||
|
full_text_buffer.write(text_part)
|
||||||
|
yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[self.node_id, "text"])
|
||||||
|
|
||||||
yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])
|
# Update the whole metadata
|
||||||
|
if not model and result.model:
|
||||||
if not model:
|
|
||||||
model = result.model
|
model = result.model
|
||||||
|
if len(prompt_messages) == 0:
|
||||||
if not prompt_messages:
|
# TODO(QuantumGhost): it seems that this update has no visable effect.
|
||||||
prompt_messages = result.prompt_messages
|
# What's the purpose of the line below?
|
||||||
|
prompt_messages = list(result.prompt_messages)
|
||||||
if not usage and result.delta.usage:
|
if usage.prompt_tokens == 0 and result.delta.usage:
|
||||||
usage = result.delta.usage
|
usage = result.delta.usage
|
||||||
|
if finish_reason is None and result.delta.finish_reason:
|
||||||
if not finish_reason and result.delta.finish_reason:
|
|
||||||
finish_reason = result.delta.finish_reason
|
finish_reason = result.delta.finish_reason
|
||||||
|
|
||||||
if not usage:
|
yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
|
||||||
usage = LLMUsage.empty_usage()
|
|
||||||
|
|
||||||
yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason)
|
def _image_file_to_markdown(self, file: "File", /):
|
||||||
|
text_chunk = f"})"
|
||||||
|
return text_chunk
|
||||||
|
|
||||||
def _transform_chat_messages(
|
def _transform_chat_messages(
|
||||||
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
|
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
|
||||||
@ -963,6 +1010,42 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
|
|
||||||
return prompt_messages
|
return prompt_messages
|
||||||
|
|
||||||
|
def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent:
|
||||||
|
buffer = io.StringIO()
|
||||||
|
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content):
|
||||||
|
buffer.write(text_part)
|
||||||
|
|
||||||
|
return ModelInvokeCompletedEvent(
|
||||||
|
text=buffer.getvalue(),
|
||||||
|
usage=invoke_result.usage,
|
||||||
|
finish_reason=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File":
|
||||||
|
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.
|
||||||
|
|
||||||
|
There are two kinds of multimodal outputs:
|
||||||
|
|
||||||
|
- Inlined data encoded in base64, which would be saved to storage directly.
|
||||||
|
- Remote files referenced by an url, which would be downloaded and then saved to storage.
|
||||||
|
|
||||||
|
Currently, only image files are supported.
|
||||||
|
"""
|
||||||
|
# Inject the saver somehow...
|
||||||
|
_saver = self._llm_file_saver
|
||||||
|
|
||||||
|
# If this
|
||||||
|
if content.url != "":
|
||||||
|
saved_file = _saver.save_remote_url(content.url, FileType.IMAGE)
|
||||||
|
else:
|
||||||
|
saved_file = _saver.save_binary_string(
|
||||||
|
data=base64.b64decode(content.base64_data),
|
||||||
|
mime_type=content.mime_type,
|
||||||
|
file_type=FileType.IMAGE,
|
||||||
|
)
|
||||||
|
self._file_outputs.append(saved_file)
|
||||||
|
return saved_file
|
||||||
|
|
||||||
def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
|
def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
|
||||||
"""
|
"""
|
||||||
Handle structured output for models with native JSON schema support.
|
Handle structured output for models with native JSON schema support.
|
||||||
@ -1123,6 +1206,41 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||||||
else SupportStructuredOutputStatus.UNSUPPORTED
|
else SupportStructuredOutputStatus.UNSUPPORTED
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _save_multimodal_output_and_convert_result_to_markdown(
|
||||||
|
self,
|
||||||
|
contents: str | list[PromptMessageContentUnionTypes] | None,
|
||||||
|
) -> Generator[str, None, None]:
|
||||||
|
"""Convert intermediate prompt messages into strings and yield them to the caller.
|
||||||
|
|
||||||
|
If the messages contain non-textual content (e.g., multimedia like images or videos),
|
||||||
|
it will be saved separately, and the corresponding Markdown representation will
|
||||||
|
be yielded to the caller.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# NOTE(QuantumGhost): This function should yield results to the caller immediately
|
||||||
|
# whenever new content or partial content is available. Avoid any intermediate buffering
|
||||||
|
# of results. Additionally, do not yield empty strings; instead, yield from an empty list
|
||||||
|
# if necessary.
|
||||||
|
if contents is None:
|
||||||
|
yield from []
|
||||||
|
return
|
||||||
|
if isinstance(contents, str):
|
||||||
|
yield contents
|
||||||
|
elif isinstance(contents, list):
|
||||||
|
for item in contents:
|
||||||
|
if isinstance(item, TextPromptMessageContent):
|
||||||
|
yield item.data
|
||||||
|
elif isinstance(item, ImagePromptMessageContent):
|
||||||
|
file = self._save_multimodal_image_output(item)
|
||||||
|
self._file_outputs.append(file)
|
||||||
|
yield self._image_file_to_markdown(file)
|
||||||
|
else:
|
||||||
|
logger.warning("unknown item type encountered, type=%s", type(item))
|
||||||
|
yield str(item)
|
||||||
|
else:
|
||||||
|
logger.warning("unknown contents type encountered, type=%s", type(contents))
|
||||||
|
yield str(contents)
|
||||||
|
|
||||||
|
|
||||||
def _combine_message_content_with_role(
|
def _combine_message_content_with_role(
|
||||||
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
|
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
|
||||||
|
@ -10,4 +10,16 @@ POSTGRES_INDEXES_NAMING_CONVENTION = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION)
|
metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION)
|
||||||
|
|
||||||
|
# ****** IMPORTANT NOTICE ******
|
||||||
|
#
|
||||||
|
# NOTE(QuantumGhost): Avoid directly importing and using `db` in modules outside of the
|
||||||
|
# `controllers` package.
|
||||||
|
#
|
||||||
|
# Instead, import `db` within the `controllers` package and pass it as an argument to
|
||||||
|
# functions or class constructors.
|
||||||
|
#
|
||||||
|
# Directly importing `db` in other modules can make the code more difficult to read, test, and maintain.
|
||||||
|
#
|
||||||
|
# Whenever possible, avoid this pattern in new code.
|
||||||
db = SQLAlchemy(metadata=metadata)
|
db = SQLAlchemy(metadata=metadata)
|
||||||
|
@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
|||||||
|
|
||||||
from core.plugin.entities.plugin import GenericProviderID
|
from core.plugin.entities.plugin import GenericProviderID
|
||||||
from core.tools.entities.tool_entities import ToolProviderType
|
from core.tools.entities.tool_entities import ToolProviderType
|
||||||
|
from core.tools.signature import sign_tool_file
|
||||||
from services.plugin.plugin_service import PluginService
|
from services.plugin.plugin_service import PluginService
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -23,7 +24,6 @@ from configs import dify_config
|
|||||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||||
from core.file import helpers as file_helpers
|
from core.file import helpers as file_helpers
|
||||||
from core.file.tool_file_parser import ToolFileParser
|
|
||||||
from libs.helper import generate_string
|
from libs.helper import generate_string
|
||||||
from models.base import Base
|
from models.base import Base
|
||||||
from models.enums import CreatedByRole
|
from models.enums import CreatedByRole
|
||||||
@ -986,9 +986,7 @@ class Message(db.Model): # type: ignore[name-defined]
|
|||||||
if not tool_file_id:
|
if not tool_file_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
sign_url = ToolFileParser.get_tool_file_manager().sign_file(
|
sign_url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
|
||||||
tool_file_id=tool_file_id, extension=extension
|
|
||||||
)
|
|
||||||
elif "file-preview" in url:
|
elif "file-preview" in url:
|
||||||
# get upload file id
|
# get upload file id
|
||||||
upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview?\?timestamp="
|
upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview?\?timestamp="
|
||||||
|
@ -263,8 +263,8 @@ class ToolConversationVariables(Base):
|
|||||||
|
|
||||||
|
|
||||||
class ToolFile(Base):
|
class ToolFile(Base):
|
||||||
"""
|
"""This table stores file metadata generated in workflows,
|
||||||
store the file created by agent
|
not only files created by agent.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__tablename__ = "tool_files"
|
__tablename__ = "tool_files"
|
||||||
|
@ -77,7 +77,7 @@ dependencies = [
|
|||||||
"sentry-sdk[flask]~=1.44.1",
|
"sentry-sdk[flask]~=1.44.1",
|
||||||
"sqlalchemy~=2.0.29",
|
"sqlalchemy~=2.0.29",
|
||||||
"starlette==0.41.0",
|
"starlette==0.41.0",
|
||||||
"tiktoken~=0.8.0",
|
"tiktoken~=0.9.0",
|
||||||
"tokenizers~=0.15.0",
|
"tokenizers~=0.15.0",
|
||||||
"transformers~=4.35.0",
|
"transformers~=4.35.0",
|
||||||
"unstructured[docx,epub,md,ppt,pptx]~=0.16.1",
|
"unstructured[docx,epub,md,ppt,pptx]~=0.16.1",
|
||||||
|
192
api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py
Normal file
192
api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
import uuid
|
||||||
|
from typing import NamedTuple
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import Engine
|
||||||
|
|
||||||
|
from core.file import FileTransferMethod, FileType, models
|
||||||
|
from core.helper import ssrf_proxy
|
||||||
|
from core.tools import signature
|
||||||
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
|
from core.workflow.nodes.llm.file_saver import (
|
||||||
|
FileSaverImpl,
|
||||||
|
_extract_content_type_and_extension,
|
||||||
|
_get_extension,
|
||||||
|
_validate_extension_override,
|
||||||
|
)
|
||||||
|
from models import ToolFile
|
||||||
|
|
||||||
|
_PNG_DATA = b"\x89PNG\r\n\x1a\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_id():
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileSaverImpl:
|
||||||
|
def test_save_binary_string(self, monkeypatch):
|
||||||
|
user_id = _gen_id()
|
||||||
|
tenant_id = _gen_id()
|
||||||
|
file_type = FileType.IMAGE
|
||||||
|
mime_type = "image/png"
|
||||||
|
mock_signed_url = "https://example.com/image.png"
|
||||||
|
mock_tool_file = ToolFile(
|
||||||
|
id=_gen_id(),
|
||||||
|
user_id=user_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
conversation_id=None,
|
||||||
|
file_key="test-file-key",
|
||||||
|
mimetype=mime_type,
|
||||||
|
original_url=None,
|
||||||
|
name=f"{_gen_id()}.png",
|
||||||
|
size=len(_PNG_DATA),
|
||||||
|
)
|
||||||
|
mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
|
||||||
|
mocked_engine = mock.MagicMock(spec=Engine)
|
||||||
|
|
||||||
|
mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
|
||||||
|
monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
|
||||||
|
# Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here.
|
||||||
|
mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file)
|
||||||
|
# Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here.
|
||||||
|
monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file)
|
||||||
|
mocked_sign_file.return_value = mock_signed_url
|
||||||
|
|
||||||
|
storage_file_manager = FileSaverImpl(
|
||||||
|
user_id=user_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
engine_factory=mocked_engine,
|
||||||
|
)
|
||||||
|
|
||||||
|
file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type)
|
||||||
|
assert file.tenant_id == tenant_id
|
||||||
|
assert file.type == file_type
|
||||||
|
assert file.transfer_method == FileTransferMethod.TOOL_FILE
|
||||||
|
assert file.extension == ".png"
|
||||||
|
assert file.mime_type == mime_type
|
||||||
|
assert file.size == len(_PNG_DATA)
|
||||||
|
assert file.related_id == mock_tool_file.id
|
||||||
|
|
||||||
|
assert file.generate_url() == mock_signed_url
|
||||||
|
|
||||||
|
mocked_tool_file_manager.create_file_by_raw.assert_called_once_with(
|
||||||
|
user_id=user_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
conversation_id=None,
|
||||||
|
file_binary=_PNG_DATA,
|
||||||
|
mimetype=mime_type,
|
||||||
|
)
|
||||||
|
mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")
|
||||||
|
|
||||||
|
def test_save_remote_url_request_failed(self, monkeypatch):
|
||||||
|
_TEST_URL = "https://example.com/image.png"
|
||||||
|
mock_request = httpx.Request("GET", _TEST_URL)
|
||||||
|
mock_response = httpx.Response(
|
||||||
|
status_code=401,
|
||||||
|
request=mock_request,
|
||||||
|
)
|
||||||
|
file_saver = FileSaverImpl(
|
||||||
|
user_id=_gen_id(),
|
||||||
|
tenant_id=_gen_id(),
|
||||||
|
)
|
||||||
|
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
|
||||||
|
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
|
||||||
|
|
||||||
|
with pytest.raises(httpx.HTTPStatusError) as exc:
|
||||||
|
file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
|
||||||
|
mock_get.assert_called_once_with(_TEST_URL)
|
||||||
|
assert exc.value.response.status_code == 401
|
||||||
|
|
||||||
|
def test_save_remote_url_success(self, monkeypatch):
|
||||||
|
_TEST_URL = "https://example.com/image.png"
|
||||||
|
mime_type = "image/png"
|
||||||
|
user_id = _gen_id()
|
||||||
|
tenant_id = _gen_id()
|
||||||
|
|
||||||
|
mock_request = httpx.Request("GET", _TEST_URL)
|
||||||
|
mock_response = httpx.Response(
|
||||||
|
status_code=200,
|
||||||
|
content=b"test-data",
|
||||||
|
headers={"Content-Type": mime_type},
|
||||||
|
request=mock_request,
|
||||||
|
)
|
||||||
|
|
||||||
|
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id)
|
||||||
|
mock_tool_file = ToolFile(
|
||||||
|
id=_gen_id(),
|
||||||
|
user_id=user_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
conversation_id=None,
|
||||||
|
file_key="test-file-key",
|
||||||
|
mimetype=mime_type,
|
||||||
|
original_url=None,
|
||||||
|
name=f"{_gen_id()}.png",
|
||||||
|
size=len(_PNG_DATA),
|
||||||
|
)
|
||||||
|
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
|
||||||
|
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
|
||||||
|
mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file)
|
||||||
|
monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string)
|
||||||
|
|
||||||
|
file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
|
||||||
|
mock_save_binary_string.assert_called_once_with(
|
||||||
|
mock_response.content,
|
||||||
|
mime_type,
|
||||||
|
FileType.IMAGE,
|
||||||
|
extension_override=".png",
|
||||||
|
)
|
||||||
|
assert file == mock_tool_file
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_extension_override():
|
||||||
|
class TestCase(NamedTuple):
|
||||||
|
extension_override: str | None
|
||||||
|
expected: str | None
|
||||||
|
|
||||||
|
cases = [TestCase(None, None), TestCase("", ""), ".png", ".png", ".tar.gz", ".tar.gz"]
|
||||||
|
|
||||||
|
for valid_ext_override in [None, "", ".png", ".tar.gz"]:
|
||||||
|
assert valid_ext_override == _validate_extension_override(valid_ext_override)
|
||||||
|
|
||||||
|
for invalid_ext_override in ["png", "tar.gz"]:
|
||||||
|
with pytest.raises(ValueError) as exc:
|
||||||
|
_validate_extension_override(invalid_ext_override)
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractContentTypeAndExtension:
|
||||||
|
def test_with_both_content_type_and_extension(self):
|
||||||
|
content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png")
|
||||||
|
assert content_type == "image/png"
|
||||||
|
assert extension == ".png"
|
||||||
|
|
||||||
|
def test_url_with_file_extension(self):
|
||||||
|
for content_type in [None, ""]:
|
||||||
|
content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type)
|
||||||
|
assert content_type == "image/png"
|
||||||
|
assert extension == ".png"
|
||||||
|
|
||||||
|
def test_response_with_content_type(self):
|
||||||
|
content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png")
|
||||||
|
assert content_type == "image/png"
|
||||||
|
assert extension == ".png"
|
||||||
|
|
||||||
|
def test_no_content_type_and_no_extension(self):
|
||||||
|
for content_type in [None, ""]:
|
||||||
|
content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type)
|
||||||
|
assert content_type == "application/octet-stream"
|
||||||
|
assert extension == ".bin"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetExtension:
|
||||||
|
def test_with_extension_override(self):
|
||||||
|
mime_type = "image/png"
|
||||||
|
for override in [".jpg", ""]:
|
||||||
|
extension = _get_extension(mime_type, override)
|
||||||
|
assert extension == override
|
||||||
|
|
||||||
|
def test_without_extension_override(self):
|
||||||
|
mime_type = "image/png"
|
||||||
|
extension = _get_extension(mime_type)
|
||||||
|
assert extension == ".png"
|
@ -1,5 +1,8 @@
|
|||||||
|
import base64
|
||||||
|
import uuid
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -30,6 +33,7 @@ from core.workflow.nodes.llm.entities import (
|
|||||||
VisionConfig,
|
VisionConfig,
|
||||||
VisionConfigOptions,
|
VisionConfigOptions,
|
||||||
)
|
)
|
||||||
|
from core.workflow.nodes.llm.file_saver import LLMFileSaver
|
||||||
from core.workflow.nodes.llm.node import LLMNode
|
from core.workflow.nodes.llm.node import LLMNode
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.provider import ProviderType
|
from models.provider import ProviderType
|
||||||
@ -49,8 +53,8 @@ class MockTokenBufferMemory:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def llm_node():
|
def llm_node_data() -> LLMNodeData:
|
||||||
data = LLMNodeData(
|
return LLMNodeData(
|
||||||
title="Test LLM",
|
title="Test LLM",
|
||||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
||||||
prompt_template=[],
|
prompt_template=[],
|
||||||
@ -64,42 +68,65 @@ def llm_node():
|
|||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def graph_init_params() -> GraphInitParams:
|
||||||
|
return GraphInitParams(
|
||||||
|
tenant_id="1",
|
||||||
|
app_id="1",
|
||||||
|
workflow_type=WorkflowType.WORKFLOW,
|
||||||
|
workflow_id="1",
|
||||||
|
graph_config={},
|
||||||
|
user_id="1",
|
||||||
|
user_from=UserFrom.ACCOUNT,
|
||||||
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
|
call_depth=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def graph() -> Graph:
|
||||||
|
return Graph(
|
||||||
|
root_node_id="1",
|
||||||
|
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
||||||
|
answer_dependencies={},
|
||||||
|
answer_generate_route={},
|
||||||
|
),
|
||||||
|
end_stream_param=EndStreamParam(
|
||||||
|
end_dependencies={},
|
||||||
|
end_stream_variable_selector_mapping={},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def graph_runtime_state() -> GraphRuntimeState:
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables={},
|
system_variables={},
|
||||||
user_inputs={},
|
user_inputs={},
|
||||||
)
|
)
|
||||||
|
return GraphRuntimeState(
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
start_at=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def llm_node(
|
||||||
|
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState
|
||||||
|
) -> LLMNode:
|
||||||
|
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
|
||||||
node = LLMNode(
|
node = LLMNode(
|
||||||
id="1",
|
id="1",
|
||||||
config={
|
config={
|
||||||
"id": "1",
|
"id": "1",
|
||||||
"data": data.model_dump(),
|
"data": llm_node_data.model_dump(),
|
||||||
},
|
},
|
||||||
graph_init_params=GraphInitParams(
|
graph_init_params=graph_init_params,
|
||||||
tenant_id="1",
|
graph=graph,
|
||||||
app_id="1",
|
graph_runtime_state=graph_runtime_state,
|
||||||
workflow_type=WorkflowType.WORKFLOW,
|
llm_file_saver=mock_file_saver,
|
||||||
workflow_id="1",
|
|
||||||
graph_config={},
|
|
||||||
user_id="1",
|
|
||||||
user_from=UserFrom.ACCOUNT,
|
|
||||||
invoke_from=InvokeFrom.SERVICE_API,
|
|
||||||
call_depth=0,
|
|
||||||
),
|
|
||||||
graph=Graph(
|
|
||||||
root_node_id="1",
|
|
||||||
answer_stream_generate_routes=AnswerStreamGenerateRoute(
|
|
||||||
answer_dependencies={},
|
|
||||||
answer_generate_route={},
|
|
||||||
),
|
|
||||||
end_stream_param=EndStreamParam(
|
|
||||||
end_dependencies={},
|
|
||||||
end_stream_variable_selector_mapping={},
|
|
||||||
),
|
|
||||||
),
|
|
||||||
graph_runtime_state=GraphRuntimeState(
|
|
||||||
variable_pool=variable_pool,
|
|
||||||
start_at=0,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
return node
|
return node
|
||||||
|
|
||||||
@ -465,3 +492,167 @@ def test_handle_list_messages_basic(llm_node):
|
|||||||
assert len(result) == 1
|
assert len(result) == 1
|
||||||
assert isinstance(result[0], UserPromptMessage)
|
assert isinstance(result[0], UserPromptMessage)
|
||||||
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]
|
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def llm_node_for_multimodal(
|
||||||
|
llm_node_data, graph_init_params, graph, graph_runtime_state
|
||||||
|
) -> tuple[LLMNode, LLMFileSaver]:
|
||||||
|
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
|
||||||
|
node = LLMNode(
|
||||||
|
id="1",
|
||||||
|
config={
|
||||||
|
"id": "1",
|
||||||
|
"data": llm_node_data.model_dump(),
|
||||||
|
},
|
||||||
|
graph_init_params=graph_init_params,
|
||||||
|
graph=graph,
|
||||||
|
graph_runtime_state=graph_runtime_state,
|
||||||
|
llm_file_saver=mock_file_saver,
|
||||||
|
)
|
||||||
|
return node, mock_file_saver
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMNodeSaveMultiModalImageOutput:
|
||||||
|
def test_llm_node_save_inline_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]):
|
||||||
|
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||||
|
content = ImagePromptMessageContent(
|
||||||
|
format="png",
|
||||||
|
base64_data=base64.b64encode(b"test-data").decode(),
|
||||||
|
mime_type="image/png",
|
||||||
|
)
|
||||||
|
mock_file = File(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
tenant_id="1",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||||
|
related_id=str(uuid.uuid4()),
|
||||||
|
filename="test-file.png",
|
||||||
|
extension=".png",
|
||||||
|
mime_type="image/png",
|
||||||
|
size=9,
|
||||||
|
)
|
||||||
|
mock_file_saver.save_binary_string.return_value = mock_file
|
||||||
|
file = llm_node._save_multimodal_image_output(content=content)
|
||||||
|
assert llm_node._file_outputs == [mock_file]
|
||||||
|
assert file == mock_file
|
||||||
|
mock_file_saver.save_binary_string.assert_called_once_with(
|
||||||
|
data=b"test-data", mime_type="image/png", file_type=FileType.IMAGE
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_llm_node_save_url_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]):
|
||||||
|
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||||
|
content = ImagePromptMessageContent(
|
||||||
|
format="png",
|
||||||
|
url="https://example.com/image.png",
|
||||||
|
mime_type="image/jpg",
|
||||||
|
)
|
||||||
|
mock_file = File(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
tenant_id="1",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||||
|
related_id=str(uuid.uuid4()),
|
||||||
|
filename="test-file.png",
|
||||||
|
extension=".png",
|
||||||
|
mime_type="image/png",
|
||||||
|
size=9,
|
||||||
|
)
|
||||||
|
mock_file_saver.save_remote_url.return_value = mock_file
|
||||||
|
file = llm_node._save_multimodal_image_output(content=content)
|
||||||
|
assert llm_node._file_outputs == [mock_file]
|
||||||
|
assert file == mock_file
|
||||||
|
mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE)
|
||||||
|
|
||||||
|
|
||||||
|
def test_llm_node_image_file_to_markdown(llm_node: LLMNode):
|
||||||
|
mock_file = mock.MagicMock(spec=File)
|
||||||
|
mock_file.generate_url.return_value = "https://example.com/image.png"
|
||||||
|
markdown = llm_node._image_file_to_markdown(mock_file)
|
||||||
|
assert markdown == ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestSaveMultimodalOutputAndConvertResultToMarkdown:
|
||||||
|
def test_str_content(self, llm_node_for_multimodal):
|
||||||
|
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||||
|
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world")
|
||||||
|
assert list(gen) == ["hello world"]
|
||||||
|
mock_file_saver.save_binary_string.assert_not_called()
|
||||||
|
mock_file_saver.save_remote_url.assert_not_called()
|
||||||
|
|
||||||
|
def test_text_prompt_message_content(self, llm_node_for_multimodal):
|
||||||
|
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||||
|
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
|
||||||
|
[TextPromptMessageContent(data="hello world")]
|
||||||
|
)
|
||||||
|
assert list(gen) == ["hello world"]
|
||||||
|
mock_file_saver.save_binary_string.assert_not_called()
|
||||||
|
mock_file_saver.save_remote_url.assert_not_called()
|
||||||
|
|
||||||
|
def test_image_content_with_inline_data(self, llm_node_for_multimodal, monkeypatch):
|
||||||
|
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||||
|
|
||||||
|
image_raw_data = b"PNG_DATA"
|
||||||
|
image_b64_data = base64.b64encode(image_raw_data).decode()
|
||||||
|
|
||||||
|
mock_saved_file = File(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
tenant_id="1",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||||
|
filename="test.png",
|
||||||
|
extension=".png",
|
||||||
|
size=len(image_raw_data),
|
||||||
|
related_id=str(uuid.uuid4()),
|
||||||
|
url="https://example.com/test.png",
|
||||||
|
storage_key="test_storage_key",
|
||||||
|
)
|
||||||
|
mock_file_saver.save_binary_string.return_value = mock_saved_file
|
||||||
|
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
|
||||||
|
[
|
||||||
|
ImagePromptMessageContent(
|
||||||
|
format="png",
|
||||||
|
base64_data=image_b64_data,
|
||||||
|
mime_type="image/png",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
yielded_strs = list(gen)
|
||||||
|
assert len(yielded_strs) == 1
|
||||||
|
|
||||||
|
# This assertion requires careful handling.
|
||||||
|
# `FILES_URL` settings can vary across environments, which might lead to fragile tests.
|
||||||
|
#
|
||||||
|
# Rather than asserting the complete URL returned by _save_multimodal_output_and_convert_result_to_markdown,
|
||||||
|
# we verify that the result includes the markdown image syntax and the expected file URL path.
|
||||||
|
expected_file_url_path = f"/files/tools/{mock_saved_file.related_id}.png"
|
||||||
|
assert yielded_strs[0].startswith("
|
||||||
|
assert expected_file_url_path in yielded_strs[0]
|
||||||
|
assert yielded_strs[0].endswith(")")
|
||||||
|
mock_file_saver.save_binary_string.assert_called_once_with(
|
||||||
|
data=image_raw_data,
|
||||||
|
mime_type="image/png",
|
||||||
|
file_type=FileType.IMAGE,
|
||||||
|
)
|
||||||
|
assert mock_saved_file in llm_node._file_outputs
|
||||||
|
|
||||||
|
def test_unknown_content_type(self, llm_node_for_multimodal):
|
||||||
|
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||||
|
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"]))
|
||||||
|
assert list(gen) == ["frozenset({'hello world'})"]
|
||||||
|
mock_file_saver.save_binary_string.assert_not_called()
|
||||||
|
mock_file_saver.save_remote_url.assert_not_called()
|
||||||
|
|
||||||
|
def test_unknown_item_type(self, llm_node_for_multimodal):
|
||||||
|
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||||
|
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])])
|
||||||
|
assert list(gen) == ["frozenset({'hello world'})"]
|
||||||
|
mock_file_saver.save_binary_string.assert_not_called()
|
||||||
|
mock_file_saver.save_remote_url.assert_not_called()
|
||||||
|
|
||||||
|
def test_none_content(self, llm_node_for_multimodal):
|
||||||
|
llm_node, mock_file_saver = llm_node_for_multimodal
|
||||||
|
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None)
|
||||||
|
assert list(gen) == []
|
||||||
|
mock_file_saver.save_binary_string.assert_not_called()
|
||||||
|
mock_file_saver.save_remote_url.assert_not_called()
|
||||||
|
30
api/uv.lock
generated
30
api/uv.lock
generated
@ -1398,7 +1398,7 @@ requires-dist = [
|
|||||||
{ name = "sentry-sdk", extras = ["flask"], specifier = "~=1.44.1" },
|
{ name = "sentry-sdk", extras = ["flask"], specifier = "~=1.44.1" },
|
||||||
{ name = "sqlalchemy", specifier = "~=2.0.29" },
|
{ name = "sqlalchemy", specifier = "~=2.0.29" },
|
||||||
{ name = "starlette", specifier = "==0.41.0" },
|
{ name = "starlette", specifier = "==0.41.0" },
|
||||||
{ name = "tiktoken", specifier = "~=0.8.0" },
|
{ name = "tiktoken", specifier = "~=0.9.0" },
|
||||||
{ name = "tokenizers", specifier = "~=0.15.0" },
|
{ name = "tokenizers", specifier = "~=0.15.0" },
|
||||||
{ name = "transformers", specifier = "~=4.35.0" },
|
{ name = "transformers", specifier = "~=4.35.0" },
|
||||||
{ name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.16.1" },
|
{ name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.16.1" },
|
||||||
@ -5335,26 +5335,26 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tiktoken"
|
name = "tiktoken"
|
||||||
version = "0.8.0"
|
version = "0.9.0"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "regex" },
|
{ name = "regex" },
|
||||||
{ name = "requests" },
|
{ name = "requests" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/37/02/576ff3a6639e755c4f70997b2d315f56d6d71e0d046f4fb64cb81a3fb099/tiktoken-0.8.0.tar.gz", hash = "sha256:9ccbb2740f24542534369c5635cfd9b2b3c2490754a78ac8831d99f89f94eeb2", size = 35107 }
|
sdist = { url = "https://files.pythonhosted.org/packages/ea/cf/756fedf6981e82897f2d570dd25fa597eb3f4459068ae0572d7e888cfd6f/tiktoken-0.9.0.tar.gz", hash = "sha256:d02a5ca6a938e0490e1ff957bc48c8b078c88cb83977be1625b1fd8aac792c5d", size = 35991 }
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/f6/1e/ca48e7bfeeccaf76f3a501bd84db1fa28b3c22c9d1a1f41af9fb7579c5f6/tiktoken-0.8.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d622d8011e6d6f239297efa42a2657043aaed06c4f68833550cac9e9bc723ef1", size = 1039700 },
|
{ url = "https://files.pythonhosted.org/packages/4d/ae/4613a59a2a48e761c5161237fc850eb470b4bb93696db89da51b79a871f1/tiktoken-0.9.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f32cc56168eac4851109e9b5d327637f15fd662aa30dd79f964b7c39fbadd26e", size = 1065987 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/8c/f8/f0101d98d661b34534769c3818f5af631e59c36ac6d07268fbfc89e539ce/tiktoken-0.8.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2efaf6199717b4485031b4d6edb94075e4d79177a172f38dd934d911b588d54a", size = 982413 },
|
{ url = "https://files.pythonhosted.org/packages/3f/86/55d9d1f5b5a7e1164d0f1538a85529b5fcba2b105f92db3622e5d7de6522/tiktoken-0.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:45556bc41241e5294063508caf901bf92ba52d8ef9222023f83d2483a3055348", size = 1009155 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/ac/3c/2b95391d9bd520a73830469f80a96e3790e6c0a5ac2444f80f20b4b31051/tiktoken-0.8.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5637e425ce1fc49cf716d88df3092048359a4b3bbb7da762840426e937ada06d", size = 1144242 },
|
{ url = "https://files.pythonhosted.org/packages/03/58/01fb6240df083b7c1916d1dcb024e2b761213c95d576e9f780dfb5625a76/tiktoken-0.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03935988a91d6d3216e2ec7c645afbb3d870b37bcb67ada1943ec48678e7ee33", size = 1142898 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/01/c4/c4a4360de845217b6aa9709c15773484b50479f36bb50419c443204e5de9/tiktoken-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fb0e352d1dbe15aba082883058b3cce9e48d33101bdaac1eccf66424feb5b47", size = 1176588 },
|
{ url = "https://files.pythonhosted.org/packages/b1/73/41591c525680cd460a6becf56c9b17468d3711b1df242c53d2c7b2183d16/tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b3d80aad8d2c6b9238fc1a5524542087c52b860b10cbf952429ffb714bc1136", size = 1197535 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/f8/a3/ef984e976822cd6c2227c854f74d2e60cf4cd6fbfca46251199914746f78/tiktoken-0.8.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:56edfefe896c8f10aba372ab5706b9e3558e78db39dd497c940b47bf228bc419", size = 1237261 },
|
{ url = "https://files.pythonhosted.org/packages/7d/7c/1069f25521c8f01a1a182f362e5c8e0337907fae91b368b7da9c3e39b810/tiktoken-0.9.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b2a21133be05dc116b1d0372af051cd2c6aa1d2188250c9b553f9fa49301b336", size = 1259548 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/1e/86/eea2309dc258fb86c7d9b10db536434fc16420feaa3b6113df18b23db7c2/tiktoken-0.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:326624128590def898775b722ccc327e90b073714227175ea8febbc920ac0a99", size = 884537 },
|
{ url = "https://files.pythonhosted.org/packages/6f/07/c67ad1724b8e14e2b4c8cca04b15da158733ac60136879131db05dda7c30/tiktoken-0.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:11a20e67fdf58b0e2dea7b8654a288e481bb4fc0289d3ad21291f8d0849915fb", size = 893895 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/c1/22/34b2e136a6f4af186b6640cbfd6f93400783c9ef6cd550d9eab80628d9de/tiktoken-0.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:881839cfeae051b3628d9823b2e56b5cc93a9e2efb435f4cf15f17dc45f21586", size = 1039357 },
|
{ url = "https://files.pythonhosted.org/packages/cf/e5/21ff33ecfa2101c1bb0f9b6df750553bd873b7fb532ce2cb276ff40b197f/tiktoken-0.9.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:e88f121c1c22b726649ce67c089b90ddda8b9662545a8aeb03cfef15967ddd03", size = 1065073 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/04/d2/c793cf49c20f5855fd6ce05d080c0537d7418f22c58e71f392d5e8c8dbf7/tiktoken-0.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fe9399bdc3f29d428f16a2f86c3c8ec20be3eac5f53693ce4980371c3245729b", size = 982616 },
|
{ url = "https://files.pythonhosted.org/packages/8e/03/a95e7b4863ee9ceec1c55983e4cc9558bcfd8f4f80e19c4f8a99642f697d/tiktoken-0.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a6600660f2f72369acb13a57fb3e212434ed38b045fd8cc6cdd74947b4b5d210", size = 1008075 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/b3/a1/79846e5ef911cd5d75c844de3fa496a10c91b4b5f550aad695c5df153d72/tiktoken-0.8.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9a58deb7075d5b69237a3ff4bb51a726670419db6ea62bdcd8bd80c78497d7ab", size = 1144011 },
|
{ url = "https://files.pythonhosted.org/packages/40/10/1305bb02a561595088235a513ec73e50b32e74364fef4de519da69bc8010/tiktoken-0.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:95e811743b5dfa74f4b227927ed86cbc57cad4df859cb3b643be797914e41794", size = 1140754 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/26/32/e0e3a859136e95c85a572e4806dc58bf1ddf651108ae8b97d5f3ebe1a244/tiktoken-0.8.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2908c0d043a7d03ebd80347266b0e58440bdef5564f84f4d29fb235b5df3b04", size = 1175432 },
|
{ url = "https://files.pythonhosted.org/packages/1b/40/da42522018ca496432ffd02793c3a72a739ac04c3794a4914570c9bb2925/tiktoken-0.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99376e1370d59bcf6935c933cb9ba64adc29033b7e73f5f7569f3aad86552b22", size = 1196678 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/c7/89/926b66e9025b97e9fbabeaa59048a736fe3c3e4530a204109571104f921c/tiktoken-0.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:294440d21a2a51e12d4238e68a5972095534fe9878be57d905c476017bff99fc", size = 1236576 },
|
{ url = "https://files.pythonhosted.org/packages/5c/41/1e59dddaae270ba20187ceb8aa52c75b24ffc09f547233991d5fd822838b/tiktoken-0.9.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:badb947c32739fb6ddde173e14885fb3de4d32ab9d8c591cbd013c22b4c31dd2", size = 1259283 },
|
||||||
{ url = "https://files.pythonhosted.org/packages/45/e2/39d4aa02a52bba73b2cd21ba4533c84425ff8786cc63c511d68c8897376e/tiktoken-0.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:d8f3192733ac4d77977432947d563d7e1b310b96497acd3c196c9bddb36ed9db", size = 883824 },
|
{ url = "https://files.pythonhosted.org/packages/5b/64/b16003419a1d7728d0d8c0d56a4c24325e7b10a21a9dd1fc0f7115c02f0a/tiktoken-0.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:5a62d7a25225bafed786a524c1b9f0910a1128f4232615bf3f8257a73aaa3b16", size = 894897 },
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -121,7 +121,7 @@ export function PreCode(props: { children: any }) {
|
|||||||
// visit https://reactjs.org/docs/error-decoder.html?invariant=185 for the full message
|
// visit https://reactjs.org/docs/error-decoder.html?invariant=185 for the full message
|
||||||
// or use the non-minified dev environment for full errors and additional helpful warnings.
|
// or use the non-minified dev environment for full errors and additional helpful warnings.
|
||||||
|
|
||||||
const CodeBlock: any = memo(({ inline, className, children, ...props }: any) => {
|
const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any) => {
|
||||||
const { theme } = useTheme()
|
const { theme } = useTheme()
|
||||||
const [isSVG, setIsSVG] = useState(true)
|
const [isSVG, setIsSVG] = useState(true)
|
||||||
const match = /language-(\w+)/.exec(className || '')
|
const match = /language-(\w+)/.exec(className || '')
|
||||||
@ -258,7 +258,7 @@ const Link = ({ node, children, ...props }: any) => {
|
|||||||
const { onSend } = useChatContext()
|
const { onSend } = useChatContext()
|
||||||
const hidden_text = decodeURIComponent(node.properties.href.toString().split('abbr:')[1])
|
const hidden_text = decodeURIComponent(node.properties.href.toString().split('abbr:')[1])
|
||||||
|
|
||||||
return <abbr className="cursor-pointer underline !decoration-primary-700 decoration-dashed" onClick={() => onSend?.(hidden_text)} title={node.children[0]?.value}>{node.children[0]?.value}</abbr>
|
return <abbr className="cursor-pointer underline !decoration-primary-700 decoration-dashed" onClick={() => onSend?.(hidden_text)} title={node.children[0]?.value || ''}>{node.children[0]?.value || ''}</abbr>
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
return <a {...props} target="_blank" className="cursor-pointer underline !decoration-primary-700 decoration-dashed">{children || 'Download'}</a>
|
return <a {...props} target="_blank" className="cursor-pointer underline !decoration-primary-700 decoration-dashed">{children || 'Download'}</a>
|
||||||
|
@ -476,15 +476,15 @@ const Flowchart = React.forwardRef((props: {
|
|||||||
'bg-white': currentTheme === Theme.light,
|
'bg-white': currentTheme === Theme.light,
|
||||||
'bg-slate-900': currentTheme === Theme.dark,
|
'bg-slate-900': currentTheme === Theme.dark,
|
||||||
}),
|
}),
|
||||||
mermaidDiv: cn('mermaid cursor-pointer h-auto w-full relative', {
|
mermaidDiv: cn('mermaid relative h-auto w-full cursor-pointer', {
|
||||||
'bg-white': currentTheme === Theme.light,
|
'bg-white': currentTheme === Theme.light,
|
||||||
'bg-slate-900': currentTheme === Theme.dark,
|
'bg-slate-900': currentTheme === Theme.dark,
|
||||||
}),
|
}),
|
||||||
errorMessage: cn('py-4 px-[26px]', {
|
errorMessage: cn('px-[26px] py-4', {
|
||||||
'text-red-500': currentTheme === Theme.light,
|
'text-red-500': currentTheme === Theme.light,
|
||||||
'text-red-400': currentTheme === Theme.dark,
|
'text-red-400': currentTheme === Theme.dark,
|
||||||
}),
|
}),
|
||||||
errorIcon: cn('w-6 h-6', {
|
errorIcon: cn('h-6 w-6', {
|
||||||
'text-red-500': currentTheme === Theme.light,
|
'text-red-500': currentTheme === Theme.light,
|
||||||
'text-red-400': currentTheme === Theme.dark,
|
'text-red-400': currentTheme === Theme.dark,
|
||||||
}),
|
}),
|
||||||
@ -492,7 +492,7 @@ const Flowchart = React.forwardRef((props: {
|
|||||||
'text-gray-700': currentTheme === Theme.light,
|
'text-gray-700': currentTheme === Theme.light,
|
||||||
'text-gray-300': currentTheme === Theme.dark,
|
'text-gray-300': currentTheme === Theme.dark,
|
||||||
}),
|
}),
|
||||||
themeToggle: cn('flex items-center justify-center w-10 h-10 rounded-full transition-all duration-300 shadow-md backdrop-blur-sm', {
|
themeToggle: cn('flex h-10 w-10 items-center justify-center rounded-full shadow-md backdrop-blur-sm transition-all duration-300', {
|
||||||
'bg-white/80 hover:bg-white hover:shadow-lg text-gray-700 border border-gray-200': currentTheme === Theme.light,
|
'bg-white/80 hover:bg-white hover:shadow-lg text-gray-700 border border-gray-200': currentTheme === Theme.light,
|
||||||
'bg-slate-800/80 hover:bg-slate-700 hover:shadow-lg text-yellow-300 border border-slate-600': currentTheme === Theme.dark,
|
'bg-slate-800/80 hover:bg-slate-700 hover:shadow-lg text-yellow-300 border border-slate-600': currentTheme === Theme.dark,
|
||||||
}),
|
}),
|
||||||
@ -501,7 +501,7 @@ const Flowchart = React.forwardRef((props: {
|
|||||||
// Style classes for look options
|
// Style classes for look options
|
||||||
const getLookButtonClass = (lookType: 'classic' | 'handDrawn') => {
|
const getLookButtonClass = (lookType: 'classic' | 'handDrawn') => {
|
||||||
return cn(
|
return cn(
|
||||||
'flex items-center justify-center mb-4 w-[calc((100%-8px)/2)] h-8 rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg cursor-pointer system-sm-medium text-text-secondary',
|
'system-sm-medium mb-4 flex h-8 w-[calc((100%-8px)/2)] cursor-pointer items-center justify-center rounded-lg border border-components-option-card-option-border bg-components-option-card-option-bg text-text-secondary',
|
||||||
look === lookType && 'border-[1.5px] border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary',
|
look === lookType && 'border-[1.5px] border-components-option-card-option-selected-border bg-components-option-card-option-selected-bg text-text-primary',
|
||||||
currentTheme === Theme.dark && 'border-slate-600 bg-slate-800 text-slate-300',
|
currentTheme === Theme.dark && 'border-slate-600 bg-slate-800 text-slate-300',
|
||||||
look === lookType && currentTheme === Theme.dark && 'border-blue-500 bg-slate-700 text-white',
|
look === lookType && currentTheme === Theme.dark && 'border-blue-500 bg-slate-700 text-white',
|
||||||
@ -512,7 +512,7 @@ const Flowchart = React.forwardRef((props: {
|
|||||||
<div ref={ref as React.RefObject<HTMLDivElement>} className={themeClasses.container}>
|
<div ref={ref as React.RefObject<HTMLDivElement>} className={themeClasses.container}>
|
||||||
<div className={themeClasses.segmented}>
|
<div className={themeClasses.segmented}>
|
||||||
<div className="msh-segmented-group">
|
<div className="msh-segmented-group">
|
||||||
<label className="msh-segmented-item flex items-center space-x-1 m-2 w-[200px]">
|
<label className="msh-segmented-item m-2 flex w-[200px] items-center space-x-1">
|
||||||
<div
|
<div
|
||||||
key='classic'
|
key='classic'
|
||||||
className={getLookButtonClass('classic')}
|
className={getLookButtonClass('classic')}
|
||||||
@ -534,7 +534,7 @@ const Flowchart = React.forwardRef((props: {
|
|||||||
<div ref={containerRef} style={{ position: 'absolute', visibility: 'hidden', height: 0, overflow: 'hidden' }} />
|
<div ref={containerRef} style={{ position: 'absolute', visibility: 'hidden', height: 0, overflow: 'hidden' }} />
|
||||||
|
|
||||||
{isLoading && !svgCode && (
|
{isLoading && !svgCode && (
|
||||||
<div className='py-4 px-[26px]'>
|
<div className='px-[26px] py-4'>
|
||||||
<LoadingAnim type='text'/>
|
<LoadingAnim type='text'/>
|
||||||
{!isCodeComplete && (
|
{!isCodeComplete && (
|
||||||
<div className="mt-2 text-sm text-gray-500">
|
<div className="mt-2 text-sm text-gray-500">
|
||||||
@ -546,7 +546,7 @@ const Flowchart = React.forwardRef((props: {
|
|||||||
|
|
||||||
{svgCode && (
|
{svgCode && (
|
||||||
<div className={themeClasses.mermaidDiv} style={{ objectFit: 'cover' }} onClick={() => setImagePreviewUrl(svgCode)}>
|
<div className={themeClasses.mermaidDiv} style={{ objectFit: 'cover' }} onClick={() => setImagePreviewUrl(svgCode)}>
|
||||||
<div className="absolute left-2 bottom-2 z-[100]">
|
<div className="absolute bottom-2 left-2 z-[100]">
|
||||||
<button
|
<button
|
||||||
onClick={(e) => {
|
onClick={(e) => {
|
||||||
e.stopPropagation()
|
e.stopPropagation()
|
||||||
|
@ -15,6 +15,7 @@ import { useProviderContext } from '@/context/provider-context'
|
|||||||
import GridMask from '@/app/components/base/grid-mask'
|
import GridMask from '@/app/components/base/grid-mask'
|
||||||
import { useAppContext } from '@/context/app-context'
|
import { useAppContext } from '@/context/app-context'
|
||||||
import classNames from '@/utils/classnames'
|
import classNames from '@/utils/classnames'
|
||||||
|
import { useGetPricingPageLanguage } from '@/context/i18n'
|
||||||
|
|
||||||
type Props = {
|
type Props = {
|
||||||
onCancel: () => void
|
onCancel: () => void
|
||||||
@ -33,6 +34,11 @@ const Pricing: FC<Props> = ({
|
|||||||
|
|
||||||
useKeyPress(['esc'], onCancel)
|
useKeyPress(['esc'], onCancel)
|
||||||
|
|
||||||
|
const pricingPageLanguage = useGetPricingPageLanguage()
|
||||||
|
const pricingPageURL = pricingPageLanguage
|
||||||
|
? `https://dify.ai/${pricingPageLanguage}/pricing#plans-and-features`
|
||||||
|
: 'https://dify.ai/pricing#plans-and-features'
|
||||||
|
|
||||||
return createPortal(
|
return createPortal(
|
||||||
<div
|
<div
|
||||||
className='fixed inset-0 bottom-0 left-0 right-0 top-0 z-[1000] bg-background-overlay-backdrop p-4 backdrop-blur-[6px]'
|
className='fixed inset-0 bottom-0 left-0 right-0 top-0 z-[1000] bg-background-overlay-backdrop p-4 backdrop-blur-[6px]'
|
||||||
@ -127,7 +133,7 @@ const Pricing: FC<Props> = ({
|
|||||||
</div>
|
</div>
|
||||||
<div className='flex items-center justify-center py-4'>
|
<div className='flex items-center justify-center py-4'>
|
||||||
<div className='flex items-center justify-center gap-x-0.5 rounded-lg px-3 py-2 text-components-button-secondary-accent-text hover:cursor-pointer hover:bg-state-accent-hover'>
|
<div className='flex items-center justify-center gap-x-0.5 rounded-lg px-3 py-2 text-components-button-secondary-accent-text hover:cursor-pointer hover:bg-state-accent-hover'>
|
||||||
<Link href='https://dify.ai/pricing#plans-and-features' className='system-sm-medium'>{t('billing.plansCommon.comparePlanAndFeatures')}</Link>
|
<Link href={pricingPageURL} className='system-sm-medium'>{t('billing.plansCommon.comparePlanAndFeatures')}</Link>
|
||||||
<RiArrowRightUpLine className='size-4' />
|
<RiArrowRightUpLine className='size-4' />
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
import { useTranslation } from 'react-i18next'
|
import { useTranslation } from 'react-i18next'
|
||||||
import { Fragment, useState } from 'react'
|
import { Fragment, useState } from 'react'
|
||||||
import { useRouter } from 'next/navigation'
|
import { useRouter } from 'next/navigation'
|
||||||
import { useContext, useContextSelector } from 'use-context-selector'
|
import { useContextSelector } from 'use-context-selector'
|
||||||
import {
|
import {
|
||||||
RiAccountCircleLine,
|
RiAccountCircleLine,
|
||||||
RiArrowRightUpLine,
|
RiArrowRightUpLine,
|
||||||
@ -23,13 +23,12 @@ import GithubStar from '../github-star'
|
|||||||
import Support from './support'
|
import Support from './support'
|
||||||
import Compliance from './compliance'
|
import Compliance from './compliance'
|
||||||
import PremiumBadge from '@/app/components/base/premium-badge'
|
import PremiumBadge from '@/app/components/base/premium-badge'
|
||||||
import I18n from '@/context/i18n'
|
import { useGetDocLanguage } from '@/context/i18n'
|
||||||
import Avatar from '@/app/components/base/avatar'
|
import Avatar from '@/app/components/base/avatar'
|
||||||
import { logout } from '@/service/common'
|
import { logout } from '@/service/common'
|
||||||
import AppContext, { useAppContext } from '@/context/app-context'
|
import AppContext, { useAppContext } from '@/context/app-context'
|
||||||
import { useProviderContext } from '@/context/provider-context'
|
import { useProviderContext } from '@/context/provider-context'
|
||||||
import { useModalContext } from '@/context/modal-context'
|
import { useModalContext } from '@/context/modal-context'
|
||||||
import { LanguagesSupported } from '@/i18n/language'
|
|
||||||
import { LicenseStatus } from '@/types/feature'
|
import { LicenseStatus } from '@/types/feature'
|
||||||
import { IS_CLOUD_EDITION } from '@/config'
|
import { IS_CLOUD_EDITION } from '@/config'
|
||||||
import cn from '@/utils/classnames'
|
import cn from '@/utils/classnames'
|
||||||
@ -43,11 +42,11 @@ export default function AppSelector() {
|
|||||||
const [aboutVisible, setAboutVisible] = useState(false)
|
const [aboutVisible, setAboutVisible] = useState(false)
|
||||||
const systemFeatures = useContextSelector(AppContext, v => v.systemFeatures)
|
const systemFeatures = useContextSelector(AppContext, v => v.systemFeatures)
|
||||||
|
|
||||||
const { locale } = useContext(I18n)
|
|
||||||
const { t } = useTranslation()
|
const { t } = useTranslation()
|
||||||
const { userProfile, langeniusVersionInfo, isCurrentWorkspaceOwner } = useAppContext()
|
const { userProfile, langeniusVersionInfo, isCurrentWorkspaceOwner } = useAppContext()
|
||||||
const { isEducationAccount } = useProviderContext()
|
const { isEducationAccount } = useProviderContext()
|
||||||
const { setShowAccountSettingModal } = useModalContext()
|
const { setShowAccountSettingModal } = useModalContext()
|
||||||
|
const docLanguage = useGetDocLanguage()
|
||||||
|
|
||||||
const handleLogout = async () => {
|
const handleLogout = async () => {
|
||||||
await logout({
|
await logout({
|
||||||
@ -132,9 +131,7 @@ export default function AppSelector() {
|
|||||||
className={cn(itemClassName, 'group justify-between',
|
className={cn(itemClassName, 'group justify-between',
|
||||||
'data-[active]:bg-state-base-hover',
|
'data-[active]:bg-state-base-hover',
|
||||||
)}
|
)}
|
||||||
href={
|
href={`https://docs.dify.ai/${docLanguage}/introduction`}
|
||||||
locale !== LanguagesSupported[1] ? 'https://docs.dify.ai/' : `https://docs.dify.ai/v/${locale.toLowerCase()}/`
|
|
||||||
}
|
|
||||||
target='_blank' rel='noopener noreferrer'>
|
target='_blank' rel='noopener noreferrer'>
|
||||||
<RiBookOpenLine className='size-4 shrink-0 text-text-tertiary' />
|
<RiBookOpenLine className='size-4 shrink-0 text-text-tertiary' />
|
||||||
<div className='system-md-regular grow px-1 text-text-secondary'>{t('common.userProfile.helpCenter')}</div>
|
<div className='system-md-regular grow px-1 text-text-secondary'>{t('common.userProfile.helpCenter')}</div>
|
||||||
|
@ -316,7 +316,7 @@ export const Workflow: FC<WorkflowProps> = memo(({
|
|||||||
nodesConnectable={!nodesReadOnly}
|
nodesConnectable={!nodesReadOnly}
|
||||||
nodesFocusable={!nodesReadOnly}
|
nodesFocusable={!nodesReadOnly}
|
||||||
edgesFocusable={!nodesReadOnly}
|
edgesFocusable={!nodesReadOnly}
|
||||||
panOnScroll
|
panOnScroll={false}
|
||||||
panOnDrag={controlMode === ControlMode.Hand && !workflowReadOnly}
|
panOnDrag={controlMode === ControlMode.Hand && !workflowReadOnly}
|
||||||
zoomOnPinch={!workflowReadOnly}
|
zoomOnPinch={!workflowReadOnly}
|
||||||
zoomOnScroll={!workflowReadOnly}
|
zoomOnScroll={!workflowReadOnly}
|
||||||
|
@ -3,7 +3,7 @@ import {
|
|||||||
useContext,
|
useContext,
|
||||||
} from 'use-context-selector'
|
} from 'use-context-selector'
|
||||||
import type { Locale } from '@/i18n'
|
import type { Locale } from '@/i18n'
|
||||||
import { getLanguage } from '@/i18n/language'
|
import { getDocLanguage, getLanguage, getPricingPageLanguage } from '@/i18n/language'
|
||||||
import { noop } from 'lodash-es'
|
import { noop } from 'lodash-es'
|
||||||
|
|
||||||
type II18NContext = {
|
type II18NContext = {
|
||||||
@ -24,5 +24,15 @@ export const useGetLanguage = () => {
|
|||||||
|
|
||||||
return getLanguage(locale)
|
return getLanguage(locale)
|
||||||
}
|
}
|
||||||
|
export const useGetDocLanguage = () => {
|
||||||
|
const { locale } = useI18N()
|
||||||
|
|
||||||
|
return getDocLanguage(locale)
|
||||||
|
}
|
||||||
|
export const useGetPricingPageLanguage = () => {
|
||||||
|
const { locale } = useI18N()
|
||||||
|
|
||||||
|
return getPricingPageLanguage(locale)
|
||||||
|
}
|
||||||
|
|
||||||
export default I18NContext
|
export default I18NContext
|
||||||
|
@ -39,6 +39,24 @@ export const getLanguage = (locale: string) => {
|
|||||||
return LanguagesSupported[0].replace('-', '_')
|
return LanguagesSupported[0].replace('-', '_')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const DOC_LANGUAGE: Record<string, string> = {
|
||||||
|
'zh-Hans': 'zh-hans',
|
||||||
|
'ja-JP': 'ja-jp',
|
||||||
|
'en-US': 'en',
|
||||||
|
}
|
||||||
|
|
||||||
|
export const getDocLanguage = (locale: string) => {
|
||||||
|
return DOC_LANGUAGE[locale] || 'en'
|
||||||
|
}
|
||||||
|
|
||||||
|
const PRICING_PAGE_LANGUAGE: Record<string, string> = {
|
||||||
|
'ja-JP': 'jp',
|
||||||
|
}
|
||||||
|
|
||||||
|
export const getPricingPageLanguage = (locale: string) => {
|
||||||
|
return PRICING_PAGE_LANGUAGE[locale] || ''
|
||||||
|
}
|
||||||
|
|
||||||
export const NOTICE_I18N = {
|
export const NOTICE_I18N = {
|
||||||
title: {
|
title: {
|
||||||
en_US: 'Important Notice',
|
en_US: 'Important Notice',
|
||||||
|
Loading…
x
Reference in New Issue
Block a user