fix: Update prompt message content types to use Literal and add union type for content (#17136)

Co-authored-by: 朱庆超 <zhuqingchao@xiaomi.com>
Co-authored-by: crazywoola <427733928@qq.com>
This commit is contained in:
ZalterCitty 2025-04-22 16:17:55 +08:00 committed by GitHub
parent 404f8a790c
commit a1158cc946
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 73 additions and 39 deletions

View File

@ -21,14 +21,13 @@ from core.model_runtime.entities import (
AssistantPromptMessage, AssistantPromptMessage,
LLMUsage, LLMUsage,
PromptMessage, PromptMessage,
PromptMessageContent,
PromptMessageTool, PromptMessageTool,
SystemPromptMessage, SystemPromptMessage,
TextPromptMessageContent, TextPromptMessageContent,
ToolPromptMessage, ToolPromptMessage,
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.utils.extract_thread_messages import extract_thread_messages from core.prompt.utils.extract_thread_messages import extract_thread_messages
@ -501,7 +500,7 @@ class BaseAgentRunner(AppRunner):
) )
if not file_objs: if not file_objs:
return UserPromptMessage(content=message.query) return UserPromptMessage(content=message.query)
prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents: list[PromptMessageContentUnionTypes] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query)) prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file in file_objs: for file in file_objs:
prompt_message_contents.append( prompt_message_contents.append(

View File

@ -5,12 +5,11 @@ from core.file import file_manager
from core.model_runtime.entities import ( from core.model_runtime.entities import (
AssistantPromptMessage, AssistantPromptMessage,
PromptMessage, PromptMessage,
PromptMessageContent,
SystemPromptMessage, SystemPromptMessage,
TextPromptMessageContent, TextPromptMessageContent,
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
@ -40,7 +39,7 @@ class CotChatAgentRunner(CotAgentRunner):
Organize user query Organize user query
""" """
if self.files: if self.files:
prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents: list[PromptMessageContentUnionTypes] = []
prompt_message_contents.append(TextPromptMessageContent(data=query)) prompt_message_contents.append(TextPromptMessageContent(data=query))
# get image detail config # get image detail config

View File

@ -15,14 +15,13 @@ from core.model_runtime.entities import (
LLMResultChunkDelta, LLMResultChunkDelta,
LLMUsage, LLMUsage,
PromptMessage, PromptMessage,
PromptMessageContent,
PromptMessageContentType, PromptMessageContentType,
SystemPromptMessage, SystemPromptMessage,
TextPromptMessageContent, TextPromptMessageContent,
ToolPromptMessage, ToolPromptMessage,
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine from core.tools.tool_engine import ToolEngine
@ -395,7 +394,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
Organize user query Organize user query
""" """
if self.files: if self.files:
prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents: list[PromptMessageContentUnionTypes] = []
prompt_message_contents.append(TextPromptMessageContent(data=query)) prompt_message_contents.append(TextPromptMessageContent(data=query))
# get image detail config # get image detail config

View File

@ -7,9 +7,9 @@ from core.model_runtime.entities import (
AudioPromptMessageContent, AudioPromptMessageContent,
DocumentPromptMessageContent, DocumentPromptMessageContent,
ImagePromptMessageContent, ImagePromptMessageContent,
MultiModalPromptMessageContent,
VideoPromptMessageContent, VideoPromptMessageContent,
) )
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from extensions.ext_storage import storage from extensions.ext_storage import storage
from . import helpers from . import helpers
@ -43,7 +43,7 @@ def to_prompt_message_content(
/, /,
*, *,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None, image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> MultiModalPromptMessageContent: ) -> PromptMessageContentUnionTypes:
if f.extension is None: if f.extension is None:
raise ValueError("Missing file extension") raise ValueError("Missing file extension")
if f.mime_type is None: if f.mime_type is None:
@ -58,7 +58,7 @@ def to_prompt_message_content(
if f.type == FileType.IMAGE: if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
prompt_class_map: Mapping[FileType, type[MultiModalPromptMessageContent]] = { prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
FileType.IMAGE: ImagePromptMessageContent, FileType.IMAGE: ImagePromptMessageContent,
FileType.AUDIO: AudioPromptMessageContent, FileType.AUDIO: AudioPromptMessageContent,
FileType.VIDEO: VideoPromptMessageContent, FileType.VIDEO: VideoPromptMessageContent,

View File

@ -8,11 +8,11 @@ from core.model_runtime.entities import (
AssistantPromptMessage, AssistantPromptMessage,
ImagePromptMessageContent, ImagePromptMessageContent,
PromptMessage, PromptMessage,
PromptMessageContent,
PromptMessageRole, PromptMessageRole,
TextPromptMessageContent, TextPromptMessageContent,
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from core.prompt.utils.extract_thread_messages import extract_thread_messages from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
@ -100,7 +100,7 @@ class TokenBufferMemory:
if not file_objs: if not file_objs:
prompt_messages.append(UserPromptMessage(content=message.query)) prompt_messages.append(UserPromptMessage(content=message.query))
else: else:
prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents: list[PromptMessageContentUnionTypes] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query)) prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file in file_objs: for file in file_objs:
prompt_message = file_manager.to_prompt_message_content( prompt_message = file_manager.to_prompt_message_content(

View File

@ -1,6 +1,6 @@
from collections.abc import Sequence from collections.abc import Sequence
from enum import Enum, StrEnum from enum import Enum, StrEnum
from typing import Any, Optional, Union from typing import Annotated, Any, Literal, Optional, Union
from pydantic import BaseModel, Field, field_serializer, field_validator from pydantic import BaseModel, Field, field_serializer, field_validator
@ -61,11 +61,7 @@ class PromptMessageContentType(StrEnum):
class PromptMessageContent(BaseModel): class PromptMessageContent(BaseModel):
""" pass
Model class for prompt message content.
"""
type: PromptMessageContentType
class TextPromptMessageContent(PromptMessageContent): class TextPromptMessageContent(PromptMessageContent):
@ -73,7 +69,7 @@ class TextPromptMessageContent(PromptMessageContent):
Model class for text prompt message content. Model class for text prompt message content.
""" """
type: PromptMessageContentType = PromptMessageContentType.TEXT type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT
data: str data: str
@ -82,7 +78,6 @@ class MultiModalPromptMessageContent(PromptMessageContent):
Model class for multi-modal prompt message content. Model class for multi-modal prompt message content.
""" """
type: PromptMessageContentType
format: str = Field(default=..., description="the format of multi-modal file") format: str = Field(default=..., description="the format of multi-modal file")
base64_data: str = Field(default="", description="the base64 data of multi-modal file") base64_data: str = Field(default="", description="the base64 data of multi-modal file")
url: str = Field(default="", description="the url of multi-modal file") url: str = Field(default="", description="the url of multi-modal file")
@ -94,11 +89,11 @@ class MultiModalPromptMessageContent(PromptMessageContent):
class VideoPromptMessageContent(MultiModalPromptMessageContent): class VideoPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.VIDEO type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO
class AudioPromptMessageContent(MultiModalPromptMessageContent): class AudioPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.AUDIO type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO
class ImagePromptMessageContent(MultiModalPromptMessageContent): class ImagePromptMessageContent(MultiModalPromptMessageContent):
@ -110,12 +105,24 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
LOW = "low" LOW = "low"
HIGH = "high" HIGH = "high"
type: PromptMessageContentType = PromptMessageContentType.IMAGE type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW detail: DETAIL = DETAIL.LOW
class DocumentPromptMessageContent(MultiModalPromptMessageContent): class DocumentPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT
PromptMessageContentUnionTypes = Annotated[
Union[
TextPromptMessageContent,
ImagePromptMessageContent,
DocumentPromptMessageContent,
AudioPromptMessageContent,
VideoPromptMessageContent,
],
Field(discriminator="type"),
]
class PromptMessage(BaseModel): class PromptMessage(BaseModel):
@ -124,7 +131,7 @@ class PromptMessage(BaseModel):
""" """
role: PromptMessageRole role: PromptMessageRole
content: Optional[str | Sequence[PromptMessageContent]] = None content: Optional[str | list[PromptMessageContentUnionTypes]] = None
name: Optional[str] = None name: Optional[str] = None
def is_empty(self) -> bool: def is_empty(self) -> bool:

View File

@ -9,13 +9,12 @@ from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities import ( from core.model_runtime.entities import (
AssistantPromptMessage, AssistantPromptMessage,
PromptMessage, PromptMessage,
PromptMessageContent,
PromptMessageRole, PromptMessageRole,
SystemPromptMessage, SystemPromptMessage,
TextPromptMessageContent, TextPromptMessageContent,
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.prompt_transform import PromptTransform from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.prompt.utils.prompt_template_parser import PromptTemplateParser
@ -125,7 +124,7 @@ class AdvancedPromptTransform(PromptTransform):
prompt = Jinja2Formatter.format(prompt, prompt_inputs) prompt = Jinja2Formatter.format(prompt, prompt_inputs)
if files: if files:
prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents: list[PromptMessageContentUnionTypes] = []
prompt_message_contents.append(TextPromptMessageContent(data=prompt)) prompt_message_contents.append(TextPromptMessageContent(data=prompt))
for file in files: for file in files:
prompt_message_contents.append( prompt_message_contents.append(
@ -201,7 +200,7 @@ class AdvancedPromptTransform(PromptTransform):
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
if files and query is not None: if files and query is not None:
prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents: list[PromptMessageContentUnionTypes] = []
prompt_message_contents.append(TextPromptMessageContent(data=query)) prompt_message_contents.append(TextPromptMessageContent(data=query))
for file in files: for file in files:
prompt_message_contents.append( prompt_message_contents.append(

View File

@ -11,7 +11,7 @@ from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent, ImagePromptMessageContent,
PromptMessage, PromptMessage,
PromptMessageContent, PromptMessageContentUnionTypes,
SystemPromptMessage, SystemPromptMessage,
TextPromptMessageContent, TextPromptMessageContent,
UserPromptMessage, UserPromptMessage,
@ -277,7 +277,7 @@ class SimplePromptTransform(PromptTransform):
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None, image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
) -> UserPromptMessage: ) -> UserPromptMessage:
if files: if files:
prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents: list[PromptMessageContentUnionTypes] = []
prompt_message_contents.append(TextPromptMessageContent(data=prompt)) prompt_message_contents.append(TextPromptMessageContent(data=prompt))
for file in files: for file in files:
prompt_message_contents.append( prompt_message_contents.append(

View File

@ -24,7 +24,7 @@ from core.model_runtime.entities import (
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
PromptMessageContent, PromptMessageContentUnionTypes,
PromptMessageRole, PromptMessageRole,
SystemPromptMessage, SystemPromptMessage,
UserPromptMessage, UserPromptMessage,
@ -594,8 +594,7 @@ class LLMNode(BaseNode[LLMNodeData]):
variable_pool: VariablePool, variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector], jinja2_variables: Sequence[VariableSelector],
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
# FIXME: fix the type error cause prompt_messages is type quick a few times prompt_messages: list[PromptMessage] = []
prompt_messages: list[Any] = []
if isinstance(prompt_template, list): if isinstance(prompt_template, list):
# For chat model # For chat model
@ -657,12 +656,14 @@ class LLMNode(BaseNode[LLMNodeData]):
# For issue #11247 - Check if prompt content is a string or a list # For issue #11247 - Check if prompt content is a string or a list
prompt_content_type = type(prompt_content) prompt_content_type = type(prompt_content)
if prompt_content_type == str: if prompt_content_type == str:
prompt_content = str(prompt_content)
if "#histories#" in prompt_content: if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text) prompt_content = prompt_content.replace("#histories#", memory_text)
else: else:
prompt_content = memory_text + "\n" + prompt_content prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content prompt_messages[0].content = prompt_content
elif prompt_content_type == list: elif prompt_content_type == list:
prompt_content = prompt_content if isinstance(prompt_content, list) else []
for content_item in prompt_content: for content_item in prompt_content:
if content_item.type == PromptMessageContentType.TEXT: if content_item.type == PromptMessageContentType.TEXT:
if "#histories#" in content_item.data: if "#histories#" in content_item.data:
@ -675,9 +676,10 @@ class LLMNode(BaseNode[LLMNodeData]):
# Add current query to the prompt message # Add current query to the prompt message
if sys_query: if sys_query:
if prompt_content_type == str: if prompt_content_type == str:
prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query) prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
prompt_messages[0].content = prompt_content prompt_messages[0].content = prompt_content
elif prompt_content_type == list: elif prompt_content_type == list:
prompt_content = prompt_content if isinstance(prompt_content, list) else []
for content_item in prompt_content: for content_item in prompt_content:
if content_item.type == PromptMessageContentType.TEXT: if content_item.type == PromptMessageContentType.TEXT:
content_item.data = sys_query + "\n" + content_item.data content_item.data = sys_query + "\n" + content_item.data
@ -707,7 +709,7 @@ class LLMNode(BaseNode[LLMNodeData]):
filtered_prompt_messages = [] filtered_prompt_messages = []
for prompt_message in prompt_messages: for prompt_message in prompt_messages:
if isinstance(prompt_message.content, list): if isinstance(prompt_message.content, list):
prompt_message_content = [] prompt_message_content: list[PromptMessageContentUnionTypes] = []
for content_item in prompt_message.content: for content_item in prompt_message.content:
# Skip content if features are not defined # Skip content if features are not defined
if not model_config.model_schema.features: if not model_config.model_schema.features:
@ -1132,7 +1134,9 @@ class LLMNode(BaseNode[LLMNodeData]):
) )
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole): def _combine_message_content_with_role(
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
):
match role: match role:
case PromptMessageRole.USER: case PromptMessageRole.USER:
return UserPromptMessage(content=contents) return UserPromptMessage(content=contents)

View File

@ -0,0 +1,27 @@
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
TextPromptMessageContent,
UserPromptMessage,
)
def test_build_prompt_message_with_prompt_message_contents():
prompt = UserPromptMessage(content=[TextPromptMessageContent(data="Hello, World!")])
assert isinstance(prompt.content, list)
assert isinstance(prompt.content[0], TextPromptMessageContent)
assert prompt.content[0].data == "Hello, World!"
def test_dump_prompt_message():
example_url = "https://example.com/image.jpg"
prompt = UserPromptMessage(
content=[
ImagePromptMessageContent(
url=example_url,
format="jpeg",
mime_type="image/jpeg",
)
]
)
data = prompt.model_dump()
assert data["content"][0].get("url") == example_url