refactor(core): Remove extra_config from File. (#10203)

This commit is contained in:
-LAN- 2024-11-08 18:13:24 +08:00 committed by GitHub
parent 78a380bcc4
commit 25ca0278dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 263 additions and 344 deletions

View File

@ -30,6 +30,7 @@ from core.model_runtime.entities import (
ToolPromptMessage, ToolPromptMessage,
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
@ -65,7 +66,7 @@ class BaseAgentRunner(AppRunner):
prompt_messages: Optional[list[PromptMessage]] = None, prompt_messages: Optional[list[PromptMessage]] = None,
variables_pool: Optional[ToolRuntimeVariablePool] = None, variables_pool: Optional[ToolRuntimeVariablePool] = None,
db_variables: Optional[ToolConversationVariables] = None, db_variables: Optional[ToolConversationVariables] = None,
model_instance: ModelInstance = None, model_instance: ModelInstance | None = None,
) -> None: ) -> None:
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.application_generate_entity = application_generate_entity self.application_generate_entity = application_generate_entity
@ -508,24 +509,27 @@ class BaseAgentRunner(AppRunner):
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if files: if not files:
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
if file_extra_config:
file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=self.tenant_id, config=file_extra_config
)
else:
file_objs = []
if not file_objs:
return UserPromptMessage(content=message.query)
else:
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file_obj in file_objs:
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
return UserPromptMessage(content=prompt_message_contents)
else:
return UserPromptMessage(content=message.query) return UserPromptMessage(content=message.query)
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
if not file_extra_config:
return UserPromptMessage(content=message.query)
image_detail_config = file_extra_config.image_config.detail if file_extra_config.image_config else None
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=self.tenant_id, config=file_extra_config
)
if not file_objs:
return UserPromptMessage(content=message.query)
prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file in file_objs:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)
return UserPromptMessage(content=prompt_message_contents)

View File

@ -10,6 +10,7 @@ from core.model_runtime.entities import (
TextPromptMessageContent, TextPromptMessageContent,
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
@ -36,8 +37,24 @@ class CotChatAgentRunner(CotAgentRunner):
if self.files: if self.files:
prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=query)) prompt_message_contents.append(TextPromptMessageContent(data=query))
for file_obj in self.files:
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) # get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
if (
self.application_generate_entity.file_upload_config
and self.application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else: else:

View File

@ -22,6 +22,7 @@ from core.model_runtime.entities import (
ToolPromptMessage, ToolPromptMessage,
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine from core.tools.tool_engine import ToolEngine
@ -397,8 +398,24 @@ class FunctionCallAgentRunner(BaseAgentRunner):
if self.files: if self.files:
prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=query)) prompt_message_contents.append(TextPromptMessageContent(data=query))
for file_obj in self.files:
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) # get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
if (
self.application_generate_entity.file_upload_config
and self.application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else: else:

View File

@ -4,7 +4,7 @@ from typing import Any, Optional
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
from core.file import FileExtraConfig, FileTransferMethod, FileType from core.file import FileTransferMethod, FileType, FileUploadConfig
from core.model_runtime.entities.message_entities import PromptMessageRole from core.model_runtime.entities.message_entities import PromptMessageRole
from models.model import AppMode from models.model import AppMode
@ -211,7 +211,7 @@ class TracingConfigEntity(BaseModel):
class AppAdditionalFeatures(BaseModel): class AppAdditionalFeatures(BaseModel):
file_upload: Optional[FileExtraConfig] = None file_upload: Optional[FileUploadConfig] = None
opening_statement: Optional[str] = None opening_statement: Optional[str] = None
suggested_questions: list[str] = [] suggested_questions: list[str] = []
suggested_questions_after_answer: bool = False suggested_questions_after_answer: bool = False

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any from typing import Any
from core.file import FileExtraConfig from core.file import FileUploadConfig
class FileUploadConfigManager: class FileUploadConfigManager:
@ -29,15 +29,14 @@ class FileUploadConfigManager:
if is_vision: if is_vision:
data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low") data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low")
return FileExtraConfig.model_validate(data) return FileUploadConfig.model_validate(data)
@classmethod @classmethod
def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]: def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
""" """
Validate and set defaults for file upload feature Validate and set defaults for file upload feature
:param config: app model config args :param config: app model config args
:param is_vision: if True, the feature is vision feature
""" """
if not config.get("file_upload"): if not config.get("file_upload"):
config["file_upload"] = {} config["file_upload"] = {}

View File

@ -52,9 +52,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
related_config_keys = [] related_config_keys = []
# file upload validation # file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
config=config, is_vision=False
)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# opening_statement # opening_statement

View File

@ -26,7 +26,6 @@ from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
from models.account import Account from models.account import Account
from models.enums import CreatedByRole
from models.model import App, Conversation, EndUser, Message from models.model import App, Conversation, EndUser, Message
from models.workflow import Workflow from models.workflow import Workflow
@ -98,13 +97,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# parse files # parse files
files = args["files"] if args.get("files") else [] files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
if file_extra_config: if file_extra_config:
file_objs = file_factory.build_from_mappings( file_objs = file_factory.build_from_mappings(
mappings=files, mappings=files,
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config, config=file_extra_config,
) )
else: else:
@ -127,10 +123,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
application_generate_entity = AdvancedChatAppGenerateEntity( application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()), task_id=str(uuid.uuid4()),
app_config=app_config, app_config=app_config,
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None, conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs inputs=conversation.inputs
if conversation if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
query=query, query=query,
files=file_objs, files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

