mirror of
https://git.mirrors.martin98.com/https://github.com/langgenius/dify.git
synced 2025-08-11 11:28:57 +08:00
refactor(core): Remove extra_config from File. (#10203)
This commit is contained in:
parent
78a380bcc4
commit
25ca0278dd
@ -30,6 +30,7 @@ from core.model_runtime.entities import (
|
|||||||
ToolPromptMessage,
|
ToolPromptMessage,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
|
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||||
from core.model_runtime.entities.model_entities import ModelFeature
|
from core.model_runtime.entities.model_entities import ModelFeature
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
@ -65,7 +66,7 @@ class BaseAgentRunner(AppRunner):
|
|||||||
prompt_messages: Optional[list[PromptMessage]] = None,
|
prompt_messages: Optional[list[PromptMessage]] = None,
|
||||||
variables_pool: Optional[ToolRuntimeVariablePool] = None,
|
variables_pool: Optional[ToolRuntimeVariablePool] = None,
|
||||||
db_variables: Optional[ToolConversationVariables] = None,
|
db_variables: Optional[ToolConversationVariables] = None,
|
||||||
model_instance: ModelInstance = None,
|
model_instance: ModelInstance | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
self.application_generate_entity = application_generate_entity
|
self.application_generate_entity = application_generate_entity
|
||||||
@ -508,24 +509,27 @@ class BaseAgentRunner(AppRunner):
|
|||||||
|
|
||||||
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
|
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
|
||||||
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
|
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
|
||||||
if files:
|
if not files:
|
||||||
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
|
||||||
|
|
||||||
if file_extra_config:
|
|
||||||
file_objs = file_factory.build_from_message_files(
|
|
||||||
message_files=files, tenant_id=self.tenant_id, config=file_extra_config
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
file_objs = []
|
|
||||||
|
|
||||||
if not file_objs:
|
|
||||||
return UserPromptMessage(content=message.query)
|
|
||||||
else:
|
|
||||||
prompt_message_contents: list[PromptMessageContent] = []
|
|
||||||
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
|
||||||
for file_obj in file_objs:
|
|
||||||
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
|
|
||||||
|
|
||||||
return UserPromptMessage(content=prompt_message_contents)
|
|
||||||
else:
|
|
||||||
return UserPromptMessage(content=message.query)
|
return UserPromptMessage(content=message.query)
|
||||||
|
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
||||||
|
if not file_extra_config:
|
||||||
|
return UserPromptMessage(content=message.query)
|
||||||
|
|
||||||
|
image_detail_config = file_extra_config.image_config.detail if file_extra_config.image_config else None
|
||||||
|
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||||
|
|
||||||
|
file_objs = file_factory.build_from_message_files(
|
||||||
|
message_files=files, tenant_id=self.tenant_id, config=file_extra_config
|
||||||
|
)
|
||||||
|
if not file_objs:
|
||||||
|
return UserPromptMessage(content=message.query)
|
||||||
|
prompt_message_contents: list[PromptMessageContent] = []
|
||||||
|
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||||
|
for file in file_objs:
|
||||||
|
prompt_message_contents.append(
|
||||||
|
file_manager.to_prompt_message_content(
|
||||||
|
file,
|
||||||
|
image_detail_config=image_detail_config,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return UserPromptMessage(content=prompt_message_contents)
|
||||||
|
@ -10,6 +10,7 @@ from core.model_runtime.entities import (
|
|||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
|
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
|
||||||
|
|
||||||
@ -36,8 +37,24 @@ class CotChatAgentRunner(CotAgentRunner):
|
|||||||
if self.files:
|
if self.files:
|
||||||
prompt_message_contents: list[PromptMessageContent] = []
|
prompt_message_contents: list[PromptMessageContent] = []
|
||||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||||
for file_obj in self.files:
|
|
||||||
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
|
# get image detail config
|
||||||
|
image_detail_config = (
|
||||||
|
self.application_generate_entity.file_upload_config.image_config.detail
|
||||||
|
if (
|
||||||
|
self.application_generate_entity.file_upload_config
|
||||||
|
and self.application_generate_entity.file_upload_config.image_config
|
||||||
|
)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||||
|
for file in self.files:
|
||||||
|
prompt_message_contents.append(
|
||||||
|
file_manager.to_prompt_message_content(
|
||||||
|
file,
|
||||||
|
image_detail_config=image_detail_config,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||||
else:
|
else:
|
||||||
|
@ -22,6 +22,7 @@ from core.model_runtime.entities import (
|
|||||||
ToolPromptMessage,
|
ToolPromptMessage,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
|
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||||
from core.tools.tool_engine import ToolEngine
|
from core.tools.tool_engine import ToolEngine
|
||||||
@ -397,8 +398,24 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||||||
if self.files:
|
if self.files:
|
||||||
prompt_message_contents: list[PromptMessageContent] = []
|
prompt_message_contents: list[PromptMessageContent] = []
|
||||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||||
for file_obj in self.files:
|
|
||||||
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
|
# get image detail config
|
||||||
|
image_detail_config = (
|
||||||
|
self.application_generate_entity.file_upload_config.image_config.detail
|
||||||
|
if (
|
||||||
|
self.application_generate_entity.file_upload_config
|
||||||
|
and self.application_generate_entity.file_upload_config.image_config
|
||||||
|
)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||||
|
for file in self.files:
|
||||||
|
prompt_message_contents.append(
|
||||||
|
file_manager.to_prompt_message_content(
|
||||||
|
file,
|
||||||
|
image_detail_config=image_detail_config,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||||
else:
|
else:
|
||||||
|
@ -4,7 +4,7 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
from core.file import FileExtraConfig, FileTransferMethod, FileType
|
from core.file import FileTransferMethod, FileType, FileUploadConfig
|
||||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
|
||||||
@ -211,7 +211,7 @@ class TracingConfigEntity(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class AppAdditionalFeatures(BaseModel):
|
class AppAdditionalFeatures(BaseModel):
|
||||||
file_upload: Optional[FileExtraConfig] = None
|
file_upload: Optional[FileUploadConfig] = None
|
||||||
opening_statement: Optional[str] = None
|
opening_statement: Optional[str] = None
|
||||||
suggested_questions: list[str] = []
|
suggested_questions: list[str] = []
|
||||||
suggested_questions_after_answer: bool = False
|
suggested_questions_after_answer: bool = False
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from core.file import FileExtraConfig
|
from core.file import FileUploadConfig
|
||||||
|
|
||||||
|
|
||||||
class FileUploadConfigManager:
|
class FileUploadConfigManager:
|
||||||
@ -29,15 +29,14 @@ class FileUploadConfigManager:
|
|||||||
if is_vision:
|
if is_vision:
|
||||||
data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low")
|
data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low")
|
||||||
|
|
||||||
return FileExtraConfig.model_validate(data)
|
return FileUploadConfig.model_validate(data)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]:
|
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||||
"""
|
"""
|
||||||
Validate and set defaults for file upload feature
|
Validate and set defaults for file upload feature
|
||||||
|
|
||||||
:param config: app model config args
|
:param config: app model config args
|
||||||
:param is_vision: if True, the feature is vision feature
|
|
||||||
"""
|
"""
|
||||||
if not config.get("file_upload"):
|
if not config.get("file_upload"):
|
||||||
config["file_upload"] = {}
|
config["file_upload"] = {}
|
||||||
|
@ -52,9 +52,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
|
|||||||
related_config_keys = []
|
related_config_keys = []
|
||||||
|
|
||||||
# file upload validation
|
# file upload validation
|
||||||
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
|
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
|
||||||
config=config, is_vision=False
|
|
||||||
)
|
|
||||||
related_config_keys.extend(current_related_config_keys)
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
# opening_statement
|
# opening_statement
|
||||||
|
@ -26,7 +26,6 @@ from core.ops.ops_trace_manager import TraceQueueManager
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.enums import CreatedByRole
|
|
||||||
from models.model import App, Conversation, EndUser, Message
|
from models.model import App, Conversation, EndUser, Message
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
|
|
||||||
@ -98,13 +97,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
# parse files
|
# parse files
|
||||||
files = args["files"] if args.get("files") else []
|
files = args["files"] if args.get("files") else []
|
||||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||||
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
|
||||||
if file_extra_config:
|
if file_extra_config:
|
||||||
file_objs = file_factory.build_from_mappings(
|
file_objs = file_factory.build_from_mappings(
|
||||||
mappings=files,
|
mappings=files,
|
||||||
tenant_id=app_model.tenant_id,
|
tenant_id=app_model.tenant_id,
|
||||||
user_id=user.id,
|
|
||||||
role=role,
|
|
||||||
config=file_extra_config,
|
config=file_extra_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -127,10 +123,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
application_generate_entity = AdvancedChatAppGenerateEntity(
|
application_generate_entity = AdvancedChatAppGenerateEntity(
|
||||||
task_id=str(uuid.uuid4()),
|
task_id=str(uuid.uuid4()),
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
|
file_upload_config=file_extra_config,
|
||||||
conversation_id=conversation.id if conversation else None,
|
conversation_id=conversation.id if conversation else None,
|
||||||
inputs=conversation.inputs
|
inputs=conversation.inputs
|
||||||
if conversation
|
if conversation
|
||||||
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
|
||||||
query=query,
|
query=query,
|
||||||
files=file_objs,
|
files=file_objs,
|
||||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||||
|
@ -23,7 +23,6 @@ from core.ops.ops_trace_manager import TraceQueueManager
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
from models import Account, App, EndUser
|
from models import Account, App, EndUser
|
||||||
from models.enums import CreatedByRole
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -103,8 +102,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
# always enable retriever resource in debugger mode
|
# always enable retriever resource in debugger mode
|
||||||
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
||||||
|
|
||||||
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
|
||||||
|
|
||||||
# parse files
|
# parse files
|
||||||
files = args.get("files") or []
|
files = args.get("files") or []
|
||||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||||
@ -112,8 +109,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
file_objs = file_factory.build_from_mappings(
|
file_objs = file_factory.build_from_mappings(
|
||||||
mappings=files,
|
mappings=files,
|
||||||
tenant_id=app_model.tenant_id,
|
tenant_id=app_model.tenant_id,
|
||||||
user_id=user.id,
|
|
||||||
role=role,
|
|
||||||
config=file_extra_config,
|
config=file_extra_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -135,10 +130,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
task_id=str(uuid.uuid4()),
|
task_id=str(uuid.uuid4()),
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
model_conf=ModelConfigConverter.convert(app_config),
|
model_conf=ModelConfigConverter.convert(app_config),
|
||||||
|
file_upload_config=file_extra_config,
|
||||||
conversation_id=conversation.id if conversation else None,
|
conversation_id=conversation.id if conversation else None,
|
||||||
inputs=conversation.inputs
|
inputs=conversation.inputs
|
||||||
if conversation
|
if conversation
|
||||||
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
|
||||||
query=query,
|
query=query,
|
||||||
files=file_objs,
|
files=file_objs,
|
||||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||||
|
@ -2,12 +2,11 @@ from collections.abc import Mapping
|
|||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
from core.app.app_config.entities import VariableEntityType
|
from core.app.app_config.entities import VariableEntityType
|
||||||
from core.file import File, FileExtraConfig
|
from core.file import File, FileUploadConfig
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.app.app_config.entities import AppConfig, VariableEntity
|
from core.app.app_config.entities import AppConfig, VariableEntity
|
||||||
from models.enums import CreatedByRole
|
|
||||||
|
|
||||||
|
|
||||||
class BaseAppGenerator:
|
class BaseAppGenerator:
|
||||||
@ -16,8 +15,6 @@ class BaseAppGenerator:
|
|||||||
*,
|
*,
|
||||||
user_inputs: Optional[Mapping[str, Any]],
|
user_inputs: Optional[Mapping[str, Any]],
|
||||||
app_config: "AppConfig",
|
app_config: "AppConfig",
|
||||||
user_id: str,
|
|
||||||
role: "CreatedByRole",
|
|
||||||
) -> Mapping[str, Any]:
|
) -> Mapping[str, Any]:
|
||||||
user_inputs = user_inputs or {}
|
user_inputs = user_inputs or {}
|
||||||
# Filter input variables from form configuration, handle required fields, default values, and option values
|
# Filter input variables from form configuration, handle required fields, default values, and option values
|
||||||
@ -34,9 +31,7 @@ class BaseAppGenerator:
|
|||||||
k: file_factory.build_from_mapping(
|
k: file_factory.build_from_mapping(
|
||||||
mapping=v,
|
mapping=v,
|
||||||
tenant_id=app_config.tenant_id,
|
tenant_id=app_config.tenant_id,
|
||||||
user_id=user_id,
|
config=FileUploadConfig(
|
||||||
role=role,
|
|
||||||
config=FileExtraConfig(
|
|
||||||
allowed_file_types=entity_dictionary[k].allowed_file_types,
|
allowed_file_types=entity_dictionary[k].allowed_file_types,
|
||||||
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
|
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
|
||||||
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
|
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
|
||||||
@ -50,9 +45,7 @@ class BaseAppGenerator:
|
|||||||
k: file_factory.build_from_mappings(
|
k: file_factory.build_from_mappings(
|
||||||
mappings=v,
|
mappings=v,
|
||||||
tenant_id=app_config.tenant_id,
|
tenant_id=app_config.tenant_id,
|
||||||
user_id=user_id,
|
config=FileUploadConfig(
|
||||||
role=role,
|
|
||||||
config=FileExtraConfig(
|
|
||||||
allowed_file_types=entity_dictionary[k].allowed_file_types,
|
allowed_file_types=entity_dictionary[k].allowed_file_types,
|
||||||
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
|
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
|
||||||
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
|
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
|
||||||
|
@ -23,7 +23,6 @@ from core.ops.ops_trace_manager import TraceQueueManager
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.enums import CreatedByRole
|
|
||||||
from models.model import App, EndUser
|
from models.model import App, EndUser
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -101,8 +100,6 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
# always enable retriever resource in debugger mode
|
# always enable retriever resource in debugger mode
|
||||||
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
||||||
|
|
||||||
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
|
||||||
|
|
||||||
# parse files
|
# parse files
|
||||||
files = args["files"] if args.get("files") else []
|
files = args["files"] if args.get("files") else []
|
||||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||||
@ -110,8 +107,6 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
file_objs = file_factory.build_from_mappings(
|
file_objs = file_factory.build_from_mappings(
|
||||||
mappings=files,
|
mappings=files,
|
||||||
tenant_id=app_model.tenant_id,
|
tenant_id=app_model.tenant_id,
|
||||||
user_id=user.id,
|
|
||||||
role=role,
|
|
||||||
config=file_extra_config,
|
config=file_extra_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -133,10 +128,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
task_id=str(uuid.uuid4()),
|
task_id=str(uuid.uuid4()),
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
model_conf=ModelConfigConverter.convert(app_config),
|
model_conf=ModelConfigConverter.convert(app_config),
|
||||||
|
file_upload_config=file_extra_config,
|
||||||
conversation_id=conversation.id if conversation else None,
|
conversation_id=conversation.id if conversation else None,
|
||||||
inputs=conversation.inputs
|
inputs=conversation.inputs
|
||||||
if conversation
|
if conversation
|
||||||
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
|
||||||
query=query,
|
query=query,
|
||||||
files=file_objs,
|
files=file_objs,
|
||||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||||
|
@ -22,7 +22,6 @@ from core.ops.ops_trace_manager import TraceQueueManager
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
from models import Account, App, EndUser, Message
|
from models import Account, App, EndUser, Message
|
||||||
from models.enums import CreatedByRole
|
|
||||||
from services.errors.app import MoreLikeThisDisabledError
|
from services.errors.app import MoreLikeThisDisabledError
|
||||||
from services.errors.message import MessageNotExistsError
|
from services.errors.message import MessageNotExistsError
|
||||||
|
|
||||||
@ -88,8 +87,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
tenant_id=app_model.tenant_id, config=args.get("model_config")
|
tenant_id=app_model.tenant_id, config=args.get("model_config")
|
||||||
)
|
)
|
||||||
|
|
||||||
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
|
||||||
|
|
||||||
# parse files
|
# parse files
|
||||||
files = args["files"] if args.get("files") else []
|
files = args["files"] if args.get("files") else []
|
||||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||||
@ -97,8 +94,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
file_objs = file_factory.build_from_mappings(
|
file_objs = file_factory.build_from_mappings(
|
||||||
mappings=files,
|
mappings=files,
|
||||||
tenant_id=app_model.tenant_id,
|
tenant_id=app_model.tenant_id,
|
||||||
user_id=user.id,
|
|
||||||
role=role,
|
|
||||||
config=file_extra_config,
|
config=file_extra_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -110,7 +105,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# get tracing instance
|
# get tracing instance
|
||||||
user_id = user.id if isinstance(user, Account) else user.session_id
|
|
||||||
trace_manager = TraceQueueManager(app_model.id)
|
trace_manager = TraceQueueManager(app_model.id)
|
||||||
|
|
||||||
# init application generate entity
|
# init application generate entity
|
||||||
@ -118,7 +112,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
task_id=str(uuid.uuid4()),
|
task_id=str(uuid.uuid4()),
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
model_conf=ModelConfigConverter.convert(app_config),
|
model_conf=ModelConfigConverter.convert(app_config),
|
||||||
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
file_upload_config=file_extra_config,
|
||||||
|
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
|
||||||
query=query,
|
query=query,
|
||||||
files=file_objs,
|
files=file_objs,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
@ -259,14 +254,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||||||
override_model_config_dict["model"] = model_dict
|
override_model_config_dict["model"] = model_dict
|
||||||
|
|
||||||
# parse files
|
# parse files
|
||||||
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
|
||||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
|
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
|
||||||
if file_extra_config:
|
if file_extra_config:
|
||||||
file_objs = file_factory.build_from_mappings(
|
file_objs = file_factory.build_from_mappings(
|
||||||
mappings=message.message_files,
|
mappings=message.message_files,
|
||||||
tenant_id=app_model.tenant_id,
|
tenant_id=app_model.tenant_id,
|
||||||
user_id=user.id,
|
|
||||||
role=role,
|
|
||||||
config=file_extra_config,
|
config=file_extra_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -46,9 +46,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
|
|||||||
related_config_keys = []
|
related_config_keys = []
|
||||||
|
|
||||||
# file upload validation
|
# file upload validation
|
||||||
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
|
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
|
||||||
config=config, is_vision=False
|
|
||||||
)
|
|
||||||
related_config_keys.extend(current_related_config_keys)
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
# text_to_speech
|
# text_to_speech
|
||||||
|
@ -25,7 +25,6 @@ from core.ops.ops_trace_manager import TraceQueueManager
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
from models import Account, App, EndUser, Workflow
|
from models import Account, App, EndUser, Workflow
|
||||||
from models.enums import CreatedByRole
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -70,15 +69,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
):
|
):
|
||||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||||
|
|
||||||
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
|
||||||
|
|
||||||
# parse files
|
# parse files
|
||||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||||
system_files = file_factory.build_from_mappings(
|
system_files = file_factory.build_from_mappings(
|
||||||
mappings=files,
|
mappings=files,
|
||||||
tenant_id=app_model.tenant_id,
|
tenant_id=app_model.tenant_id,
|
||||||
user_id=user.id,
|
|
||||||
role=role,
|
|
||||||
config=file_extra_config,
|
config=file_extra_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -100,7 +95,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
application_generate_entity = WorkflowAppGenerateEntity(
|
application_generate_entity = WorkflowAppGenerateEntity(
|
||||||
task_id=str(uuid.uuid4()),
|
task_id=str(uuid.uuid4()),
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
file_upload_config=file_extra_config,
|
||||||
|
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
|
||||||
files=system_files,
|
files=system_files,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
|
@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat
|
|||||||
from constants import UUID_NIL
|
from constants import UUID_NIL
|
||||||
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
||||||
from core.entities.provider_configuration import ProviderModelBundle
|
from core.entities.provider_configuration import ProviderModelBundle
|
||||||
from core.file.models import File
|
from core.file import File, FileUploadConfig
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
|
|
||||||
@ -80,6 +80,7 @@ class AppGenerateEntity(BaseModel):
|
|||||||
|
|
||||||
# app config
|
# app config
|
||||||
app_config: AppConfig
|
app_config: AppConfig
|
||||||
|
file_upload_config: Optional[FileUploadConfig] = None
|
||||||
|
|
||||||
inputs: Mapping[str, Any]
|
inputs: Mapping[str, Any]
|
||||||
files: Sequence[File]
|
files: Sequence[File]
|
||||||
|
@ -2,13 +2,13 @@ from .constants import FILE_MODEL_IDENTITY
|
|||||||
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
|
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
|
||||||
from .models import (
|
from .models import (
|
||||||
File,
|
File,
|
||||||
FileExtraConfig,
|
FileUploadConfig,
|
||||||
ImageConfig,
|
ImageConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"FileType",
|
"FileType",
|
||||||
"FileExtraConfig",
|
"FileUploadConfig",
|
||||||
"FileTransferMethod",
|
"FileTransferMethod",
|
||||||
"FileBelongsTo",
|
"FileBelongsTo",
|
||||||
"File",
|
"File",
|
||||||
|
@ -33,25 +33,28 @@ def get_attr(*, file: File, attr: FileAttribute):
|
|||||||
raise ValueError(f"Invalid file attribute: {attr}")
|
raise ValueError(f"Invalid file attribute: {attr}")
|
||||||
|
|
||||||
|
|
||||||
def to_prompt_message_content(f: File, /):
|
def to_prompt_message_content(
|
||||||
|
f: File,
|
||||||
|
/,
|
||||||
|
*,
|
||||||
|
image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Convert a File object to an ImagePromptMessageContent object.
|
Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object.
|
||||||
|
|
||||||
This function takes a File object and converts it to an ImagePromptMessageContent
|
This function takes a File object and converts it to an appropriate PromptMessageContent
|
||||||
object, which can be used as a prompt for image-based AI models.
|
object, which can be used as a prompt for image or audio-based AI models.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file (File): The File object to convert. Must be of type FileType.IMAGE.
|
f (File): The File object to convert.
|
||||||
|
detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts.
|
||||||
|
If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ImagePromptMessageContent: An object containing the image data and detail level.
|
Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the file is not an image or if the file data is missing.
|
ValueError: If the file type is not supported or if required data is missing.
|
||||||
|
|
||||||
Note:
|
|
||||||
The detail level of the image prompt is determined by the file's extra_config.
|
|
||||||
If not specified, it defaults to ImagePromptMessageContent.DETAIL.LOW.
|
|
||||||
"""
|
"""
|
||||||
match f.type:
|
match f.type:
|
||||||
case FileType.IMAGE:
|
case FileType.IMAGE:
|
||||||
@ -60,12 +63,7 @@ def to_prompt_message_content(f: File, /):
|
|||||||
else:
|
else:
|
||||||
data = _to_base64_data_string(f)
|
data = _to_base64_data_string(f)
|
||||||
|
|
||||||
if f._extra_config and f._extra_config.image_config and f._extra_config.image_config.detail:
|
return ImagePromptMessageContent(data=data, detail=image_detail_config)
|
||||||
detail = f._extra_config.image_config.detail
|
|
||||||
else:
|
|
||||||
detail = ImagePromptMessageContent.DETAIL.LOW
|
|
||||||
|
|
||||||
return ImagePromptMessageContent(data=data, detail=detail)
|
|
||||||
case FileType.AUDIO:
|
case FileType.AUDIO:
|
||||||
encoded_string = _file_to_encoded_string(f)
|
encoded_string = _file_to_encoded_string(f)
|
||||||
if f.extension is None:
|
if f.extension is None:
|
||||||
@ -78,7 +76,7 @@ def to_prompt_message_content(f: File, /):
|
|||||||
data = _to_base64_data_string(f)
|
data = _to_base64_data_string(f)
|
||||||
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
|
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"file type {f.type} is not supported")
|
raise ValueError("file type f.type is not supported")
|
||||||
|
|
||||||
|
|
||||||
def download(f: File, /):
|
def download(f: File, /):
|
||||||
|
@ -21,7 +21,7 @@ class ImageConfig(BaseModel):
|
|||||||
detail: ImagePromptMessageContent.DETAIL | None = None
|
detail: ImagePromptMessageContent.DETAIL | None = None
|
||||||
|
|
||||||
|
|
||||||
class FileExtraConfig(BaseModel):
|
class FileUploadConfig(BaseModel):
|
||||||
"""
|
"""
|
||||||
File Upload Entity.
|
File Upload Entity.
|
||||||
"""
|
"""
|
||||||
@ -46,7 +46,6 @@ class File(BaseModel):
|
|||||||
extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
|
extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
|
||||||
mime_type: Optional[str] = None
|
mime_type: Optional[str] = None
|
||||||
size: int = -1
|
size: int = -1
|
||||||
_extra_config: FileExtraConfig | None = None
|
|
||||||
|
|
||||||
def to_dict(self) -> Mapping[str, str | int | None]:
|
def to_dict(self) -> Mapping[str, str | int | None]:
|
||||||
data = self.model_dump(mode="json")
|
data = self.model_dump(mode="json")
|
||||||
@ -107,34 +106,4 @@ class File(BaseModel):
|
|||||||
case FileTransferMethod.TOOL_FILE:
|
case FileTransferMethod.TOOL_FILE:
|
||||||
if not self.related_id:
|
if not self.related_id:
|
||||||
raise ValueError("Missing file related_id")
|
raise ValueError("Missing file related_id")
|
||||||
|
|
||||||
# Validate the extra config.
|
|
||||||
if not self._extra_config:
|
|
||||||
return self
|
|
||||||
|
|
||||||
if self._extra_config.allowed_file_types:
|
|
||||||
if self.type not in self._extra_config.allowed_file_types and self.type != FileType.CUSTOM:
|
|
||||||
raise ValueError(f"Invalid file type: {self.type}")
|
|
||||||
|
|
||||||
if self._extra_config.allowed_extensions and self.extension not in self._extra_config.allowed_extensions:
|
|
||||||
raise ValueError(f"Invalid file extension: {self.extension}")
|
|
||||||
|
|
||||||
if (
|
|
||||||
self._extra_config.allowed_upload_methods
|
|
||||||
and self.transfer_method not in self._extra_config.allowed_upload_methods
|
|
||||||
):
|
|
||||||
raise ValueError(f"Invalid transfer method: {self.transfer_method}")
|
|
||||||
|
|
||||||
match self.type:
|
|
||||||
case FileType.IMAGE:
|
|
||||||
# NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
|
|
||||||
if not self._extra_config.image_config:
|
|
||||||
return self
|
|
||||||
# TODO: skip check if transfer_methods is empty, because many test cases are not setting this field
|
|
||||||
if (
|
|
||||||
self._extra_config.image_config.transfer_methods
|
|
||||||
and self.transfer_method not in self._extra_config.image_config.transfer_methods
|
|
||||||
):
|
|
||||||
raise ValueError(f"Invalid transfer method: {self.transfer_method}")
|
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
@ -81,15 +81,18 @@ class TokenBufferMemory:
|
|||||||
db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first()
|
db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first()
|
||||||
)
|
)
|
||||||
|
|
||||||
if workflow_run:
|
if workflow_run and workflow_run.workflow:
|
||||||
file_extra_config = FileUploadConfigManager.convert(
|
file_extra_config = FileUploadConfigManager.convert(
|
||||||
workflow_run.workflow.features_dict, is_vision=False
|
workflow_run.workflow.features_dict, is_vision=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
detail = ImagePromptMessageContent.DETAIL.LOW
|
||||||
if file_extra_config and app_record:
|
if file_extra_config and app_record:
|
||||||
file_objs = file_factory.build_from_message_files(
|
file_objs = file_factory.build_from_message_files(
|
||||||
message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
|
message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
|
||||||
)
|
)
|
||||||
|
if file_extra_config.image_config and file_extra_config.image_config.detail:
|
||||||
|
detail = file_extra_config.image_config.detail
|
||||||
else:
|
else:
|
||||||
file_objs = []
|
file_objs = []
|
||||||
|
|
||||||
@ -98,12 +101,16 @@ class TokenBufferMemory:
|
|||||||
else:
|
else:
|
||||||
prompt_message_contents: list[PromptMessageContent] = []
|
prompt_message_contents: list[PromptMessageContent] = []
|
||||||
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||||
for file_obj in file_objs:
|
for file in file_objs:
|
||||||
if file_obj.type in {FileType.IMAGE, FileType.AUDIO}:
|
if file.type in {FileType.IMAGE, FileType.AUDIO}:
|
||||||
prompt_message = file_manager.to_prompt_message_content(file_obj)
|
prompt_message = file_manager.to_prompt_message_content(
|
||||||
|
file,
|
||||||
|
image_detail_config=detail,
|
||||||
|
)
|
||||||
prompt_message_contents.append(prompt_message)
|
prompt_message_contents.append(prompt_message)
|
||||||
|
|
||||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
prompt_messages.append(UserPromptMessage(content=message.query))
|
prompt_messages.append(UserPromptMessage(content=message.query))
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@ from core.model_runtime.entities import (
|
|||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
|
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||||
from core.prompt.prompt_transform import PromptTransform
|
from core.prompt.prompt_transform import PromptTransform
|
||||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
@ -26,8 +27,13 @@ class AdvancedPromptTransform(PromptTransform):
|
|||||||
Advanced Prompt Transform for Workflow LLM Node.
|
Advanced Prompt Transform for Workflow LLM Node.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, with_variable_tmpl: bool = False) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
with_variable_tmpl: bool = False,
|
||||||
|
image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
|
||||||
|
) -> None:
|
||||||
self.with_variable_tmpl = with_variable_tmpl
|
self.with_variable_tmpl = with_variable_tmpl
|
||||||
|
self.image_detail_config = image_detail_config
|
||||||
|
|
||||||
def get_prompt(
|
def get_prompt(
|
||||||
self,
|
self,
|
||||||
|
@ -1,19 +1,23 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from core.file import File
|
from core.file import FileTransferMethod, FileType
|
||||||
from core.file.enums import FileTransferMethod, FileType
|
|
||||||
from core.tools.errors import ToolProviderCredentialValidationError
|
from core.tools.errors import ToolProviderCredentialValidationError
|
||||||
from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool
|
from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool
|
||||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||||
|
from factories import file_factory
|
||||||
|
|
||||||
|
|
||||||
class VectorizerProvider(BuiltinToolProviderController):
|
class VectorizerProvider(BuiltinToolProviderController):
|
||||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||||
test_img = File(
|
mapping = {
|
||||||
|
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||||
|
"type": FileType.IMAGE,
|
||||||
|
"id": "test_id",
|
||||||
|
"url": "https://cloud.dify.ai/logo/logo-site.png",
|
||||||
|
}
|
||||||
|
test_img = file_factory.build_from_mapping(
|
||||||
|
mapping=mapping,
|
||||||
tenant_id="__test_123",
|
tenant_id="__test_123",
|
||||||
remote_url="https://cloud.dify.ai/logo/logo-site.png",
|
|
||||||
type=FileType.IMAGE,
|
|
||||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
VectorizerTool().fork_tool_runtime(
|
VectorizerTool().fork_tool_runtime(
|
||||||
|
@ -13,6 +13,7 @@ from core.workflow.nodes.base import BaseNode
|
|||||||
from core.workflow.nodes.enums import NodeType
|
from core.workflow.nodes.enums import NodeType
|
||||||
from core.workflow.nodes.http_request.executor import Executor
|
from core.workflow.nodes.http_request.executor import Executor
|
||||||
from core.workflow.utils import variable_template_parser
|
from core.workflow.utils import variable_template_parser
|
||||||
|
from factories import file_factory
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
from .entities import (
|
from .entities import (
|
||||||
@ -161,16 +162,15 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
|||||||
mimetype=content_type,
|
mimetype=content_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
files.append(
|
mapping = {
|
||||||
File(
|
"tool_file_id": tool_file.id,
|
||||||
tenant_id=self.tenant_id,
|
"type": FileType.IMAGE.value,
|
||||||
type=FileType.IMAGE,
|
"transfer_method": FileTransferMethod.TOOL_FILE.value,
|
||||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
}
|
||||||
related_id=tool_file.id,
|
file = file_factory.build_from_mapping(
|
||||||
filename=filename,
|
mapping=mapping,
|
||||||
extension=extension,
|
tenant_id=self.tenant_id,
|
||||||
mime_type=content_type,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
files.append(file)
|
||||||
|
|
||||||
return files
|
return files
|
||||||
|
@ -17,6 +17,7 @@ from core.workflow.nodes.base import BaseNode
|
|||||||
from core.workflow.nodes.enums import NodeType
|
from core.workflow.nodes.enums import NodeType
|
||||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from factories import file_factory
|
||||||
from models import ToolFile
|
from models import ToolFile
|
||||||
from models.workflow import WorkflowNodeExecutionStatus
|
from models.workflow import WorkflowNodeExecutionStatus
|
||||||
|
|
||||||
@ -189,19 +190,17 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||||||
if tool_file is None:
|
if tool_file is None:
|
||||||
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
||||||
|
|
||||||
result.append(
|
mapping = {
|
||||||
File(
|
"tool_file_id": tool_file_id,
|
||||||
tenant_id=self.tenant_id,
|
"type": FileType.IMAGE,
|
||||||
type=FileType.IMAGE,
|
"transfer_method": transfer_method,
|
||||||
transfer_method=transfer_method,
|
"url": url,
|
||||||
remote_url=url,
|
}
|
||||||
related_id=tool_file.id,
|
file = file_factory.build_from_mapping(
|
||||||
filename=tool_file.name,
|
mapping=mapping,
|
||||||
extension=ext,
|
tenant_id=self.tenant_id,
|
||||||
mime_type=tool_file.mimetype,
|
|
||||||
size=tool_file.size,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
result.append(file)
|
||||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||||
# get tool file id
|
# get tool file id
|
||||||
tool_file_id = str(response.message).split("/")[-1].split(".")[0]
|
tool_file_id = str(response.message).split("/")[-1].split(".")[0]
|
||||||
@ -209,19 +208,17 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||||
tool_file = session.scalar(stmt)
|
tool_file = session.scalar(stmt)
|
||||||
if tool_file is None:
|
if tool_file is None:
|
||||||
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
raise ValueError(f"tool file {tool_file_id} not exists")
|
||||||
result.append(
|
mapping = {
|
||||||
File(
|
"tool_file_id": tool_file_id,
|
||||||
tenant_id=self.tenant_id,
|
"type": FileType.IMAGE,
|
||||||
type=FileType.IMAGE,
|
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
}
|
||||||
related_id=tool_file.id,
|
file = file_factory.build_from_mapping(
|
||||||
filename=tool_file.name,
|
mapping=mapping,
|
||||||
extension=path.splitext(response.save_as)[1],
|
tenant_id=self.tenant_id,
|
||||||
mime_type=tool_file.mimetype,
|
|
||||||
size=tool_file.size,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
result.append(file)
|
||||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||||
url = str(response.message)
|
url = str(response.message)
|
||||||
transfer_method = FileTransferMethod.TOOL_FILE
|
transfer_method = FileTransferMethod.TOOL_FILE
|
||||||
@ -235,16 +232,15 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||||||
extension = "." + url.split("/")[-1].split(".")[1]
|
extension = "." + url.split("/")[-1].split(".")[1]
|
||||||
else:
|
else:
|
||||||
extension = ".bin"
|
extension = ".bin"
|
||||||
file = File(
|
mapping = {
|
||||||
|
"tool_file_id": tool_file_id,
|
||||||
|
"type": FileType.IMAGE,
|
||||||
|
"transfer_method": transfer_method,
|
||||||
|
"url": url,
|
||||||
|
}
|
||||||
|
file = file_factory.build_from_mapping(
|
||||||
|
mapping=mapping,
|
||||||
tenant_id=self.tenant_id,
|
tenant_id=self.tenant_id,
|
||||||
type=FileType(response.save_as),
|
|
||||||
transfer_method=transfer_method,
|
|
||||||
remote_url=url,
|
|
||||||
filename=tool_file.name,
|
|
||||||
related_id=tool_file.id,
|
|
||||||
extension=extension,
|
|
||||||
mime_type=tool_file.mimetype,
|
|
||||||
size=tool_file.size,
|
|
||||||
)
|
)
|
||||||
result.append(file)
|
result.append(file)
|
||||||
|
|
||||||
|
@ -5,10 +5,10 @@ from collections.abc import Generator, Mapping, Sequence
|
|||||||
from typing import Any, Optional, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.app.app_config.entities import FileExtraConfig
|
from core.app.app_config.entities import FileUploadConfig
|
||||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.file.models import File, FileTransferMethod, FileType, ImageConfig
|
from core.file.models import File, FileTransferMethod, ImageConfig
|
||||||
from core.workflow.callbacks import WorkflowCallback
|
from core.workflow.callbacks import WorkflowCallback
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||||
@ -22,6 +22,7 @@ from core.workflow.nodes.base import BaseNode, BaseNodeData
|
|||||||
from core.workflow.nodes.event import NodeEvent
|
from core.workflow.nodes.event import NodeEvent
|
||||||
from core.workflow.nodes.llm import LLMNodeData
|
from core.workflow.nodes.llm import LLMNodeData
|
||||||
from core.workflow.nodes.node_mapping import node_type_classes_mapping
|
from core.workflow.nodes.node_mapping import node_type_classes_mapping
|
||||||
|
from factories import file_factory
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.workflow import (
|
from models.workflow import (
|
||||||
Workflow,
|
Workflow,
|
||||||
@ -271,19 +272,17 @@ class WorkflowEntry:
|
|||||||
for item in input_value:
|
for item in input_value:
|
||||||
if isinstance(item, dict) and "type" in item and item["type"] == "image":
|
if isinstance(item, dict) and "type" in item and item["type"] == "image":
|
||||||
transfer_method = FileTransferMethod.value_of(item.get("transfer_method"))
|
transfer_method = FileTransferMethod.value_of(item.get("transfer_method"))
|
||||||
file = File(
|
mapping = {
|
||||||
|
"id": item.get("id"),
|
||||||
|
"transfer_method": transfer_method,
|
||||||
|
"upload_file_id": item.get("upload_file_id"),
|
||||||
|
"url": item.get("url"),
|
||||||
|
}
|
||||||
|
config = FileUploadConfig(image_config=ImageConfig(detail=detail) if detail else None)
|
||||||
|
file = file_factory.build_from_mapping(
|
||||||
|
mapping=mapping,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
type=FileType.IMAGE,
|
config=config,
|
||||||
transfer_method=transfer_method,
|
|
||||||
remote_url=item.get("url")
|
|
||||||
if transfer_method == FileTransferMethod.REMOTE_URL
|
|
||||||
else None,
|
|
||||||
related_id=item.get("upload_file_id")
|
|
||||||
if transfer_method == FileTransferMethod.LOCAL_FILE
|
|
||||||
else None,
|
|
||||||
_extra_config=FileExtraConfig(
|
|
||||||
image_config=ImageConfig(detail=detail) if detail else None
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
new_value.append(file)
|
new_value.append(file)
|
||||||
|
|
||||||
|
@ -1,23 +1,21 @@
|
|||||||
import mimetypes
|
import mimetypes
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Callable, Mapping, Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
|
from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig
|
||||||
from core.file import File, FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType
|
|
||||||
from core.helper import ssrf_proxy
|
from core.helper import ssrf_proxy
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models import MessageFile, ToolFile, UploadFile
|
from models import MessageFile, ToolFile, UploadFile
|
||||||
from models.enums import CreatedByRole
|
|
||||||
|
|
||||||
|
|
||||||
def build_from_message_files(
|
def build_from_message_files(
|
||||||
*,
|
*,
|
||||||
message_files: Sequence["MessageFile"],
|
message_files: Sequence["MessageFile"],
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
config: FileExtraConfig,
|
config: FileUploadConfig,
|
||||||
) -> Sequence[File]:
|
) -> Sequence[File]:
|
||||||
results = [
|
results = [
|
||||||
build_from_message_file(message_file=file, tenant_id=tenant_id, config=config)
|
build_from_message_file(message_file=file, tenant_id=tenant_id, config=config)
|
||||||
@ -31,7 +29,7 @@ def build_from_message_file(
|
|||||||
*,
|
*,
|
||||||
message_file: "MessageFile",
|
message_file: "MessageFile",
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
config: FileExtraConfig,
|
config: FileUploadConfig,
|
||||||
):
|
):
|
||||||
mapping = {
|
mapping = {
|
||||||
"transfer_method": message_file.transfer_method,
|
"transfer_method": message_file.transfer_method,
|
||||||
@ -43,8 +41,6 @@ def build_from_message_file(
|
|||||||
return build_from_mapping(
|
return build_from_mapping(
|
||||||
mapping=mapping,
|
mapping=mapping,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
user_id=message_file.created_by,
|
|
||||||
role=CreatedByRole(message_file.created_by_role),
|
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -53,38 +49,30 @@ def build_from_mapping(
|
|||||||
*,
|
*,
|
||||||
mapping: Mapping[str, Any],
|
mapping: Mapping[str, Any],
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
user_id: str,
|
config: FileUploadConfig | None = None,
|
||||||
role: "CreatedByRole",
|
) -> File:
|
||||||
config: FileExtraConfig,
|
config = config or FileUploadConfig()
|
||||||
):
|
|
||||||
transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
|
transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
|
||||||
match transfer_method:
|
|
||||||
case FileTransferMethod.REMOTE_URL:
|
build_functions: dict[FileTransferMethod, Callable] = {
|
||||||
file = _build_from_remote_url(
|
FileTransferMethod.LOCAL_FILE: _build_from_local_file,
|
||||||
mapping=mapping,
|
FileTransferMethod.REMOTE_URL: _build_from_remote_url,
|
||||||
tenant_id=tenant_id,
|
FileTransferMethod.TOOL_FILE: _build_from_tool_file,
|
||||||
config=config,
|
}
|
||||||
transfer_method=transfer_method,
|
|
||||||
)
|
build_func = build_functions.get(transfer_method)
|
||||||
case FileTransferMethod.LOCAL_FILE:
|
if not build_func:
|
||||||
file = _build_from_local_file(
|
raise ValueError(f"Invalid file transfer method: {transfer_method}")
|
||||||
mapping=mapping,
|
|
||||||
tenant_id=tenant_id,
|
file = build_func(
|
||||||
user_id=user_id,
|
mapping=mapping,
|
||||||
role=role,
|
tenant_id=tenant_id,
|
||||||
config=config,
|
transfer_method=transfer_method,
|
||||||
transfer_method=transfer_method,
|
)
|
||||||
)
|
|
||||||
case FileTransferMethod.TOOL_FILE:
|
if not _is_file_valid_with_config(file=file, config=config):
|
||||||
file = _build_from_tool_file(
|
raise ValueError(f"File validation failed for file: {file.filename}")
|
||||||
mapping=mapping,
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
user_id=user_id,
|
|
||||||
config=config,
|
|
||||||
transfer_method=transfer_method,
|
|
||||||
)
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"Invalid file transfer method: {transfer_method}")
|
|
||||||
|
|
||||||
return file
|
return file
|
||||||
|
|
||||||
@ -92,10 +80,8 @@ def build_from_mapping(
|
|||||||
def build_from_mappings(
|
def build_from_mappings(
|
||||||
*,
|
*,
|
||||||
mappings: Sequence[Mapping[str, Any]],
|
mappings: Sequence[Mapping[str, Any]],
|
||||||
config: FileExtraConfig | None,
|
config: FileUploadConfig | None,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
user_id: str,
|
|
||||||
role: "CreatedByRole",
|
|
||||||
) -> Sequence[File]:
|
) -> Sequence[File]:
|
||||||
if not config:
|
if not config:
|
||||||
return []
|
return []
|
||||||
@ -104,8 +90,6 @@ def build_from_mappings(
|
|||||||
build_from_mapping(
|
build_from_mapping(
|
||||||
mapping=mapping,
|
mapping=mapping,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
user_id=user_id,
|
|
||||||
role=role,
|
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
for mapping in mappings
|
for mapping in mappings
|
||||||
@ -128,31 +112,20 @@ def _build_from_local_file(
|
|||||||
*,
|
*,
|
||||||
mapping: Mapping[str, Any],
|
mapping: Mapping[str, Any],
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
user_id: str,
|
|
||||||
role: "CreatedByRole",
|
|
||||||
config: FileExtraConfig,
|
|
||||||
transfer_method: FileTransferMethod,
|
transfer_method: FileTransferMethod,
|
||||||
):
|
) -> File:
|
||||||
# check if the upload file exists.
|
|
||||||
file_type = FileType.value_of(mapping.get("type"))
|
file_type = FileType.value_of(mapping.get("type"))
|
||||||
stmt = select(UploadFile).where(
|
stmt = select(UploadFile).where(
|
||||||
UploadFile.id == mapping.get("upload_file_id"),
|
UploadFile.id == mapping.get("upload_file_id"),
|
||||||
UploadFile.tenant_id == tenant_id,
|
UploadFile.tenant_id == tenant_id,
|
||||||
UploadFile.created_by == user_id,
|
|
||||||
UploadFile.created_by_role == role,
|
|
||||||
)
|
)
|
||||||
if file_type == FileType.IMAGE:
|
|
||||||
stmt = stmt.where(UploadFile.extension.in_(IMAGE_EXTENSIONS))
|
|
||||||
elif file_type == FileType.VIDEO:
|
|
||||||
stmt = stmt.where(UploadFile.extension.in_(VIDEO_EXTENSIONS))
|
|
||||||
elif file_type == FileType.AUDIO:
|
|
||||||
stmt = stmt.where(UploadFile.extension.in_(AUDIO_EXTENSIONS))
|
|
||||||
elif file_type == FileType.DOCUMENT:
|
|
||||||
stmt = stmt.where(UploadFile.extension.in_(DOCUMENT_EXTENSIONS))
|
|
||||||
row = db.session.scalar(stmt)
|
row = db.session.scalar(stmt)
|
||||||
|
|
||||||
if row is None:
|
if row is None:
|
||||||
raise ValueError("Invalid upload file")
|
raise ValueError("Invalid upload file")
|
||||||
file = File(
|
|
||||||
|
return File(
|
||||||
id=mapping.get("id"),
|
id=mapping.get("id"),
|
||||||
filename=row.name,
|
filename=row.name,
|
||||||
extension="." + row.extension,
|
extension="." + row.extension,
|
||||||
@ -162,23 +135,37 @@ def _build_from_local_file(
|
|||||||
transfer_method=transfer_method,
|
transfer_method=transfer_method,
|
||||||
remote_url=row.source_url,
|
remote_url=row.source_url,
|
||||||
related_id=mapping.get("upload_file_id"),
|
related_id=mapping.get("upload_file_id"),
|
||||||
_extra_config=config,
|
|
||||||
size=row.size,
|
size=row.size,
|
||||||
)
|
)
|
||||||
return file
|
|
||||||
|
|
||||||
|
|
||||||
def _build_from_remote_url(
|
def _build_from_remote_url(
|
||||||
*,
|
*,
|
||||||
mapping: Mapping[str, Any],
|
mapping: Mapping[str, Any],
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
config: FileExtraConfig,
|
|
||||||
transfer_method: FileTransferMethod,
|
transfer_method: FileTransferMethod,
|
||||||
):
|
) -> File:
|
||||||
url = mapping.get("url")
|
url = mapping.get("url")
|
||||||
if not url:
|
if not url:
|
||||||
raise ValueError("Invalid file url")
|
raise ValueError("Invalid file url")
|
||||||
|
|
||||||
|
mime_type, filename, file_size = _get_remote_file_info(url)
|
||||||
|
extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin"
|
||||||
|
|
||||||
|
return File(
|
||||||
|
id=mapping.get("id"),
|
||||||
|
filename=filename,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
type=FileType.value_of(mapping.get("type")),
|
||||||
|
transfer_method=transfer_method,
|
||||||
|
remote_url=url,
|
||||||
|
mime_type=mime_type,
|
||||||
|
extension=extension,
|
||||||
|
size=file_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_remote_file_info(url: str):
|
||||||
mime_type = mimetypes.guess_type(url)[0] or ""
|
mime_type = mimetypes.guess_type(url)[0] or ""
|
||||||
file_size = -1
|
file_size = -1
|
||||||
filename = url.split("/")[-1].split("?")[0] or "unknown_file"
|
filename = url.split("/")[-1].split("?")[0] or "unknown_file"
|
||||||
@ -186,56 +173,34 @@ def _build_from_remote_url(
|
|||||||
resp = ssrf_proxy.head(url, follow_redirects=True)
|
resp = ssrf_proxy.head(url, follow_redirects=True)
|
||||||
if resp.status_code == httpx.codes.OK:
|
if resp.status_code == httpx.codes.OK:
|
||||||
if content_disposition := resp.headers.get("Content-Disposition"):
|
if content_disposition := resp.headers.get("Content-Disposition"):
|
||||||
filename = content_disposition.split("filename=")[-1].strip('"')
|
filename = str(content_disposition.split("filename=")[-1].strip('"'))
|
||||||
file_size = int(resp.headers.get("Content-Length", file_size))
|
file_size = int(resp.headers.get("Content-Length", file_size))
|
||||||
mime_type = mime_type or str(resp.headers.get("Content-Type", ""))
|
mime_type = mime_type or str(resp.headers.get("Content-Type", ""))
|
||||||
|
|
||||||
# Determine file extension
|
return mime_type, filename, file_size
|
||||||
extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin"
|
|
||||||
|
|
||||||
if not mime_type:
|
|
||||||
mime_type, _ = mimetypes.guess_type(url)
|
|
||||||
file = File(
|
|
||||||
id=mapping.get("id"),
|
|
||||||
filename=filename,
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
type=FileType.value_of(mapping.get("type")),
|
|
||||||
transfer_method=transfer_method,
|
|
||||||
remote_url=url,
|
|
||||||
_extra_config=config,
|
|
||||||
mime_type=mime_type,
|
|
||||||
extension=extension,
|
|
||||||
size=file_size,
|
|
||||||
)
|
|
||||||
return file
|
|
||||||
|
|
||||||
|
|
||||||
def _build_from_tool_file(
|
def _build_from_tool_file(
|
||||||
*,
|
*,
|
||||||
mapping: Mapping[str, Any],
|
mapping: Mapping[str, Any],
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
user_id: str,
|
|
||||||
config: FileExtraConfig,
|
|
||||||
transfer_method: FileTransferMethod,
|
transfer_method: FileTransferMethod,
|
||||||
):
|
) -> File:
|
||||||
tool_file = (
|
tool_file = (
|
||||||
db.session.query(ToolFile)
|
db.session.query(ToolFile)
|
||||||
.filter(
|
.filter(
|
||||||
ToolFile.id == mapping.get("tool_file_id"),
|
ToolFile.id == mapping.get("tool_file_id"),
|
||||||
ToolFile.tenant_id == tenant_id,
|
ToolFile.tenant_id == tenant_id,
|
||||||
ToolFile.user_id == user_id,
|
|
||||||
)
|
)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
if tool_file is None:
|
if tool_file is None:
|
||||||
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
|
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
|
||||||
|
|
||||||
path = tool_file.file_key
|
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
|
||||||
if "." in path:
|
|
||||||
extension = "." + path.split("/")[-1].split(".")[-1]
|
return File(
|
||||||
else:
|
|
||||||
extension = ".bin"
|
|
||||||
file = File(
|
|
||||||
id=mapping.get("id"),
|
id=mapping.get("id"),
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
filename=tool_file.name,
|
filename=tool_file.name,
|
||||||
@ -246,6 +211,21 @@ def _build_from_tool_file(
|
|||||||
extension=extension,
|
extension=extension,
|
||||||
mime_type=tool_file.mimetype,
|
mime_type=tool_file.mimetype,
|
||||||
size=tool_file.size,
|
size=tool_file.size,
|
||||||
_extra_config=config,
|
|
||||||
)
|
)
|
||||||
return file
|
|
||||||
|
|
||||||
|
def _is_file_valid_with_config(*, file: File, config: FileUploadConfig) -> bool:
|
||||||
|
if config.allowed_file_types and file.type not in config.allowed_file_types and file.type != FileType.CUSTOM:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if config.allowed_extensions and file.extension not in config.allowed_extensions:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if config.allowed_upload_methods and file.transfer_method not in config.allowed_upload_methods:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if file.type == FileType.IMAGE and config.image_config:
|
||||||
|
if config.image_config.transfer_methods and file.transfer_method not in config.image_config.transfer_methods:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
@ -13,7 +13,7 @@ from sqlalchemy import Float, func, text
|
|||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.file import FILE_MODEL_IDENTITY, File, FileExtraConfig, FileTransferMethod, FileType
|
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||||
from core.file import helpers as file_helpers
|
from core.file import helpers as file_helpers
|
||||||
from core.file.tool_file_parser import ToolFileParser
|
from core.file.tool_file_parser import ToolFileParser
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -949,9 +949,6 @@ class Message(db.Model):
|
|||||||
"type": message_file.type,
|
"type": message_file.type,
|
||||||
},
|
},
|
||||||
tenant_id=current_app.tenant_id,
|
tenant_id=current_app.tenant_id,
|
||||||
user_id=self.from_account_id or self.from_end_user_id or "",
|
|
||||||
role=CreatedByRole(message_file.created_by_role),
|
|
||||||
config=FileExtraConfig(),
|
|
||||||
)
|
)
|
||||||
elif message_file.transfer_method == "remote_url":
|
elif message_file.transfer_method == "remote_url":
|
||||||
if message_file.url is None:
|
if message_file.url is None:
|
||||||
@ -964,9 +961,6 @@ class Message(db.Model):
|
|||||||
"url": message_file.url,
|
"url": message_file.url,
|
||||||
},
|
},
|
||||||
tenant_id=current_app.tenant_id,
|
tenant_id=current_app.tenant_id,
|
||||||
user_id=self.from_account_id or self.from_end_user_id or "",
|
|
||||||
role=CreatedByRole(message_file.created_by_role),
|
|
||||||
config=FileExtraConfig(),
|
|
||||||
)
|
)
|
||||||
elif message_file.transfer_method == "tool_file":
|
elif message_file.transfer_method == "tool_file":
|
||||||
if message_file.upload_file_id is None:
|
if message_file.upload_file_id is None:
|
||||||
@ -981,9 +975,6 @@ class Message(db.Model):
|
|||||||
file = file_factory.build_from_mapping(
|
file = file_factory.build_from_mapping(
|
||||||
mapping=mapping,
|
mapping=mapping,
|
||||||
tenant_id=current_app.tenant_id,
|
tenant_id=current_app.tenant_id,
|
||||||
user_id=self.from_account_id or self.from_end_user_id or "",
|
|
||||||
role=CreatedByRole(message_file.created_by_role),
|
|
||||||
config=FileExtraConfig(),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -13,7 +13,7 @@ from core.app.app_config.entities import (
|
|||||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||||
from core.file.models import FileExtraConfig
|
from core.file.models import FileUploadConfig
|
||||||
from core.helper import encrypter
|
from core.helper import encrypter
|
||||||
from core.model_runtime.entities.llm_entities import LLMMode
|
from core.model_runtime.entities.llm_entities import LLMMode
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
@ -381,7 +381,7 @@ class WorkflowConverter:
|
|||||||
graph: dict,
|
graph: dict,
|
||||||
model_config: ModelConfigEntity,
|
model_config: ModelConfigEntity,
|
||||||
prompt_template: PromptTemplateEntity,
|
prompt_template: PromptTemplateEntity,
|
||||||
file_upload: Optional[FileExtraConfig] = None,
|
file_upload: Optional[FileUploadConfig] = None,
|
||||||
external_data_variable_node_mapping: dict[str, str] | None = None,
|
external_data_variable_node_mapping: dict[str, str] | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
|
@ -430,37 +430,3 @@ def test_multi_colons_parse(setup_http_mock):
|
|||||||
assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "")
|
assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "")
|
||||||
assert 'form-data; name="Redirect"\r\n\r\nhttp://example6.com' in result.process_data.get("request", "")
|
assert 'form-data; name="Redirect"\r\n\r\nhttp://example6.com' in result.process_data.get("request", "")
|
||||||
# assert "http://example3.com" == resp.get("headers", {}).get("referer")
|
# assert "http://example3.com" == resp.get("headers", {}).get("referer")
|
||||||
|
|
||||||
|
|
||||||
def test_image_file(monkeypatch):
|
|
||||||
from types import SimpleNamespace
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"core.tools.tool_file_manager.ToolFileManager.create_file_by_raw",
|
|
||||||
lambda *args, **kwargs: SimpleNamespace(id="1"),
|
|
||||||
)
|
|
||||||
|
|
||||||
node = init_http_node(
|
|
||||||
config={
|
|
||||||
"id": "1",
|
|
||||||
"data": {
|
|
||||||
"title": "http",
|
|
||||||
"desc": "",
|
|
||||||
"method": "get",
|
|
||||||
"url": "https://cloud.dify.ai/logo/logo-site.png",
|
|
||||||
"authorization": {
|
|
||||||
"type": "no-auth",
|
|
||||||
"config": None,
|
|
||||||
},
|
|
||||||
"params": "",
|
|
||||||
"headers": "",
|
|
||||||
"body": None,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
result = node._run()
|
|
||||||
assert result.process_data is not None
|
|
||||||
assert result.outputs is not None
|
|
||||||
resp = result.outputs
|
|
||||||
assert len(resp.get("files", [])) == 1
|
|
||||||
|
@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from core.app.app_config.entities import ModelConfigEntity
|
from core.app.app_config.entities import ModelConfigEntity
|
||||||
from core.file import File, FileExtraConfig, FileTransferMethod, FileType, ImageConfig
|
from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
@ -134,7 +134,6 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
|||||||
type=FileType.IMAGE,
|
type=FileType.IMAGE,
|
||||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||||
remote_url="https://example.com/image1.jpg",
|
remote_url="https://example.com/image1.jpg",
|
||||||
_extra_config=FileExtraConfig(image_config=ImageConfig(detail=ImagePromptMessageContent.DETAIL.HIGH)),
|
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user