diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 507455c176..860ec5de0c 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -30,6 +30,7 @@ from core.model_runtime.entities import ( ToolPromptMessage, UserPromptMessage, ) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder @@ -65,7 +66,7 @@ class BaseAgentRunner(AppRunner): prompt_messages: Optional[list[PromptMessage]] = None, variables_pool: Optional[ToolRuntimeVariablePool] = None, db_variables: Optional[ToolConversationVariables] = None, - model_instance: ModelInstance = None, + model_instance: ModelInstance | None = None, ) -> None: self.tenant_id = tenant_id self.application_generate_entity = application_generate_entity @@ -508,24 +509,27 @@ class BaseAgentRunner(AppRunner): def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() - if files: - file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) - - if file_extra_config: - file_objs = file_factory.build_from_message_files( - message_files=files, tenant_id=self.tenant_id, config=file_extra_config - ) - else: - file_objs = [] - - if not file_objs: - return UserPromptMessage(content=message.query) - else: - prompt_message_contents: list[PromptMessageContent] = [] - prompt_message_contents.append(TextPromptMessageContent(data=message.query)) - for file_obj in file_objs: - prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) - - return UserPromptMessage(content=prompt_message_contents) - else: + if not files: return UserPromptMessage(content=message.query) + file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) + if not file_extra_config: + return UserPromptMessage(content=message.query) + + image_detail_config = file_extra_config.image_config.detail if file_extra_config.image_config else None + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + + file_objs = file_factory.build_from_message_files( + message_files=files, tenant_id=self.tenant_id, config=file_extra_config + ) + if not file_objs: + return UserPromptMessage(content=message.query) + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=message.query)) + for file in file_objs: + prompt_message_contents.append( + file_manager.to_prompt_message_content( + file, + image_detail_config=image_detail_config, + ) + ) + return UserPromptMessage(content=prompt_message_contents) diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index 6261a9b12c..d8d047fe91 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -10,6 +10,7 @@ from core.model_runtime.entities import ( TextPromptMessageContent, UserPromptMessage, ) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.model_runtime.utils.encoders import jsonable_encoder @@ -36,8 +37,24 @@ class CotChatAgentRunner(CotAgentRunner): if self.files: prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents.append(TextPromptMessageContent(data=query)) - for file_obj in self.files: - prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) + + # get image detail config + image_detail_config = ( + self.application_generate_entity.file_upload_config.image_config.detail + if ( + self.application_generate_entity.file_upload_config + and self.application_generate_entity.file_upload_config.image_config + ) + else None + ) + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + for file in self.files: + prompt_message_contents.append( + file_manager.to_prompt_message_content( + file, + image_detail_config=image_detail_config, + ) + ) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 9083b4e85f..cd546dee12 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -22,6 +22,7 @@ from core.model_runtime.entities import ( ToolPromptMessage, UserPromptMessage, ) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine @@ -397,8 +398,24 @@ class FunctionCallAgentRunner(BaseAgentRunner): if self.files: prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents.append(TextPromptMessageContent(data=query)) - for file_obj in self.files: - prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) + + # get image detail config + image_detail_config = ( + self.application_generate_entity.file_upload_config.image_config.detail + if ( + self.application_generate_entity.file_upload_config + and self.application_generate_entity.file_upload_config.image_config + ) + else None + ) + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + for file in self.files: + prompt_message_contents.append( + file_manager.to_prompt_message_content( + file, + image_detail_config=image_detail_config, + ) + ) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 6c6e342a07..9b72452d7a 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -4,7 +4,7 @@ from typing import Any, Optional from pydantic import BaseModel, Field, field_validator -from core.file import FileExtraConfig, FileTransferMethod, FileType +from core.file import FileTransferMethod, FileType, FileUploadConfig from core.model_runtime.entities.message_entities import PromptMessageRole from models.model import AppMode @@ -211,7 +211,7 @@ class TracingConfigEntity(BaseModel): class AppAdditionalFeatures(BaseModel): - file_upload: Optional[FileExtraConfig] = None + file_upload: Optional[FileUploadConfig] = None opening_statement: Optional[str] = None suggested_questions: list[str] = [] suggested_questions_after_answer: bool = False diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index d0f75d0b75..a79ddf3ddf 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -1,7 +1,7 @@ from collections.abc import Mapping from typing import Any -from core.file import FileExtraConfig +from core.file import FileUploadConfig class FileUploadConfigManager: @@ -29,15 +29,14 @@ class FileUploadConfigManager: if is_vision: data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low") - return FileExtraConfig.model_validate(data) + return FileUploadConfig.model_validate(data) @classmethod - def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for file upload feature :param config: app model config args - :param is_vision: if True, the feature is vision feature """ if not config.get("file_upload"): config["file_upload"] = {} diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py index b52f235849..cb606953cd 100644 --- a/api/core/app/apps/advanced_chat/app_config_manager.py +++ b/api/core/app/apps/advanced_chat/app_config_manager.py @@ -52,9 +52,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager): related_config_keys = [] # file upload validation - config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( - config=config, is_vision=False - ) + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config) related_config_keys.extend(current_related_config_keys) # opening_statement diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 3010f8a03f..0b88345061 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -26,7 +26,6 @@ from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory from models.account import Account -from models.enums import CreatedByRole from models.model import App, Conversation, EndUser, Message from models.workflow import Workflow @@ -98,13 +97,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # parse files files = args["files"] if args.get("files") else [] file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) - role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER if file_extra_config: file_objs = file_factory.build_from_mappings( mappings=files, tenant_id=app_model.tenant_id, - user_id=user.id, - role=role, config=file_extra_config, ) else: @@ -127,10 +123,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): application_generate_entity = AdvancedChatAppGenerateEntity( task_id=str(uuid.uuid4()), app_config=app_config, + file_upload_config=file_extra_config, conversation_id=conversation.id if conversation else None, inputs=conversation.inputs if conversation - else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), + else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), query=query, files=file_objs, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 73d433d94d..d1564a260e 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -23,7 +23,6 @@ from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser -from models.enums import CreatedByRole logger = logging.getLogger(__name__) @@ -103,8 +102,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): # always enable retriever resource in debugger mode override_model_config_dict["retriever_resource"] = {"enabled": True} - role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER - # parse files files = args.get("files") or [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) @@ -112,8 +109,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): file_objs = file_factory.build_from_mappings( mappings=files, tenant_id=app_model.tenant_id, - user_id=user.id, - role=role, config=file_extra_config, ) else: @@ -135,10 +130,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): task_id=str(uuid.uuid4()), app_config=app_config, model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, conversation_id=conversation.id if conversation else None, inputs=conversation.inputs if conversation - else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), + else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), query=query, files=file_objs, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index d8e38476c7..6e6da95401 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -2,12 +2,11 @@ from collections.abc import Mapping from typing import TYPE_CHECKING, Any, Optional from core.app.app_config.entities import VariableEntityType -from core.file import File, FileExtraConfig +from core.file import File, FileUploadConfig from factories import file_factory if TYPE_CHECKING: from core.app.app_config.entities import AppConfig, VariableEntity - from models.enums import CreatedByRole class BaseAppGenerator: @@ -16,8 +15,6 @@ class BaseAppGenerator: *, user_inputs: Optional[Mapping[str, Any]], app_config: "AppConfig", - user_id: str, - role: "CreatedByRole", ) -> Mapping[str, Any]: user_inputs = user_inputs or {} # Filter input variables from form configuration, handle required fields, default values, and option values @@ -34,9 +31,7 @@ class BaseAppGenerator: k: file_factory.build_from_mapping( mapping=v, tenant_id=app_config.tenant_id, - user_id=user_id, - role=role, - config=FileExtraConfig( + config=FileUploadConfig( allowed_file_types=entity_dictionary[k].allowed_file_types, allowed_extensions=entity_dictionary[k].allowed_file_extensions, allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods, @@ -50,9 +45,7 @@ class BaseAppGenerator: k: file_factory.build_from_mappings( mappings=v, tenant_id=app_config.tenant_id, - user_id=user_id, - role=role, - config=FileExtraConfig( + config=FileUploadConfig( allowed_file_types=entity_dictionary[k].allowed_file_types, allowed_extensions=entity_dictionary[k].allowed_file_extensions, allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods, diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index d0ba90cc5e..e683dfef3f 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -23,7 +23,6 @@ from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory from models.account import Account -from models.enums import CreatedByRole from models.model import App, EndUser logger = logging.getLogger(__name__) @@ -101,8 +100,6 @@ class ChatAppGenerator(MessageBasedAppGenerator): # always enable retriever resource in debugger mode override_model_config_dict["retriever_resource"] = {"enabled": True} - role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER - # parse files files = args["files"] if args.get("files") else [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) @@ -110,8 +107,6 @@ class ChatAppGenerator(MessageBasedAppGenerator): file_objs = file_factory.build_from_mappings( mappings=files, tenant_id=app_model.tenant_id, - user_id=user.id, - role=role, config=file_extra_config, ) else: @@ -133,10 +128,11 @@ class ChatAppGenerator(MessageBasedAppGenerator): task_id=str(uuid.uuid4()), app_config=app_config, model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, conversation_id=conversation.id if conversation else None, inputs=conversation.inputs if conversation - else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), + else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), query=query, files=file_objs, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 3bb05d05d8..22ee8b0967 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -22,7 +22,6 @@ from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser, Message -from models.enums import CreatedByRole from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError @@ -88,8 +87,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator): tenant_id=app_model.tenant_id, config=args.get("model_config") ) - role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER - # parse files files = args["files"] if args.get("files") else [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) @@ -97,8 +94,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator): file_objs = file_factory.build_from_mappings( mappings=files, tenant_id=app_model.tenant_id, - user_id=user.id, - role=role, config=file_extra_config, ) else: @@ -110,7 +105,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator): ) # get tracing instance - user_id = user.id if isinstance(user, Account) else user.session_id trace_manager = TraceQueueManager(app_model.id) # init application generate entity @@ -118,7 +112,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator): task_id=str(uuid.uuid4()), app_config=app_config, model_conf=ModelConfigConverter.convert(app_config), - inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), + file_upload_config=file_extra_config, + inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), query=query, files=file_objs, user_id=user.id, @@ -259,14 +254,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator): override_model_config_dict["model"] = model_dict # parse files - role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER file_extra_config = FileUploadConfigManager.convert(override_model_config_dict) if file_extra_config: file_objs = file_factory.build_from_mappings( mappings=message.message_files, tenant_id=app_model.tenant_id, - user_id=user.id, - role=role, config=file_extra_config, ) else: diff --git a/api/core/app/apps/workflow/app_config_manager.py b/api/core/app/apps/workflow/app_config_manager.py index 8b98e74b85..b0aa21c731 100644 --- a/api/core/app/apps/workflow/app_config_manager.py +++ b/api/core/app/apps/workflow/app_config_manager.py @@ -46,9 +46,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager): related_config_keys = [] # file upload validation - config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( - config=config, is_vision=False - ) + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config) related_config_keys.extend(current_related_config_keys) # text_to_speech diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 6e9c6804f9..a0080ece20 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -25,7 +25,6 @@ from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser, Workflow -from models.enums import CreatedByRole logger = logging.getLogger(__name__) @@ -70,15 +69,11 @@ class WorkflowAppGenerator(BaseAppGenerator): ): files: Sequence[Mapping[str, Any]] = args.get("files") or [] - role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER - # parse files file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) system_files = file_factory.build_from_mappings( mappings=files, tenant_id=app_model.tenant_id, - user_id=user.id, - role=role, config=file_extra_config, ) @@ -100,7 +95,8 @@ class WorkflowAppGenerator(BaseAppGenerator): application_generate_entity = WorkflowAppGenerateEntity( task_id=str(uuid.uuid4()), app_config=app_config, - inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), + file_upload_config=file_extra_config, + inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), files=system_files, user_id=user.id, stream=stream, diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index f2eba29323..31c3a996e1 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat from constants import UUID_NIL from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle -from core.file.models import File +from core.file import File, FileUploadConfig from core.model_runtime.entities.model_entities import AIModelEntity from core.ops.ops_trace_manager import TraceQueueManager @@ -80,6 +80,7 @@ class AppGenerateEntity(BaseModel): # app config app_config: AppConfig + file_upload_config: Optional[FileUploadConfig] = None inputs: Mapping[str, Any] files: Sequence[File] diff --git a/api/core/file/__init__.py b/api/core/file/__init__.py index bdaf8793fa..fe9e52258a 100644 --- a/api/core/file/__init__.py +++ b/api/core/file/__init__.py @@ -2,13 +2,13 @@ from .constants import FILE_MODEL_IDENTITY from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType from .models import ( File, - FileExtraConfig, + FileUploadConfig, ImageConfig, ) __all__ = [ "FileType", - "FileExtraConfig", + "FileUploadConfig", "FileTransferMethod", "FileBelongsTo", "File", diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index ff9220d35f..eb260a8f84 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -33,25 +33,28 @@ def get_attr(*, file: File, attr: FileAttribute): raise ValueError(f"Invalid file attribute: {attr}") -def to_prompt_message_content(f: File, /): +def to_prompt_message_content( + f: File, + /, + *, + image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW, +): """ - Convert a File object to an ImagePromptMessageContent object. + Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object. - This function takes a File object and converts it to an ImagePromptMessageContent - object, which can be used as a prompt for image-based AI models. + This function takes a File object and converts it to an appropriate PromptMessageContent + object, which can be used as a prompt for image or audio-based AI models. Args: - file (File): The File object to convert. Must be of type FileType.IMAGE. + f (File): The File object to convert. + detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts. + If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW. Returns: - ImagePromptMessageContent: An object containing the image data and detail level. + Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level Raises: - ValueError: If the file is not an image or if the file data is missing. - - Note: - The detail level of the image prompt is determined by the file's extra_config. - If not specified, it defaults to ImagePromptMessageContent.DETAIL.LOW. + ValueError: If the file type is not supported or if required data is missing. """ match f.type: case FileType.IMAGE: @@ -60,12 +63,7 @@ def to_prompt_message_content(f: File, /): else: data = _to_base64_data_string(f) - if f._extra_config and f._extra_config.image_config and f._extra_config.image_config.detail: - detail = f._extra_config.image_config.detail - else: - detail = ImagePromptMessageContent.DETAIL.LOW - - return ImagePromptMessageContent(data=data, detail=detail) + return ImagePromptMessageContent(data=data, detail=image_detail_config) case FileType.AUDIO: encoded_string = _file_to_encoded_string(f) if f.extension is None: @@ -78,7 +76,7 @@ def to_prompt_message_content(f: File, /): data = _to_base64_data_string(f) return VideoPromptMessageContent(data=data, format=f.extension.lstrip(".")) case _: - raise ValueError(f"file type {f.type} is not supported") + raise ValueError("file type f.type is not supported") def download(f: File, /): diff --git a/api/core/file/models.py b/api/core/file/models.py index 866ff3155b..0142893787 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -21,7 +21,7 @@ class ImageConfig(BaseModel): detail: ImagePromptMessageContent.DETAIL | None = None -class FileExtraConfig(BaseModel): +class FileUploadConfig(BaseModel): """ File Upload Entity. """ @@ -46,7 +46,6 @@ class File(BaseModel): extension: Optional[str] = Field(default=None, description="File extension, should contains dot") mime_type: Optional[str] = None size: int = -1 - _extra_config: FileExtraConfig | None = None def to_dict(self) -> Mapping[str, str | int | None]: data = self.model_dump(mode="json") @@ -107,34 +106,4 @@ class File(BaseModel): case FileTransferMethod.TOOL_FILE: if not self.related_id: raise ValueError("Missing file related_id") - - # Validate the extra config. - if not self._extra_config: - return self - - if self._extra_config.allowed_file_types: - if self.type not in self._extra_config.allowed_file_types and self.type != FileType.CUSTOM: - raise ValueError(f"Invalid file type: {self.type}") - - if self._extra_config.allowed_extensions and self.extension not in self._extra_config.allowed_extensions: - raise ValueError(f"Invalid file extension: {self.extension}") - - if ( - self._extra_config.allowed_upload_methods - and self.transfer_method not in self._extra_config.allowed_upload_methods - ): - raise ValueError(f"Invalid transfer method: {self.transfer_method}") - - match self.type: - case FileType.IMAGE: - # NOTE: This part of validation is deprecated, but still used in app features "Image Upload". - if not self._extra_config.image_config: - return self - # TODO: skip check if transfer_methods is empty, because many test cases are not setting this field - if ( - self._extra_config.image_config.transfer_methods - and self.transfer_method not in self._extra_config.image_config.transfer_methods - ): - raise ValueError(f"Invalid transfer method: {self.transfer_method}") - return self diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index d92c36a2df..688fb4776a 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -81,15 +81,18 @@ class TokenBufferMemory: db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first() ) - if workflow_run: + if workflow_run and workflow_run.workflow: file_extra_config = FileUploadConfigManager.convert( workflow_run.workflow.features_dict, is_vision=False ) + detail = ImagePromptMessageContent.DETAIL.LOW if file_extra_config and app_record: file_objs = file_factory.build_from_message_files( message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config ) + if file_extra_config.image_config and file_extra_config.image_config.detail: + detail = file_extra_config.image_config.detail else: file_objs = [] @@ -98,12 +101,16 @@ class TokenBufferMemory: else: prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents.append(TextPromptMessageContent(data=message.query)) - for file_obj in file_objs: - if file_obj.type in {FileType.IMAGE, FileType.AUDIO}: - prompt_message = file_manager.to_prompt_message_content(file_obj) + for file in file_objs: + if file.type in {FileType.IMAGE, FileType.AUDIO}: + prompt_message = file_manager.to_prompt_message_content( + file, + image_detail_config=detail, + ) prompt_message_contents.append(prompt_message) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: prompt_messages.append(UserPromptMessage(content=message.query)) diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index bbd9531b19..0f3f824966 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -15,6 +15,7 @@ from core.model_runtime.entities import ( TextPromptMessageContent, UserPromptMessage, ) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.prompt_transform import PromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser @@ -26,8 +27,13 @@ class AdvancedPromptTransform(PromptTransform): Advanced Prompt Transform for Workflow LLM Node. """ - def __init__(self, with_variable_tmpl: bool = False) -> None: + def __init__( + self, + with_variable_tmpl: bool = False, + image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW, + ) -> None: self.with_variable_tmpl = with_variable_tmpl + self.image_detail_config = image_detail_config def get_prompt( self, diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/vectorizer.py index 8140348723..211ec78f4d 100644 --- a/api/core/tools/provider/builtin/vectorizer/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/vectorizer.py @@ -1,19 +1,23 @@ from typing import Any -from core.file import File -from core.file.enums import FileTransferMethod, FileType +from core.file import FileTransferMethod, FileType from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from factories import file_factory class VectorizerProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: - test_img = File( + mapping = { + "transfer_method": FileTransferMethod.TOOL_FILE, + "type": FileType.IMAGE, + "id": "test_id", + "url": "https://cloud.dify.ai/logo/logo-site.png", + } + test_img = file_factory.build_from_mapping( + mapping=mapping, tenant_id="__test_123", - remote_url="https://cloud.dify.ai/logo/logo-site.png", - type=FileType.IMAGE, - transfer_method=FileTransferMethod.REMOTE_URL, ) try: VectorizerTool().fork_tool_runtime( diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 61c661e587..5b399bed63 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -13,6 +13,7 @@ from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.http_request.executor import Executor from core.workflow.utils import variable_template_parser +from factories import file_factory from models.workflow import WorkflowNodeExecutionStatus from .entities import ( @@ -161,16 +162,15 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): mimetype=content_type, ) - files.append( - File( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.TOOL_FILE, - related_id=tool_file.id, - filename=filename, - extension=extension, - mime_type=content_type, - ) + mapping = { + "tool_file_id": tool_file.id, + "type": FileType.IMAGE.value, + "transfer_method": FileTransferMethod.TOOL_FILE.value, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, ) + files.append(file) return files diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 42e870c46c..6870b7467d 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -17,6 +17,7 @@ from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db +from factories import file_factory from models import ToolFile from models.workflow import WorkflowNodeExecutionStatus @@ -189,19 +190,17 @@ class ToolNode(BaseNode[ToolNodeData]): if tool_file is None: raise ToolFileError(f"Tool file {tool_file_id} does not exist") - result.append( - File( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=transfer_method, - remote_url=url, - related_id=tool_file.id, - filename=tool_file.name, - extension=ext, - mime_type=tool_file.mimetype, - size=tool_file.size, - ) + mapping = { + "tool_file_id": tool_file_id, + "type": FileType.IMAGE, + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, ) + result.append(file) elif response.type == ToolInvokeMessage.MessageType.BLOB: # get tool file id tool_file_id = str(response.message).split("/")[-1].split(".")[0] @@ -209,19 +208,17 @@ class ToolNode(BaseNode[ToolNodeData]): stmt = select(ToolFile).where(ToolFile.id == tool_file_id) tool_file = session.scalar(stmt) if tool_file is None: - raise ToolFileError(f"Tool file {tool_file_id} does not exist") - result.append( - File( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.TOOL_FILE, - related_id=tool_file.id, - filename=tool_file.name, - extension=path.splitext(response.save_as)[1], - mime_type=tool_file.mimetype, - size=tool_file.size, - ) + raise ValueError(f"tool file {tool_file_id} not exists") + mapping = { + "tool_file_id": tool_file_id, + "type": FileType.IMAGE, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, ) + result.append(file) elif response.type == ToolInvokeMessage.MessageType.LINK: url = str(response.message) transfer_method = FileTransferMethod.TOOL_FILE @@ -235,16 +232,15 @@ class ToolNode(BaseNode[ToolNodeData]): extension = "." + url.split("/")[-1].split(".")[1] else: extension = ".bin" - file = File( + mapping = { + "tool_file_id": tool_file_id, + "type": FileType.IMAGE, + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, tenant_id=self.tenant_id, - type=FileType(response.save_as), - transfer_method=transfer_method, - remote_url=url, - filename=tool_file.name, - related_id=tool_file.id, - extension=extension, - mime_type=tool_file.mimetype, - size=tool_file.size, ) result.append(file) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index eb812bad21..84b251223f 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -5,10 +5,10 @@ from collections.abc import Generator, Mapping, Sequence from typing import Any, Optional, cast from configs import dify_config -from core.app.app_config.entities import FileExtraConfig +from core.app.app_config.entities import FileUploadConfig from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.models import File, FileTransferMethod, FileType, ImageConfig +from core.file.models import File, FileTransferMethod, ImageConfig from core.workflow.callbacks import WorkflowCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.errors import WorkflowNodeRunFailedError @@ -22,6 +22,7 @@ from core.workflow.nodes.base import BaseNode, BaseNodeData from core.workflow.nodes.event import NodeEvent from core.workflow.nodes.llm import LLMNodeData from core.workflow.nodes.node_mapping import node_type_classes_mapping +from factories import file_factory from models.enums import UserFrom from models.workflow import ( Workflow, @@ -271,19 +272,17 @@ class WorkflowEntry: for item in input_value: if isinstance(item, dict) and "type" in item and item["type"] == "image": transfer_method = FileTransferMethod.value_of(item.get("transfer_method")) - file = File( + mapping = { + "id": item.get("id"), + "transfer_method": transfer_method, + "upload_file_id": item.get("upload_file_id"), + "url": item.get("url"), + } + config = FileUploadConfig(image_config=ImageConfig(detail=detail) if detail else None) + file = file_factory.build_from_mapping( + mapping=mapping, tenant_id=tenant_id, - type=FileType.IMAGE, - transfer_method=transfer_method, - remote_url=item.get("url") - if transfer_method == FileTransferMethod.REMOTE_URL - else None, - related_id=item.get("upload_file_id") - if transfer_method == FileTransferMethod.LOCAL_FILE - else None, - _extra_config=FileExtraConfig( - image_config=ImageConfig(detail=detail) if detail else None - ), + config=config, ) new_value.append(file) diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 1066dc8862..738b2b3478 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -1,23 +1,21 @@ import mimetypes -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from typing import Any import httpx from sqlalchemy import select -from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS -from core.file import File, FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType +from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig from core.helper import ssrf_proxy from extensions.ext_database import db from models import MessageFile, ToolFile, UploadFile -from models.enums import CreatedByRole def build_from_message_files( *, message_files: Sequence["MessageFile"], tenant_id: str, - config: FileExtraConfig, + config: FileUploadConfig, ) -> Sequence[File]: results = [ build_from_message_file(message_file=file, tenant_id=tenant_id, config=config) @@ -31,7 +29,7 @@ def build_from_message_file( *, message_file: "MessageFile", tenant_id: str, - config: FileExtraConfig, + config: FileUploadConfig, ): mapping = { "transfer_method": message_file.transfer_method, @@ -43,8 +41,6 @@ def build_from_message_file( return build_from_mapping( mapping=mapping, tenant_id=tenant_id, - user_id=message_file.created_by, - role=CreatedByRole(message_file.created_by_role), config=config, ) @@ -53,38 +49,30 @@ def build_from_mapping( *, mapping: Mapping[str, Any], tenant_id: str, - user_id: str, - role: "CreatedByRole", - config: FileExtraConfig, -): + config: FileUploadConfig | None = None, +) -> File: + config = config or FileUploadConfig() + transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method")) - match transfer_method: - case FileTransferMethod.REMOTE_URL: - file = _build_from_remote_url( - mapping=mapping, - tenant_id=tenant_id, - config=config, - transfer_method=transfer_method, - ) - case FileTransferMethod.LOCAL_FILE: - file = _build_from_local_file( - mapping=mapping, - tenant_id=tenant_id, - user_id=user_id, - role=role, - config=config, - transfer_method=transfer_method, - ) - case FileTransferMethod.TOOL_FILE: - file = _build_from_tool_file( - mapping=mapping, - tenant_id=tenant_id, - user_id=user_id, - config=config, - transfer_method=transfer_method, - ) - case _: - raise ValueError(f"Invalid file transfer method: {transfer_method}") + + build_functions: dict[FileTransferMethod, Callable] = { + FileTransferMethod.LOCAL_FILE: _build_from_local_file, + FileTransferMethod.REMOTE_URL: _build_from_remote_url, + FileTransferMethod.TOOL_FILE: _build_from_tool_file, + } + + build_func = build_functions.get(transfer_method) + if not build_func: + raise ValueError(f"Invalid file transfer method: {transfer_method}") + + file = build_func( + mapping=mapping, + tenant_id=tenant_id, + transfer_method=transfer_method, + ) + + if not _is_file_valid_with_config(file=file, config=config): + raise ValueError(f"File validation failed for file: {file.filename}") return file @@ -92,10 +80,8 @@ def build_from_mapping( def build_from_mappings( *, mappings: Sequence[Mapping[str, Any]], - config: FileExtraConfig | None, + config: FileUploadConfig | None, tenant_id: str, - user_id: str, - role: "CreatedByRole", ) -> Sequence[File]: if not config: return [] @@ -104,8 +90,6 @@ def build_from_mappings( build_from_mapping( mapping=mapping, tenant_id=tenant_id, - user_id=user_id, - role=role, config=config, ) for mapping in mappings @@ -128,31 +112,20 @@ def _build_from_local_file( *, mapping: Mapping[str, Any], tenant_id: str, - user_id: str, - role: "CreatedByRole", - config: FileExtraConfig, transfer_method: FileTransferMethod, -): - # check if the upload file exists. +) -> File: file_type = FileType.value_of(mapping.get("type")) stmt = select(UploadFile).where( UploadFile.id == mapping.get("upload_file_id"), UploadFile.tenant_id == tenant_id, - UploadFile.created_by == user_id, - UploadFile.created_by_role == role, ) - if file_type == FileType.IMAGE: - stmt = stmt.where(UploadFile.extension.in_(IMAGE_EXTENSIONS)) - elif file_type == FileType.VIDEO: - stmt = stmt.where(UploadFile.extension.in_(VIDEO_EXTENSIONS)) - elif file_type == FileType.AUDIO: - stmt = stmt.where(UploadFile.extension.in_(AUDIO_EXTENSIONS)) - elif file_type == FileType.DOCUMENT: - stmt = stmt.where(UploadFile.extension.in_(DOCUMENT_EXTENSIONS)) + row = db.session.scalar(stmt) + if row is None: raise ValueError("Invalid upload file") - file = File( + + return File( id=mapping.get("id"), filename=row.name, extension="." + row.extension, @@ -162,23 +135,37 @@ def _build_from_local_file( transfer_method=transfer_method, remote_url=row.source_url, related_id=mapping.get("upload_file_id"), - _extra_config=config, size=row.size, ) - return file def _build_from_remote_url( *, mapping: Mapping[str, Any], tenant_id: str, - config: FileExtraConfig, transfer_method: FileTransferMethod, -): +) -> File: url = mapping.get("url") if not url: raise ValueError("Invalid file url") + mime_type, filename, file_size = _get_remote_file_info(url) + extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin" + + return File( + id=mapping.get("id"), + filename=filename, + tenant_id=tenant_id, + type=FileType.value_of(mapping.get("type")), + transfer_method=transfer_method, + remote_url=url, + mime_type=mime_type, + extension=extension, + size=file_size, + ) + + +def _get_remote_file_info(url: str): mime_type = mimetypes.guess_type(url)[0] or "" file_size = -1 filename = url.split("/")[-1].split("?")[0] or "unknown_file" @@ -186,56 +173,34 @@ def _build_from_remote_url( resp = ssrf_proxy.head(url, follow_redirects=True) if resp.status_code == httpx.codes.OK: if content_disposition := resp.headers.get("Content-Disposition"): - filename = content_disposition.split("filename=")[-1].strip('"') + filename = str(content_disposition.split("filename=")[-1].strip('"')) file_size = int(resp.headers.get("Content-Length", file_size)) mime_type = mime_type or str(resp.headers.get("Content-Type", "")) - # Determine file extension - extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin" - - if not mime_type: - mime_type, _ = mimetypes.guess_type(url) - file = File( - id=mapping.get("id"), - filename=filename, - tenant_id=tenant_id, - type=FileType.value_of(mapping.get("type")), - transfer_method=transfer_method, - remote_url=url, - _extra_config=config, - mime_type=mime_type, - extension=extension, - size=file_size, - ) - return file + return mime_type, filename, file_size def _build_from_tool_file( *, mapping: Mapping[str, Any], tenant_id: str, - user_id: str, - config: FileExtraConfig, transfer_method: FileTransferMethod, -): +) -> File: tool_file = ( db.session.query(ToolFile) .filter( ToolFile.id == mapping.get("tool_file_id"), ToolFile.tenant_id == tenant_id, - ToolFile.user_id == user_id, ) .first() ) + if tool_file is None: raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") - path = tool_file.file_key - if "." in path: - extension = "." + path.split("/")[-1].split(".")[-1] - else: - extension = ".bin" - file = File( + extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" + + return File( id=mapping.get("id"), tenant_id=tenant_id, filename=tool_file.name, @@ -246,6 +211,21 @@ def _build_from_tool_file( extension=extension, mime_type=tool_file.mimetype, size=tool_file.size, - _extra_config=config, ) - return file + + +def _is_file_valid_with_config(*, file: File, config: FileUploadConfig) -> bool: + if config.allowed_file_types and file.type not in config.allowed_file_types and file.type != FileType.CUSTOM: + return False + + if config.allowed_extensions and file.extension not in config.allowed_extensions: + return False + + if config.allowed_upload_methods and file.transfer_method not in config.allowed_upload_methods: + return False + + if file.type == FileType.IMAGE and config.image_config: + if config.image_config.transfer_methods and file.transfer_method not in config.image_config.transfer_methods: + return False + + return True diff --git a/api/models/model.py b/api/models/model.py index d049cd373d..e909d53e3e 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -13,7 +13,7 @@ from sqlalchemy import Float, func, text from sqlalchemy.orm import Mapped, mapped_column from configs import dify_config -from core.file import FILE_MODEL_IDENTITY, File, FileExtraConfig, FileTransferMethod, FileType +from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from core.file import helpers as file_helpers from core.file.tool_file_parser import ToolFileParser from extensions.ext_database import db @@ -949,9 +949,6 @@ class Message(db.Model): "type": message_file.type, }, tenant_id=current_app.tenant_id, - user_id=self.from_account_id or self.from_end_user_id or "", - role=CreatedByRole(message_file.created_by_role), - config=FileExtraConfig(), ) elif message_file.transfer_method == "remote_url": if message_file.url is None: @@ -964,9 +961,6 @@ class Message(db.Model): "url": message_file.url, }, tenant_id=current_app.tenant_id, - user_id=self.from_account_id or self.from_end_user_id or "", - role=CreatedByRole(message_file.created_by_role), - config=FileExtraConfig(), ) elif message_file.transfer_method == "tool_file": if message_file.upload_file_id is None: @@ -981,9 +975,6 @@ class Message(db.Model): file = file_factory.build_from_mapping( mapping=mapping, tenant_id=current_app.tenant_id, - user_id=self.from_account_id or self.from_end_user_id or "", - role=CreatedByRole(message_file.created_by_role), - config=FileExtraConfig(), ) else: raise ValueError( diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 75c11afa94..90b5cc4836 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -13,7 +13,7 @@ from core.app.app_config.entities import ( from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager -from core.file.models import FileExtraConfig +from core.file.models import FileUploadConfig from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils.encoders import jsonable_encoder @@ -381,7 +381,7 @@ class WorkflowConverter: graph: dict, model_config: ModelConfigEntity, prompt_template: PromptTemplateEntity, - file_upload: Optional[FileExtraConfig] = None, + file_upload: Optional[FileUploadConfig] = None, external_data_variable_node_mapping: dict[str, str] | None = None, ) -> dict: """ diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 0da6622658..9eea63f722 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -430,37 +430,3 @@ def test_multi_colons_parse(setup_http_mock): assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "") assert 'form-data; name="Redirect"\r\n\r\nhttp://example6.com' in result.process_data.get("request", "") # assert "http://example3.com" == resp.get("headers", {}).get("referer") - - -def test_image_file(monkeypatch): - from types import SimpleNamespace - - monkeypatch.setattr( - "core.tools.tool_file_manager.ToolFileManager.create_file_by_raw", - lambda *args, **kwargs: SimpleNamespace(id="1"), - ) - - node = init_http_node( - config={ - "id": "1", - "data": { - "title": "http", - "desc": "", - "method": "get", - "url": "https://cloud.dify.ai/logo/logo-site.png", - "authorization": { - "type": "no-auth", - "config": None, - }, - "params": "", - "headers": "", - "body": None, - }, - } - ) - - result = node._run() - assert result.process_data is not None - assert result.outputs is not None - resp = result.outputs - assert len(resp.get("files", [])) == 1 diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index ece2173090..7d19cff3e8 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from core.app.app_config.entities import ModelConfigEntity -from core.file import File, FileExtraConfig, FileTransferMethod, FileType, ImageConfig +from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -134,7 +134,6 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", - _extra_config=FileExtraConfig(image_config=ImageConfig(detail=ImagePromptMessageContent.DETAIL.HIGH)), ) ]