View File

@ -23,7 +23,6 @@ from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
from models import Account, App, EndUser from models import Account, App, EndUser
from models.enums import CreatedByRole
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -103,8 +102,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# always enable retriever resource in debugger mode # always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {"enabled": True} override_model_config_dict["retriever_resource"] = {"enabled": True}
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files # parse files
files = args.get("files") or [] files = args.get("files") or []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
@ -112,8 +109,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
file_objs = file_factory.build_from_mappings( file_objs = file_factory.build_from_mappings(
mappings=files, mappings=files,
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config, config=file_extra_config,
) )
else: else:
@ -135,10 +130,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
task_id=str(uuid.uuid4()), task_id=str(uuid.uuid4()),
app_config=app_config, app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config), model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None, conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs inputs=conversation.inputs
if conversation if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
query=query, query=query,
files=file_objs, files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

View File

@ -2,12 +2,11 @@ from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
from core.app.app_config.entities import VariableEntityType from core.app.app_config.entities import VariableEntityType
from core.file import File, FileExtraConfig from core.file import File, FileUploadConfig
from factories import file_factory from factories import file_factory
if TYPE_CHECKING: if TYPE_CHECKING:
from core.app.app_config.entities import AppConfig, VariableEntity from core.app.app_config.entities import AppConfig, VariableEntity
from models.enums import CreatedByRole
class BaseAppGenerator: class BaseAppGenerator:
@ -16,8 +15,6 @@ class BaseAppGenerator:
*, *,
user_inputs: Optional[Mapping[str, Any]], user_inputs: Optional[Mapping[str, Any]],
app_config: "AppConfig", app_config: "AppConfig",
user_id: str,
role: "CreatedByRole",
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
user_inputs = user_inputs or {} user_inputs = user_inputs or {}
# Filter input variables from form configuration, handle required fields, default values, and option values # Filter input variables from form configuration, handle required fields, default values, and option values
@ -34,9 +31,7 @@ class BaseAppGenerator:
k: file_factory.build_from_mapping( k: file_factory.build_from_mapping(
mapping=v, mapping=v,
tenant_id=app_config.tenant_id, tenant_id=app_config.tenant_id,
user_id=user_id, config=FileUploadConfig(
role=role,
config=FileExtraConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types, allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions, allowed_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods, allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
@ -50,9 +45,7 @@ class BaseAppGenerator:
k: file_factory.build_from_mappings( k: file_factory.build_from_mappings(
mappings=v, mappings=v,
tenant_id=app_config.tenant_id, tenant_id=app_config.tenant_id,
user_id=user_id, config=FileUploadConfig(
role=role,
config=FileExtraConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types, allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_extensions=entity_dictionary[k].allowed_file_extensions, allowed_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods, allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,

View File

@ -23,7 +23,6 @@ from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
from models.account import Account from models.account import Account
from models.enums import CreatedByRole
from models.model import App, EndUser from models.model import App, EndUser
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -101,8 +100,6 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# always enable retriever resource in debugger mode # always enable retriever resource in debugger mode
override_model_config_dict["retriever_resource"] = {"enabled": True} override_model_config_dict["retriever_resource"] = {"enabled": True}
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files # parse files
files = args["files"] if args.get("files") else [] files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
@ -110,8 +107,6 @@ class ChatAppGenerator(MessageBasedAppGenerator):
file_objs = file_factory.build_from_mappings( file_objs = file_factory.build_from_mappings(
mappings=files, mappings=files,
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config, config=file_extra_config,
) )
else: else:
@ -133,10 +128,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
task_id=str(uuid.uuid4()), task_id=str(uuid.uuid4()),
app_config=app_config, app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config), model_conf=ModelConfigConverter.convert(app_config),
file_upload_config=file_extra_config,
conversation_id=conversation.id if conversation else None, conversation_id=conversation.id if conversation else None,
inputs=conversation.inputs inputs=conversation.inputs
if conversation if conversation
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
query=query, query=query,
files=file_objs, files=file_objs,
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,

View File

@ -22,7 +22,6 @@ from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
from models import Account, App, EndUser, Message from models import Account, App, EndUser, Message
from models.enums import CreatedByRole
from services.errors.app import MoreLikeThisDisabledError from services.errors.app import MoreLikeThisDisabledError
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
@ -88,8 +87,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
tenant_id=app_model.tenant_id, config=args.get("model_config") tenant_id=app_model.tenant_id, config=args.get("model_config")
) )
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files # parse files
files = args["files"] if args.get("files") else [] files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
@ -97,8 +94,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
file_objs = file_factory.build_from_mappings( file_objs = file_factory.build_from_mappings(
mappings=files, mappings=files,
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config, config=file_extra_config,
) )
else: else:
@ -110,7 +105,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
) )
# get tracing instance # get tracing instance
user_id = user.id if isinstance(user, Account) else user.session_id
trace_manager = TraceQueueManager(app_model.id) trace_manager = TraceQueueManager(app_model.id)
# init application generate entity # init application generate entity
@ -118,7 +112,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
task_id=str(uuid.uuid4()), task_id=str(uuid.uuid4()),
app_config=app_config, app_config=app_config,
model_conf=ModelConfigConverter.convert(app_config), model_conf=ModelConfigConverter.convert(app_config),
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
query=query, query=query,
files=file_objs, files=file_objs,
user_id=user.id, user_id=user.id,
@ -259,14 +254,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
override_model_config_dict["model"] = model_dict override_model_config_dict["model"] = model_dict
# parse files # parse files
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
if file_extra_config: if file_extra_config:
file_objs = file_factory.build_from_mappings( file_objs = file_factory.build_from_mappings(
mappings=message.message_files, mappings=message.message_files,
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config, config=file_extra_config,
) )
else: else:

View File

@ -46,9 +46,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
related_config_keys = [] related_config_keys = []
# file upload validation # file upload validation
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
config=config, is_vision=False
)
related_config_keys.extend(current_related_config_keys) related_config_keys.extend(current_related_config_keys)
# text_to_speech # text_to_speech

View File

@ -25,7 +25,6 @@ from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
from models import Account, App, EndUser, Workflow from models import Account, App, EndUser, Workflow
from models.enums import CreatedByRole
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -70,15 +69,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
): ):
files: Sequence[Mapping[str, Any]] = args.get("files") or [] files: Sequence[Mapping[str, Any]] = args.get("files") or []
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
# parse files # parse files
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
system_files = file_factory.build_from_mappings( system_files = file_factory.build_from_mappings(
mappings=files, mappings=files,
tenant_id=app_model.tenant_id, tenant_id=app_model.tenant_id,
user_id=user.id,
role=role,
config=file_extra_config, config=file_extra_config,
) )
@ -100,7 +95,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity = WorkflowAppGenerateEntity( application_generate_entity = WorkflowAppGenerateEntity(
task_id=str(uuid.uuid4()), task_id=str(uuid.uuid4()),
app_config=app_config, app_config=app_config,
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), file_upload_config=file_extra_config,
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
files=system_files, files=system_files,
user_id=user.id, user_id=user.id,
stream=stream, stream=stream,

View File

@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat
from constants import UUID_NIL from constants import UUID_NIL
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle from core.entities.provider_configuration import ProviderModelBundle
from core.file.models import File from core.file import File, FileUploadConfig
from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.model_entities import AIModelEntity
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
@ -80,6 +80,7 @@ class AppGenerateEntity(BaseModel):
# app config # app config
app_config: AppConfig app_config: AppConfig
file_upload_config: Optional[FileUploadConfig] = None
inputs: Mapping[str, Any] inputs: Mapping[str, Any]
files: Sequence[File] files: Sequence[File]

View File

@ -2,13 +2,13 @@ from .constants import FILE_MODEL_IDENTITY
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
from .models import ( from .models import (
File, File,
FileExtraConfig, FileUploadConfig,
ImageConfig, ImageConfig,
) )
__all__ = [ __all__ = [
"FileType", "FileType",
"FileExtraConfig", "FileUploadConfig",
"FileTransferMethod", "FileTransferMethod",
"FileBelongsTo", "FileBelongsTo",
"File", "File",

View File

@ -33,25 +33,28 @@ def get_attr(*, file: File, attr: FileAttribute):
raise ValueError(f"Invalid file attribute: {attr}") raise ValueError(f"Invalid file attribute: {attr}")
def to_prompt_message_content(f: File, /): def to_prompt_message_content(
f: File,
/,
*,
image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
):
""" """
Convert a File object to an ImagePromptMessageContent object. Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object.
This function takes a File object and converts it to an ImagePromptMessageContent This function takes a File object and converts it to an appropriate PromptMessageContent
object, which can be used as a prompt for image-based AI models. object, which can be used as a prompt for image or audio-based AI models.
Args: Args:
file (File): The File object to convert. Must be of type FileType.IMAGE. f (File): The File object to convert.
detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts.
If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW.
Returns: Returns:
ImagePromptMessageContent: An object containing the image data and detail level. Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level
Raises: Raises:
ValueError: If the file is not an image or if the file data is missing. ValueError: If the file type is not supported or if required data is missing.
Note:
The detail level of the image prompt is determined by the file's extra_config.
If not specified, it defaults to ImagePromptMessageContent.DETAIL.LOW.
""" """
match f.type: match f.type:
case FileType.IMAGE: case FileType.IMAGE:
@ -60,12 +63,7 @@ def to_prompt_message_content(f: File, /):
else: else:
data = _to_base64_data_string(f) data = _to_base64_data_string(f)
if f._extra_config and f._extra_config.image_config and f._extra_config.image_config.detail: return ImagePromptMessageContent(data=data, detail=image_detail_config)
detail = f._extra_config.image_config.detail
else:
detail = ImagePromptMessageContent.DETAIL.LOW
return ImagePromptMessageContent(data=data, detail=detail)
case FileType.AUDIO: case FileType.AUDIO:
encoded_string = _file_to_encoded_string(f) encoded_string = _file_to_encoded_string(f)
if f.extension is None: if f.extension is None:
@ -78,7 +76,7 @@ def to_prompt_message_content(f: File, /):
data = _to_base64_data_string(f) data = _to_base64_data_string(f)
return VideoPromptMessageContent(data=data, format=f.extension.lstrip(".")) return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case _: case _:
raise ValueError(f"file type {f.type} is not supported") raise ValueError("file type f.type is not supported")
def download(f: File, /): def download(f: File, /):

View File

@ -21,7 +21,7 @@ class ImageConfig(BaseModel):
detail: ImagePromptMessageContent.DETAIL | None = None detail: ImagePromptMessageContent.DETAIL | None = None
class FileExtraConfig(BaseModel): class FileUploadConfig(BaseModel):
""" """
File Upload Entity. File Upload Entity.
""" """
@ -46,7 +46,6 @@ class File(BaseModel):
extension: Optional[str] = Field(default=None, description="File extension, should contains dot") extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
mime_type: Optional[str] = None mime_type: Optional[str] = None
size: int = -1 size: int = -1
_extra_config: FileExtraConfig | None = None
def to_dict(self) -> Mapping[str, str | int | None]: def to_dict(self) -> Mapping[str, str | int | None]:
data = self.model_dump(mode="json") data = self.model_dump(mode="json")
@ -107,34 +106,4 @@ class File(BaseModel):
case FileTransferMethod.TOOL_FILE: case FileTransferMethod.TOOL_FILE:
if not self.related_id: if not self.related_id:
raise ValueError("Missing file related_id") raise ValueError("Missing file related_id")
# Validate the extra config.
if not self._extra_config:
return self
if self._extra_config.allowed_file_types:
if self.type not in self._extra_config.allowed_file_types and self.type != FileType.CUSTOM:
raise ValueError(f"Invalid file type: {self.type}")
if self._extra_config.allowed_extensions and self.extension not in self._extra_config.allowed_extensions:
raise ValueError(f"Invalid file extension: {self.extension}")
if (
self._extra_config.allowed_upload_methods
and self.transfer_method not in self._extra_config.allowed_upload_methods
):
raise ValueError(f"Invalid transfer method: {self.transfer_method}")
match self.type:
case FileType.IMAGE:
# NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
if not self._extra_config.image_config:
return self
# TODO: skip check if transfer_methods is empty, because many test cases are not setting this field
if (
self._extra_config.image_config.transfer_methods
and self.transfer_method not in self._extra_config.image_config.transfer_methods
):
raise ValueError(f"Invalid transfer method: {self.transfer_method}")
return self return self

View File

@ -81,15 +81,18 @@ class TokenBufferMemory:
db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first() db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first()
) )
if workflow_run: if workflow_run and workflow_run.workflow:
file_extra_config = FileUploadConfigManager.convert( file_extra_config = FileUploadConfigManager.convert(
workflow_run.workflow.features_dict, is_vision=False workflow_run.workflow.features_dict, is_vision=False
) )
detail = ImagePromptMessageContent.DETAIL.LOW
if file_extra_config and app_record: if file_extra_config and app_record:
file_objs = file_factory.build_from_message_files( file_objs = file_factory.build_from_message_files(
message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
) )
if file_extra_config.image_config and file_extra_config.image_config.detail:
detail = file_extra_config.image_config.detail
else: else:
file_objs = [] file_objs = []
@ -98,12 +101,16 @@ class TokenBufferMemory:
else: else:
prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents: list[PromptMessageContent] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query)) prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file_obj in file_objs: for file in file_objs:
if file_obj.type in {FileType.IMAGE, FileType.AUDIO}: if file.type in {FileType.IMAGE, FileType.AUDIO}:
prompt_message = file_manager.to_prompt_message_content(file_obj) prompt_message = file_manager.to_prompt_message_content(
file,
image_detail_config=detail,
)
prompt_message_contents.append(prompt_message) prompt_message_contents.append(prompt_message)
prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else: else:
prompt_messages.append(UserPromptMessage(content=message.query)) prompt_messages.append(UserPromptMessage(content=message.query))

View File

@ -15,6 +15,7 @@ from core.model_runtime.entities import (
TextPromptMessageContent, TextPromptMessageContent,
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.prompt_transform import PromptTransform from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.prompt.utils.prompt_template_parser import PromptTemplateParser
@ -26,8 +27,13 @@ class AdvancedPromptTransform(PromptTransform):
Advanced Prompt Transform for Workflow LLM Node. Advanced Prompt Transform for Workflow LLM Node.
""" """
def __init__(self, with_variable_tmpl: bool = False) -> None: def __init__(
self,
with_variable_tmpl: bool = False,
image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
) -> None:
self.with_variable_tmpl = with_variable_tmpl self.with_variable_tmpl = with_variable_tmpl
self.image_detail_config = image_detail_config
def get_prompt( def get_prompt(
self, self,

View File

@ -1,19 +1,23 @@
from typing import Any from typing import Any
from core.file import File from core.file import FileTransferMethod, FileType
from core.file.enums import FileTransferMethod, FileType
from core.tools.errors import ToolProviderCredentialValidationError from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from factories import file_factory
class VectorizerProvider(BuiltinToolProviderController): class VectorizerProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None: def _validate_credentials(self, credentials: dict[str, Any]) -> None:
test_img = File( mapping = {
"transfer_method": FileTransferMethod.TOOL_FILE,
"type": FileType.IMAGE,
"id": "test_id",
"url": "https://cloud.dify.ai/logo/logo-site.png",
}
test_img = file_factory.build_from_mapping(
mapping=mapping,
tenant_id="__test_123", tenant_id="__test_123",
remote_url="https://cloud.dify.ai/logo/logo-site.png",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
) )
try: try:
VectorizerTool().fork_tool_runtime( VectorizerTool().fork_tool_runtime(

View File

@ -13,6 +13,7 @@ from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.http_request.executor import Executor from core.workflow.nodes.http_request.executor import Executor
from core.workflow.utils import variable_template_parser from core.workflow.utils import variable_template_parser
from factories import file_factory
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
from .entities import ( from .entities import (
@ -161,16 +162,15 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
mimetype=content_type, mimetype=content_type,
) )
files.append( mapping = {
File( "tool_file_id": tool_file.id,
tenant_id=self.tenant_id, "type": FileType.IMAGE.value,
type=FileType.IMAGE, "transfer_method": FileTransferMethod.TOOL_FILE.value,
transfer_method=FileTransferMethod.TOOL_FILE, }
related_id=tool_file.id, file = file_factory.build_from_mapping(
filename=filename, mapping=mapping,
extension=extension, tenant_id=self.tenant_id,
mime_type=content_type,
)
) )
files.append(file)
return files return files

View File

@ -17,6 +17,7 @@ from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory
from models import ToolFile from models import ToolFile
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
@ -189,19 +190,17 @@ class ToolNode(BaseNode[ToolNodeData]):
if tool_file is None: if tool_file is None:
raise ToolFileError(f"Tool file {tool_file_id} does not exist") raise ToolFileError(f"Tool file {tool_file_id} does not exist")
result.append( mapping = {
File( "tool_file_id": tool_file_id,
tenant_id=self.tenant_id, "type": FileType.IMAGE,
type=FileType.IMAGE, "transfer_method": transfer_method,
transfer_method=transfer_method, "url": url,
remote_url=url, }
related_id=tool_file.id, file = file_factory.build_from_mapping(
filename=tool_file.name, mapping=mapping,
extension=ext, tenant_id=self.tenant_id,
mime_type=tool_file.mimetype,
size=tool_file.size,
)
) )
result.append(file)
elif response.type == ToolInvokeMessage.MessageType.BLOB: elif response.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id # get tool file id
tool_file_id = str(response.message).split("/")[-1].split(".")[0] tool_file_id = str(response.message).split("/")[-1].split(".")[0]
@ -209,19 +208,17 @@ class ToolNode(BaseNode[ToolNodeData]):
stmt = select(ToolFile).where(ToolFile.id == tool_file_id) stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt) tool_file = session.scalar(stmt)
if tool_file is None: if tool_file is None:
raise ToolFileError(f"Tool file {tool_file_id} does not exist") raise ValueError(f"tool file {tool_file_id} not exists")
result.append( mapping = {
File( "tool_file_id": tool_file_id,
tenant_id=self.tenant_id, "type": FileType.IMAGE,
type=FileType.IMAGE, "transfer_method": FileTransferMethod.TOOL_FILE,
transfer_method=FileTransferMethod.TOOL_FILE, }
related_id=tool_file.id, file = file_factory.build_from_mapping(
filename=tool_file.name, mapping=mapping,
extension=path.splitext(response.save_as)[1], tenant_id=self.tenant_id,
mime_type=tool_file.mimetype,
size=tool_file.size,
)
) )
result.append(file)
elif response.type == ToolInvokeMessage.MessageType.LINK: elif response.type == ToolInvokeMessage.MessageType.LINK:
url = str(response.message) url = str(response.message)
transfer_method = FileTransferMethod.TOOL_FILE transfer_method = FileTransferMethod.TOOL_FILE
@ -235,16 +232,15 @@ class ToolNode(BaseNode[ToolNodeData]):
extension = "." + url.split("/")[-1].split(".")[1] extension = "." + url.split("/")[-1].split(".")[1]
else: else:
extension = ".bin" extension = ".bin"
file = File( mapping = {
"tool_file_id": tool_file_id,
"type": FileType.IMAGE,
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
type=FileType(response.save_as),
transfer_method=transfer_method,
remote_url=url,
filename=tool_file.name,
related_id=tool_file.id,
extension=extension,
mime_type=tool_file.mimetype,
size=tool_file.size,
) )
result.append(file) result.append(file)

View File

@ -5,10 +5,10 @@ from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast from typing import Any, Optional, cast
from configs import dify_config from configs import dify_config
from core.app.app_config.entities import FileExtraConfig from core.app.app_config.entities import FileUploadConfig
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File, FileTransferMethod, FileType, ImageConfig from core.file.models import File, FileTransferMethod, ImageConfig
from core.workflow.callbacks import WorkflowCallback from core.workflow.callbacks import WorkflowCallback
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.errors import WorkflowNodeRunFailedError
@ -22,6 +22,7 @@ from core.workflow.nodes.base import BaseNode, BaseNodeData
from core.workflow.nodes.event import NodeEvent from core.workflow.nodes.event import NodeEvent
from core.workflow.nodes.llm import LLMNodeData from core.workflow.nodes.llm import LLMNodeData
from core.workflow.nodes.node_mapping import node_type_classes_mapping from core.workflow.nodes.node_mapping import node_type_classes_mapping
from factories import file_factory
from models.enums import UserFrom from models.enums import UserFrom
from models.workflow import ( from models.workflow import (
Workflow, Workflow,
@ -271,19 +272,17 @@ class WorkflowEntry:
for item in input_value: for item in input_value:
if isinstance(item, dict) and "type" in item and item["type"] == "image": if isinstance(item, dict) and "type" in item and item["type"] == "image":
transfer_method = FileTransferMethod.value_of(item.get("transfer_method")) transfer_method = FileTransferMethod.value_of(item.get("transfer_method"))
file = File( mapping = {
"id": item.get("id"),
"transfer_method": transfer_method,
"upload_file_id": item.get("upload_file_id"),
"url": item.get("url"),
}
config = FileUploadConfig(image_config=ImageConfig(detail=detail) if detail else None)
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id, tenant_id=tenant_id,
type=FileType.IMAGE, config=config,
transfer_method=transfer_method,
remote_url=item.get("url")
if transfer_method == FileTransferMethod.REMOTE_URL
else None,
related_id=item.get("upload_file_id")
if transfer_method == FileTransferMethod.LOCAL_FILE
else None,
_extra_config=FileExtraConfig(
image_config=ImageConfig(detail=detail) if detail else None
),
) )
new_value.append(file) new_value.append(file)

View File

@ -1,23 +1,21 @@
import mimetypes import mimetypes
from collections.abc import Mapping, Sequence from collections.abc import Callable, Mapping, Sequence
from typing import Any from typing import Any
import httpx import httpx
from sqlalchemy import select from sqlalchemy import select
from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig
from core.file import File, FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from extensions.ext_database import db from extensions.ext_database import db
from models import MessageFile, ToolFile, UploadFile from models import MessageFile, ToolFile, UploadFile
from models.enums import CreatedByRole
def build_from_message_files( def build_from_message_files(
*, *,
message_files: Sequence["MessageFile"], message_files: Sequence["MessageFile"],
tenant_id: str, tenant_id: str,
config: FileExtraConfig, config: FileUploadConfig,
) -> Sequence[File]: ) -> Sequence[File]:
results = [ results = [
build_from_message_file(message_file=file, tenant_id=tenant_id, config=config) build_from_message_file(message_file=file, tenant_id=tenant_id, config=config)
@ -31,7 +29,7 @@ def build_from_message_file(
*, *,
message_file: "MessageFile", message_file: "MessageFile",
tenant_id: str, tenant_id: str,
config: FileExtraConfig, config: FileUploadConfig,
): ):
mapping = { mapping = {
"transfer_method": message_file.transfer_method, "transfer_method": message_file.transfer_method,
@ -43,8 +41,6 @@ def build_from_message_file(
return build_from_mapping( return build_from_mapping(
mapping=mapping, mapping=mapping,
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=message_file.created_by,
role=CreatedByRole(message_file.created_by_role),
config=config, config=config,
) )
@ -53,38 +49,30 @@ def build_from_mapping(
*, *,
mapping: Mapping[str, Any], mapping: Mapping[str, Any],
tenant_id: str, tenant_id: str,
user_id: str, config: FileUploadConfig | None = None,
role: "CreatedByRole", ) -> File:
config: FileExtraConfig, config = config or FileUploadConfig()
):
transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method")) transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
match transfer_method:
case FileTransferMethod.REMOTE_URL: build_functions: dict[FileTransferMethod, Callable] = {
file = _build_from_remote_url( FileTransferMethod.LOCAL_FILE: _build_from_local_file,
mapping=mapping, FileTransferMethod.REMOTE_URL: _build_from_remote_url,
tenant_id=tenant_id, FileTransferMethod.TOOL_FILE: _build_from_tool_file,
config=config, }
transfer_method=transfer_method,
) build_func = build_functions.get(transfer_method)
case FileTransferMethod.LOCAL_FILE: if not build_func:
file = _build_from_local_file( raise ValueError(f"Invalid file transfer method: {transfer_method}")
mapping=mapping,
tenant_id=tenant_id, file = build_func(
user_id=user_id, mapping=mapping,
role=role, tenant_id=tenant_id,
config=config, transfer_method=transfer_method,
transfer_method=transfer_method, )
)
case FileTransferMethod.TOOL_FILE: if not _is_file_valid_with_config(file=file, config=config):
file = _build_from_tool_file( raise ValueError(f"File validation failed for file: {file.filename}")
mapping=mapping,
tenant_id=tenant_id,
user_id=user_id,
config=config,
transfer_method=transfer_method,
)
case _:
raise ValueError(f"Invalid file transfer method: {transfer_method}")
return file return file
@ -92,10 +80,8 @@ def build_from_mapping(
def build_from_mappings( def build_from_mappings(
*, *,
mappings: Sequence[Mapping[str, Any]], mappings: Sequence[Mapping[str, Any]],
config: FileExtraConfig | None, config: FileUploadConfig | None,
tenant_id: str, tenant_id: str,
user_id: str,
role: "CreatedByRole",
) -> Sequence[File]: ) -> Sequence[File]:
if not config: if not config:
return [] return []
@ -104,8 +90,6 @@ def build_from_mappings(
build_from_mapping( build_from_mapping(
mapping=mapping, mapping=mapping,
tenant_id=tenant_id, tenant_id=tenant_id,
user_id=user_id,
role=role,
config=config, config=config,
) )
for mapping in mappings for mapping in mappings
@ -128,31 +112,20 @@ def _build_from_local_file(
*, *,
mapping: Mapping[str, Any], mapping: Mapping[str, Any],
tenant_id: str, tenant_id: str,
user_id: str,
role: "CreatedByRole",
config: FileExtraConfig,
transfer_method: FileTransferMethod, transfer_method: FileTransferMethod,
): ) -> File:
# check if the upload file exists.
file_type = FileType.value_of(mapping.get("type")) file_type = FileType.value_of(mapping.get("type"))
stmt = select(UploadFile).where( stmt = select(UploadFile).where(
UploadFile.id == mapping.get("upload_file_id"), UploadFile.id == mapping.get("upload_file_id"),
UploadFile.tenant_id == tenant_id, UploadFile.tenant_id == tenant_id,
UploadFile.created_by == user_id,
UploadFile.created_by_role == role,
) )
if file_type == FileType.IMAGE:
stmt = stmt.where(UploadFile.extension.in_(IMAGE_EXTENSIONS))
elif file_type == FileType.VIDEO:
stmt = stmt.where(UploadFile.extension.in_(VIDEO_EXTENSIONS))
elif file_type == FileType.AUDIO:
stmt = stmt.where(UploadFile.extension.in_(AUDIO_EXTENSIONS))
elif file_type == FileType.DOCUMENT:
stmt = stmt.where(UploadFile.extension.in_(DOCUMENT_EXTENSIONS))
row = db.session.scalar(stmt) row = db.session.scalar(stmt)
if row is None: if row is None:
raise ValueError("Invalid upload file") raise ValueError("Invalid upload file")
file = File(
return File(
id=mapping.get("id"), id=mapping.get("id"),
filename=row.name, filename=row.name,
extension="." + row.extension, extension="." + row.extension,
@ -162,23 +135,37 @@ def _build_from_local_file(
transfer_method=transfer_method, transfer_method=transfer_method,
remote_url=row.source_url, remote_url=row.source_url,
related_id=mapping.get("upload_file_id"), related_id=mapping.get("upload_file_id"),
_extra_config=config,
size=row.size, size=row.size,
) )
return file
def _build_from_remote_url( def _build_from_remote_url(
*, *,
mapping: Mapping[str, Any], mapping: Mapping[str, Any],
tenant_id: str, tenant_id: str,
config: FileExtraConfig,
transfer_method: FileTransferMethod, transfer_method: FileTransferMethod,
): ) -> File:
url = mapping.get("url") url = mapping.get("url")
if not url: if not url:
raise ValueError("Invalid file url") raise ValueError("Invalid file url")
mime_type, filename, file_size = _get_remote_file_info(url)
extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin"
return File(
id=mapping.get("id"),
filename=filename,
tenant_id=tenant_id,
type=FileType.value_of(mapping.get("type")),
transfer_method=transfer_method,
remote_url=url,
mime_type=mime_type,
extension=extension,
size=file_size,
)
def _get_remote_file_info(url: str):
mime_type = mimetypes.guess_type(url)[0] or "" mime_type = mimetypes.guess_type(url)[0] or ""
file_size = -1 file_size = -1
filename = url.split("/")[-1].split("?")[0] or "unknown_file" filename = url.split("/")[-1].split("?")[0] or "unknown_file"
@ -186,56 +173,34 @@ def _build_from_remote_url(
resp = ssrf_proxy.head(url, follow_redirects=True) resp = ssrf_proxy.head(url, follow_redirects=True)
if resp.status_code == httpx.codes.OK: if resp.status_code == httpx.codes.OK:
if content_disposition := resp.headers.get("Content-Disposition"): if content_disposition := resp.headers.get("Content-Disposition"):
filename = content_disposition.split("filename=")[-1].strip('"') filename = str(content_disposition.split("filename=")[-1].strip('"'))
file_size = int(resp.headers.get("Content-Length", file_size)) file_size = int(resp.headers.get("Content-Length", file_size))
mime_type = mime_type or str(resp.headers.get("Content-Type", "")) mime_type = mime_type or str(resp.headers.get("Content-Type", ""))
# Determine file extension return mime_type, filename, file_size
extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin"
if not mime_type:
mime_type, _ = mimetypes.guess_type(url)
file = File(
id=mapping.get("id"),
filename=filename,
tenant_id=tenant_id,
type=FileType.value_of(mapping.get("type")),
transfer_method=transfer_method,
remote_url=url,
_extra_config=config,
mime_type=mime_type,
extension=extension,
size=file_size,
)
return file
def _build_from_tool_file( def _build_from_tool_file(
*, *,
mapping: Mapping[str, Any], mapping: Mapping[str, Any],
tenant_id: str, tenant_id: str,
user_id: str,
config: FileExtraConfig,
transfer_method: FileTransferMethod, transfer_method: FileTransferMethod,
): ) -> File:
tool_file = ( tool_file = (
db.session.query(ToolFile) db.session.query(ToolFile)
.filter( .filter(
ToolFile.id == mapping.get("tool_file_id"), ToolFile.id == mapping.get("tool_file_id"),
ToolFile.tenant_id == tenant_id, ToolFile.tenant_id == tenant_id,
ToolFile.user_id == user_id,
) )
.first() .first()
) )
if tool_file is None: if tool_file is None:
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
path = tool_file.file_key extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
if "." in path:
extension = "." + path.split("/")[-1].split(".")[-1] return File(
else:
extension = ".bin"
file = File(
id=mapping.get("id"), id=mapping.get("id"),
tenant_id=tenant_id, tenant_id=tenant_id,
filename=tool_file.name, filename=tool_file.name,
@ -246,6 +211,21 @@ def _build_from_tool_file(
extension=extension, extension=extension,
mime_type=tool_file.mimetype, mime_type=tool_file.mimetype,
size=tool_file.size, size=tool_file.size,
_extra_config=config,
) )
return file
def _is_file_valid_with_config(*, file: File, config: FileUploadConfig) -> bool:
if config.allowed_file_types and file.type not in config.allowed_file_types and file.type != FileType.CUSTOM:
return False
if config.allowed_extensions and file.extension not in config.allowed_extensions:
return False
if config.allowed_upload_methods and file.transfer_method not in config.allowed_upload_methods:
return False
if file.type == FileType.IMAGE and config.image_config:
if config.image_config.transfer_methods and file.transfer_method not in config.image_config.transfer_methods:
return False
return True

View File

@ -13,7 +13,7 @@ from sqlalchemy import Float, func, text
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from configs import dify_config from configs import dify_config
from core.file import FILE_MODEL_IDENTITY, File, FileExtraConfig, FileTransferMethod, FileType from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from core.file import helpers as file_helpers from core.file import helpers as file_helpers
from core.file.tool_file_parser import ToolFileParser from core.file.tool_file_parser import ToolFileParser
from extensions.ext_database import db from extensions.ext_database import db
@ -949,9 +949,6 @@ class Message(db.Model):
"type": message_file.type, "type": message_file.type,
}, },
tenant_id=current_app.tenant_id, tenant_id=current_app.tenant_id,
user_id=self.from_account_id or self.from_end_user_id or "",
role=CreatedByRole(message_file.created_by_role),
config=FileExtraConfig(),
) )
elif message_file.transfer_method == "remote_url": elif message_file.transfer_method == "remote_url":
if message_file.url is None: if message_file.url is None:
@ -964,9 +961,6 @@ class Message(db.Model):
"url": message_file.url, "url": message_file.url,
}, },
tenant_id=current_app.tenant_id, tenant_id=current_app.tenant_id,
user_id=self.from_account_id or self.from_end_user_id or "",
role=CreatedByRole(message_file.created_by_role),
config=FileExtraConfig(),
) )
elif message_file.transfer_method == "tool_file": elif message_file.transfer_method == "tool_file":
if message_file.upload_file_id is None: if message_file.upload_file_id is None:
@ -981,9 +975,6 @@ class Message(db.Model):
file = file_factory.build_from_mapping( file = file_factory.build_from_mapping(
mapping=mapping, mapping=mapping,
tenant_id=current_app.tenant_id, tenant_id=current_app.tenant_id,
user_id=self.from_account_id or self.from_end_user_id or "",
role=CreatedByRole(message_file.created_by_role),
config=FileExtraConfig(),
) )
else: else:
raise ValueError( raise ValueError(

View File

@ -13,7 +13,7 @@ from core.app.app_config.entities import (
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.file.models import FileExtraConfig from core.file.models import FileUploadConfig
from core.helper import encrypter from core.helper import encrypter
from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
@ -381,7 +381,7 @@ class WorkflowConverter:
graph: dict, graph: dict,
model_config: ModelConfigEntity, model_config: ModelConfigEntity,
prompt_template: PromptTemplateEntity, prompt_template: PromptTemplateEntity,
file_upload: Optional[FileExtraConfig] = None, file_upload: Optional[FileUploadConfig] = None,
external_data_variable_node_mapping: dict[str, str] | None = None, external_data_variable_node_mapping: dict[str, str] | None = None,
) -> dict: ) -> dict:
""" """

View File

@ -430,37 +430,3 @@ def test_multi_colons_parse(setup_http_mock):
assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "") assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "")
assert 'form-data; name="Redirect"\r\n\r\nhttp://example6.com' in result.process_data.get("request", "") assert 'form-data; name="Redirect"\r\n\r\nhttp://example6.com' in result.process_data.get("request", "")
# assert "http://example3.com" == resp.get("headers", {}).get("referer") # assert "http://example3.com" == resp.get("headers", {}).get("referer")
def test_image_file(monkeypatch):
from types import SimpleNamespace
monkeypatch.setattr(
"core.tools.tool_file_manager.ToolFileManager.create_file_by_raw",
lambda *args, **kwargs: SimpleNamespace(id="1"),
)
node = init_http_node(
config={
"id": "1",
"data": {
"title": "http",
"desc": "",
"method": "get",
"url": "https://cloud.dify.ai/logo/logo-site.png",
"authorization": {
"type": "no-auth",
"config": None,
},
"params": "",
"headers": "",
"body": None,
},
}
)
result = node._run()
assert result.process_data is not None
assert result.outputs is not None
resp = result.outputs
assert len(resp.get("files", [])) == 1

View File

@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from core.app.app_config.entities import ModelConfigEntity from core.app.app_config.entities import ModelConfigEntity
from core.file import File, FileExtraConfig, FileTransferMethod, FileType, ImageConfig from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig
from core.memory.token_buffer_memory import TokenBufferMemory from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
@ -134,7 +134,6 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
type=FileType.IMAGE, type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL, transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image1.jpg", remote_url="https://example.com/image1.jpg",
_extra_config=FileExtraConfig(image_config=ImageConfig(detail=ImagePromptMessageContent.DETAIL.HIGH)),
) )
] ]